//! Buffer pool primitives for `cu29_runtime` (`pool.rs`): monitored host,
//! shared-memory, and CUDA buffer pools plus the `CuHandle` lease type.
use std::alloc::{Layout, alloc, dealloc};
use std::cell::{Cell, UnsafeCell};
use std::fmt::Debug;
use std::fs::OpenOptions;
use std::marker::PhantomData;
use std::mem::{align_of, size_of};
use std::ops::{Deref, DerefMut};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, MutexGuard, OnceLock};

use arrayvec::ArrayString;
use bincode::de::Decoder;
use bincode::enc::Encoder;
use bincode::error::{DecodeError, EncodeError};
use bincode::{Decode, Encode};
use cu29_traits::CuResult;
use hashbrown::HashMap;
use memmap2::{MmapMut, MmapOptions};
use object_pool::{Pool, ReusableOwned};
use serde::de::{self, MapAccess, SeqAccess, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use smallvec::SmallVec;
use tempfile::NamedTempFile;
25
/// Short, fixed-capacity (64 byte) identifier used to name and key pools.
type PoolID = ArrayString<64>;

/// Trait for a Pool to be exposed to the monitoring API.
pub trait PoolMonitor: Send + Sync {
    /// A unique and descriptive identifier for the pool.
    fn id(&self) -> PoolID;

    /// Number of buffer slots left in the pool.
    fn space_left(&self) -> usize;

    /// Total size of the pool in number of buffers.
    fn total_size(&self) -> usize;

    /// Size of one buffer, in bytes.
    fn buffer_size(&self) -> usize;
}

// Process-wide registry of every live pool, keyed by its id string.
// Lazily initialized on first registration.
static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
// Upper bound used to size the stack-allocated statistics vector in
// `pools_statistics` (a SmallVec spills to the heap past this count).
const MAX_POOLS: usize = 16;
45
/// Lock a mutex, recovering the guard even when another thread panicked while
/// holding the lock. Poisoning is deliberately ignored: pool bookkeeping must
/// stay reachable after a task panic.
fn lock_unpoison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
}
52
53// Register a pool to the global registry.
54fn register_pool(pool: Arc<dyn PoolMonitor>) {
55    POOL_REGISTRY
56        .get_or_init(|| Mutex::new(HashMap::new()))
57        .lock()
58        .unwrap_or_else(|poison| poison.into_inner())
59        .insert(pool.id().to_string(), pool);
60}
61
62type PoolStats = (PoolID, usize, usize, usize);
63
64/// Get the list of pools and their statistics.
65/// We use SmallVec here to avoid heap allocations while the stack is running.
66pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
67    // Safely get the registry, returning empty stats if not initialized.
68    let registry_lock = match POOL_REGISTRY.get() {
69        Some(lock) => lock_unpoison(lock),
70        None => return SmallVec::new(), // Return empty if registry is not initialized
71    };
72    let mut result = SmallVec::with_capacity(MAX_POOLS);
73    for pool in registry_lock.values() {
74        result.push((
75            pool.id(),
76            pool.space_left(),
77            pool.total_size(),
78            pool.buffer_size(),
79        ));
80    }
81    result
82}
83
84/// Basic Type that can be used in a buffer in a CuPool.
85pub trait ElementType: Default + Sized + Copy + Debug + Unpin + Send + Sync {
86    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError>;
87    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError>;
88}
89
90/// Blanket implementation for all types that are Sized, Copy, Encode, Decode and Debug.
91impl<T> ElementType for T
92where
93    T: Default + Sized + Copy + Debug + Unpin + Send + Sync,
94    T: Encode,
95    T: Decode<()>,
96{
97    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
98        self.encode(encoder)
99    }
100
101    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
102        Self::decode(decoder)
103    }
104}
105
/// Minimal contract for pool buffers: a mutable, contiguous slice of
/// `ElementType` items, reachable through `Deref`/`DerefMut`.
pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
    type Element: ElementType;
}

thread_local! {
    // Per-thread flag: when set, `CuHandle<CuSharedMemoryBuffer<_>>` serializes
    // as a shared-memory descriptor instead of copying the element data.
    static SHARED_HANDLE_SERIALIZATION_ENABLED: Cell<bool> = const { Cell::new(false) };
}
113
/// RAII guard returned by `enable_shared_handle_serialization`; restores the
/// previous value of the per-thread flag when dropped.
pub struct SharedHandleSerializationGuard {
    // Flag value captured when the guard was created.
    previous: bool,
}

