use arrayvec::ArrayVec;
use bincode::{Decode, Encode};
#[derive(Clone, Debug, Default)]
pub struct CuArray<T, const N: usize> {
inner: ArrayVec<T, N>,
}
impl<T, const N: usize> CuArray<T, N> {
    /// Creates an empty array whose capacity is fixed at `N`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: ArrayVec::new(),
        }
    }

    /// Replaces the current contents with the items yielded by `iter`.
    ///
    /// At most `N` items are taken. Any further items are silently discarded:
    /// they are never pulled from the iterator. This method therefore never
    /// panics on overflow.
    pub fn fill_from_iter<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = T>,
    {
        self.inner.clear();
        // `ArrayVec::extend` panics when capacity is exceeded. After `clear()`,
        // `take(N)` guarantees we stay within capacity.
        self.inner.extend(iter.into_iter().take(N));
    }

    /// Returns the number of elements currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` if the array holds no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the stored elements as a slice.
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        self.inner.as_slice()
    }

    /// Returns the fixed maximum number of elements, `N`.
    #[must_use]
    pub const fn capacity(&self) -> usize {
        N
    }
}
impl<T, const N: usize> Encode for CuArray<T, N>
where
    T: Encode,
{
    /// Writes a `u32` element count, then each stored element in order.
    ///
    /// # Errors
    /// Propagates the first error returned while encoding the count or any element.
    fn encode<E: bincode::enc::Encoder>(
        &self,
        encoder: &mut E,
    ) -> Result<(), bincode::error::EncodeError> {
        // NOTE(review): `as` would truncate a length above `u32::MAX`. arrayvec 0.7
        // appears to cap capacity at `u32::MAX`, so this should be unreachable.
        // Confirm against the arrayvec version in use.
        let count = self.inner.len() as u32;
        count.encode(encoder)?;
        self.inner.iter().try_for_each(|item| item.encode(encoder))
    }
}
impl<T, const N: usize> Decode for CuArray<T, N>
where
T: Decode,
{
fn decode<D: bincode::de::Decoder>(
decoder: &mut D,
) -> Result<Self, bincode::error::DecodeError> {
let len = u32::decode(decoder)? as usize;
if len > N {
return Err(bincode::error::DecodeError::OtherString(format!(
"Decoded length {len} exceeds maximum capacity {N}"
)));
}
let mut inner = ArrayVec::new();
for _ in 0..len {
inner.push(T::decode(decoder)?);
}
Ok(Self { inner })
}
}