1use arrayvec::ArrayString;
2use bincode::de::Decoder;
3use bincode::enc::Encoder;
4use bincode::error::{DecodeError, EncodeError};
5use bincode::{Decode, Encode};
6use cu29_traits::CuResult;
7use hashbrown::HashMap;
8use object_pool::{Pool, ReusableOwned};
9use smallvec::SmallVec;
10use std::alloc::{Layout, alloc, dealloc};
11use std::fmt::Debug;
12use std::mem::{align_of, size_of};
13use std::ops::{Deref, DerefMut};
14use std::sync::{Arc, Mutex, MutexGuard, OnceLock};
15
/// Identifier of a pool, stored inline as an `ArrayString` with a capacity of 64.
type PoolID = ArrayString<64>;
17
/// Read-only monitoring view of a pool: its identity and its occupancy figures.
///
/// The `Send + Sync` supertraits let implementors be stored as
/// `Arc<dyn PoolMonitor>` inside the global pool registry.
pub trait PoolMonitor: Send + Sync {
    /// Returns the identifier of the pool.
    fn id(&self) -> PoolID;

    /// Returns the number of buffers currently available in the pool.
    fn space_left(&self) -> usize;

    /// Returns the total number of buffers the pool was created with.
    fn total_size(&self) -> usize;