impl Drop for SharedHandleSerializationGuard {
    fn drop(&mut self) {
        // Restore instead of unconditionally clearing, so nested guards compose.
        SHARED_HANDLE_SERIALIZATION_ENABLED.with(|enabled| enabled.set(self.previous));
    }
}
123
124pub fn enable_shared_handle_serialization() -> SharedHandleSerializationGuard {
125    let previous = SHARED_HANDLE_SERIALIZATION_ENABLED.with(|enabled| {
126        let previous = enabled.get();
127        enabled.set(true);
128        previous
129    });
130    SharedHandleSerializationGuard { previous }
131}
132
133fn shared_handle_serialization_enabled() -> bool {
134    SHARED_HANDLE_SERIALIZATION_ENABLED.with(Cell::get)
135}
136
137#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
138#[serde(rename_all = "snake_case")]
139pub enum CuSharedMemoryElementType {
140    U8,
141    U16,
142    U32,
143    U64,
144    I8,
145    I16,
146    I32,
147    I64,
148    F32,
149    F64,
150}
151
152impl CuSharedMemoryElementType {
153    pub fn of<E: ElementType + 'static>() -> Option<Self> {
154        let type_id = core::any::TypeId::of::<E>();
155        if type_id == core::any::TypeId::of::<u8>() {
156            Some(Self::U8)
157        } else if type_id == core::any::TypeId::of::<u16>() {
158            Some(Self::U16)
159        } else if type_id == core::any::TypeId::of::<u32>() {
160            Some(Self::U32)
161        } else if type_id == core::any::TypeId::of::<u64>() {
162            Some(Self::U64)
163        } else if type_id == core::any::TypeId::of::<i8>() {
164            Some(Self::I8)
165        } else if type_id == core::any::TypeId::of::<i16>() {
166            Some(Self::I16)
167        } else if type_id == core::any::TypeId::of::<i32>() {
168            Some(Self::I32)
169        } else if type_id == core::any::TypeId::of::<i64>() {
170            Some(Self::I64)
171        } else if type_id == core::any::TypeId::of::<f32>() {
172            Some(Self::F32)
173        } else if type_id == core::any::TypeId::of::<f64>() {
174            Some(Self::F64)
175        } else {
176            None
177        }
178    }
179}
180
/// Serialized stand-in for a shared-memory buffer: enough information for a
/// cooperating process on the same host to re-map the region instead of
/// copying the bytes.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CuSharedMemoryHandleDescriptor {
    /// Marker field (serialized as `__cu_shm_handle__`) so deserializers can
    /// tell a descriptor map apart from a plain element sequence.
    #[serde(rename = "__cu_shm_handle__")]
    pub marker: bool,
    /// Filesystem path of the backing shared-memory file.
    pub path: String,
    /// Byte offset of this buffer's slot inside the mapped region.
    pub offset_bytes: usize,
    /// Number of elements (not bytes) in the buffer.
    pub len_elements: usize,
    /// Tag identifying the primitive element type.
    pub element_type: CuSharedMemoryElementType,
}

impl CuSharedMemoryHandleDescriptor {
    /// Build a descriptor with the marker field pre-set to `true`.
    fn new(
        path: String,
        offset_bytes: usize,
        len_elements: usize,
        element_type: CuSharedMemoryElementType,
    ) -> Self {
        Self {
            marker: true,
            path,
            offset_bytes,
            len_elements,
            element_type,
        }
    }
}
207
/// A memory-mapped file region backing one or more shared-memory buffers.
struct CuSharedMemoryRegion {
    // Path of the backing file; also the key into the process-local region cache.
    path: PathBuf,
    // UnsafeCell because buffers hand out slices into (disjoint) slots of the
    // same mapping; synchronization is enforced by handle/pool leasing.
    mmap: UnsafeCell<MmapMut>,
    // Keeps the temp file alive for regions this process created;
    // `None` for regions opened from another process's path.
    _backing_file: Option<NamedTempFile>,
}

impl CuSharedMemoryRegion {
    /// Create a fresh region of `byte_len` bytes backed by a new temp file,
    /// map it read/write, and cache it by path.
    fn create(byte_len: usize) -> CuResult<Arc<Self>> {
        let file = NamedTempFile::new()
            .map_err(|e| cu29_traits::CuError::new_with_cause("create shared memory file", e))?;
        // Grow the file to the full region size before mapping it.
        file.as_file()
            .set_len(byte_len as u64)
            .map_err(|e| cu29_traits::CuError::new_with_cause("size shared memory file", e))?;
        let mmap = unsafe {
            MmapOptions::new()
                .len(byte_len)
                .map_mut(file.as_file())
                .map_err(|e| cu29_traits::CuError::new_with_cause("map shared memory file", e))?
        };
        let region = Arc::new(Self {
            path: file.path().to_path_buf(),
            mmap: UnsafeCell::new(mmap),
            _backing_file: Some(file),
        });
        // Cache so descriptors pointing back at our own file reuse this mapping.
        cache_shared_region(region.clone());
        Ok(region)
    }

    /// Open (or fetch from cache) an existing region created by this or
    /// another process; the mapping length comes from the file's metadata.
    fn open(path: &Path) -> CuResult<Arc<Self>> {
        if let Some(region) = cached_shared_region(path) {
            return Ok(region);
        }

        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(path)
            .map_err(|e| cu29_traits::CuError::new_with_cause("open shared memory file", e))?;
        let len = file
            .metadata()
            .map_err(|e| cu29_traits::CuError::new_with_cause("stat shared memory file", e))?
            .len() as usize;
        let mmap = unsafe {
            MmapOptions::new()
                .len(len)
                .map_mut(&file)
                .map_err(|e| cu29_traits::CuError::new_with_cause("map shared memory file", e))?
        };
        let region = Arc::new(Self {
            path: path.to_path_buf(),
            mmap: UnsafeCell::new(mmap),
            _backing_file: None,
        });
        cache_shared_region(region.clone());
        Ok(region)
    }
}

impl Debug for CuSharedMemoryRegion {
    // Only the path is printed; the mapping itself is opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CuSharedMemoryRegion")
            .field("path", &self.path)
            .finish_non_exhaustive()
    }
}

