1use arrayvec::ArrayString;
2use bincode::de::Decoder;
3use bincode::enc::Encoder;
4use bincode::error::{DecodeError, EncodeError};
5use bincode::{Decode, Encode};
6use cu29_traits::CuResult;
7use hashbrown::HashMap;
8use object_pool::{Pool, ReusableOwned};
9use serde::de::{self, MapAccess, SeqAccess, Visitor};
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11use smallvec::SmallVec;
12use std::alloc::{Layout, alloc, dealloc};
13use std::cell::Cell;
14use std::cell::UnsafeCell;
15use std::fmt::Debug;
16use std::fs::OpenOptions;
17use std::marker::PhantomData;
18use std::mem::{align_of, size_of};
19use std::ops::{Deref, DerefMut};
20use std::path::{Path, PathBuf};
21use std::sync::{Arc, Mutex, MutexGuard, OnceLock};
22
23use memmap2::{MmapMut, MmapOptions};
24use tempfile::NamedTempFile;
25
/// Fixed-capacity (64-byte, stack-allocated) pool identifier; avoids heap
/// allocation on the monitoring path.
type PoolID = ArrayString<64>;
27
/// Monitoring surface every pool exposes so it can be reported by
/// [`pools_statistics`].
pub trait PoolMonitor: Send + Sync {
    /// Human-readable identifier of this pool (used as the registry key).
    fn id(&self) -> PoolID;

    /// Number of buffers currently available for acquisition.
    fn space_left(&self) -> usize;

    /// Total number of buffers this pool was created with.
    fn total_size(&self) -> usize;