    /// Returns the size of a single buffer, in bytes.
    fn buffer_size(&self) -> usize;
}
32
/// Global registry of pools, keyed by the pool id rendered as a `String`.
/// The map is created lazily by the first call to `register_pool`.
static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
/// Inline capacity of the `SmallVec` returned by `pools_statistics`.
const MAX_POOLS: usize = 16;
35
/// Locks `mutex` and returns its guard, ignoring poisoning.
///
/// If a previous holder panicked while holding the lock, the guard is
/// recovered from the poison error instead of propagating a failure.
fn lock_unpoison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
42
43fn register_pool(pool: Arc<dyn PoolMonitor>) {
45 POOL_REGISTRY
46 .get_or_init(|| Mutex::new(HashMap::new()))
47 .lock()
48 .unwrap_or_else(|poison| poison.into_inner())
49 .insert(pool.id().to_string(), pool);
50}
51
/// One monitoring sample for a pool: `(id, space_left, total_size, buffer_size)`.
type PoolStats = (PoolID, usize, usize, usize);
53
54pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
57 let registry_lock = match POOL_REGISTRY.get() {
59 Some(lock) => lock_unpoison(lock),
60 None => return SmallVec::new(), };
62 let mut result = SmallVec::with_capacity(MAX_POOLS);
63 for pool in registry_lock.values() {
64 result.push((
65 pool.id(),
66 pool.space_left(),
67 pool.total_size(),
68 pool.buffer_size(),
69 ));
70 }
71 result
72}
73
/// Bound for the element type of a pooled buffer: a plain `Copy` value that
/// can be encoded to and decoded from bincode.
pub trait ElementType: Default + Sized + Copy + Debug + Unpin + Send + Sync {
    /// Writes this element to `encoder`.
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError>;
    /// Reads one element from `decoder`.
    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError>;
}
79
80impl<T> ElementType for T
82where
83 T: Default + Sized + Copy + Debug + Unpin + Send + Sync,
84 T: Encode,
85 T: Decode<()>,
86{
87 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
88 self.encode(encoder)
89 }
90
91 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
92 Self::decode(decoder)
93 }
94}
95
/// A buffer that dereferences (shared and mutable) to a slice of
/// `Self::Element` and can be shared across threads.
pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
    /// Type of the items stored in the buffer.
    type Element: ElementType;
}
99
100use crate::monitoring::CuPayloadSize;
101
/// Storage behind a `CuHandle`: either a buffer on loan from a pool or a
/// buffer owned directly.
pub enum CuHandleInner<T: Debug> {
    /// Buffer pulled from an `object_pool::Pool`.
    Pooled(ReusableOwned<T>),
    /// Buffer that is not tied to any pool.
    Detached(T),
}
109
110impl<T> Debug for CuHandleInner<T>
111where
112 T: Debug,
113{
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 match self {
116 CuHandleInner::Pooled(r) => {
117 write!(f, "Pooled: {:?}", r.deref())
118 }
119 CuHandleInner::Detached(r) => write!(f, "Detached: {r:?}"),
120 }
121 }
122}
123
124impl<T: ArrayLike> Deref for CuHandleInner<T> {
125 type Target = [T::Element];
126
127 fn deref(&self) -> &Self::Target {
128 match self {
129 CuHandleInner::Pooled(pooled) => pooled,
130 CuHandleInner::Detached(detached) => detached,
131 }
132 }
133}
134
135impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
136 fn deref_mut(&mut self) -> &mut Self::Target {
137 match self {
138 CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
139 CuHandleInner::Detached(detached) => detached,
140 }
141 }
142}
143
/// Shared, lockable handle to a buffer (`Arc<Mutex<CuHandleInner<T>>>`).
/// Cloning the handle clones the `Arc`, so clones refer to the same buffer.
#[derive(Clone, Debug)]
pub struct CuHandle<T: ArrayLike>(Arc<Mutex<CuHandleInner<T>>>);
147
148impl<T: ArrayLike> Deref for CuHandle<T> {
149 type Target = Arc<Mutex<CuHandleInner<T>>>;
150
151 fn deref(&self) -> &Self::Target {
152 &self.0
153 }
154}
155
156impl<T: ArrayLike> CuHandle<T> {
157 pub fn new_detached(inner: T) -> Self {
159 CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
160 }
161
162 pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
164 let lock = lock_unpoison(&self.0);
165 f(&*lock)
166 }
167
168 pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
170 let mut lock = lock_unpoison(&self.0);
171 f(&mut *lock)
172 }
173}
174
175impl<T> CuPayloadSize for CuHandle<T>
176where
177 T: ArrayLike,
178{
179 fn raw_bytes(&self) -> usize {
180 lock_unpoison(&self.0).deref().len() * size_of::<T::Element>()
181 }
182
183 fn handle_bytes(&self) -> usize {
184 self.raw_bytes()
185 }
186}
187
188impl<T: ArrayLike + Encode> Encode for CuHandle<T>
189where
190 <T as ArrayLike>::Element: 'static,
191{
192 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
193 let inner = lock_unpoison(&self.0);
194 match inner.deref() {
195 CuHandleInner::Pooled(pooled) => pooled.deref().encode(encoder),
196 CuHandleInner::Detached(detached) => detached.encode(encoder),
197 }
198 }
199}
200
/// `Default` exists only to satisfy trait bounds; a handle cannot be built
/// without a buffer.
impl<T: ArrayLike> Default for CuHandle<T> {
    /// # Panics
    ///
    /// Always panics.
    fn default() -> Self {
        panic!("Cannot create a default CuHandle")
    }
}
206
207impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<Vec<U>> {
208 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
209 let vec: Vec<U> = Vec::decode(decoder)?;
210 Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(vec)))))
211 }
212}
213
/// A pool that hands out buffers of type `T` wrapped in `CuHandle`s.
pub trait CuPool<T: ArrayLike>: PoolMonitor {
    /// Takes a buffer from the pool, or returns `None` when none is available.
    fn acquire(&self) -> Option<CuHandle<T>>;