// SAFETY:
// Access to the mapped bytes is mediated through Copper handles and pool slot
// leasing, so cross-thread aliasing follows the same external synchronization as
// other mutable payload buffers.
unsafe impl Send for CuSharedMemoryRegion {}
// SAFETY:
// See `Send` rationale above.
unsafe impl Sync for CuSharedMemoryRegion {}
282
283fn shared_region_cache() -> &'static Mutex<HashMap<PathBuf, std::sync::Weak<CuSharedMemoryRegion>>>
284{
285    static CACHE: OnceLock<Mutex<HashMap<PathBuf, std::sync::Weak<CuSharedMemoryRegion>>>> =
286        OnceLock::new();
287    CACHE.get_or_init(|| Mutex::new(HashMap::new()))
288}
289
290fn cache_shared_region(region: Arc<CuSharedMemoryRegion>) {
291    lock_unpoison(shared_region_cache()).insert(region.path.clone(), Arc::downgrade(&region));
292}
293
294fn cached_shared_region(path: &Path) -> Option<Arc<CuSharedMemoryRegion>> {
295    lock_unpoison(shared_region_cache())
296        .get(path)
297        .and_then(std::sync::Weak::upgrade)
298}
299
300fn shared_slot_stride<E: ElementType>(len_elements: usize) -> usize {
301    let raw_bytes = len_elements
302        .checked_mul(size_of::<E>())
303        .expect("shared memory slot size overflow");
304    let alignment = align_of::<E>().max(1);
305    raw_bytes.div_ceil(alignment) * alignment
306}
307
/// One fixed-size slice of elements living inside a shared-memory region.
#[derive(Debug)]
pub struct CuSharedMemoryBuffer<E: ElementType> {
    // Region that owns the mapped bytes; kept alive by this Arc.
    region: Arc<CuSharedMemoryRegion>,
    // Byte offset of this buffer's slot inside the region.
    offset_bytes: usize,
    // Number of `E` elements in the slot.
    len_elements: usize,
    _marker: PhantomData<E>,
}

impl<E: ElementType + 'static> CuSharedMemoryBuffer<E> {
    /// Wrap an existing slot of `region`. Callers (the pool and the
    /// descriptor path) are responsible for the slot being in bounds.
    fn from_region(
        region: Arc<CuSharedMemoryRegion>,
        offset_bytes: usize,
        len_elements: usize,
    ) -> Self {
        Self {
            region,
            offset_bytes,
            len_elements,
            _marker: PhantomData,
        }
    }

    /// Copy `data` into a brand-new single-slot region (not pool-managed).
    ///
    /// `.max(1)` keeps the backing file non-empty even for an empty vector,
    /// so the mmap call always has a valid length.
    pub fn from_vec_detached(data: Vec<E>) -> CuResult<Self> {
        let len_elements = data.len();
        let slot_stride = shared_slot_stride::<E>(len_elements.max(1));
        let region = CuSharedMemoryRegion::create(slot_stride)?;
        let mut buffer = Self::from_region(region, 0, len_elements);
        if !data.is_empty() {
            buffer.copy_from_slice(&data);
        }
        Ok(buffer)
    }

    /// Re-attach to the region described by `descriptor` (zero-copy receive).
    ///
    /// Fails if `E` is not a supported primitive type or if the descriptor's
    /// element-type tag does not match `E`.
    pub fn from_descriptor(descriptor: &CuSharedMemoryHandleDescriptor) -> CuResult<Self> {
        let expected = CuSharedMemoryElementType::of::<E>()
            .ok_or_else(|| cu29_traits::CuError::from("unsupported shared memory element type"))?;
        if descriptor.element_type != expected {
            return Err(cu29_traits::CuError::from(
                "shared memory descriptor element type mismatch",
            ));
        }
        let region = CuSharedMemoryRegion::open(Path::new(&descriptor.path))?;
        Ok(Self::from_region(
            region,
            descriptor.offset_bytes,
            descriptor.len_elements,
        ))
    }

    /// Descriptor for this buffer, or `None` when `E` has no
    /// `CuSharedMemoryElementType` tag (i.e. it is not a supported primitive).
    pub fn descriptor(&self) -> Option<CuSharedMemoryHandleDescriptor>
    where
        E: 'static,
    {
        CuSharedMemoryElementType::of::<E>().map(|element_type| {
            CuSharedMemoryHandleDescriptor::new(
                self.region.path.display().to_string(),
                self.offset_bytes,
                self.len_elements,
                element_type,
            )
        })
    }
}
371
impl<E: ElementType> Deref for CuSharedMemoryBuffer<E> {
    type Target = [E];

    fn deref(&self) -> &Self::Target {
        // SAFETY: the slot [offset_bytes, offset_bytes + len_elements*size_of::<E>())
        // is assumed in-bounds of the mapping (established by the constructors —
        // TODO confirm descriptor-provided offsets are validated upstream), and
        // aliasing is externally synchronized per the region's Send/Sync rationale.
        let ptr = unsafe { (*self.region.mmap.get()).as_ptr().add(self.offset_bytes) as *const E };
        unsafe { std::slice::from_raw_parts(ptr, self.len_elements) }
    }
}

impl<E: ElementType> DerefMut for CuSharedMemoryBuffer<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same bounds/aliasing reasoning as `deref` above; `&mut self`
        // guarantees exclusive access through this particular buffer handle.
        let ptr = unsafe {
            (*self.region.mmap.get())
                .as_mut_ptr()
                .add(self.offset_bytes) as *mut E
        };
        unsafe { std::slice::from_raw_parts_mut(ptr, self.len_elements) }
    }
}

