Skip to main content

cu29_runtime/
pool.rs

1use arrayvec::ArrayString;
2use bincode::de::Decoder;
3use bincode::enc::Encoder;
4use bincode::error::{DecodeError, EncodeError};
5use bincode::{Decode, Encode};
6use cu29_traits::CuResult;
7use hashbrown::HashMap;
8use object_pool::{Pool, ReusableOwned};
9use smallvec::SmallVec;
10use std::alloc::{Layout, alloc, dealloc};
11use std::fmt::Debug;
12use std::mem::{align_of, size_of};
13use std::ops::{Deref, DerefMut};
14use std::sync::{Arc, Mutex, MutexGuard, OnceLock};
15
/// Identifier of a pool: an inline string of at most 64 bytes (copied by value, no heap).
type PoolID = ArrayString<64>;

/// Trait for a pool to be exposed to, and monitored by, the monitoring API.
pub trait PoolMonitor: Send + Sync {
    /// A unique and descriptive identifier for the pool.
    ///
    /// The global registry is keyed by this id, so registering two pools with the same id
    /// makes the second one replace the first.
    fn id(&self) -> PoolID;

    /// Number of buffer slots left in the pool.
    fn space_left(&self) -> usize;

    /// Total size of the pool in number of buffers.
    fn total_size(&self) -> usize;