    /// Size of a single buffer, in bytes.
    fn buffer_size(&self) -> usize;
}
42
/// Process-wide registry of live pools, keyed by pool id; populated by
/// `register_pool` and read by `pools_statistics`.
static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
/// Inline capacity hint for the statistics vector — not a hard cap on how many
/// pools may register.
const MAX_POOLS: usize = 16;
45
/// Locks `mutex`, recovering the guard even when the mutex was poisoned by a
/// panicking holder (pool bookkeeping stays usable after a worker panic).
fn lock_unpoison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
    mutex
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
52
53fn register_pool(pool: Arc<dyn PoolMonitor>) {
55 POOL_REGISTRY
56 .get_or_init(|| Mutex::new(HashMap::new()))
57 .lock()
58 .unwrap_or_else(|poison| poison.into_inner())
59 .insert(pool.id().to_string(), pool);
60}
61
/// One statistics row: (id, space_left, total_size, buffer_size).
type PoolStats = (PoolID, usize, usize, usize);
63
64pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
67 let registry_lock = match POOL_REGISTRY.get() {
69 Some(lock) => lock_unpoison(lock),
70 None => return SmallVec::new(), };
72 let mut result = SmallVec::with_capacity(MAX_POOLS);
73 for pool in registry_lock.values() {
74 result.push((
75 pool.id(),
76 pool.space_left(),
77 pool.total_size(),
78 pool.buffer_size(),
79 ));
80 }
81 result
82}
83
/// Plain-data element type storable inside pool buffers. The `encode`/`decode`
/// hooks mirror bincode's `Encode`/`Decode` so buffers can be serialized
/// element by element.
pub trait ElementType: Default + Sized + Copy + Debug + Unpin + Send + Sync {
    /// Serializes one element with bincode.
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError>;
    /// Deserializes one element with bincode (unit decoding context).
    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError>;
}
89
90impl<T> ElementType for T
92where
93 T: Default + Sized + Copy + Debug + Unpin + Send + Sync,
94 T: Encode,
95 T: Decode<()>,
96{
97 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
98 self.encode(encoder)
99 }
100
101 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
102 Self::decode(decoder)
103 }
104}
105
/// Contiguous, sliceable buffer of `Element`s — the shape every pool buffer
/// must have so it can be copied with plain slice operations.
pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
    type Element: ElementType;
}
109
thread_local! {
    // Per-thread opt-in: while `true`, shared-memory handles serialize as
    // compact descriptors instead of inline element data
    // (see `enable_shared_handle_serialization`).
    static SHARED_HANDLE_SERIALIZATION_ENABLED: Cell<bool> = const { Cell::new(false) };
}
113
/// RAII guard that restores the per-thread handle-serialization mode on drop.
pub struct SharedHandleSerializationGuard {
    // Flag value to restore when the guard is dropped.
    previous: bool,
}
117
impl Drop for SharedHandleSerializationGuard {
    fn drop(&mut self) {
        // Restore whatever mode was active when the guard was created,
        // supporting nested guards correctly.
        SHARED_HANDLE_SERIALIZATION_ENABLED.with(|enabled| enabled.set(self.previous));
    }
}
123
124pub fn enable_shared_handle_serialization() -> SharedHandleSerializationGuard {
125 let previous = SHARED_HANDLE_SERIALIZATION_ENABLED.with(|enabled| {
126 let previous = enabled.get();
127 enabled.set(true);
128 previous
129 });
130 SharedHandleSerializationGuard { previous }
131}
132
133fn shared_handle_serialization_enabled() -> bool {
134 SHARED_HANDLE_SERIALIZATION_ENABLED.with(Cell::get)
135}
136
/// Wire tag naming the primitive element type stored in a shared-memory
/// buffer, so a received descriptor can be type-checked before mapping.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CuSharedMemoryElementType {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F32,
    F64,
}
151
152impl CuSharedMemoryElementType {
153 pub fn of<E: ElementType + 'static>() -> Option<Self> {
154 let type_id = core::any::TypeId::of::<E>();
155 if type_id == core::any::TypeId::of::<u8>() {
156 Some(Self::U8)
157 } else if type_id == core::any::TypeId::of::<u16>() {
158 Some(Self::U16)
159 } else if type_id == core::any::TypeId::of::<u32>() {
160 Some(Self::U32)
161 } else if type_id == core::any::TypeId::of::<u64>() {
162 Some(Self::U64)
163 } else if type_id == core::any::TypeId::of::<i8>() {
164 Some(Self::I8)
165 } else if type_id == core::any::TypeId::of::<i16>() {
166 Some(Self::I16)
167 } else if type_id == core::any::TypeId::of::<i32>() {
168 Some(Self::I32)
169 } else if type_id == core::any::TypeId::of::<i64>() {
170 Some(Self::I64)
171 } else if type_id == core::any::TypeId::of::<f32>() {
172 Some(Self::F32)
173 } else if type_id == core::any::TypeId::of::<f64>() {
174 Some(Self::F64)
175 } else {
176 None
177 }
178 }
179}
180
/// Serializable description of a shared-memory slice: backing file path, byte
/// offset, element count, and element type. The renamed `marker` field lets a
/// deserializer tell this map apart from ordinary data.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CuSharedMemoryHandleDescriptor {
    #[serde(rename = "__cu_shm_handle__")]
    pub marker: bool,
    pub path: String,
    pub offset_bytes: usize,
    pub len_elements: usize,
    pub element_type: CuSharedMemoryElementType,
}
190
191impl CuSharedMemoryHandleDescriptor {
192 fn new(
193 path: String,
194 offset_bytes: usize,
195 len_elements: usize,
196 element_type: CuSharedMemoryElementType,
197 ) -> Self {
198 Self {
199 marker: true,
200 path,
201 offset_bytes,
202 len_elements,
203 element_type,
204 }
205 }
206}
207
/// A file-backed memory mapping shared by all the buffers carved out of it.
struct CuSharedMemoryRegion {
    // Filesystem path another process can use to `open` the same region.
    path: PathBuf,
    // UnsafeCell: buffers hand out `&[E]` / `&mut [E]` views through `&self`.
    mmap: UnsafeCell<MmapMut>,
    // `Some` for regions we created: keeps the temp file alive (and thus the
    // path valid) until the region drops.
    _backing_file: Option<NamedTempFile>,
}
213
impl CuSharedMemoryRegion {
    /// Creates a brand-new region of `byte_len` bytes backed by a named temp
    /// file, so other processes can map it by path, and publishes it in the
    /// process-wide cache.
    fn create(byte_len: usize) -> CuResult<Arc<Self>> {
        let file = NamedTempFile::new()
            .map_err(|e| cu29_traits::CuError::new_with_cause("create shared memory file", e))?;
        file.as_file()
            .set_len(byte_len as u64)
            .map_err(|e| cu29_traits::CuError::new_with_cause("size shared memory file", e))?;
        // SAFETY: mapping a file we just created and sized to `byte_len`.
        // NOTE(review): soundness relies on no one truncating the backing file
        // while mapped — confirm external writers are excluded by convention.
        let mmap = unsafe {
            MmapOptions::new()
                .len(byte_len)
                .map_mut(file.as_file())
                .map_err(|e| cu29_traits::CuError::new_with_cause("map shared memory file", e))?
        };
        let region = Arc::new(Self {
            path: file.path().to_path_buf(),
            mmap: UnsafeCell::new(mmap),
            // Keep the NamedTempFile alive so the file isn't deleted early.
            _backing_file: Some(file),
        });
        // Cache so a later `open` on the same path reuses this mapping.
        cache_shared_region(region.clone());
        Ok(region)
    }