// Shared-memory buffers are plain element slices as far as the pools care.
impl<E: ElementType> ArrayLike for CuSharedMemoryBuffer<E> {
    type Element = E;
}
395
396impl<E: ElementType> Encode for CuSharedMemoryBuffer<E> {
397    fn encode<Enc: Encoder>(&self, encoder: &mut Enc) -> Result<(), EncodeError> {
398        let len = self.len_elements as u64;
399        Encode::encode(&len, encoder)?;
400        for value in self.deref() {
401            value.encode(encoder)?;
402        }
403        Ok(())
404    }
405}
406
407impl<E: ElementType + 'static> Decode<()> for CuSharedMemoryBuffer<E> {
408    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
409        let len = <u64 as Decode<()>>::decode(decoder)? as usize;
410        let mut vec = Vec::with_capacity(len);
411        for _ in 0..len {
412            vec.push(E::decode(decoder)?);
413        }
414        Self::from_vec_detached(vec).map_err(|e| DecodeError::OtherString(e.to_string()))
415    }
416}
417
418/// A handle to a pooled or detached object.
419///
420/// For onboard usages, large payloads should typically be pooled. The detached form exists for
421/// offline/deserialization flows and for payloads that are intentionally heap-backed instead of
422/// pool-backed.
423pub enum CuHandleInner<T: Debug + Send + Sync> {
424    Pooled(ReusableOwned<Box<T>>),
425    Detached(Box<T>), // Should only be used in offline cases (e.g. deserialization)
426}
427
428impl<T> Debug for CuHandleInner<T>
429where
430    T: Debug + Send + Sync,
431{
432    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433        match self {
434            CuHandleInner::Pooled(r) => {
435                write!(f, "Pooled: {:?}", r.deref().deref())
436            }
437            CuHandleInner::Detached(r) => write!(f, "Detached: {r:?}"),
438        }
439    }
440}
441
442impl<T> CuHandleInner<T>
443where
444    T: Debug + Send + Sync,
445{
446    fn inner_ref(&self) -> &T {
447        match self {
448            CuHandleInner::Pooled(pooled) => pooled.deref().as_ref(),
449            CuHandleInner::Detached(detached) => detached.deref(),
450        }
451    }
452
453    fn inner_mut(&mut self) -> &mut T {
454        match self {
455            CuHandleInner::Pooled(pooled) => pooled.deref_mut().as_mut(),
456            CuHandleInner::Detached(detached) => detached.deref_mut(),
457        }
458    }
459}
460
// Convenience conversions: expose the payload uniformly regardless of the
// pooled/detached backing.
impl<T> AsRef<T> for CuHandleInner<T>
where
    T: Debug + Send + Sync,
{
    fn as_ref(&self) -> &T {
        self.inner_ref()
    }
}

impl<T> AsMut<T> for CuHandleInner<T>
where
    T: Debug + Send + Sync,
{
    fn as_mut(&mut self) -> &mut T {
        self.inner_mut()
    }
}

// When the payload is array-like, the handle itself derefs straight to the
// element slice, so callers can index it without unwrapping variants.
impl<T: ArrayLike> Deref for CuHandleInner<T> {
    type Target = [T::Element];

    fn deref(&self) -> &Self::Target {
        self.inner_ref().deref()
    }
}

impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner_mut().deref_mut()
    }
}
492
493/// A shareable handle to a pooled or detached object.
494///
495/// When `T: ArrayLike`, the handle also participates in Copper's buffer pool APIs.
496#[derive(Debug)]
497pub struct CuHandle<T: Debug + Send + Sync>(Arc<Mutex<CuHandleInner<T>>>);
498
499impl<T: Debug + Send + Sync> Clone for CuHandle<T> {
500    fn clone(&self) -> Self {
501        Self(self.0.clone())
502    }
503}
504
505impl<T: Debug + Send + Sync> Deref for CuHandle<T> {
506    type Target = Arc<Mutex<CuHandleInner<T>>>;
507
508    fn deref(&self) -> &Self::Target {
509        &self.0
510    }
511}
512
513impl<T: Debug + Send + Sync> CuHandle<T> {
514    /// Create a new CuHandle not part of a Pool (not for onboard usages, use pools instead)
515    pub fn new_detached(inner: T) -> Self {
516        Self::new_detached_box(Box::new(inner))
517    }
518
519    /// Create a detached handle from an already heap-allocated object.
520    pub fn new_detached_box(inner: Box<T>) -> Self {
521        CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
522    }
523
524    /// Safely access the inner value, applying a closure to it.
525    pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
526        let lock = lock_unpoison(&self.0);
527        f(&*lock)
528    }
529
530    /// Mutably access the inner value, applying a closure to it.
531    pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
532        let mut lock = lock_unpoison(&self.0);
533        f(&mut *lock)
534    }
535}
536
impl<U> Serialize for CuHandle<Vec<U>>
where
    U: ElementType + Serialize + 'static,
{
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // Serialize the payload only; whether it was pooled or detached is
        // not preserved across the wire.
        let inner = lock_unpoison(&self.0);
        inner.inner_ref().serialize(serializer)
    }
}