    /// Size of one buffer, in bytes (the implementations in this module multiply the element
    /// count by the element size).
    fn buffer_size(&self) -> usize;
}
32
/// Global registry of monitored pools, keyed by pool id; created lazily on first registration.
static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
/// Number of pool statistics `pools_statistics` can return in its inline (non-heap) storage.
const MAX_POOLS: usize = 16;
35
/// Locks `mutex` and returns its guard, even if the mutex is poisoned.
///
/// Poisoning is deliberately ignored: if a previous holder panicked while holding the lock,
/// the guard is recovered from the `PoisonError` instead of propagating the failure.
fn lock_unpoison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
}
42
43// Register a pool to the global registry.
44fn register_pool(pool: Arc<dyn PoolMonitor>) {
45    POOL_REGISTRY
46        .get_or_init(|| Mutex::new(HashMap::new()))
47        .lock()
48        .unwrap_or_else(|poison| poison.into_inner())
49        .insert(pool.id().to_string(), pool);
50}
51
52type PoolStats = (PoolID, usize, usize, usize);
53
54/// Get the list of pools and their statistics.
55/// We use SmallVec here to avoid heap allocations while the stack is running.
56pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
57    // Safely get the registry, returning empty stats if not initialized.
58    let registry_lock = match POOL_REGISTRY.get() {
59        Some(lock) => lock_unpoison(lock),
60        None => return SmallVec::new(), // Return empty if registry is not initialized
61    };
62    let mut result = SmallVec::with_capacity(MAX_POOLS);
63    for pool in registry_lock.values() {
64        result.push((
65            pool.id(),
66            pool.space_left(),
67            pool.total_size(),
68            pool.buffer_size(),
69        ));
70    }
71    result
72}
73
/// Basic type that can be used as an element of a buffer in a `CuPool`.
///
/// Elements are plain `Copy` data with a default value. `encode`/`decode` serialize a single
/// element with bincode; a blanket implementation provides them for every type that meets
/// these bounds and also implements bincode's `Encode` and `Decode<()>`.
pub trait ElementType: Default + Sized + Copy + Debug + Unpin + Send + Sync {
    /// Serializes this element into `encoder`.
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError>;
    /// Deserializes one element from `decoder` (bincode decoding context `()`).
    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError>;
}
79
80/// Blanket implementation for all types that are Sized, Copy, Encode, Decode and Debug.
81impl<T> ElementType for T
82where
83    T: Default + Sized + Copy + Debug + Unpin + Send + Sync,
84    T: Encode,
85    T: Decode<()>,
86{
87    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
88        self.encode(encoder)
89    }
90
91    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
92        Self::decode(decoder)
93    }
94}
95
/// A contiguous buffer of [`ElementType`] values that can be viewed as a slice.
///
/// In this module it is implemented for `Vec<E>` and, behind the `cuda` feature, for
/// `CudaSliceWrapper<E>` (whose slice view panics until the data is copied to the host).
pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
    /// Type of one element of the buffer.
    type Element: ElementType;
}
99
100use crate::monitoring::CuPayloadSize;
101
/// A Handle to a Buffer.
/// For onboard usages, the buffer should be Pooled (ie, coming from a preallocated pool).
/// The Detached version is for offline usages where we don't really need a pool to deserialize them.
pub enum CuHandleInner<T: Debug> {
    /// Buffer taken from a pool; it goes back to the pool when this value is dropped.
    Pooled(ReusableOwned<T>),
    /// Buffer owned outside of any pool. Should only be used in offline cases
    /// (e.g. deserialization).
    Detached(T),
}
109
110impl<T> Debug for CuHandleInner<T>
111where
112    T: Debug,
113{
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match self {
116            CuHandleInner::Pooled(r) => {
117                write!(f, "Pooled: {:?}", r.deref())
118            }
119            CuHandleInner::Detached(r) => write!(f, "Detached: {r:?}"),
120        }
121    }
122}
123
124impl<T: ArrayLike> Deref for CuHandleInner<T> {
125    type Target = [T::Element];
126
127    fn deref(&self) -> &Self::Target {
128        match self {
129            CuHandleInner::Pooled(pooled) => pooled,
130            CuHandleInner::Detached(detached) => detached,
131        }
132    }
133}
134
135impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
136    fn deref_mut(&mut self) -> &mut Self::Target {
137        match self {
138            CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
139            CuHandleInner::Detached(detached) => detached,
140        }
141    }
142}
143
144/// A shareable handle to an Array coming from a pool (either host or device).
145#[derive(Clone, Debug)]
146pub struct CuHandle<T: ArrayLike>(Arc<Mutex<CuHandleInner<T>>>);
147
148impl<T: ArrayLike> Deref for CuHandle<T> {
149    type Target = Arc<Mutex<CuHandleInner<T>>>;
150
151    fn deref(&self) -> &Self::Target {
152        &self.0
153    }
154}
155
156impl<T: ArrayLike> CuHandle<T> {
157    /// Create a new CuHandle not part of a Pool (not for onboard usages, use pools instead)
158    pub fn new_detached(inner: T) -> Self {
159        CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
160    }
161
162    /// Safely access the inner value, applying a closure to it.
163    pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
164        let lock = lock_unpoison(&self.0);
165        f(&*lock)
166    }
167
168    /// Mutably access the inner value, applying a closure to it.
169    pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
170        let mut lock = lock_unpoison(&self.0);
171        f(&mut *lock)
172    }
173}
174
175impl<T> CuPayloadSize for CuHandle<T>
176where
177    T: ArrayLike,
178{
179    fn raw_bytes(&self) -> usize {
180        lock_unpoison(&self.0).deref().len() * size_of::<T::Element>()
181    }
182
183    fn handle_bytes(&self) -> usize {
184        self.raw_bytes()
185    }
186}
187
188impl<T: ArrayLike + Encode> Encode for CuHandle<T>
189where
190    <T as ArrayLike>::Element: 'static,
191{
192    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
193        let inner = lock_unpoison(&self.0);
194        match inner.deref() {
195            CuHandleInner::Pooled(pooled) => pooled.deref().encode(encoder),
196            CuHandleInner::Detached(detached) => detached.encode(encoder),
197        }
198    }
199}
200
201impl<T: ArrayLike> Default for CuHandle<T> {
202    fn default() -> Self {
203        panic!("Cannot create a default CuHandle")
204    }
205}
206
207impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<Vec<U>> {
208    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
209        let vec: Vec<U> = Vec::decode(decoder)?;
210        Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(vec)))))
211    }
212}
213
/// A CuPool is a pool of buffers that can be shared between different parts of the code.
/// Handles can be stored locally in the tasks and shared between them.
pub trait CuPool<T: ArrayLike>: PoolMonitor {
    /// Acquire a buffer from the pool.
    ///
    /// The implementations in this module return `None` when every buffer is handed out.
    fn acquire(&self) -> Option<CuHandle<T>>;