    /// Takes a buffer from the pool and fills it with the content of `from`.
    ///
    /// The implementations in this file panic when the pool has no buffer left.
    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
    where
        O: ArrayLike<Element = T::Element>;
}
225
/// A pool whose buffers live on a device and can be copied back to host buffers.
pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
    /// Copies the content of `from_device_handle` into `to_host_handle`.
    ///
    /// Both buffers must hold the same element type.
    fn copy_to_host_pool<O>(
        &self,
        from_device_handle: &CuHandle<T>,
        to_host_handle: &mut CuHandle<O>,
    ) -> CuResult<()>
    where
        O: ArrayLike<Element = T::Element>;
}
238
/// A pool of host-memory buffers of type `T`.
pub struct CuHostMemoryPool<T> {
    /// Identifier under which the pool is registered for monitoring.
    id: PoolID,
    /// Underlying object pool holding the buffers.
    pool: Arc<Pool<T>>,
    /// Number of buffers the pool was created with.
    size: usize,
    /// Size of one buffer in bytes, measured once at construction.
    buffer_size: usize,
}
248
249impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
250 pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
251 where
252 F: Fn() -> T,
253 {
254 let pool = Arc::new(Pool::new(size, buffer_initializer));
255 let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
256
257 let og = Self {
258 id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
259 pool,
260 size,
261 buffer_size,
262 };
263 let og = Arc::new(og);
264 register_pool(og.clone());
265 Ok(og)
266 }
267}
268
/// Monitoring figures for a host pool.
impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
    /// Identifier given at construction.
    fn id(&self) -> PoolID {
        self.id
    }

    /// Number of buffers currently sitting in the underlying pool.
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    /// Number of buffers the pool was created with.
    fn total_size(&self) -> usize {
        self.size
    }

    /// Size of one buffer in bytes, as measured at construction.
    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
286
287impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
288 fn acquire(&self) -> Option<CuHandle<T>> {
289 let owned_object = self.pool.try_pull_owned(); owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
292 }
293
294 fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
295 let to_handle = self.acquire().expect("No available buffers in the pool");
296
297 match lock_unpoison(&from.0).deref() {
298 CuHandleInner::Detached(source) => match lock_unpoison(&to_handle.0).deref_mut() {
299 CuHandleInner::Detached(destination) => {
300 destination.copy_from_slice(source);
301 }
302 CuHandleInner::Pooled(destination) => {
303 destination.copy_from_slice(source);
304 }
305 },
306 CuHandleInner::Pooled(source) => match lock_unpoison(&to_handle.0).deref_mut() {
307 CuHandleInner::Detached(destination) => {
308 destination.copy_from_slice(source);
309 }
310 CuHandleInner::Pooled(destination) => {
311 destination.copy_from_slice(source);
312 }
313 },
314 }
315 to_handle
316 }
317}
318
/// A `Vec` of elements can be used directly as a pool buffer.
impl<E: ElementType + 'static> ArrayLike for Vec<E> {
    type Element = E;
}
322
323#[cfg(all(feature = "cuda", not(target_os = "macos")))]
324mod cuda {
325 use super::*;
326 use cu29_traits::CuError;
327 use cudarc::driver::{
328 CudaContext, CudaSlice, CudaStream, DeviceRepr, HostSlice, SyncOnDrop, ValidAsZeroBits,
329 };
330 use std::sync::Arc;
331
    /// Newtype around a `CudaSlice` so that it can implement `ArrayLike`.
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);
334
    /// Present only to satisfy the `ArrayLike` bound; the wrapped slice cannot
    /// be viewed as a host slice.
    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        /// # Panics
        ///
        /// Always panics: the data must first be copied to a host buffer.
        fn deref(&self) -> &Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }
346
    /// Present only to satisfy the `ArrayLike` bound; the wrapped slice cannot
    /// be viewed as a mutable host slice.
    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        /// # Panics
        ///
        /// Always panics: the data must first be copied to a host buffer.
        fn deref_mut(&mut self) -> &mut Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }
355
    /// Lets a wrapped CUDA slice be used as a pool buffer and inside a `CuHandle`.
    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }
359
    impl<E> CudaSliceWrapper<E> {
        /// Borrows the wrapped `CudaSlice`.
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        /// Mutably borrows the wrapped `CudaSlice`.
        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }
369
    /// Read-only adapter that lets a borrowed `ArrayLike` buffer be passed
    /// where a cudarc `HostSlice` is expected.
    pub struct HostSliceWrapper<'a, T: ArrayLike> {
        /// The borrowed host buffer.
        inner: &'a T,
    }
