use arrayvec::ArrayString;
use bincode::de::Decoder;
use bincode::enc::Encoder;
use bincode::error::{DecodeError, EncodeError};
use bincode::{Decode, Encode};
use cu29_traits::CuResult;
use object_pool::{Pool, ReusableOwned};
use smallvec::SmallVec;
use std::alloc::{alloc, dealloc, Layout};
use std::collections::HashMap;
use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex, OnceLock};

type PoolID = ArrayString<64>;
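
/// Monitoring interface exposed by every memory pool in this module.
///
/// Pools register themselves in a global registry at construction time so their usage
/// (free buffers, capacity, buffer size) can be reported through [`pools_statistics`].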
pub trait PoolMonitor: Send + Sync {
    fn id(&self) -> PoolID;

    fn space_left(&self) -> usize;

    fn total_size(&self) -> usize;

    fn buffer_size(&self) -> usize;
}

static POOL_REGISTRY: OnceLock<Mutex<HashMap<String, Arc<dyn PoolMonitor>>>> = OnceLock::new();
const MAX_POOLS: usize = 16;

fn register_pool(pool: Arc<dyn PoolMonitor>) {
    POOL_REGISTRY
        .get_or_init(|| Mutex::new(HashMap::new()))
        .lock()
        .unwrap()
        .insert(pool.id().to_string(), pool);
}

type PoolStats = (PoolID, usize, usize, usize);
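
/// Returns `(id, space_left, total_size, buffer_size)` for every registered pool.
///
/// If no pool has been created yet, the registry is uninitialized and an empty list is returned.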
pub fn pools_statistics() -> SmallVec<[PoolStats; MAX_POOLS]> {
    let registry_lock = match POOL_REGISTRY.get() {
        Some(lock) => lock.lock().unwrap(),
        None => return SmallVec::new(),
    };
    let mut result = SmallVec::with_capacity(MAX_POOLS);
    for pool in registry_lock.values() {
        result.push((
            pool.id(),
            pool.space_left(),
            pool.total_size(),
            pool.buffer_size(),
        ));
    }
    result
}
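
/// Basic element type that can be stored in a pooled buffer.
///
/// The `encode`/`decode` methods mirror bincode's `Encode`/`Decode` so pooled buffers can be
/// serialized element by element.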
pub trait ElementType: Default + Sized + Copy + Debug + Unpin + Send + Sync {
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError>;
    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError>;
}

impl<T> ElementType for T
where
    T: Default + Sized + Copy + Debug + Unpin + Send + Sync,
    T: Encode,
    T: Decode<()>,
{
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
        // Delegate explicitly to bincode's Encode to avoid resolving back to ElementType::encode.
        Encode::encode(self, encoder)
    }

    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
        // Same here: delegate explicitly to bincode's Decode.
        Decode::decode(decoder)
    }
}
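
/// A contiguous, slice-like buffer type that can live inside a pool (e.g. `Vec<E>`).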
pub trait ArrayLike: Deref<Target = [Self::Element]> + DerefMut + Debug + Sync + Send {
    type Element: ElementType;
}
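
/// Inner storage of a [`CuHandle`]: either a buffer borrowed from a pool (returned to it on
/// drop) or a detached buffer owned by the handle itself.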
pub enum CuHandleInner<T: Debug> {
    Pooled(ReusableOwned<T>),
    Detached(T),
}

impl<T> Debug for CuHandleInner<T>
where
    T: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CuHandleInner::Pooled(r) => {
                write!(f, "Pooled: {:?}", r.deref())
            }
            CuHandleInner::Detached(r) => write!(f, "Detached: {r:?}"),
        }
    }
}

impl<T: ArrayLike> Deref for CuHandleInner<T> {
    type Target = [T::Element];

    fn deref(&self) -> &Self::Target {
        match self {
            CuHandleInner::Pooled(pooled) => pooled,
            CuHandleInner::Detached(detached) => detached,
        }
    }
}

impl<T: ArrayLike> DerefMut for CuHandleInner<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            CuHandleInner::Pooled(pooled) => pooled.deref_mut(),
            CuHandleInner::Detached(detached) => detached,
        }
    }
}
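
/// A shared, lockable handle to a pooled or detached buffer.
///
/// Cloning a handle only clones the inner `Arc`; all clones refer to the same underlying buffer.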
#[derive(Clone, Debug)]
pub struct CuHandle<T: ArrayLike>(Arc<Mutex<CuHandleInner<T>>>);

impl<T: ArrayLike> Deref for CuHandle<T> {
    type Target = Arc<Mutex<CuHandleInner<T>>>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<T: ArrayLike> CuHandle<T> {
    pub fn new_detached(inner: T) -> Self {
        CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(inner))))
    }

    pub fn with_inner<R>(&self, f: impl FnOnce(&CuHandleInner<T>) -> R) -> R {
        let lock = self.lock().unwrap();
        f(&*lock)
    }

    pub fn with_inner_mut<R>(&self, f: impl FnOnce(&mut CuHandleInner<T>) -> R) -> R {
        let mut lock = self.lock().unwrap();
        f(&mut *lock)
    }
}

impl<T: ArrayLike + Encode> Encode for CuHandle<T>
where
    <T as ArrayLike>::Element: 'static,
{
    fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
        let inner = self.lock().unwrap();
        match inner.deref() {
            CuHandleInner::Pooled(pooled) => pooled.deref().encode(encoder),
            CuHandleInner::Detached(detached) => detached.encode(encoder),
        }
    }
}

impl<T: ArrayLike> Default for CuHandle<T> {
    fn default() -> Self {
        panic!("Cannot create a default CuHandle")
    }
}

impl<U: ElementType + Decode<()> + 'static> Decode<()> for CuHandle<Vec<U>> {
    fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
        let vec: Vec<U> = Vec::decode(decoder)?;
        Ok(CuHandle(Arc::new(Mutex::new(CuHandleInner::Detached(vec)))))
    }
}
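
/// Common interface of all buffer pools: acquire a free buffer, or copy data into a freshly
/// acquired one.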
pub trait CuPool<T: ArrayLike>: PoolMonitor {
    fn acquire(&self) -> Option<CuHandle<T>>;

    fn copy_from<O>(&self, from: &mut CuHandle<O>) -> CuHandle<T>
    where
        O: ArrayLike<Element = T::Element>;
}
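
/// Extension of [`CuPool`] for device-side pools that can copy a buffer back into a host pool.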
pub trait DeviceCuPool<T: ArrayLike>: CuPool<T> {
    fn copy_to_host_pool<O>(
        &self,
        from_device_handle: &CuHandle<T>,
        to_host_handle: &mut CuHandle<O>,
    ) -> CuResult<()>
    where
        O: ArrayLike<Element = T::Element>;
}
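
/// A pool of preallocated, reusable host (CPU) buffers.
///
/// Buffers are created up front by the `buffer_initializer` closure passed to
/// [`CuHostMemoryPool::new`] and are returned to the pool when their handle is dropped. Each pool
/// registers itself in the global registry so it is reported by [`pools_statistics`].
///
/// Minimal usage sketch (the pool name and buffer shape below are illustrative only):
///
/// ```ignore
/// // A pool of 10 reusable host buffers of 1024 bytes each.
/// let pool = CuHostMemoryPool::new("image_buffers", 10, || vec![0u8; 1024])?;
///
/// // Acquire a buffer, mutate it, then give it back by dropping the handle.
/// let handle = pool.acquire().expect("pool exhausted");
/// handle.with_inner_mut(|inner| inner[0] = 42);
/// drop(handle); // the buffer returns to the pool
/// ```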
pub struct CuHostMemoryPool<T> {
    id: PoolID,
    pool: Arc<Pool<T>>,
    size: usize,
    buffer_size: usize,
}

impl<T: ArrayLike + 'static> CuHostMemoryPool<T> {
    pub fn new<F>(id: &str, size: usize, buffer_initializer: F) -> CuResult<Arc<Self>>
    where
        F: Fn() -> T,
    {
        let pool = Arc::new(Pool::new(size, buffer_initializer));
        let buffer_size = pool.try_pull().unwrap().len() * size_of::<T::Element>();

        let og = Self {
            id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
            pool,
            size,
            buffer_size,
        };
        let og = Arc::new(og);
        register_pool(og.clone());
        Ok(og)
    }
}

impl<T: ArrayLike> PoolMonitor for CuHostMemoryPool<T> {
    fn id(&self) -> PoolID {
        self.id
    }

    fn space_left(&self) -> usize {
        self.pool.len()
    }

    fn total_size(&self) -> usize {
        self.size
    }

    fn buffer_size(&self) -> usize {
        self.buffer_size
    }
}

impl<T: ArrayLike> CuPool<T> for CuHostMemoryPool<T> {
    fn acquire(&self) -> Option<CuHandle<T>> {
        let owned_object = self.pool.try_pull_owned();

        owned_object.map(|reusable| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(reusable)))))
    }

    fn copy_from<O: ArrayLike<Element = T::Element>>(&self, from: &mut CuHandle<O>) -> CuHandle<T> {
        let to_handle = self.acquire().expect("No available buffers in the pool");

        match from.lock().unwrap().deref() {
            CuHandleInner::Detached(source) => match to_handle.lock().unwrap().deref_mut() {
                CuHandleInner::Detached(destination) => {
                    destination.copy_from_slice(source);
                }
                CuHandleInner::Pooled(destination) => {
                    destination.copy_from_slice(source);
                }
            },
            CuHandleInner::Pooled(source) => match to_handle.lock().unwrap().deref_mut() {
                CuHandleInner::Detached(destination) => {
                    destination.copy_from_slice(source);
                }
                CuHandleInner::Pooled(destination) => {
                    destination.copy_from_slice(source);
                }
            },
        }
        to_handle
    }
}

impl<E: ElementType + 'static> ArrayLike for Vec<E> {
    type Element = E;
}
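
/// CUDA-backed pools, only compiled when the `cuda` feature is enabled (and not on macOS).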
#[cfg(all(feature = "cuda", not(target_os = "macos")))]
mod cuda {
    use super::*;
    use cu29_traits::CuError;
    use cudarc::driver::{
        CudaContext, CudaSlice, CudaStream, DeviceRepr, HostSlice, SyncOnDrop, ValidAsZeroBits,
    };
    use std::sync::Arc;
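
    /// Newtype around `CudaSlice<E>` so device buffers can implement [`ArrayLike`].
    ///
    /// The `Deref`/`DerefMut` impls below panic on purpose: device memory must be copied back to
    /// a host pool before it can be read or written from the CPU.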
    #[derive(Debug)]
    pub struct CudaSliceWrapper<E>(CudaSlice<E>);

    impl<E> Deref for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        type Target = [E];

        fn deref(&self) -> &Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E> DerefMut for CudaSliceWrapper<E>
    where
        E: ElementType,
    {
        fn deref_mut(&mut self) -> &mut Self::Target {
            panic!("You need to copy data to host memory pool before accessing it.");
        }
    }

    impl<E: ElementType> ArrayLike for CudaSliceWrapper<E> {
        type Element = E;
    }

    impl<E> CudaSliceWrapper<E> {
        pub fn as_cuda_slice(&self) -> &CudaSlice<E> {
            &self.0
        }

        pub fn as_cuda_slice_mut(&mut self) -> &mut CudaSlice<E> {
            &mut self.0
        }
    }
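
    /// Read-only adapter exposing a host-side [`ArrayLike`] buffer as a cudarc `HostSlice`.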
    pub struct HostSliceWrapper<'a, T: ArrayLike> {
        inner: &'a T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            _stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            panic!("Cannot get mutable reference from immutable wrapper")
        }
    }
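
    /// Mutable adapter exposing a host-side [`ArrayLike`] buffer as a cudarc `HostSlice`.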
    pub struct HostSliceMutWrapper<'a, T: ArrayLike> {
        inner: &'a mut T,
    }

    impl<T: ArrayLike> HostSlice<T::Element> for HostSliceMutWrapper<'_, T> {
        fn len(&self) -> usize {
            self.inner.len()
        }

        unsafe fn stream_synced_slice<'b>(
            &'b self,
            stream: &'b CudaStream,
        ) -> (&'b [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref(), SyncOnDrop::sync_stream(stream))
        }

        unsafe fn stream_synced_mut_slice<'b>(
            &'b mut self,
            stream: &'b CudaStream,
        ) -> (&'b mut [T::Element], SyncOnDrop<'b>) {
            (self.inner.deref_mut(), SyncOnDrop::sync_stream(stream))
        }
    }

    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        fn get_host_slice_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &CuHandleInner<O>,
        ) -> HostSliceWrapper<'_, O> {
            match handle_inner {
                CuHandleInner::Pooled(pooled) => HostSliceWrapper { inner: pooled },
                CuHandleInner::Detached(detached) => HostSliceWrapper { inner: detached },
            }
        }

        fn get_host_slice_mut_wrapper<O: ArrayLike<Element = E>>(
            handle_inner: &mut CuHandleInner<O>,
        ) -> HostSliceMutWrapper<'_, O> {
            match handle_inner {
                CuHandleInner::Pooled(pooled) => HostSliceMutWrapper { inner: pooled },
                CuHandleInner::Detached(detached) => HostSliceMutWrapper { inner: detached },
            }
        }
    }
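
    /// A pool of preallocated CUDA device buffers bound to a CUDA stream.
    ///
    /// Device buffers cannot be dereferenced directly; data is moved with [`CuPool::copy_from`]
    /// (host to device) and [`DeviceCuPool::copy_to_host_pool`] (device to host). A rough sketch,
    /// mirroring the round-trip test at the bottom of this file:
    ///
    /// ```ignore
    /// let ctx = CudaContext::new(0)?;
    /// let host_pool = CuHostMemoryPool::new("host", 3, || vec![0.0f32; 1])?;
    /// let cuda_pool = CuCudaPool::<f32>::new("device", ctx, 3, 1)?;
    ///
    /// let mut host_handle = host_pool.acquire().expect("host pool exhausted");
    /// let device_handle = cuda_pool.copy_from(&mut host_handle); // host -> device
    /// cuda_pool.copy_to_host_pool(&device_handle, &mut host_handle)?; // device -> host
    /// ```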
    pub struct CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr + Unpin,
    {
        id: PoolID,
        stream: Arc<CudaStream>,
        pool: Arc<Pool<CudaSliceWrapper<E>>>,
        nb_buffers: usize,
        nb_element_per_buffer: usize,
    }

    impl<E: ElementType + ValidAsZeroBits + DeviceRepr> CuCudaPool<E> {
        #[allow(dead_code)]
        pub fn new(
            id: &'static str,
            ctx: Arc<CudaContext>,
            nb_buffers: usize,
            nb_element_per_buffer: usize,
        ) -> CuResult<Self> {
            let stream = ctx.default_stream();
            let pool = (0..nb_buffers)
                .map(|_| {
                    stream
                        .alloc_zeros(nb_element_per_buffer)
                        .map(CudaSliceWrapper)
                        .map_err(|_| "Failed to allocate device memory")
                })
                .collect::<Result<Vec<_>, _>>()?;

            Ok(Self {
                id: PoolID::from(id).map_err(|_| "Failed to create PoolID")?,
                stream,
                pool: Arc::new(Pool::from_vec(pool)),
                nb_buffers,
                nb_element_per_buffer,
            })
        }
    }

    impl<E> PoolMonitor for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn id(&self) -> PoolID {
            self.id
        }

        fn space_left(&self) -> usize {
            self.pool.len()
        }

        fn total_size(&self) -> usize {
            self.nb_buffers
        }

        fn buffer_size(&self) -> usize {
            self.nb_element_per_buffer * size_of::<E>()
        }
    }

    impl<E> CuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: DeviceRepr + ElementType + ValidAsZeroBits,
    {
        fn acquire(&self) -> Option<CuHandle<CudaSliceWrapper<E>>> {
            self.pool
                .try_pull_owned()
                .map(|x| CuHandle(Arc::new(Mutex::new(CuHandleInner::Pooled(x)))))
        }

        fn copy_from<O>(&self, from_handle: &mut CuHandle<O>) -> CuHandle<CudaSliceWrapper<E>>
        where
            O: ArrayLike<Element = E>,
        {
            let to_handle = self.acquire().expect("No available buffers in the pool");

            {
                let from_lock = from_handle.lock().unwrap();
                let mut to_lock = to_handle.lock().unwrap();

                match &mut *to_lock {
                    CuHandleInner::Detached(CudaSliceWrapper(to)) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to)
                            .expect("Failed to copy data to device");
                    }
                    CuHandleInner::Pooled(to) => {
                        let wrapper = Self::get_host_slice_wrapper(&*from_lock);
                        self.stream
                            .memcpy_htod(&wrapper, to.as_cuda_slice_mut())
                            .expect("Failed to copy data to device");
                    }
                }
            }

            to_handle
        }
    }

    impl<E> DeviceCuPool<CudaSliceWrapper<E>> for CuCudaPool<E>
    where
        E: ElementType + ValidAsZeroBits + DeviceRepr,
    {
        fn copy_to_host_pool<O>(
            &self,
            device_handle: &CuHandle<CudaSliceWrapper<E>>,
            host_handle: &mut CuHandle<O>,
        ) -> Result<(), CuError>
        where
            O: ArrayLike<Element = E>,
        {
            let device_lock = device_handle.lock().unwrap();
            let mut host_lock = host_handle.lock().unwrap();
            let src = match &*device_lock {
                CuHandleInner::Pooled(source) => source.as_cuda_slice(),
                CuHandleInner::Detached(source) => source.as_cuda_slice(),
            };
            let mut wrapper = Self::get_host_slice_mut_wrapper(&mut *host_lock);
            self.stream
                .memcpy_dtoh(src, &mut wrapper)
                .expect("Failed to copy data from device to host");
            Ok(())
        }
    }
}
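
/// A buffer of `E` elements allocated with an explicit alignment and freed on drop.
///
/// Useful for buffers that must respect hardware alignment constraints (e.g. DMA or zero-copy
/// transfers).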
#[derive(Debug)]
pub struct AlignedBuffer<E: ElementType> {
    ptr: *mut E,
    size: usize,
    layout: Layout,
}

impl<E: ElementType> AlignedBuffer<E> {
    pub fn new(num_elements: usize, alignment: usize) -> Self {
        let layout = Layout::from_size_align(num_elements * size_of::<E>(), alignment).unwrap();
        let ptr = unsafe { alloc(layout) as *mut E };
        if ptr.is_null() {
            panic!("Failed to allocate memory");
        }
        Self {
            ptr,
            size: num_elements,
            layout,
        }
    }
}

impl<E: ElementType> Deref for AlignedBuffer<E> {
    type Target = [E];

    fn deref(&self) -> &Self::Target {
        unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
    }
}

impl<E: ElementType> DerefMut for AlignedBuffer<E> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
    }
}

impl<E: ElementType> Drop for AlignedBuffer<E> {
    fn drop(&mut self) {
        if !self.ptr.is_null() {
            unsafe {
                dealloc(self.ptr as *mut u8, self.layout);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool() {
        use std::cell::RefCell;
        let objs = RefCell::new(vec![vec![1], vec![2], vec![3]]);
        let holding = objs.borrow().clone();
        let objs_as_slices = holding.iter().map(|x| x.as_slice()).collect::<Vec<_>>();
        let pool = CuHostMemoryPool::new("mytestcudapool", 3, || objs.borrow_mut().pop().unwrap())
            .unwrap();

        let obj1 = pool.acquire().unwrap();
        {
            let obj2 = pool.acquire().unwrap();
            assert!(objs_as_slices.contains(&obj1.lock().unwrap().deref().deref()));
            assert!(objs_as_slices.contains(&obj2.lock().unwrap().deref().deref()));
            assert_eq!(pool.space_left(), 1);
        }
        assert_eq!(pool.space_left(), 2);

        let obj3 = pool.acquire().unwrap();
        assert!(objs_as_slices.contains(&obj3.lock().unwrap().deref().deref()));

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_cuda_pool() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let _obj1 = pool.acquire().unwrap();

        {
            let _obj2 = pool.acquire().unwrap();
            assert_eq!(pool.space_left(), 1);
        }
        assert_eq!(pool.space_left(), 2);

        let _obj3 = pool.acquire().unwrap();

        assert_eq!(pool.space_left(), 1);

        let _obj4 = pool.acquire().unwrap();
        assert_eq!(pool.space_left(), 0);

        let obj5 = pool.acquire();
        assert!(obj5.is_none());
    }

    #[cfg(all(feature = "cuda", has_nvidia_gpu))]
    #[test]
    fn test_copy_roundtrip() {
        use crate::pool::cuda::CuCudaPool;
        use cudarc::driver::CudaContext;
        let ctx = CudaContext::new(0).unwrap();
        let host_pool = CuHostMemoryPool::new("mytesthostpool", 3, || vec![0.0; 1]).unwrap();
        let cuda_pool = CuCudaPool::<f32>::new("mytestcudapool", ctx, 3, 1).unwrap();

        let cuda_handle = {
            let mut initial_handle = host_pool.acquire().unwrap();
            {
                let mut inner_initial_handle = initial_handle.lock().unwrap();
                if let CuHandleInner::Pooled(ref mut pooled) = *inner_initial_handle {
                    pooled[0] = 42.0;
                } else {
                    panic!();
                }
            }

            cuda_pool.copy_from(&mut initial_handle)
        };

        let mut final_handle = host_pool.acquire().unwrap();
        cuda_pool
            .copy_to_host_pool(&cuda_handle, &mut final_handle)
            .unwrap();

        let value = final_handle.lock().unwrap().deref().deref()[0];
        assert_eq!(value, 42.0);
    }
}