    /// Copy data from a handle to a new handle from the pool.
    ///
    /// `from` only has to share the element type with this pool's buffers (it may be detached
    /// or come from another pool). The implementations in this module only read `from`, and
    /// panic when the pool has no free buffer.
    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
    where
        O: ArrayLike<Element = T::Element>;
}
225
/// A device memory pool that can copy its buffers back into host memory.
pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
    /// Copies the contents of a device buffer into an existing host buffer.
    ///
    /// `to_host_handle` is overwritten in place (typically it was acquired from a host pool);
    /// no new handle is created or returned. Both buffers share the element type `T::Element`.
    ///
    /// # Errors
    ///
    /// Returns an error if the copy cannot be performed.
    fn copy_to_host_pool<O>(
        &self,
        from_device_handle: &CuHandle<T>,
        to_host_handle: &mut CuHandle<O>,
    ) -> CuResult<()>
    where
        O: ArrayLike<Element = T::Element>;
}
238
/// A pool of host memory buffers.
///
/// Built with `CuHostMemoryPool::new`, which also registers the pool for monitoring.
pub struct CuHostMemoryPool<T> {
    /// Identifier reported to the monitoring API.
    id: PoolID,
    /// Underlying pool of host buffers.
    // Being an Arc is a requirement of try_pull_owned() so buffers can refer back to the pool.
    pool: Arc<Pool<T>>,
    /// Number of buffers the pool was created with.
    size: usize,
    /// Size of one buffer in bytes, measured at construction.
    buffer_size: usize,
}
248
249impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
250    pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
251    where
252        F: Fn() -> T,
253    {
254        let pool = Arc::new(Pool::new(size, buffer_initializer));
255        let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
256
257        let og = Self {
258            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
259            pool,
260            size,
261            buffer_size,
262        };
263        let og = Arc::new(og);
264        register_pool(og.clone());
265        Ok(og)
266    }
267}
268
269impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
270    fn id(&self) -> PoolID {
271        self.id
272    }
273
274    fn space_left(&self) -> usize {
275        self.pool.len()
276    }
277
278    fn total_size(&self) -> usize {
279        self.size
280    }
281
282    fn buffer_size(&self) -> usize {
283        self.buffer_size
284    }
285}
286
287impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
288    fn acquire(&self) -> Option<CuHandle<T>> {
289        let owned_object = self.pool.try_pull_owned(); // Use the owned version
290
291        owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
292    }
293
294    fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
295        let to_handle = self.acquire().expect("No available buffers in the pool");
296
297        match lock_unpoison(&from.0).deref() {
298            CuHandleInner::Detached(source) => match lock_unpoison(&to_handle.0).deref_mut() {
299                CuHandleInner::Detached(destination) => {
300                    destination.copy_from_slice(source);
301                }
302                CuHandleInner::Pooled(destination) => {
303                    destination.copy_from_slice(source);
304                }
305            },
306            CuHandleInner::Pooled(source) => match lock_unpoison(&to_handle.0).deref_mut() {
307                CuHandleInner::Detached(destination) => {
308                    destination.copy_from_slice(source);
309                }
310                CuHandleInner::Pooled(destination) => {
311                    destination.copy_from_slice(source);
312                }
313            },
314        }
315        to_handle
316    }
317}
318
/// A plain `Vec` of elements is the simplest host buffer; detached handles produced by
/// decoding use it.
impl<E: ElementType + 'static> ArrayLike for Vec<E> {
    type Element = E;
}
322
#[cfg(all(feature = "cuda", not(target_os = "macos")))]
mod cuda {
    use super::*;
    use cu29_traits::CuError;
    use cudarc::driver::{
        CudaContext, CudaSlice, CudaStream, DeviceRepr, HostSlice, SyncOnDrop, ValidAsZeroBits,
    };
    use std::sync::Arc;