    /// Opens an existing region by path, reusing a cached mapping when one is
    /// still alive in this process.
    fn open(path: &Path) -> CuResult<Arc<Self>> {
        if let Some(region) = cached_shared_region(path) {
            return Ok(region);
        }

        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(path)
            .map_err(|e| cu29_traits::CuError::new_with_cause("open shared memory file", e))?;
        // Map the whole file, whatever its current size.
        let len = file
            .metadata()
            .map_err(|e| cu29_traits::CuError::new_with_cause("stat shared memory file", e))?
            .len() as usize;
        // SAFETY: mapping a file we just opened, sized from its metadata.
        // Same truncation caveat as `create` above.
        let mmap = unsafe {
            MmapOptions::new()
                .len(len)
                .map_mut(&file)
                .map_err(|e| cu29_traits::CuError::new_with_cause("map shared memory file", e))?
        };
        let region = Arc::new(Self {
            path: path.to_path_buf(),
            mmap: UnsafeCell::new(mmap),
            // We don't own the backing file's lifetime — the creator does.
            _backing_file: None,
        });
        cache_shared_region(region.clone());
        Ok(region)
    }
}
265
impl Debug for CuSharedMemoryRegion {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the path is useful to print; the mapping contents are elided
        // (`..` via finish_non_exhaustive).
        f.debug_struct("CuSharedMemoryRegion")
            .field("path", &self.path)
            .finish_non_exhaustive()
    }
}
273
// SAFETY: the only interior mutability is the mmap behind `UnsafeCell`, and
// slice views into it are handed out through `CuSharedMemoryBuffer`, whose
// access is serialized by the `CuHandle` mutex. NOTE(review): this relies on
// no two handles mutably aliasing the same slot across threads — confirm the
// pool's slot carving guarantees disjoint offsets.
unsafe impl Send for CuSharedMemoryRegion {}
unsafe impl Sync for CuSharedMemoryRegion {}
282
283fn shared_region_cache() -> &'static Mutex<HashMap<PathBuf, std::sync::Weak<CuSharedMemoryRegion>>>
284{
285 static CACHE: OnceLock<Mutex<HashMap<PathBuf, std::sync::Weak<CuSharedMemoryRegion>>>> =
286 OnceLock::new();
287 CACHE.get_or_init(|| Mutex::new(HashMap::new()))
288}
289
290fn cache_shared_region(region: Arc<CuSharedMemoryRegion>) {
291 lock_unpoison(shared_region_cache()).insert(region.path.clone(), Arc::downgrade(®ion));
292}
293
294fn cached_shared_region(path: &Path) -> Option<Arc<CuSharedMemoryRegion>> {
295 lock_unpoison(shared_region_cache())
296 .get(path)
297 .and_then(std::sync::Weak::upgrade)
298}
299
300fn shared_slot_stride<E: ElementType>(len_elements: usize) -> usize {
301 let raw_bytes = len_elements
302 .checked_mul(size_of::<E>())
303 .expect("shared memory slot size overflow");
304 let alignment = align_of::<E>().max(1);
305 raw_bytes.div_ceil(alignment) * alignment
306}
307
/// Typed view of `len_elements` elements of `E` starting `offset_bytes` into a
/// shared region; the `Arc` keeps the mapping alive for the buffer's lifetime.
#[derive(Debug)]
pub struct CuSharedMemoryBuffer<E: ElementType> {
    region: Arc<CuSharedMemoryRegion>,
    offset_bytes: usize,
    len_elements: usize,
    _marker: PhantomData<E>,
}
315
impl<E: ElementType + 'static> CuSharedMemoryBuffer<E> {
    /// Builds a typed view over `len_elements` elements starting at
    /// `offset_bytes` inside `region`. Offsets/lengths are not re-checked
    /// here; callers (the pool and the constructors below) pick them.
    fn from_region(
        region: Arc<CuSharedMemoryRegion>,
        offset_bytes: usize,
        len_elements: usize,
    ) -> Self {
        Self {
            region,
            offset_bytes,
            len_elements,
            _marker: PhantomData,
        }
    }

    /// Creates a standalone (non-pooled) shared-memory buffer holding `data`.
    pub fn from_vec_detached(data: Vec<E>) -> CuResult<Self> {
        let len_elements = data.len();
        // `.max(1)` keeps the backing region non-empty even for an empty vec.
        let slot_stride = shared_slot_stride::<E>(len_elements.max(1));
        let region = CuSharedMemoryRegion::create(slot_stride)?;
        let mut buffer = Self::from_region(region, 0, len_elements);
        if !data.is_empty() {
            buffer.copy_from_slice(&data);
        }
        Ok(buffer)
    }

    /// Re-attaches to the region named by `descriptor` (e.g. received from
    /// another process). Fails when the descriptor's element type doesn't
    /// match `E`, or when `E` isn't a supported primitive.
    pub fn from_descriptor(descriptor: &CuSharedMemoryHandleDescriptor) -> CuResult<Self> {
        let expected = CuSharedMemoryElementType::of::<E>()
            .ok_or_else(|| cu29_traits::CuError::from("unsupported shared memory element type"))?;
        if descriptor.element_type != expected {
            return Err(cu29_traits::CuError::from(
                "shared memory descriptor element type mismatch",
            ));
        }
        let region = CuSharedMemoryRegion::open(Path::new(&descriptor.path))?;
        Ok(Self::from_region(
            region,
            descriptor.offset_bytes,
            descriptor.len_elements,
        ))
    }

    /// Describes this buffer for cross-process transmission. Returns `None`
    /// when `E` is not one of the supported primitive element types.
    pub fn descriptor(&self) -> Option<CuSharedMemoryHandleDescriptor>
    where
        E: 'static,
    {
        CuSharedMemoryElementType::of::<E>().map(|element_type| {
            CuSharedMemoryHandleDescriptor::new(
                self.region.path.display().to_string(),
                self.offset_bytes,
                self.len_elements,
                element_type,
            )
        })
    }
}
371
impl<E: ElementType> Deref for CuSharedMemoryBuffer<E> {
    type Target = [E];

