1mod format;
2#[cfg(feature = "macro_debug")]
3use format::{highlight_rust_code, rustfmt_generated_code};
4use proc_macro::TokenStream;
5use quote::{format_ident, quote, ToTokens};
6use syn::{parse_macro_input, Data, DeriveInput, Fields, Type};
7
8#[proc_macro_derive(Soa)]
69pub fn derive_soa(input: TokenStream) -> TokenStream {
70 use syn::TypePath;
71
72 let input = parse_macro_input!(input as DeriveInput);
73 let visibility = &input.vis;
74
75 let name = &input.ident;
76 let module_name = format_ident!("{}_soa", name.to_string().to_lowercase());
77 let soa_struct_name = format_ident!("{}Soa", name);
78
79 let data = match &input.data {
80 Data::Struct(data) => data,
81 _ => panic!("Only structs are supported"),
82 };
83 let fields = match &data.fields {
84 Fields::Named(fields) => &fields.named,
85 _ => panic!("Only named fields are supported"),
86 };
87
88 let mut field_names = vec![];
89 let mut field_names_mut = vec![];
90 let mut field_names_range = vec![];
91 let mut field_names_range_mut = vec![];
92 let mut field_types = vec![];
93 let mut unique_imports = vec![];
94 let mut unique_import_names = vec![];
95
96 fn is_primitive(type_name: &str) -> bool {
97 matches!(
98 type_name,
99 "i8" | "i16"
100 | "i32"
101 | "i64"
102 | "i128"
103 | "u8"
104 | "u16"
105 | "u32"
106 | "u64"
107 | "u128"
108 | "f32"
109 | "f64"
110 | "bool"
111 | "char"
112 | "str"
113 | "usize"
114 | "isize"
115 )
116 }
117
118 for field in fields {
119 let field_name = field.ident.as_ref().unwrap();
120 let field_type = &field.ty;
121 field_names.push(field_name);
122 field_names_mut.push(format_ident!("{}_mut", field_name));
123 field_names_range.push(format_ident!("{}_range", field_name));
124 field_names_range_mut.push(format_ident!("{}_range_mut", field_name));
125 field_types.push(field_type);
126
127 if let Type::Path(TypePath { path, .. }) = field_type {
128 let type_name = path.segments.last().unwrap().ident.to_string();
129 let path_str = path.to_token_stream().to_string();
130
131 if !is_primitive(&type_name) && !unique_import_names.contains(&path_str) {
132 unique_imports.push(path.clone());
133 unique_import_names.push(path_str);
134 }
135 }
136 }
137
138 let soa_struct_name_iterator = format_ident!("{}Iterator", name);
139 let field_count = field_names.len() + 1; let iterator = quote! {
142 pub struct #soa_struct_name_iterator<'a, const N: usize> {
143 soa_struct: &'a #soa_struct_name<N>,
144 current: usize,
145 }
146
147 impl<'a, const N: usize> #soa_struct_name_iterator<'a, N> {
148 pub fn new(soa_struct: &'a #soa_struct_name<N>) -> Self {
149 Self {
150 soa_struct,
151 current: 0,
152 }
153 }
154 }
155
156 impl<'a, const N: usize> Iterator for #soa_struct_name_iterator<'a, N> {
157 type Item = super::#name;
158
159 fn next(&mut self) -> Option<Self::Item> {
160 if self.current < self.soa_struct.len {
161 let item = self.soa_struct.get(self.current); self.current += 1;
163 Some(item)
164 } else {
165 None
166 }
167 }
168 }
169 };
170
171 let expanded = quote! {
172 #visibility mod #module_name {
173 use bincode::{Decode, Encode};
174 use bincode::enc::Encoder;
175 use bincode::de::Decoder;
176 use bincode::error::{DecodeError, EncodeError};
177 use serde::Serialize;
178 use serde::Serializer;
179 use serde::ser::SerializeStruct;
180 use std::ops::{Index, IndexMut};
181 #( use super::#unique_imports; )*
182 use core::array::from_fn;
183
184 #[derive(Debug)]
185 #visibility struct #soa_struct_name<const N: usize> {
186 pub len: usize,
187 #(pub #field_names: [#field_types; N], )*
188 }
189
190 impl<const N: usize> #soa_struct_name<N> {
191 pub fn new(default: super::#name) -> Self {
192 Self {
193 #( #field_names: from_fn(|_| default.#field_names.clone()), )*
194 len: 0,
195 }
196 }
197
198 pub fn len(&self) -> usize {
199 self.len
200 }
201
202 pub fn is_empty(&self) -> bool {
203 self.len == 0
204 }
205
206 pub fn push(&mut self, value: super::#name) {
207 if self.len < N {
208 #( self.#field_names[self.len] = value.#field_names.clone(); )*
209 self.len += 1;
210 } else {
211 panic!("Capacity exceeded")
212 }
213 }
214
215 pub fn pop(&mut self) -> Option<super::#name> {
216 if self.len == 0 {
217 None
218 } else {
219 self.len -= 1;
220 Some(super::#name {
221 #( #field_names: self.#field_names[self.len].clone(), )*
222 })
223 }
224 }
225
226 pub fn set(&mut self, index: usize, value: super::#name) {
227 assert!(index < self.len, "Index out of bounds");
228 #( self.#field_names[index] = value.#field_names.clone(); )*
229 }
230
231 pub fn get(&self, index: usize) -> super::#name {
232 assert!(index < self.len, "Index out of bounds");
233 super::#name {
234 #( #field_names: self.#field_names[index].clone(), )*
235 }
236 }
237
238 pub fn apply<F>(&mut self, mut f: F)
239 where
240 F: FnMut(#(#field_types),*) -> (#(#field_types),*)
241 {
242 for _idx in 0..self.len {
244 let result = f(#(self.#field_names[_idx].clone()),*);
245 let (#(#field_names),*) = result;
246 #(
247 self.#field_names[_idx] = #field_names;
248 )*
249 }
250 }
251
252 pub fn iter(&self) -> #soa_struct_name_iterator<N> {
253 #soa_struct_name_iterator::new(self)
254 }
255
256 #(
257 pub fn #field_names(&self) -> &[#field_types] {
258 &self.#field_names
259 }
260
261 pub fn #field_names_mut(&mut self) -> &mut [#field_types] {
262 &mut self.#field_names
263 }
264
265 pub fn #field_names_range(&self, range: std::ops::Range<usize>) -> &[#field_types] {
266 &self.#field_names[range]
267 }
268
269 pub fn #field_names_range_mut(&mut self, range: std::ops::Range<usize>) -> &mut [#field_types] {
270 &mut self.#field_names[range]
271 }
272 )*
273 }
274
275 impl<const N: usize> Encode for #soa_struct_name<N> {
276 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
277 self.len.encode(encoder)?;
278 #( self.#field_names[..self.len].encode(encoder)?; )*
279 Ok(())
280 }
281 }
282
283 impl<const N: usize> Decode<()> for #soa_struct_name<N> {
284 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
285 let mut result = Self::default();
286 result.len = Decode::decode(decoder)?;
287 #(
288 for _idx in 0..result.len {
289 result.#field_names[_idx] = Decode::decode(decoder)?;
290 }
291 )*
292 Ok(result)
293 }
294 }
295
296 impl<const N: usize> Default for #soa_struct_name<N> {
297 fn default() -> Self {
298 Self {
299 #( #field_names: from_fn(|_| #field_types::default()), )*
300 len: 0,
301 }
302 }
303 }
304
305 impl<const N: usize> Clone for #soa_struct_name<N>
306 where
307 #(
308 #field_types: Clone,
309 )*
310 {
311 fn clone(&self) -> Self {
312 Self {
313 #( #field_names: self.#field_names.clone(), )*
314 len: self.len,
315 }
316 }
317 }
318
319 impl<const N: usize> Serialize for #soa_struct_name<N>
320 where
321 #(
322 #field_types: Serialize,
323 )*
324 {
325 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
326 where
327 S: Serializer,
328 {
329 let mut state = serializer.serialize_struct(stringify!(#soa_struct_name), #field_count)?;
330 state.serialize_field("len", &self.len)?;
331 #(
332 state.serialize_field(stringify!(#field_names), &self.#field_names[..self.len])?;
333 )*
334 state.end()
335 }
336 }
337
338 #iterator
339
340 }
341 #visibility use #module_name::#soa_struct_name;
342 };
343
344 let tokens: TokenStream = expanded.into();
345
346 #[cfg(feature = "macro_debug")]
347 {
348 let formatted_code = rustfmt_generated_code(tokens.to_string());
349 eprintln!("\n === Gen. SOA ===\n");
350 eprintln!("{}", highlight_rust_code(formatted_code));
351 eprintln!("\n === === === === === ===\n");
352 }
353 tokens
354}