impl<'de, U> Deserialize<'de> for CuHandle<Vec<U>>
where
    U: ElementType + Deserialize<'de> + 'static,
{
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Deserialized handles are always detached: there is no pool to lease from.
        Vec::<U>::deserialize(deserializer).map(CuHandle::new_detached)
    }
}
555
556impl<U> Serialize for CuHandle<CuSharedMemoryBuffer<U>>
557where
558    U: ElementType + Serialize + 'static,
559{
560    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
561        let inner = lock_unpoison(&self.0);
562        let buffer = inner.inner_ref();
563
564        if shared_handle_serialization_enabled()
565            && let Some(descriptor) = buffer.descriptor()
566        {
567            return descriptor.serialize(serializer);
568        }
569
570        buffer.deref().serialize(serializer)
571    }
572}
573
impl<'de, U> Deserialize<'de> for CuHandle<CuSharedMemoryBuffer<U>>
where
    U: ElementType + Deserialize<'de> + 'static,
{
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // The wire form is either a descriptor map (zero-copy path) or a plain
        // element sequence (data path); this enum captures the distinction.
        enum Repr<U> {
            Descriptor(CuSharedMemoryHandleDescriptor),
            Data(Vec<U>),
        }

        impl<'de, U> Deserialize<'de> for Repr<U>
        where
            U: ElementType + Deserialize<'de>,
        {
            fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                struct ReprVisitor<U>(PhantomData<U>);

                impl<'de, U> Visitor<'de> for ReprVisitor<U>
                where
                    U: ElementType + Deserialize<'de>,
                {
                    type Value = Repr<U>;

                    fn expecting(
                        &self,
                        formatter: &mut std::fmt::Formatter<'_>,
                    ) -> std::fmt::Result {
                        formatter
                            .write_str("a shared-memory handle descriptor or an element sequence")
                    }

                    // A sequence means raw element data.
                    fn visit_seq<A: SeqAccess<'de>>(self, seq: A) -> Result<Self::Value, A::Error> {
                        let data =
                            Vec::<U>::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
                        Ok(Repr::Data(data))
                    }

                    // A map means a handle descriptor (the `__cu_shm_handle__` form).
                    fn visit_map<A: MapAccess<'de>>(self, map: A) -> Result<Self::Value, A::Error> {
                        let descriptor = CuSharedMemoryHandleDescriptor::deserialize(
                            de::value::MapAccessDeserializer::new(map),
                        )?;
                        Ok(Repr::Descriptor(descriptor))
                    }
                }

                // NOTE: `deserialize_any` requires a self-describing format
                // (e.g. JSON); non-self-describing formats cannot take this path.
                deserializer.deserialize_any(ReprVisitor(PhantomData))
            }
        }

        match Repr::<U>::deserialize(deserializer)? {
            // Zero-copy: re-map the region the descriptor points at.
            Repr::Descriptor(descriptor) => CuSharedMemoryBuffer::from_descriptor(&descriptor)
                .map(CuHandle::new_detached)
                .map_err(de::Error::custom),
            // Copy: materialize the elements into a fresh detached buffer.
            Repr::Data(data) => CuSharedMemoryBuffer::from_vec_detached(data)
                .map(CuHandle::new_detached)
                .map_err(de::Error::custom),
        }
    }
}
633
impl<T: ArrayLike + Encode> Encode for CuHandle<T>
where
    <T as ArrayLike>::Element: 'static,
{
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
        let inner = lock_unpoison(&self.0);
        // Report the payload size in bytes to the monitoring subsystem before
        // delegating to the buffer's own encoding.
        crate::monitoring::record_payload_handle_bytes(
            inner.inner_ref().len() * size_of::<T::Element>(),
        );
        inner.inner_ref().encode(encoder)
    }
}

impl<T: Debug + Send + Sync> Default for CuHandle<T> {
    fn default() -> Self {
        // A handle must wrap either a pool lease or a detached allocation;
        // there is no sensible "empty" handle, so default construction is a
        // hard programming error.
        panic!("Cannot create a default CuHandle")
    }
}
652
653impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<Vec<U>> {
654    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
655        let vec: Vec<U> = Vec::decode(decoder)?;
656        Ok(CuHandle::new_detached(vec))
657    }
658}
659
660impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<CuSharedMemoryBuffer<U>> {
661    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
662        let buffer = CuSharedMemoryBuffer::<U>::decode(decoder)?;
663        Ok(CuHandle::new_detached(buffer))
664    }
665}
666
/// A CuPool is a pool of buffers that can be shared between different parts of the code.
/// Handles can be stored locally in the tasks and shared between them.
pub trait CuPool<T: ArrayLike>: PoolMonitor {
    /// Acquire a buffer from the pool. Returns `None` when the pool is exhausted.
    fn acquire(&self) -> Option<CuHandle<T>>;

    /// Copy data from a handle to a new handle from the pool.
    /// The source may come from a different pool, as long as element types match.
    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
    where
        O: ArrayLike<Element = T::Element>;
}

/// A device memory pool can copy data from a device to a host memory pool on top.
pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
    /// Takes a handle to a device buffer and copies it into a host buffer pool.
    /// It returns a new handle from the host pool with the data from the device handle given.
    fn copy_to_host_pool<O>(
        &self,
        from_device_handle: &CuHandle<T>,
        to_host_handle: &mut CuHandle<O>,
    ) -> CuResult<()>
    where
        O: ArrayLike<Element = T::Element>;
}
691
/// A pool of host memory buffers.
pub struct CuHostMemoryPool<T> {
    /// Identifier reported to the monitoring registry.
    id: PoolID,
    /// Underlying pool of host buffers.
    // Being an Arc is a requirement of try_pull_owned() so buffers can refer back to the pool.
    pool: Arc<Pool<Box<T>>>,
    /// Total number of buffers the pool was created with.
    size: usize,
    /// Size of one buffer, in bytes.
    buffer_size: usize,
}
701
702impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
703    pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
704    where
705        F: Fn() -> T,
706    {
707        let pool = Arc::new(Pool::new(size, move || Box::new(buffer_initializer())));
708        let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
709
710        let og = Self {
711            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
712            pool,
713            size,
714            buffer_size,
715        };
716        let og = Arc::new(og);
717        register_pool(og.clone());
718        Ok(og)
719    }
720}
721
// Monitoring view: straight pass-through of the pool's bookkeeping fields.
impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
    fn id(&self) -> PoolID {
        self.id
    }

    // `Pool::len` reports how many buffers currently sit in the pool,
    // i.e. slots not leased out.
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    fn total_size(&self) -> usize {
        self.size
    }

    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