374
    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceWrapper<'_, T> {
        /// Number of elements in the wrapped buffer.
        fn len(&self) -> usize {
            self.inner.len()
        }

        /// Returns the wrapped buffer as a shared slice, paired with a
        /// `SyncOnDrop` built from `stream`.
        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        /// # Panics
        ///
        /// Always panics: this wrapper only holds a shared borrow of the buffer.
        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            _stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            panic!("Cannot get mutable reference from immutable wrapper")
        }
    }
396
    /// Mutable adapter that lets a borrowed `ArrayLike` buffer be passed where
    /// a cudarc `HostSlice` is expected, including as a copy destination.
    pub struct HostSliceMutWrapper<'a, T: ArrayLike> {
        /// The mutably borrowed host buffer.
        inner: &'a mut T,
    }
401
    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceMutWrapper<'_, T> {
        /// Number of elements in the wrapped buffer.
        fn len(&self) -> usize {
            self.inner.len()
        }

        /// Returns the wrapped buffer as a shared slice, paired with a
        /// `SyncOnDrop` built from `stream`.
        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        /// Returns the wrapped buffer as a mutable slice, paired with a
        /// `SyncOnDrop` built from `stream`.
        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref_mut(), SyncOnDrop::sync_stream(stream))
        }
    }
423
424 impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
426 fn get_host_slice_wrapper<O: ArrayLike<Element = E>>(
428 handle_inner: &CuHandleInner<O>,
429 ) -> HostSliceWrapper<'_, O> {
430 match handle_inner {
431 CuHandleInner::Pooled(pooled) => HostSliceWrapper { inner: pooled },
432 CuHandleInner::Detached(detached) => HostSliceWrapper { inner: detached },
433 }
434 }
435
436 fn get_host_slice_mut_wrapper<O: ArrayLike<Element = E>>(
438 handle_inner: &mut CuHandleInner<O>,
439 ) -> HostSliceMutWrapper<'_, O> {
440 match handle_inner {
441 CuHandleInner::Pooled(pooled) => HostSliceMutWrapper { inner: pooled },
442 CuHandleInner::Detached(detached) => HostSliceMutWrapper { inner: detached },
443 }
444 }
445 }
    /// A pool of device buffers, each a `CudaSlice` of `E` elements.
    pub struct CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        /// Identifier reported through `PoolMonitor::id`.
        id: PoolID,
        /// Stream used for the allocations and the host/device copies.
        stream: Arc<CudaStream>,
        /// Underlying object pool holding the device buffers.
        pool: Arc<Pool<CudaSliceWrapper<E>>>,
        /// Number of buffers the pool was created with.
        nb_buffers: usize,
        /// Number of elements in each buffer.
        nb_element_per_buffer: usize,
    }
457
458 impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
459 #[allow(dead_code)]
460 pub fn new(
461 id: &'static str,
462 ctx: Arc<CudaContext>,
463 nb_buffers: usize,
464 nb_element_per_buffer: usize,
465 ) -> CuResult<Self> {
466 let stream = ctx.default_stream();
467 let pool = (0..nb_buffers)
468 .map(|_| {
469 stream
470 .alloc_zeros(nb_element_per_buffer)
471 .map(CudaSliceWrapper)
472 .map_err(|_| "Failed to allocate device memory")
473 })
474 .collect::<Result<Vec<_>, _>>()?;
475
476 Ok(Self {
477 id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
478 stream,
479 pool: Arc::new(Pool::from_vec(pool)),
480 nb_buffers,
481 nb_element_per_buffer,
482 })
483 }
484 }
485
    /// Monitoring figures for a CUDA pool.
    impl<E> PoolMonitor for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        /// Identifier given at construction.
        fn id(&self) -> PoolID {
            self.id
        }

        /// Number of buffers currently sitting in the underlying pool.
        fn space_left(&self) -> usize {
            self.pool.len()
        }

        /// Number of buffers the pool was created with.
        fn total_size(&self) -> usize {
            self.nb_buffers
        }

        /// Size of one buffer in bytes: elements per buffer times the element size.
        fn buffer_size(&self) -> usize {
            self.nb_element_per_buffer * size_of::<E>()
        }
    }