    fn deref(&self) -> &Self::Target {
        // SAFETY: the mapping outlives `self` (held via `Arc`). NOTE(review):
        // assumes `offset_bytes + len_elements * size_of::<E>()` is within the
        // mapped length and `offset_bytes` is aligned for `E` — upheld by the
        // constructors/pool slot layout in this file, not re-checked here.
        let ptr = unsafe { (*self.region.mmap.get()).as_ptr().add(self.offset_bytes) as *const E };
        unsafe { std::slice::from_raw_parts(ptr, self.len_elements) }
    }
}
380
impl<E: ElementType> DerefMut for CuSharedMemoryBuffer<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same bounds/alignment reasoning as `deref`; `&mut self`
        // guarantees this view is not aliased through *this* buffer.
        // NOTE(review): mutable disjointness across buffers relies on the pool
        // assigning non-overlapping slots — confirm.
        let ptr = unsafe {
            (*self.region.mmap.get())
                .as_mut_ptr()
                .add(self.offset_bytes) as *mut E
        };
        unsafe { std::slice::from_raw_parts_mut(ptr, self.len_elements) }
    }
}
391
// A shared-memory buffer is a valid pool buffer.
impl<E: ElementType> ArrayLike for CuSharedMemoryBuffer<E> {
    type Element = E;
}
395
396impl<E: ElementType> Encode for CuSharedMemoryBuffer<E> {
397 fn encode<Enc: Encoder>(&self, encoder: &mut Enc) -> Result<(), EncodeError> {
398 let len = self.len_elements as u64;
399 Encode::encode(&len, encoder)?;
400 for value in self.deref() {
401 value.encode(encoder)?;
402 }
403 Ok(())
404 }
405}
406
407impl<E: ElementType + 'static> Decode<()> for CuSharedMemoryBuffer<E> {
408 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
409 let len = <u64 as Decode<()>>::decode(decoder)? as usize;
410 let mut vec = Vec::with_capacity(len);
411 for _ in 0..len {
412 vec.push(E::decode(decoder)?);
413 }
414 Self::from_vec_detached(vec).map_err(|e| DecodeError::OtherString(e.to_string()))
415 }
416}
417
/// Storage behind a `CuHandle`: either borrowed from a pool (automatically
/// returned on drop) or owned outright.
pub enum CuHandleInner<T: Debug + Send + Sync> {
    // Goes back to its pool when dropped.
    Pooled(ReusableOwned<Box<T>>),
    // Plain owned storage, freed on drop.
    Detached(Box<T>),
}
427
428impl<T> Debug for CuHandleInner<T>
429where
430 T: Debug + Send + Sync,
431{
432 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433 match self {
434 CuHandleInner::Pooled(r) => {
435 write!(f, "Pooled: {:?}", r.deref().deref())
436 }
437 CuHandleInner::Detached(r) => write!(f, "Detached: {r:?}"),
438 }
439 }
440}
441
442impl<T> CuHandleInner<T>
443where
444 T: Debug + Send + Sync,
445{
446 fn inner_ref(&self) -> &T {
447 match self {
448 CuHandleInner::Pooled(pooled) => pooled.deref().as_ref(),
449 CuHandleInner::Detached(detached) => detached.deref(),
450 }
451 }
452
453 fn inner_mut(&mut self) -> &mut T {
454 match self {
455 CuHandleInner::Pooled(pooled) => pooled.deref_mut().as_mut(),
456 CuHandleInner::Detached(detached) => detached.deref_mut(),
457 }
458 }
459}
460
// Standard shared-borrow adapter over `inner_ref`.
impl<T> AsRef<T> for CuHandleInner<T>
where
    T: Debug + Send + Sync,
{
    fn as_ref(&self) -> &T {
        self.inner_ref()
    }
}
469
// Standard mutable-borrow adapter over `inner_mut`.
impl<T> AsMut<T> for CuHandleInner<T>
where
    T: Debug + Send + Sync,
{
    fn as_mut(&mut self) -> &mut T {
        self.inner_mut()
    }
}
478
// When the payload is a pool buffer, expose it directly as an element slice.
impl<T: ArrayLike> Deref for CuHandleInner<T> {
    type Target = [T::Element];

    fn deref(&self) -> &Self::Target {
        self.inner_ref().deref()
    }
}
486
// Mutable slice access to a pool-buffer payload.
impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner_mut().deref_mut()
    }
}
492
/// Cloneable, thread-safe handle to a pooled or detached buffer; clones share
/// the same storage through `Arc<Mutex<_>>`.
#[derive(Debug)]
pub struct CuHandle<T: Debug + Send + Sync>(Arc<Mutex<CuHandleInner<T>>>);
498
499impl<T: Debug + Send + Sync> Clone for CuHandle<T> {
500 fn clone(&self) -> Self {
501 Self(self.0.clone())
502 }
503}
504
// Expose the shared `Arc<Mutex<_>>` so callers can `.lock()` directly.
impl<T: Debug + Send + Sync> Deref for CuHandle<T> {
    type Target = Arc<Mutex<CuHandleInner<T>>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}