    /// A device (GPU) buffer handed out by a [`CuCudaPool`].
    ///
    /// It implements [`ArrayLike`] so it can travel in a [`CuHandle`], but its elements are not
    /// reachable from the host: both `deref` and `deref_mut` panic. Copy the data into a host
    /// buffer with [`DeviceCuPool::copy_to_host_pool`] to read it.
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);

    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        /// Always panics: device memory cannot be viewed as a host slice.
        fn deref(&self) -> &Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        /// Always panics: device memory cannot be viewed as a host slice.
        fn deref_mut(&mut self) -> &mut Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }

    impl<E> CudaSliceWrapper<E> {
        /// Borrows the underlying cudarc device slice.
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        /// Mutably borrows the underlying cudarc device slice (e.g. as a copy destination).
        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }

    /// Read-only bridge presenting an [`ArrayLike`] host buffer as a cudarc [`HostSlice`];
    /// used as the source of host-to-device copies.
    pub struct HostSliceWrapper<'a, T: ArrayLike> {
        inner: &'a T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        // SAFETY: HostSlice requires the returned slice to remain valid for 'b; it borrows
        // `self.inner`, which lives at least as long as `&'b self`.
        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        // This wrapper only holds a shared reference, so it cannot provide mutable access:
        // calling this always panics. Use `HostSliceMutWrapper` as a copy destination instead.
        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            _stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            panic!("Cannot get mutable reference from immutable wrapper")
        }
    }

    /// Mutable bridge presenting an [`ArrayLike`] host buffer as a cudarc [`HostSlice`];
    /// used as the destination of device-to-host copies.
    pub struct HostSliceMutWrapper<'a, T: ArrayLike> {
        inner: &'a mut T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceMutWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        // SAFETY: HostSlice requires the returned slice to remain valid for 'b; it borrows
        // `self.inner`, which lives at least as long as `&'b self`.
        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        // SAFETY: HostSlice requires the returned slice to remain valid for 'b; it mutably
        // borrows `self.inner` through `&'b mut self`, so no other access can alias it.
        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref_mut(), SyncOnDrop::sync_stream(stream))
        }
    }

    // Helpers turning a locked handle into the cudarc host-slice bridges above.
    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        /// Wraps the buffer of `handle_inner` (pooled or detached) as a read-only host slice.
        fn get_host_slice_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &CuHandleInner<O>,
        ) -> HostSliceWrapper<'_, O> {
            match handle_inner {
                CuHandleInner::Pooled(pooled) => HostSliceWrapper { inner: pooled },
                CuHandleInner::Detached(detached) => HostSliceWrapper { inner: detached },
            }
        }

        /// Wraps the buffer of `handle_inner` (pooled or detached) as a mutable host slice.
        fn get_host_slice_mut_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &mut CuHandleInner<O>,
        ) -> HostSliceMutWrapper<'_, O> {
            match handle_inner {
                CuHandleInner::Pooled(pooled) => HostSliceMutWrapper { inner: pooled },
                CuHandleInner::Detached(detached) => HostSliceMutWrapper { inner: detached },
            }
        }
    }

    /// A pool of CUDA memory buffers.
    ///
    /// Every buffer holds `nb_element_per_buffer` elements and is allocated up front on the
    /// default stream of the context given to [`CuCudaPool::new`]; copies go through that
    /// same stream.
    pub struct CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        /// Identifier reported to the monitoring API.
        id: PoolID,
        /// Stream used for allocations and for host/device copies.
        stream: Arc<CudaStream>,
        /// Pool of preallocated device buffers.
        pool: Arc<Pool<CudaSliceWrapper<E>>>,
        /// Number of buffers allocated at construction.
        nb_buffers: usize,
        /// Number of elements in each buffer.
        nb_element_per_buffer: usize,
    }

    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        /// Allocates `nb_buffers` zero-initialized device buffers of `nb_element_per_buffer`
        /// elements each, on the default stream of `ctx`.
        ///
        /// Unlike `CuHostMemoryPool::new`, this does not register the pool in the global
        /// registry, so it does not appear in `pools_statistics`.
        ///
        /// # Errors
        ///
        /// Returns an error if a device allocation fails or if `id` does not fit in a
        /// [`PoolID`] (64 bytes).
        #[allow(dead_code)]
        pub fn new(
            id: &str,
            ctx: Arc<CudaContext>,
            nb_buffers: usize,
            nb_element_per_buffer: usize,
        ) -> CuResult<Self> {
            let stream = ctx.default_stream();
            let pool = (0..nb_buffers)
                .map(|_| {
                    stream
                        .alloc_zeros(nb_element_per_buffer)
                        .map(CudaSliceWrapper)
                        .map_err(|_| "Failed to allocate device memory")
                })
                .collect::<Result<Vec<_>, _>>()?;

            Ok(Self {
                id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
                stream,
                pool: Arc::new(Pool::from_vec(pool)),
                nb_buffers,
                nb_element_per_buffer,
            })
        }
    }

    impl<E> PoolMonitor for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn id(&self) -> PoolID {
            self.id
        }

        /// Device buffers currently sitting in the pool, i.e. not handed out.
        fn space_left(&self) -> usize {
            self.pool.len()
        }

        fn total_size(&self) -> usize {
            self.nb_buffers
        }

        /// Size of one device buffer in bytes.
        fn buffer_size(&self) -> usize {
            self.nb_element_per_buffer * size_of::<E>()
        }
    }

    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        /// Takes a free device buffer out of the pool, or returns `None` if all are in use.
        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
            self.pool
                .try_pull_owned()
                .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
        }

        /// Acquires a device buffer and uploads the contents of the host buffer
        /// `from_handle` into it.
        ///
        /// # Panics
        ///
        /// Panics if the pool has no free buffer or if the host-to-device copy reports an
        /// error (presumably including a length mismatch — TODO confirm against cudarc).
        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
        where
            O: ArrayLike<Element = E>,
        {
            let to_handle = self.acquire().expect("No available buffers in the pool");

            {
                let from_lock = lock_unpoison(&from_handle.0);
                let mut to_lock = lock_unpoison(&to_handle.0);

                match &mut *to_lock {
                    // Not produced by `acquire`, but handled for completeness.
                    CuHandleInner::Detached(CudaSliceWrapper(to)) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to)
                            .expect("Failed to copy data to device");
                    }
                    CuHandleInner::Pooled(to) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to.as_cuda_slice_mut())
                            .expect("Failed to copy data to device");
                    }
                }
            } // Both locks are released here, before `to_handle` is moved out.
            to_handle
        }
    }

    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        /// Copies a device buffer into the existing host buffer `host_handle`.
        ///
        /// # Errors
        ///
        /// Unlike the rest of this module, a poisoned handle mutex is reported as an error
        /// here instead of being recovered. A failed device-to-host copy is also an error.
        fn copy_to_host_pool<O>(
            &self,
            device_handle: &CuHandle<CudaSliceWrapper<E>>,
            host_handle: &mut CuHandle<O>,
        ) -> Result<(), CuError>
        where
            O: ArrayLike<Element = E>,
        {
            let device_lock = device_handle.lock().map_err(|e| {
                CuError::from("Device handle mutex poisoned").add_cause(&e.to_string())
            })?;
            let mut host_lock = host_handle.lock().map_err(|e| {
                CuError::from("Host handle mutex poisoned").add_cause(&e.to_string())
            })?;
            let src = match &*device_lock {
                CuHandleInner::Pooled(source) => source.as_cuda_slice(),
                CuHandleInner::Detached(source) => source.as_cuda_slice(),
            };
            let mut wrapper = Self::get_host_slice_mut_wrapper(&mut *host_lock);
            self.stream.memcpy_dtoh(src, &mut wrapper).map_err(|e| {
                CuError::from("Failed to copy data from device to host").add_cause(&e.to_string())
            })?;
            Ok(())
        }
    }
}
577
#[derive(Debug)]
/// A buffer that is aligned to a specific size with the Element of type E.
///
/// It owns a heap allocation of `size` elements whose start address honours the alignment
/// requested in `AlignedBuffer::new` (never less than the natural alignment of `E`).
/// NOTE(review): because it stores a raw pointer and no `unsafe impl Send/Sync` is visible in
/// this file, it is neither `Send` nor `Sync`, so it cannot satisfy `ArrayLike` — confirm
/// whether that is intended.
pub struct AlignedBuffer<E: ElementType> {
    /// Start of the allocation; points to `size` initialized elements.
    ptr: *mut E,
    /// Number of elements (not bytes).
    size: usize,
    /// Layout used for the allocation, needed again to free it.
    layout: Layout,
}
585
586impl<E: ElementType> AlignedBuffer<E> {
587    pub fn new(num_elements: usize, alignment: usize) -> Self {
588        assert!(
589            num_elements > 0 && size_of::<E>() > 0,
590            "AlignedBuffer requires a non-zero element count and non-zero-sized element type"
591        );
592        let alignment = alignment.max(align_of::<E>());
593        let alloc_size = num_elements
594            .checked_mul(size_of::<E>())
595            .expect("AlignedBuffer allocation size overflow");
596        let layout = Layout::from_size_align(alloc_size, alignment).unwrap();
597        // SAFETY: layout describes a valid, non-zero allocation request.
598        let ptr = unsafe { alloc(layout) as *mut E };
599        if ptr.is_null() {
600            panic!("Failed to allocate memory");
601        }
602        // SAFETY: ptr is valid for writes of `num_elements` elements.
603        unsafe {
604            for i in 0..num_elements {
605                std::ptr::write(ptr.add(i), E::default());
606            }
607        }
608        Self {
609            ptr,
610            size: num_elements,
611            layout,
612        }
613    }
614}
615
616impl<E: ElementType> Deref for AlignedBuffer<E> {
617    type Target = [E];
618
619    fn deref(&self) -> &Self::Target {
620        // SAFETY: `new` initializes all elements and keeps the pointer aligned.
621        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
622    }
623}
624
625impl<E: ElementType> DerefMut for AlignedBuffer<E> {
626    fn deref_mut(&mut self) -> &mut Self::Target {
627        // SAFETY: `new` initializes all elements and keeps the pointer aligned.
628        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
629    }
630}
631
632impl<E: ElementType> Drop for AlignedBuffer<E> {
633    fn drop(&mut self) {
634        // SAFETY: `ptr` was allocated with `layout` in `new`.
635        unsafe { dealloc(self.ptr as *mut u8, self.layout) }
636    }
637}
638
#[cfg(test)]
mod tests {
    use super::*;