506
507 impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
508 where
509 E: DeviceRepr + ElementType + ValidAsZeroBits,
510 {
511 fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
512 self.pool
513 .try_pull_owned()
514 .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
515 }
516
517 fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
518 where
519 O: ArrayLike<Element = E>,
520 {
521 let to_handle = self.acquire().expect("No available buffers in the pool");
522
523 {
524 let from_lock = lock_unpoison(&from_handle.0);
525 let mut to_lock = lock_unpoison(&to_handle.0);
526
527 match &mut *to_lock {
528 CuHandleInner::Detached(CudaSliceWrapper(to)) => {
529 let wrapper = Self::get_host_slice_wrapper(&*from_lock);
530 self.stream
531 .memcpy_htod(&wrapper, to)
532 .expect("Failed to copy data to device");
533 }
534 CuHandleInner::Pooled(to) => {
535 let wrapper = Self::get_host_slice_wrapper(&*from_lock);
536 self.stream
537 .memcpy_htod(&wrapper, to.as_cuda_slice_mut())
538 .expect("Failed to copy data to device");
539 }
540 }
541 } to_handle }
544 }
545
546 impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
547 where
548 E: ElementType + ValidAsZeroBits + DeviceRepr,
549 {
550 fn copy_to_host_pool<O>(
552 &self,
553 device_handle: &CuHandle<CudaSliceWrapper<E>>,
554 host_handle: &mut CuHandle<O>,
555 ) -> Result<(), CuError>
556 where
557 O: ArrayLike<Element = E>,
558 {
559 let device_lock = device_handle.lock().map_err(|e| {
560 CuError::from("Device handle mutex poisoned").add_cause(&e.to_string())
561 })?;
562 let mut host_lock = host_handle.lock().map_err(|e| {
563 CuError::from("Host handle mutex poisoned").add_cause(&e.to_string())
564 })?;
565 let src = match &*device_lock {
566 CuHandleInner::Pooled(source) => source.as_cuda_slice(),
567 CuHandleInner::Detached(source) => source.as_cuda_slice(),
568 };
569 let mut wrapper = Self::get_host_slice_mut_wrapper(&mut *host_lock);
570 self.stream.memcpy_dtoh(src, &mut wrapper).map_err(|e| {
571 CuError::from("Failed to copy data from device to host").add_cause(&e.to_string())
572 })?;
573 Ok(())
574 }
575 }
576}
577
/// A heap buffer of `E` elements allocated with a caller-chosen minimum alignment.
#[derive(Debug)]
pub struct AlignedBuffer<E: ElementType> {
    /// Start of the allocation.
    ptr: *mut E,
    /// Number of elements in the buffer.
    size: usize,
    /// Layout used for the allocation; reused to free it.
    layout: Layout,
}
585
586impl<E: ElementType> AlignedBuffer<E> {
587 pub fn new(num_elements: usize, alignment: usize) -> Self {
588 assert!(
589 num_elements > 0 && size_of::<E>() > 0,
590 "AlignedBuffer requires a non-zero element count and non-zero-sized element type"
591 );
592 let alignment = alignment.max(align_of::<E>());
593 let alloc_size = num_elements
594 .checked_mul(size_of::<E>())
595 .expect("AlignedBuffer allocation size overflow");
596 let layout = Layout::from_size_align(alloc_size, alignment).unwrap();
597 let ptr = unsafe { alloc(layout) as *mut E };
599 if ptr.is_null() {
600 panic!("Failed to allocate memory");
601 }
602 unsafe {
604 for i in 0..num_elements {
605 std::ptr::write(ptr.add(i), E::default());
606 }
607 }
608 Self {
609 ptr,
610 size: num_elements,
611 layout,
612 }
613 }
614}
615
/// Views the buffer as a shared slice of its elements.
impl<E: ElementType> Deref for AlignedBuffer<E> {
    type Target = [E];