512
513impl<T: Debug + Send + Sync> CuHandle<T> {
514 pub fn new_detached(inner: T) -> Self {
516 Self::new_detached_box(Box::new(inner))
517 }
518
519 pub fn new_detached_box(inner: Box<T>) -> Self {
521 CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
522 }
523
524 pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
526 let lock = lock_unpoison(&self.0);
527 f(&*lock)
528 }
529
530 pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
532 let mut lock = lock_unpoison(&self.0);
533 f(&mut *lock)
534 }
535}
536
537impl<U> Serialize for CuHandle<Vec<U>>
538where
539 U: ElementType + Serialize + 'static,
540{
541 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
542 let inner = lock_unpoison(&self.0);
543 inner.inner_ref().serialize(serializer)
544 }
545}
546
547impl<'de, U> Deserialize<'de> for CuHandle<Vec<U>>
548where
549 U: ElementType + Deserialize<'de> + 'static,
550{
551 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
552 Vec::<U>::deserialize(deserializer).map(CuHandle::new_detached)
553 }
554}
555
556impl<U> Serialize for CuHandle<CuSharedMemoryBuffer<U>>
557where
558 U: ElementType + Serialize + 'static,
559{
560 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
561 let inner = lock_unpoison(&self.0);
562 let buffer = inner.inner_ref();
563
564 if shared_handle_serialization_enabled()
565 && let Some(descriptor) = buffer.descriptor()
566 {
567 return descriptor.serialize(serializer);
568 }
569
570 buffer.deref().serialize(serializer)
571 }
572}
573
impl<'de, U> Deserialize<'de> for CuHandle<CuSharedMemoryBuffer<U>>
where
    U: ElementType + Deserialize<'de> + 'static,
{
    /// Accepts either wire form produced by the matching `Serialize` impl —
    /// a shared-memory descriptor (map) or raw element data (sequence) — and
    /// rebuilds a detached handle from it.
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Internal two-armed representation mirroring the two wire forms.
        enum Repr<U> {
            Descriptor(CuSharedMemoryHandleDescriptor),
            Data(Vec<U>),
        }

        impl<'de, U> Deserialize<'de> for Repr<U>
        where
            U: ElementType + Deserialize<'de>,
        {
            fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
                struct ReprVisitor<U>(PhantomData<U>);

                impl<'de, U> Visitor<'de> for ReprVisitor<U>
                where
                    U: ElementType + Deserialize<'de>,
                {
                    type Value = Repr<U>;

                    fn expecting(
                        &self,
                        formatter: &mut std::fmt::Formatter<'_>,
                    ) -> std::fmt::Result {
                        formatter
                            .write_str("a shared-memory handle descriptor or an element sequence")
                    }

                    // A sequence means inline element data.
                    fn visit_seq<A: SeqAccess<'de>>(self, seq: A) -> Result<Self::Value, A::Error> {
                        let data =
                            Vec::<U>::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
                        Ok(Repr::Data(data))
                    }

                    // A map means a descriptor (its `__cu_shm_handle__` marker
                    // and friends are the keys).
                    fn visit_map<A: MapAccess<'de>>(self, map: A) -> Result<Self::Value, A::Error> {
                        let descriptor = CuSharedMemoryHandleDescriptor::deserialize(
                            de::value::MapAccessDeserializer::new(map),
                        )?;
                        Ok(Repr::Descriptor(descriptor))
                    }
                }

                // NOTE: `deserialize_any` requires a self-describing format.
                deserializer.deserialize_any(ReprVisitor(PhantomData))
            }
        }

        match Repr::<U>::deserialize(deserializer)? {
            // Re-attach to the existing region named by the descriptor.
            Repr::Descriptor(descriptor) => CuSharedMemoryBuffer::from_descriptor(&descriptor)
                .map(CuHandle::new_detached)
                .map_err(de::Error::custom),
            // Materialize a fresh detached region holding the inline data.
            Repr::Data(data) => CuSharedMemoryBuffer::from_vec_detached(data)
                .map(CuHandle::new_detached)
                .map_err(de::Error::custom),
        }
    }
}
633
634impl<T: ArrayLike + Encode> Encode for CuHandle<T>
635where
636 <T as ArrayLike>::Element: 'static,
637{
638 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
639 let inner = lock_unpoison(&self.0);
640 crate::monitoring::record_payload_handle_bytes(
641 inner.inner_ref().len() * size_of::<T::Element>(),
642 );
643 inner.inner_ref().encode(encoder)
644 }
645}
646
impl<T: Debug + Send + Sync> Default for CuHandle<T> {
    /// Always panics: a handle must wrap real storage, so no meaningful
    /// default exists. Implemented only to satisfy `Default` bounds elsewhere.
    fn default() -> Self {
        panic!("Cannot create a default CuHandle")
    }
}
652
653impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<Vec<U>> {
654 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
655 let vec: Vec<U> = Vec::decode(decoder)?;
656 Ok(CuHandle::new_detached(vec))
657 }
658}
659
660impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<CuSharedMemoryBuffer<U>> {
661 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
662 let buffer = CuSharedMemoryBuffer::<U>::decode(decoder)?;
663 Ok(CuHandle::new_detached(buffer))
664 }
665}
666
/// A monitored pool of reusable buffers with pool-to-pool copy support.
pub trait CuPool<T: ArrayLike>: PoolMonitor {
    /// Borrows a free buffer, or `None` when the pool is exhausted.
    fn acquire(&self) -> Option<CuHandle<T>>;