739
740impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
741    fn acquire(&self) -> Option<CuHandle<T>> {
742        let owned_object = self.pool.try_pull_owned(); // Use the owned version
743
744        owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
745    }
746
747    fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
748        let to_handle = self.acquire().expect("No available buffers in the pool");
749        {
750            let from_lock = lock_unpoison(&from.0);
751            let mut to_lock = lock_unpoison(&to_handle.0);
752            to_lock.inner_mut().copy_from_slice(from_lock.inner_ref());
753        }
754        to_handle
755    }
756}
757
/// A pool of fixed-size shared-memory buffers that can be leased to a child
/// process without copying the underlying bytes.
pub struct CuSharedMemoryPool<E: ElementType> {
    /// Identifier reported to the monitoring registry.
    id: PoolID,
    /// Pool of slot wrappers; all slots share one mapped region.
    pool: Arc<Pool<Box<CuSharedMemoryBuffer<E>>>>,
    /// Total number of buffer slots.
    size: usize,
    /// Payload size of one buffer, in bytes (elements * element size).
    buffer_size: usize,
}
766
impl<E: ElementType + 'static> CuSharedMemoryPool<E> {
    /// Create a pool of `size` fixed-size buffers carved out of a single
    /// shared-memory region, laid out back-to-back at an alignment-rounded stride.
    pub fn new(id: &str, size: usize, elements_per_buffer: usize) -> CuResult<Arc<Self>> {
        // `.max(1)` keeps the stride non-zero so the backing file is never empty.
        let slot_stride = shared_slot_stride::<E>(elements_per_buffer.max(1));
        let region = CuSharedMemoryRegion::create(
            slot_stride
                .checked_mul(size)
                .ok_or_else(|| cu29_traits::CuError::from("shared memory pool size overflow"))?,
        )?;
        // Hands out a fresh slot index each time the pool initializer runs.
        let next_slot = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let initializer_region = region.clone();
        let initializer_next_slot = next_slot.clone();
        let pool = Arc::new(Pool::new(size, move || {
            let slot = initializer_next_slot.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            // The initializer is expected to run exactly `size` times; any
            // extra invocation would hand out a slot beyond the mapped region.
            assert!(slot < size, "shared memory pool slot index overflow");
            Box::new(CuSharedMemoryBuffer::from_region(
                initializer_region.clone(),
                slot * slot_stride,
                elements_per_buffer,
            ))
        }));

        let pool = Arc::new(Self {
            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
            pool,
            size,
            // Payload bytes per buffer (unpadded, unlike `slot_stride`).
            buffer_size: elements_per_buffer * size_of::<E>(),
        });
        register_pool(pool.clone());
        Ok(pool)
    }
}
798
// Monitoring view: straight pass-through of the pool's bookkeeping fields.
impl<E: ElementType> PoolMonitor for CuSharedMemoryPool<E> {
    fn id(&self) -> PoolID {
        self.id
    }

    // Buffers currently sitting in the pool (not leased out).
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    fn total_size(&self) -> usize {
        self.size
    }

    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
816
817impl<E: ElementType> CuPool<CuSharedMemoryBuffer<E>> for CuSharedMemoryPool<E> {
818    fn acquire(&self) -> Option<CuHandle<CuSharedMemoryBuffer<E>>> {
819        self.pool
820            .try_pull_owned()
821            .map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
822    }
823
824    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<CuSharedMemoryBuffer<E>>
825    where
826        O: ArrayLike<Element = E>,
827    {
828        let to_handle = self.acquire().expect("No available buffers in the pool");
829        {
830            let from_lock = lock_unpoison(&from.0);
831            let mut to_lock = lock_unpoison(&to_handle.0);
832            to_lock.inner_mut().copy_from_slice(from_lock.inner_ref());
833        }
834        to_handle
835    }
836}
837
// Plain heap vectors also satisfy `ArrayLike`, so they interoperate with
// pooled and shared buffers throughout the `CuPool` APIs.
impl<E: ElementType + 'static> ArrayLike for Vec<E> {
    type Element = E;
}
841
842#[cfg(all(feature = "cuda", not(target_os = "macos")))]
843mod cuda {
844    use super::*;
845    use cu29_traits::CuError;
846    use cudarc::driver::{
847        CudaContext, CudaSlice, CudaStream, DeviceRepr, HostSlice, SyncOnDrop, ValidAsZeroBits,
848    };
849    use std::sync::Arc;
850
    /// Wrapper over a device-side `CudaSlice` so CUDA buffers can flow through
    /// the `ArrayLike`-based pool APIs.
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);

    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        fn deref(&self) -> &Self::Target {
            // Device memory is not host-addressable; slicing it directly is a
            // usage error that is reported loudly at runtime.
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        fn deref_mut(&mut self) -> &mut Self::Target {
            // Same restriction as `deref` above.
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }

    impl<E> CudaSliceWrapper<E> {
        /// Borrow the underlying device slice for cudarc API calls.
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        /// Mutably borrow the underlying device slice.
        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }
888
    // Create a wrapper type to bridge between ArrayLike and HostSlice
    // (read-only direction: host buffer used as a copy source).
    pub struct HostSliceWrapper<'a, T: ArrayLike> {
        inner: &'a T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        // SAFETY: HostSlice requires the returned slice to remain valid for 'b.
        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        // SAFETY: This wrapper cannot provide mutable access; callers must not rely on this.
        // Panics by design — use `HostSliceMutWrapper` for writable access.
        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            _stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            panic!("Cannot get mutable reference from immutable wrapper")
        }
    }