    fn deref(&self) -> &Self::Target {
        // SAFETY: `new` allocates `ptr` for `size` elements and initialises
        // every one of them before returning the buffer.
        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
    }
}
624
/// Views the buffer as a mutable slice of its elements.
impl<E: ElementType> DerefMut for AlignedBuffer<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: `new` allocates `ptr` for `size` initialised elements, and
        // `&mut self` guarantees exclusive access for the returned borrow.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
    }
}
631
/// Releases the allocation made in `new`.
impl<E: ElementType> Drop for AlignedBuffer<E> {
    fn drop(&mut self) {
        // SAFETY: `ptr` was obtained from `alloc` with this exact `layout`.
        // Elements are `Copy` (required by `ElementType`), so none of them
        // needs to be dropped individually.
        unsafe { dealloc(self.ptr as *mut u8, self.layout) }
    }
}
638
639#[cfg(test)]
640mod tests {
641 use super::*;
642
    /// Checks the occupancy of a 3-buffer host pool as handles are acquired
    /// and dropped, and that acquiring from an empty pool yields `None`.
    #[test]
    fn test_pool() {
        use std::cell::RefCell;
        // The initializer pops one prepared vector per buffer.
        let objs = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        let holding = objs.borrow().clone();
        let objs_as_slices = holding.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
        let pool = CuHostMemoryPool::new("mytestcudapool", 3, || objs.borrow_mut().pop().unwrap())
            .unwrap();

        let obj1 = pool.acquire().unwrap();
        {
            let obj2 = pool.acquire().unwrap();
            assert!(objs_as_slices.contains(&obj1.lock().unwrap().deref().deref()));
            assert!(objs_as_slices.contains(&obj2.lock().unwrap().deref().deref()));
            assert_eq!(pool.space_left(), 1);
        }
        // `obj2` went out of scope, so its buffer is available again.
        assert_eq!(pool.space_left(), 2);

        let obj3 = pool.acquire().unwrap();
        assert!(objs_as_slices.contains(&obj3.lock().unwrap().deref().deref()));

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // The pool is exhausted: no handle can be acquired.
        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }
672
    /// Same occupancy scenario as `test_pool`, run against a CUDA pool.
    /// Only built with the `cuda` feature and the `has_nvidia_gpu` cfg.
    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_cuda_pool() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let _obj1 = pool.acquire().unwrap();

        {
            let _obj2 = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        }
        // `_obj2` went out of scope, so its buffer is available again.
        assert_eq!(pool.space_left(), 2);

        let _obj3 = pool.acquire().unwrap();

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // The pool is exhausted: no handle can be acquired.
        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }
699
    /// Writes a value into a host buffer, copies it to a device buffer and
    /// back into a second host buffer, then checks the value survived.
    /// Only built with the `cuda` feature and the `has_nvidia_gpu` cfg.
    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_copy_roundtrip() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let cuda_handle = {
            let mut initial_handle = host_pool.acquire().unwrap();
            {
                // Scope the lock so it is released before the host-to-device copy.
                let mut inner_initial_handle = initial_handle.lock().unwrap();
                if let CuHandleInner::Pooled(ref mut pooled) = *inner_initial_handle {
                    pooled[0] = 42.0;
                } else {
                    panic!();
                }
            }

            // Host -> device.
            cuda_pool.copy_from(&mut initial_handle)
        };

        let mut final_handle = host_pool.acquire().unwrap();
        // Device -> host.
        cuda_pool
            .copy_to_host_pool(&cuda_handle, &mut final_handle)
            .unwrap();

        let value = final_handle.lock().unwrap().deref().deref()[0];
        assert_eq!(value, 42.0);
    }
733}