    /// Copies `from` into a newly acquired buffer of this pool.
    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
    where
        O: ArrayLike<Element = T::Element>;
}
678
/// Device-side pool (e.g. GPU memory) that can push its buffers back to host memory.
pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
    /// Copies the contents of `from_device_handle` into `to_host_handle`.
    fn copy_to_host_pool<O>(
        &self,
        from_device_handle: &CuHandle<T>,
        to_host_handle: &mut CuHandle<O>,
    ) -> CuResult<()>
    where
        O: ArrayLike<Element = T::Element>;
}
691
/// Heap-backed pool of pre-allocated host buffers.
pub struct CuHostMemoryPool<T> {
    id: PoolID,
    // Freelist of pre-built buffers; drops return them here.
    pool: Arc<Pool<Box<T>>>,
    // Total number of buffers created.
    size: usize,
    // Size of one buffer in bytes (measured once at construction).
    buffer_size: usize,
}
701
702impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
703 pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
704 where
705 F: Fn() -> T,
706 {
707 let pool = Arc::new(Pool::new(size, move || Box::new(buffer_initializer())));
708 let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();
709
710 let og = Self {
711 id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
712 pool,
713 size,
714 buffer_size,
715 };
716 let og = Arc::new(og);
717 register_pool(og.clone());
718 Ok(og)
719 }
720}
721
impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
    fn id(&self) -> PoolID {
        self.id
    }

    // Buffers currently sitting in the freelist.
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    fn total_size(&self) -> usize {
        self.size
    }

    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
739
740impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
741 fn acquire(&self) -> Option<CuHandle<T>> {
742 let owned_object = self.pool.try_pull_owned(); owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
745 }
746
747 fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
748 let to_handle = self.acquire().expect("No available buffers in the pool");
749 {
750 let from_lock = lock_unpoison(&from.0);
751 let mut to_lock = lock_unpoison(&to_handle.0);
752 to_lock.inner_mut().copy_from_slice(from_lock.inner_ref());
753 }
754 to_handle
755 }
756}
757
/// Pool of fixed-size buffers all carved out of one shared-memory region,
/// suitable for zero-copy hand-off to other processes via descriptors.
pub struct CuSharedMemoryPool<E: ElementType> {
    id: PoolID,
    // Freelist of slot views into the shared region.
    pool: Arc<Pool<Box<CuSharedMemoryBuffer<E>>>>,
    // Total number of slots.
    size: usize,
    // Payload bytes per buffer (elements * element size, un-padded).
    buffer_size: usize,
}
766
impl<E: ElementType + 'static> CuSharedMemoryPool<E> {
    /// Creates and registers a pool of `size` shared-memory buffers, each
    /// holding `elements_per_buffer` elements, laid out at aligned slot
    /// offsets inside a single mmap'ed region.
    pub fn new(id: &str, size: usize, elements_per_buffer: usize) -> CuResult<Arc<Self>> {
        // `.max(1)` keeps the slot (and the region) non-empty.
        let slot_stride = shared_slot_stride::<E>(elements_per_buffer.max(1));
        let region = CuSharedMemoryRegion::create(
            slot_stride
                .checked_mul(size)
                .ok_or_else(|| cu29_traits::CuError::from("shared memory pool size overflow"))?,
        )?;
        // Each initializer invocation claims the next disjoint slot.
        let next_slot = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let initializer_region = region.clone();
        let initializer_next_slot = next_slot.clone();
        let pool = Arc::new(Pool::new(size, move || {
            let slot = initializer_next_slot.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            // The pool is built with exactly `size` objects; this only fires
            // if the initializer is ever invoked more than `size` times.
            assert!(slot < size, "shared memory pool slot index overflow");
            Box::new(CuSharedMemoryBuffer::from_region(
                initializer_region.clone(),
                slot * slot_stride,
                elements_per_buffer,
            ))
        }));

        let pool = Arc::new(Self {
            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
            pool,
            size,
            buffer_size: elements_per_buffer * size_of::<E>(),
        });
        register_pool(pool.clone());
        Ok(pool)
    }
}
798
impl<E: ElementType> PoolMonitor for CuSharedMemoryPool<E> {
    fn id(&self) -> PoolID {
        self.id
    }

    // Slots currently sitting in the freelist.
    fn space_left(&self) -> usize {
        self.pool.len()
    }

    fn total_size(&self) -> usize {
        self.size
    }

    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}
816
817impl<E: ElementType> CuPool<CuSharedMemoryBuffer<E>> for CuSharedMemoryPool<E> {
818 fn acquire(&self) -> Option<CuHandle<CuSharedMemoryBuffer<E>>> {
819 self.pool
820 .try_pull_owned()
821 .map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
822 }
823
824 fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<CuSharedMemoryBuffer<E>>
825 where
826 O: ArrayLike<Element = E>,
827 {
828 let to_handle = self.acquire().expect("No available buffers in the pool");
829 {
830 let from_lock = lock_unpoison(&from.0);
831 let mut to_lock = lock_unpoison(&to_handle.0);
832 to_lock.inner_mut().copy_from_slice(from_lock.inner_ref());
833 }
834 to_handle
835 }
836}
837
// A plain Vec of plain-data elements is a valid pool buffer.
impl<E: ElementType + 'static> ArrayLike for Vec<E> {
    type Element = E;
}
841
#[cfg(all(feature = "cuda", not(target_os = "macos")))]
mod cuda {
    use super::*;
    use cu29_traits::CuError;
    use cudarc::driver::{
        CudaContext, CudaSlice, CudaStream, DeviceRepr, HostSlice, SyncOnDrop, ValidAsZeroBits,
    };
    use std::sync::Arc;