915
    // Mutable wrapper (writable direction: host buffer used as a copy destination).
    pub struct HostSliceMutWrapper<'a, T: ArrayLike> {
        inner: &'a mut T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceMutWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        // SAFETY: HostSlice requires the returned slice to remain valid for 'b.
        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        // SAFETY: HostSlice requires the returned slice to remain valid for 'b.
        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref_mut(), SyncOnDrop::sync_stream(stream))
        }
    }
942
    // Helpers converting pool handles into the HostSlice adapters above.
    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        // Borrows the buffer inside a handle as a read-only HostSlice
        // (source side of a host-to-device copy).
        fn get_host_slice_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &CuHandleInner<O>,
        ) -> HostSliceWrapper<'_, O> {
            HostSliceWrapper {
                inner: handle_inner.inner_ref(),
            }
        }

        // Mutably borrows the buffer inside a handle as a writable HostSlice
        // (destination side of a device-to-host copy).
        fn get_host_slice_mut_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &mut CuHandleInner<O>,
        ) -> HostSliceMutWrapper<'_, O> {
            HostSliceMutWrapper {
                inner: handle_inner.inner_mut(),
            }
        }
    }
    /// A pool of CUDA memory buffers.
    ///
    /// Holds a fixed set of pre-allocated device buffers that are handed out
    /// via handles and recycled when the handles are dropped.
    pub struct CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        // Identifier reported through the PoolMonitor API.
        id: PoolID,
        // Stream on which all of this pool's copies are issued.
        stream: Arc<CudaStream>,
        // Recyclable device buffers; its `len()` is the number of free slots.
        pool: Arc<Pool<Box<CudaSliceWrapper<E>>>>,
        // Total number of buffers, fixed at construction.
        nb_buffers: usize,
        // Capacity of each buffer, in elements of E (not bytes).
        nb_element_per_buffer: usize,
        // NOTE(review): the `Unpin` bound here does not appear on this type's
        // impl blocks below; presumably `ElementType` already implies `Unpin`
        // — confirm, otherwise those impls would not type-check.
    }
974
975    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
976        #[allow(dead_code)]
977        pub fn new(
978            id: &'static str,
979            ctx: Arc<CudaContext>,
980            nb_buffers: usize,
981            nb_element_per_buffer: usize,
982        ) -> CuResult<Self> {
983            let stream = ctx.default_stream();
984            let pool = (0..nb_buffers)
985                .map(|_| {
986                    stream
987                        .alloc_zeros(nb_element_per_buffer)
988                        .map(CudaSliceWrapper)
989                        .map(Box::new)
990                        .map_err(|_| "Failed to allocate device memory")
991                })
992                .collect::<Result<Vec<_>, _>>()?;
993
994            Ok(Self {
995                id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
996                stream,
997                pool: Arc::new(Pool::from_vec(pool)),
998                nb_buffers,
999                nb_element_per_buffer,
1000            })
1001        }
1002    }
1003
1004    impl<E> PoolMonitor for CuCudaPool<E>
1005    where
1006        E: DeviceRepr + ElementType + ValidAsZeroBits,
1007    {
1008        fn id(&self) -> PoolID {
1009            self.id
1010        }
1011
1012        fn space_left(&self) -> usize {
1013            self.pool.len()
1014        }
1015
1016        fn total_size(&self) -> usize {
1017            self.nb_buffers
1018        }
1019
1020        fn buffer_size(&self) -> usize {
1021            self.nb_element_per_buffer * size_of::<E>()
1022        }
1023    }
1024
    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        // Takes a free device buffer out of the pool, or returns None when all
        // buffers are in use. The buffer returns to the pool when the handle
        // (and its clones) are dropped.
        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
            self.pool
                .try_pull_owned()
                .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
        }

        // Copies a host buffer into a freshly acquired device buffer and
        // returns the device handle.
        //
        // Panics if the pool is exhausted or the host-to-device copy fails.
        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
        where
            O: ArrayLike<Element = E>,
        {
            let to_handle = self.acquire().expect("No available buffers in the pool");

            {
                // `to_handle` was freshly created above, so locking source
                // then destination cannot contend with another holder of it.
                let from_lock = lock_unpoison(&from_handle.0);
                let mut to_lock = lock_unpoison(&to_handle.0);

                match &mut *to_lock {
                    CuHandleInner::Detached(to) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to.deref_mut().as_cuda_slice_mut())
                            .expect("Failed to copy data to device");
                    }
                    CuHandleInner::Pooled(to) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to.deref_mut().as_mut().as_cuda_slice_mut())
                            .expect("Failed to copy data to device");
                    }
                }
            } // locks are dropped here
            to_handle // now we can safely return to_handle
        }
    }