    /// Buffers from a host pool come from the initializer, `space_left` tracks acquisitions
    /// and releases, and acquiring from an exhausted pool yields `None`.
    #[test]
    fn test_pool() {
        use std::cell::RefCell;
        let objs = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        let holding = objs.borrow().clone();
        let objs_as_slices = holding.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
        let pool =
            CuHostMemoryPool::new("mytesthostpool_basic", 3, || objs.borrow_mut().pop().unwrap())
                .unwrap();

        let obj1 = pool.acquire().unwrap();
        {
            let obj2 = pool.acquire().unwrap();
            assert!(objs_as_slices.contains(&obj1.lock().unwrap().deref().deref()));
            assert!(objs_as_slices.contains(&obj2.lock().unwrap().deref().deref()));
            assert_eq!(pool.space_left(), 1);
        }
        assert_eq!(pool.space_left(), 2);

        let obj3 = pool.acquire().unwrap();
        assert!(objs_as_slices.contains(&obj3.lock().unwrap().deref().deref()));

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    /// `copy_from` on a host pool copies both detached and pooled sources into a fresh
    /// buffer taken from the pool.
    #[test]
    fn test_host_copy_from() {
        let pool = CuHostMemoryPool::new("mytesthostpool_copy", 2, || vec![0u32; 4]).unwrap();
        assert_eq!(pool.space_left(), 2);

        let mut detached = CuHandle::new_detached(vec![1u32, 2, 3, 4]);
        let first = pool.copy_from(&mut detached);
        assert_eq!(&**lock_unpoison(&first.0), &[1u32, 2, 3, 4][..]);
        assert_eq!(pool.space_left(), 1);

        let mut pooled = first.clone();
        let second = pool.copy_from(&mut pooled);
        assert_eq!(&**lock_unpoison(&second.0), &[1u32, 2, 3, 4][..]);
        assert_eq!(pool.space_left(), 0);
    }

    /// `AlignedBuffer` honours the requested alignment (never going below the natural
    /// alignment of the element type) and starts with default-initialized elements.
    #[test]
    fn test_aligned_buffer() {
        let mut buffer = AlignedBuffer::<f32>::new(8, 64);
        assert_eq!(buffer.len(), 8);
        assert_eq!(buffer.as_ptr() as usize % 64, 0);
        assert!(buffer.iter().all(|&x| x == 0.0));
        buffer[3] = 1.5;
        assert_eq!(buffer[3], 1.5);

        let small = AlignedBuffer::<u64>::new(2, 1);
        assert_eq!(small.as_ptr() as usize % align_of::<u64>(), 0);
        assert_eq!(&*small, &[0u64, 0][..]);
    }

    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_cuda_pool() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let _obj1 = pool.acquire().unwrap();

        {
            let _obj2 = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        }
        assert_eq!(pool.space_left(), 2);

        let _obj3 = pool.acquire().unwrap();

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_copy_roundtrip() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let cuda_handle = {
            let mut initial_handle = host_pool.acquire().unwrap();
            {
                let mut inner_initial_handle = initial_handle.lock().unwrap();
                if let CuHandleInner::Pooled(ref mut pooled) = *inner_initial_handle {
                    pooled[0] = 42.0;
                } else {
                    panic!();
                }
            }

            // send that to the GPU
            cuda_pool.copy_from(&mut initial_handle)
        };

        // get it back to the host
        let mut final_handle = host_pool.acquire().unwrap();
        cuda_pool
            .copy_to_host_pool(&cuda_handle, &mut final_handle)
            .unwrap();

        let value = final_handle.lock().unwrap().deref().deref()[0];
        assert_eq!(value, 42.0);
    }
}