    /// Device-resident buffer wrapper so a `CudaSlice` can participate in the
    /// `ArrayLike`/`CuPool` machinery.
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);

    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        /// Device memory is not host-addressable; slicing always panics.
        fn deref(&self) -> &Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        /// Device memory is not host-addressable; slicing always panics.
        fn deref_mut(&mut self) -> &mut Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }

    impl<E> CudaSliceWrapper<E> {
        /// Borrows the underlying device slice.
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        /// Mutably borrows the underlying device slice.
        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }

    /// Read-only host-side view adapting an `ArrayLike` to cudarc's `HostSlice`.
    pub struct HostSliceWrapper<'a, T: ArrayLike> {
        inner: &'a T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            _stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            // This wrapper only holds a shared reference; requesting mutable
            // access is a logic error on the caller's side.
            panic!("Cannot get mutable reference from immutable wrapper")
        }
    }

    /// Mutable host-side view adapting an `ArrayLike` to cudarc's `HostSlice`.
    pub struct HostSliceMutWrapper<'a, T: ArrayLike> {
        inner: &'a mut T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceMutWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref_mut(), SyncOnDrop::sync_stream(stream))
        }
    }

    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        // Adapts a handle's payload to cudarc's read-only host-slice interface.
        fn get_host_slice_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &CuHandleInner<O>,
        ) -> HostSliceWrapper<'_, O> {
            HostSliceWrapper {
                inner: handle_inner.inner_ref(),
            }
        }

        // Adapts a handle's payload to cudarc's mutable host-slice interface.
        fn get_host_slice_mut_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &mut CuHandleInner<O>,
        ) -> HostSliceMutWrapper<'_, O> {
            HostSliceMutWrapper {
                inner: handle_inner.inner_mut(),
            }
        }
    }

    /// Pool of pre-allocated, zero-initialized device buffers tied to one stream.
    pub struct CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        id: PoolID,
        stream: Arc<CudaStream>,
        pool: Arc<Pool<Box<CudaSliceWrapper<E>>>>,
        nb_buffers: usize,
        nb_element_per_buffer: usize,
    }

    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        /// Eagerly allocates `nb_buffers` zeroed device buffers on the
        /// context's default stream.
        #[allow(dead_code)]
        pub fn new(
            id: &'static str,
            ctx: Arc<CudaContext>,
            nb_buffers: usize,
            nb_element_per_buffer: usize,
        ) -> CuResult<Self> {
            let stream = ctx.default_stream();
            let pool = (0..nb_buffers)
                .map(|_| {
                    stream
                        .alloc_zeros(nb_element_per_buffer)
                        .map(CudaSliceWrapper)
                        .map(Box::new)
                        .map_err(|_| "Failed to allocate device memory")
                })
                .collect::<Result<Vec<_>, _>>()?;

            Ok(Self {
                id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
                stream,
                pool: Arc::new(Pool::from_vec(pool)),
                nb_buffers,
                nb_element_per_buffer,
            })
        }
    }

    impl<E> PoolMonitor for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn id(&self) -> PoolID {
            self.id
        }

        fn space_left(&self) -> usize {
            self.pool.len()
        }

        fn total_size(&self) -> usize {
            self.nb_buffers
        }

        fn buffer_size(&self) -> usize {
            self.nb_element_per_buffer * size_of::<E>()
        }
    }

    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
            self.pool
                .try_pull_owned()
                .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
        }

        /// Host-to-device copy into a freshly acquired device buffer.
        /// Panics when the pool is exhausted or when the copy fails.
        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
        where
            O: ArrayLike<Element = E>,
        {
            let to_handle = self.acquire().expect("No available buffers in the pool");

            {
                let from_lock = lock_unpoison(&from_handle.0);
                let mut to_lock = lock_unpoison(&to_handle.0);

                // Both storage kinds need the same copy; they differ only in
                // how the destination slice is reached.
                match &mut *to_lock {
                    CuHandleInner::Detached(to) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to.deref_mut().as_cuda_slice_mut())
                            .expect("Failed to copy data to device");
                    }
                    CuHandleInner::Pooled(to) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to.deref_mut().as_mut().as_cuda_slice_mut())
                            .expect("Failed to copy data to device");
                    }
                }
            }
            to_handle
        }
    }

    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        /// Device-to-host copy into an existing host handle.
        fn copy_to_host_pool<O>(
            &self,
            device_handle: &CuHandle<CudaSliceWrapper<E>>,
            host_handle: &mut CuHandle<O>,
        ) -> Result<(), CuError>
        where
            O: ArrayLike<Element = E>,
        {
            // Unlike the pool-internal paths, poisoning here is surfaced to
            // the caller as a CuError rather than recovered silently.
            let device_lock = device_handle.lock().map_err(|e| {
                CuError::from("Device handle mutex poisoned").add_cause(&e.to_string())
            })?;
            let mut host_lock = host_handle.lock().map_err(|e| {
                CuError::from("Host handle mutex poisoned").add_cause(&e.to_string())
            })?;
            let src = match &*device_lock {
                CuHandleInner::Pooled(source) => source.deref().as_ref().as_cuda_slice(),
                CuHandleInner::Detached(source) => source.deref().as_cuda_slice(),
            };
            let mut wrapper = Self::get_host_slice_mut_wrapper(&mut *host_lock);
            self.stream.memcpy_dtoh(src, &mut wrapper).map_err(|e| {
                CuError::from("Failed to copy data from device to host").add_cause(&e.to_string())
            })?;
            Ok(())
        }
    }
}
1095
/// Owned heap buffer of `E` with a caller-specified minimum alignment
/// (e.g. for DMA or SIMD requirements).
/// NOTE(review): the raw pointer makes this auto-`!Send`/`!Sync`; confirm
/// that is intended before adding unsafe impls.
#[derive(Debug)]
pub struct AlignedBuffer<E: ElementType> {
    ptr: *mut E,
    size: usize,
    // Layout used for `alloc`, retained so `drop` can `dealloc` correctly.
    layout: Layout,
}
1103
impl<E: ElementType> AlignedBuffer<E> {
    /// Allocates `num_elements` default-initialized elements whose start
    /// address is aligned to at least `alignment` (never below `align_of::<E>()`).
    ///
    /// # Panics
    /// Panics on a zero element count or zero-sized element type, on
    /// allocation-size overflow, on an invalid layout (e.g. `alignment` not a
    /// power of two), or when the allocator returns null.
    pub fn new(num_elements: usize, alignment: usize) -> Self {
        assert!(
            num_elements > 0 && size_of::<E>() > 0,
            "AlignedBuffer requires a non-zero element count and non-zero-sized element type"
        );
        let alignment = alignment.max(align_of::<E>());
        let alloc_size = num_elements
            .checked_mul(size_of::<E>())
            .expect("AlignedBuffer allocation size overflow");
        let layout = Layout::from_size_align(alloc_size, alignment).unwrap();
        // SAFETY: `layout` has a non-zero size (asserted above).
        let ptr = unsafe { alloc(layout) as *mut E };
        if ptr.is_null() {
            panic!("Failed to allocate memory");
        }
        // SAFETY: `ptr` points to `alloc_size` bytes, enough for `num_elements`
        // values of `E`; each slot is written exactly once before any read.
        unsafe {
            for i in 0..num_elements {
                std::ptr::write(ptr.add(i), E::default());
            }
        }
        Self {
            ptr,
            size: num_elements,
            layout,
        }
    }
}
1133
impl<E: ElementType> Deref for AlignedBuffer<E> {
    type Target = [E];