1063
    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        /// Copy from device to host
        ///
        /// Locks both handles, then issues a device-to-host memcpy on this
        /// pool's stream into the host handle's buffer.
        ///
        /// # Errors
        /// Returns an error if either handle's mutex is poisoned or if the
        /// device-to-host copy fails.
        fn copy_to_host_pool<O>(
            &self,
            device_handle: &CuHandle<CudaSliceWrapper<E>>,
            host_handle: &mut CuHandle<O>,
        ) -> Result<(), CuError>
        where
            O: ArrayLike<Element = E>,
        {
            let device_lock = device_handle.lock().map_err(|e| {
                CuError::from("Device handle mutex poisoned").add_cause(&e.to_string())
            })?;
            let mut host_lock = host_handle.lock().map_err(|e| {
                CuError::from("Host handle mutex poisoned").add_cause(&e.to_string())
            })?;
            // The device buffer may live in the pool (Pooled) or outside of
            // it (Detached); either way we only need a shared CudaSlice view.
            let src = match &*device_lock {
                CuHandleInner::Pooled(source) => source.deref().as_ref().as_cuda_slice(),
                CuHandleInner::Detached(source) => source.deref().as_cuda_slice(),
            };
            let mut wrapper = Self::get_host_slice_mut_wrapper(&mut *host_lock);
            self.stream.memcpy_dtoh(src, &mut wrapper).map_err(|e| {
                CuError::from("Failed to copy data from device to host").add_cause(&e.to_string())
            })?;
            Ok(())
        }
    }
1094}
1095
#[derive(Debug)]
/// A buffer that is aligned to a specific size with the Element of type E.
pub struct AlignedBuffer<E: ElementType> {
    // Start of the allocation; points at `size` initialized elements.
    ptr: *mut E,
    // Number of elements (not bytes) in the buffer.
    size: usize,
    // Layout used for the allocation; needed again by `dealloc` in Drop.
    layout: Layout,
}
1103
1104impl<E: ElementType> AlignedBuffer<E> {
1105    pub fn new(num_elements: usize, alignment: usize) -> Self {
1106        assert!(
1107            num_elements > 0 && size_of::<E>() > 0,
1108            "AlignedBuffer requires a non-zero element count and non-zero-sized element type"
1109        );
1110        let alignment = alignment.max(align_of::<E>());
1111        let alloc_size = num_elements
1112            .checked_mul(size_of::<E>())
1113            .expect("AlignedBuffer allocation size overflow");
1114        let layout = Layout::from_size_align(alloc_size, alignment).unwrap();
1115        // SAFETY: layout describes a valid, non-zero allocation request.
1116        let ptr = unsafe { alloc(layout) as *mut E };
1117        if ptr.is_null() {
1118            panic!("Failed to allocate memory");
1119        }
1120        // SAFETY: ptr is valid for writes of `num_elements` elements.
1121        unsafe {
1122            for i in 0..num_elements {
1123                std::ptr::write(ptr.add(i), E::default());
1124            }
1125        }
1126        Self {
1127            ptr,
1128            size: num_elements,
1129            layout,
1130        }
1131    }
1132}
1133
impl<E: ElementType> Deref for AlignedBuffer<E> {
    type Target = [E];

    // Views the buffer as an immutable slice of its `size` elements.
    fn deref(&self) -> &Self::Target {
        // SAFETY: `new` initializes all elements and keeps the pointer aligned.
        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
    }
}
1142
impl<E: ElementType> DerefMut for AlignedBuffer<E> {
    // Views the buffer as a mutable slice of its `size` elements; the
    // exclusive `&mut self` borrow rules out aliasing.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: `new` initializes all elements and keeps the pointer aligned.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
    }
}
1149
impl<E: ElementType> Drop for AlignedBuffer<E> {
    fn drop(&mut self) {
        // SAFETY: `ptr` was allocated with `layout` in `new`.
        // NOTE(review): the elements are freed without running their `Drop`;
        // presumably `ElementType` implies `Copy`/no-drop types — confirm.
        unsafe { dealloc(self.ptr as *mut u8, self.layout) }
    }
}
1156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool() {
        use std::cell::RefCell;
        // Stack of pre-built vectors the pool will draw its buffers from.
        let source = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        let snapshot = source.borrow().clone();
        let expected = snapshot.iter().map(|v| v.as_slice()).collect::<Vec<_>>();
        let pool =
            CuHostMemoryPool::new("mytestcudapool", 3, || source.borrow_mut().pop().unwrap())
                .unwrap();

        let first = pool.acquire().unwrap();
        {
            let second = pool.acquire().unwrap();
            assert!(expected.contains(&first.lock().unwrap().deref().deref()));
            assert!(expected.contains(&second.lock().unwrap().deref().deref()));
            assert_eq!(pool.space_left(), 1);
        } // `second` goes back to the pool here.
        assert_eq!(pool.space_left(), 2);

        let third = pool.acquire().unwrap();
        assert!(expected.contains(&third.lock().unwrap().deref().deref()));
        assert_eq!(pool.space_left(), 1);

        let _fourth = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // Pool exhausted: no further acquisition can succeed.
        assert!(pool.acquire().is_none());
    }

    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_cuda_pool() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let _held_a = pool.acquire().unwrap();
        {
            let _held_b = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        } // `_held_b` is recycled at the end of this scope.
        assert_eq!(pool.space_left(), 2);

        let _held_c = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 1);

        let _held_d = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // All three buffers are out: acquisition must fail.
        assert!(pool.acquire().is_none());
    }

    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_copy_roundtrip() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let device_handle = {
            let mut staging = host_pool.acquire().unwrap();
            {
                let mut guard = staging.lock().unwrap();
                match *guard {
                    CuHandleInner::Pooled(ref mut buf) => buf[0] = 42.0,
                    _ => panic!(),
                }
            }

            // Upload the staged value to the GPU.
            cuda_pool.copy_from(&mut staging)
        };

        // Download it back into a fresh host buffer.
        let mut round_trip = host_pool.acquire().unwrap();
        cuda_pool
            .copy_to_host_pool(&device_handle, &mut round_trip)
            .unwrap();

        assert_eq!(round_trip.lock().unwrap().deref().deref()[0], 42.0);
    }
}