cu29_soa_derive/
lib.rs

1mod format;
2#[cfg(feature = "macro_debug")]
3use format::{highlight_rust_code, rustfmt_generated_code};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote, ToTokens};
6use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
7
8/// Build a fixed sized SoA (Structure of Arrays) from a struct.
9/// The outputted SoA will be suitable for in place storage in messages and should be
10/// easier for the compiler to vectorize.
11///
12/// for example:
13///
14/// ```ignore
15/// #[derive(Soa)]
16/// struct MyStruct {
17///    a: i32,
18///    b: f32,
19/// }
20/// ```
21///
22/// will generate:
23/// ```ignore
24/// pub struct MyStructSoa<const N: usize> {
25///     pub a: [i32; N],
26///     pub b: [f32; N],
27/// }
28/// ```
29///
30/// You can then use the generated struct to store multiple
31/// instances of the original struct in an SoA format.
32///
33/// ```ignore
34/// // makes an SOA with a default value
35/// let soa1: MyStructSoa<8> = XyzSoa::new(MyStruct{ a: 1, b: 2.3 });
36/// ```
37///
38/// Then you can access the fields of the SoA as slices:
39/// ```ignore
40/// let a = soa1.a();
41/// let b = soa1.b();
42/// ```
43///
44/// You can also access a range of the fields:
45/// ```ignore
46/// let a = soa1.a_range(0..4);
47/// let b = soa1.b_range(0..4);
48/// ```
49///
50/// You can also modify the fields of the SoA:
51/// ```ignore
52/// soa1.a_mut()[0] = 42;
53/// soa1.b_mut()[0] = 42.0;
54/// ```
55///
56/// You can also modify a range of the fields:
57/// ```ignore
58/// soa1.a_range_mut(0..4)[0] = 42;
59/// soa1.b_range_mut(0..4)[0] = 42.0;
60/// ```
61///
62/// You can also apply a function to all the fields of the SoA:
63/// ```ignore
64/// soa1.apply(|a, b| {
65///    (a + 1, b + 1.0)
66/// });
67/// ```
68#[proc_macro_derive(Soa)]
69pub fn derive_soa(input: TokenStream) -> TokenStream {
70    use syn::TypePath;
71
72    let input = parse_macro_input!(input as DeriveInput);
73    let visibility = &input.vis;
74
75    let name = &input.ident;
76    let module_name = format_ident!("{}_soa", name.to_string().to_lowercase());
77    let soa_struct_name = format_ident!("{}Soa", name);
78
79    let data = match &input.data {
80        Data::Struct(data) => data,
81        _ => panic!("Only structs are supported"),
82    };
83    let fields = match &data.fields {
84        Fields::Named(fields) => &fields.named,
85        _ => panic!("Only named fields are supported"),
86    };
87
88    let mut field_names = vec![];
89    let mut field_names_mut = vec![];
90    let mut field_names_range = vec![];
91    let mut field_names_range_mut = vec![];
92    let mut field_types = vec![];
93    let mut unique_imports = vec![];
94    let mut unique_import_names = vec![];
95
96    fn is_primitive(type_name: &str) -> bool {
97        matches!(
98            type_name,
99            "i8" | "i16"
100                | "i32"
101                | "i64"
102                | "i128"
103                | "u8"
104                | "u16"
105                | "u32"
106                | "u64"
107                | "u128"
108                | "f32"
109                | "f64"
110                | "bool"
111                | "char"
112                | "str"
113                | "usize"
114                | "isize"
115        )
116    }
117
118    for field in fields {
119        let field_name = field.ident.as_ref().unwrap();
120        let field_type = &field.ty;
121        field_names.push(field_name);
122        field_names_mut.push(format_ident!("{}_mut", field_name));
123        field_names_range.push(format_ident!("{}_range", field_name));
124        field_names_range_mut.push(format_ident!("{}_range_mut", field_name));
125        field_types.push(field_type);
126
127        if let Type::Path(TypePath { path, .. }) = field_type {
128            let type_name = path.segments.last().unwrap().ident.to_string();
129            let path_str = path.to_token_stream().to_string();
130
131            if !is_primitive(&type_name) && !unique_import_names.contains(&path_str) {
132                unique_imports.push(path.clone());
133                unique_import_names.push(path_str);
134            }
135        }
136    }
137
138    let soa_struct_name_iterator = format_ident!("{}Iterator", name);
139    let field_count = field_names.len() + 1; // +1 for the len field
140
141    let iterator = quote! {
142        pub struct #soa_struct_name_iterator<'a, const N: usize> {
143            soa_struct: &'a #soa_struct_name<N>,
144            current: usize,
145        }
146
147        impl<'a, const N: usize> #soa_struct_name_iterator<'a, N> {
148            pub fn new(soa_struct: &'a #soa_struct_name<N>) -> Self {
149                Self {
150                    soa_struct,
151                    current: 0,
152                }
153            }
154        }
155
156        impl<'a, const N: usize> Iterator for #soa_struct_name_iterator<'a, N> {
157            type Item = super::#name;
158
159            fn next(&mut self) -> Option<Self::Item> {
160                if self.current < self.soa_struct.len {
161                    let item = self.soa_struct.get(self.current); // Reuse `get` method
162                    self.current += 1;
163                    Some(item)
164                } else {
165                    None
166                }
167            }
168        }
169    };
170
171    let expanded = quote! {
172        #visibility mod #module_name {
173            use bincode::{Decode, Encode};
174            use bincode::enc::Encoder;
175            use bincode::de::Decoder;
176            use bincode::error::{DecodeError, EncodeError};
177            use serde::Serialize;
178            use serde::Serializer;
179            use serde::ser::SerializeStruct;
180            use std::ops::{Index, IndexMut};
181            #( use super::#unique_imports; )*
182            use core::array::from_fn;
183
184            #[derive(Debug)]
185            #visibility struct #soa_struct_name<const N: usize> {
186                pub len: usize,
187                #(pub #field_names: [#field_types; N], )*
188            }
189
190            impl<const N: usize> #soa_struct_name<N> {
191                pub fn new(default: super::#name) -> Self {
192                    Self {
193                        #( #field_names: from_fn(|_| default.#field_names.clone()), )*
194                        len: 0,
195                    }
196                }
197
198                pub fn len(&self) -> usize {
199                    self.len
200                }
201
202                pub fn is_empty(&self) -> bool {
203                    self.len == 0
204                }
205
206                pub fn push(&mut self, value: super::#name) {
207                    if self.len < N {
208                        #( self.#field_names[self.len] = value.#field_names.clone(); )*
209                        self.len += 1;
210                    } else {
211                        panic!("Capacity exceeded")
212                    }
213                }
214
215                pub fn pop(&mut self) -> Option<super::#name> {
216                    if self.len == 0 {
217                        None
218                    } else {
219                        self.len -= 1;
220                        Some(super::#name {
221                            #( #field_names: self.#field_names[self.len].clone(), )*
222                        })
223                    }
224                }
225
226                pub fn set(&mut self, index: usize, value: super::#name) {
227                    assert!(index < self.len, "Index out of bounds");
228                    #( self.#field_names[index] = value.#field_names.clone(); )*
229                }
230
231                pub fn get(&self, index: usize) -> super::#name {
232                    assert!(index < self.len, "Index out of bounds");
233                    super::#name {
234                        #( #field_names: self.#field_names[index].clone(), )*
235                    }
236                }
237
238                pub fn apply<F>(&mut self, mut f: F)
239                where
240                    F: FnMut(#(#field_types),*) -> (#(#field_types),*)
241                {
242                    // don't use something common like i here.
243                    for _idx in 0..self.len {
244                        let result = f(#(self.#field_names[_idx].clone()),*);
245                        let (#(#field_names),*) = result;
246                        #(
247                            self.#field_names[_idx] = #field_names;
248                        )*
249                    }
250                }
251
252                pub fn iter(&self) -> #soa_struct_name_iterator<N> {
253                    #soa_struct_name_iterator::new(self)
254                }
255
256                #(
257                    pub fn #field_names(&self) -> &[#field_types] {
258                        &self.#field_names
259                    }
260
261                    pub fn #field_names_mut(&mut self) -> &mut [#field_types] {
262                        &mut self.#field_names
263                    }
264
265                    pub fn #field_names_range(&self, range: std::ops::Range<usize>) -> &[#field_types] {
266                        &self.#field_names[range]
267                    }
268
269                    pub fn #field_names_range_mut(&mut self, range: std::ops::Range<usize>) -> &mut [#field_types] {
270                        &mut self.#field_names[range]
271                    }
272                )*
273            }
274
275            impl<const N: usize> Encode for #soa_struct_name<N> {
276                fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
277                    self.len.encode(encoder)?;
278                    #( self.#field_names[..self.len].encode(encoder)?; )*
279                    Ok(())
280                }
281            }
282
283            impl<const N: usize> Decode<()> for #soa_struct_name<N> {
284                fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
285                    let mut result = Self::default();
286                    result.len = Decode::decode(decoder)?;
287                    #(
288                        for _idx in 0..result.len {
289                            result.#field_names[_idx] = Decode::decode(decoder)?;
290                        }
291                    )*
292                    Ok(result)
293                }
294            }
295
296            impl<const N: usize> Default for #soa_struct_name<N> {
297                fn default() -> Self {
298                    Self {
299                        #( #field_names: from_fn(|_| #field_types::default()), )*
300                        len: 0,
301                    }
302                }
303            }
304
305            impl<const N: usize> Clone for #soa_struct_name<N>
306            where
307                #(
308                    #field_types: Clone,
309                )*
310            {
311                fn clone(&self) -> Self {
312                    Self {
313                        #( #field_names: self.#field_names.clone(), )*
314                        len: self.len,
315                    }
316                }
317            }
318
319            impl<const N: usize> Serialize for #soa_struct_name<N>
320            where
321                #(
322                    #field_types: Serialize,
323                )*
324            {
325                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
326                where
327                    S: Serializer,
328                {
329                    let mut state = serializer.serialize_struct(stringify!(#soa_struct_name), #field_count)?;
330                    state.serialize_field("len", &self.len)?;
331                    #(
332                        state.serialize_field(stringify!(#field_names), &self.#field_names[..self.len])?;
333                    )*
334                    state.end()
335                }
336            }
337
338            #iterator
339
340        }
341        #visibility use #module_name::#soa_struct_name;
342    };
343
344    let tokens: TokenStream = expanded.into();
345
346    #[cfg(feature = "macro_debug")]
347    {
348        let formatted_code = rustfmt_generated_code(tokens.to_string());
349        eprintln!("\n     ===    Gen. SOA     ===\n");
350        eprintln!("{}", highlight_rust_code(formatted_code));
351        eprintln!("\n     === === === === === ===\n");
352    }
353    tokens
354}