    fn deref(&self) -> &Self::Target {
        // SAFETY: `ptr` points to `size` initialized elements (see `new`) and
        // stays valid until `drop`.
        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
    }
}
1142
impl<E: ElementType> DerefMut for AlignedBuffer<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same validity as `deref`; `&mut self` guarantees uniqueness.
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
    }
}
1149
impl<E: ElementType> Drop for AlignedBuffer<E> {
    fn drop(&mut self) {
        // SAFETY: `ptr`/`layout` come from the matching `alloc` in `new`.
        // `E: Copy` (via `ElementType`) means the elements need no drop glue.
        unsafe { dealloc(self.ptr as *mut u8, self.layout) }
    }
}
1156
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises acquire/release accounting on the host pool: buffers come
    /// from the initializer, and dropping a handle returns its buffer.
    #[test]
    fn test_pool() {
        use std::cell::RefCell;
        // Three distinct backing vectors; the initializer pops one per slot.
        let objs = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        let holding = objs.borrow().clone();
        let objs_as_slices = holding.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
        let pool = CuHostMemoryPool::new("mytestcudapool", 3, || objs.borrow_mut().pop().unwrap())
            .unwrap();

        let obj1 = pool.acquire().unwrap();
        {
            let obj2 = pool.acquire().unwrap();
            assert!(objs_as_slices.contains(&obj1.lock().unwrap().deref().deref()));
            assert!(objs_as_slices.contains(&obj2.lock().unwrap().deref().deref()));
            assert_eq!(pool.space_left(), 1);
        }
        // obj2 dropped above: its buffer went back to the pool.
        assert_eq!(pool.space_left(), 2);

        let obj3 = pool.acquire().unwrap();
        assert!(objs_as_slices.contains(&obj3.lock().unwrap().deref().deref()));

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        // Exhausted: acquire must fail rather than allocate.
        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    /// Same accounting checks on the CUDA device pool (requires an NVIDIA GPU).
    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_cuda_pool() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let _obj1 = pool.acquire().unwrap();

        {
            let _obj2 = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        }
        assert_eq!(pool.space_left(), 2);

        let _obj3 = pool.acquire().unwrap();

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    /// Host -> device -> host round trip preserves the payload value.
    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_copy_roundtrip() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let cuda_handle = {
            let mut initial_handle = host_pool.acquire().unwrap();
            {
                let mut inner_initial_handle = initial_handle.lock().unwrap();
                // Write a sentinel value into the pooled host buffer.
                if let CuHandleInner::Pooled(ref mut pooled) = *inner_initial_handle {
                    pooled[0] = 42.0;
                } else {
                    panic!();
                }
            }

            cuda_pool.copy_from(&mut initial_handle)
        };

        let mut final_handle = host_pool.acquire().unwrap();
        cuda_pool
            .copy_to_host_pool(&cuda_handle, &mut final_handle)
            .unwrap();

        let value = final_handle.lock().unwrap().deref().deref()[0];
        assert_eq!(value, 42.0);
    }
}