1use proc_macro::TokenStream;
2use quote::{ToTokens, format_ident, quote};
3use syn::{Attribute, Data, DeriveInput, Fields, Path, PathArguments, Type, parse_macro_input};
4
5#[proc_macro_derive(Soa, attributes(soa))]
128pub fn derive_soa(input: TokenStream) -> TokenStream {
129 use syn::TypePath;
130
131 let input = parse_macro_input!(input as DeriveInput);
132 let visibility = &input.vis;
133 let derive_reflect = input
134 .attrs
135 .iter()
136 .any(|attr| attr.path().is_ident("reflect"));
137 let reflect_import = if derive_reflect {
138 quote!(
139 use super::{bevy_reflect, Reflect};
140 )
141 } else {
142 quote!()
143 };
144 let soa_reflect_attrs = if derive_reflect {
145 quote! {
146 #[derive(Reflect)]
147 #[reflect(from_reflect = false)]
148 }
149 } else {
150 quote!()
151 };
152
153 let name = &input.ident;
154 let module_name = format_ident!("{}_soa", name.to_string().to_lowercase());
155 let soa_struct_name = format_ident!("{}Soa", name);
156 let soa_storage_name = format_ident!("{}SoaStorage", name);
157 let soa_storage_wire_name = format_ident!("{}SoaStorageWire", name);
158 let soa_storage_serde_name = format_ident!("{}SoaStorageSerde", name);
159 let soa_struct_wire_name = format_ident!("{}SoaWire", name);
160
161 let data = match &input.data {
162 Data::Struct(data) => data,
163 _ => {
164 return syn::Error::new_spanned(&input, "Only structs are supported")
165 .to_compile_error()
166 .into();
167 }
168 };
169 let fields = match &data.fields {
170 Fields::Named(fields) => &fields.named,
171 _ => {
172 return syn::Error::new_spanned(&data.fields, "Only named fields are supported")
173 .to_compile_error()
174 .into();
175 }
176 };
177
178 struct FieldInfo {
179 name: syn::Ident,
180 ty: syn::Type,
181 nested: bool,
182 storage_path: Option<syn::Path>,
183 storage_wire_path: Option<syn::Path>,
184 }
185
186 let mut field_infos = Vec::new();
187 let mut unique_imports = vec![];
188 let mut unique_import_names = vec![];
189
190 fn is_primitive(type_name: &str) -> bool {
191 matches!(
192 type_name,
193 "i8" | "i16"
194 | "i32"
195 | "i64"
196 | "i128"
197 | "u8"
198 | "u16"
199 | "u32"
200 | "u64"
201 | "u128"
202 | "f32"
203 | "f64"
204 | "bool"
205 | "char"
206 | "str"
207 | "usize"
208 | "isize"
209 )
210 }
211
212 fn parse_soa_nested(attrs: &[Attribute]) -> Result<bool, syn::Error> {
213 let mut nested = false;
214 for attr in attrs {
215 if attr.path().is_ident("soa") {
216 attr.parse_nested_meta(|meta| {
217 if meta.path.is_ident("nested") {
218 nested = true;
219 Ok(())
220 } else {
221 Err(meta.error("unsupported #[soa] option, expected `nested`"))
222 }
223 })?;
224 }
225 }
226 Ok(nested)
227 }
228
229 fn qualify_path(path: &Path) -> Path {
230 if path.leading_colon.is_some() {
231 return path.clone();
232 }
233 if let Some(first) = path.segments.first() {
234 if first.ident == "crate" {
235 return path.clone();
236 }
237 if first.ident == "self" {
238 let mut path = path.clone();
239 if let Some(segment) = path.segments.first_mut() {
240 segment.ident = format_ident!("super");
241 }
242 return path;
243 }
244 if first.ident == "std" || first.ident == "core" || first.ident == "alloc" {
245 return path.clone();
246 }
247 }
248 syn::parse_quote!(super::#path)
249 }
250
251 fn storage_paths(field_type: &Type) -> Result<(Path, Path), syn::Error> {
252 let Type::Path(type_path) = field_type else {
253 return Err(syn::Error::new_spanned(
254 field_type,
255 "expected a path type for #[soa(nested)]",
256 ));
257 };
258
259 let mut storage_path = type_path.path.clone();
260 let last_segment = storage_path
261 .segments
262 .last_mut()
263 .ok_or_else(|| syn::Error::new_spanned(field_type, "expected a non-empty type path"))?;
264 if !matches!(last_segment.arguments, PathArguments::None) {
265 return Err(syn::Error::new_spanned(
266 field_type,
267 "generic types are not supported with #[soa(nested)]",
268 ));
269 }
270 let base_ident = last_segment.ident.clone();
271 last_segment.ident = format_ident!("{}SoaStorage", base_ident);
272
273 let mut storage_wire_path = type_path.path.clone();
274 let last_wire_segment = storage_wire_path
275 .segments
276 .last_mut()
277 .ok_or_else(|| syn::Error::new_spanned(field_type, "expected a non-empty type path"))?;
278 last_wire_segment.ident = format_ident!("{}SoaStorageWire", base_ident);
279
280 Ok((
281 qualify_path(&storage_path),
282 qualify_path(&storage_wire_path),
283 ))
284 }
285
286 for field in fields {
287 let field_name = match &field.ident {
288 Some(ident) => ident.clone(),
289 None => {
290 return syn::Error::new_spanned(field, "Only named fields are supported")
291 .to_compile_error()
292 .into();
293 }
294 };
295 let field_type = field.ty.clone();
296 let nested = match parse_soa_nested(&field.attrs) {
297 Ok(value) => value,
298 Err(err) => return err.to_compile_error().into(),
299 };
300 let (storage_path, storage_wire_path) = if nested {
301 match storage_paths(&field_type) {
302 Ok((storage_path, storage_wire_path)) => {
303 (Some(storage_path), Some(storage_wire_path))
304 }
305 Err(err) => return err.to_compile_error().into(),
306 }
307 } else {
308 (None, None)
309 };
310
311 if let Type::Path(TypePath { path, .. }) = &field_type {
312 let Some(last_segment) = path.segments.last() else {
313 return syn::Error::new_spanned(path, "expected a non-empty type path")
314 .to_compile_error()
315 .into();
316 };
317 let type_name = last_segment.ident.to_string();
318 let path_str = path.to_token_stream().to_string();
319
320 if !is_primitive(&type_name) && !unique_import_names.contains(&path_str) {
321 unique_imports.push(path.clone());
322 unique_import_names.push(path_str);
323 }
324 }
325
326 field_infos.push(FieldInfo {
327 name: field_name,
328 ty: field_type,
329 nested,
330 storage_path,
331 storage_wire_path,
332 });
333 }
334
335 let field_names: Vec<_> = field_infos.iter().map(|info| &info.name).collect();
336 let field_types: Vec<_> = field_infos.iter().map(|info| &info.ty).collect();
337
338 let soa_struct_name_iterator = format_ident!("{}Iterator", name);
339 let storage_field_count = field_names.len();
340 let field_count = storage_field_count + 1; let iterator = quote! {
343 pub struct #soa_struct_name_iterator<'a, const N: usize> {
344 soa_struct: &'a #soa_struct_name<N>,
345 current: usize,
346 }
347
348 impl<'a, const N: usize> #soa_struct_name_iterator<'a, N> {
349 pub fn new(soa_struct: &'a #soa_struct_name<N>) -> Self {
350 Self {
351 soa_struct,
352 current: 0,
353 }
354 }
355 }
356
357 impl<'a, const N: usize> Iterator for #soa_struct_name_iterator<'a, N> {
358 type Item = super::#name;
359
360 fn next(&mut self) -> Option<Self::Item> {
361 if self.current < self.soa_struct.len {
362 let item = self.soa_struct.get(self.current); self.current += 1;
364 Some(item)
365 } else {
366 None
367 }
368 }
369 }
370 };
371
372 let mut field_decls = Vec::new();
374 let mut new_inits = Vec::new();
375 let mut default_inits = Vec::new();
376 let mut get_fields = Vec::new();
377 let mut set_fields = Vec::new();
378 let mut accessors = Vec::new();
379
380 let mut soa_push_fields = Vec::new();
382 let mut soa_pop_fields = Vec::new();
383 let mut soa_apply_args = Vec::new();
384 let mut soa_apply_sets = Vec::new();
385 let mut soa_encode_fields = Vec::new();
386 let mut soa_decode_fields = Vec::new();
387 let mut soa_serialize_fields = Vec::new();
388 let mut soa_serialize_bounds = Vec::new();
389 let mut soa_deserialize_bounds = Vec::new();
390 let mut soa_wire_fields = Vec::new();
391 let mut soa_wire_checks = Vec::new();
392 let mut soa_wire_assignments = Vec::new();
393
394 let mut storage_encode_fields = Vec::new();
396 let mut storage_decode_fields = Vec::new();
397 let mut storage_serialize_fields = Vec::new();
398 let mut storage_serialize_bounds = Vec::new();
399 let mut storage_clone_bounds = Vec::new();
400 let mut storage_wire_fields = Vec::new();
401 let mut storage_wire_checks = Vec::new();
402 let mut storage_wire_assignments = Vec::new();
403
404 for info in &field_infos {
405 let name = &info.name;
406 let ty = &info.ty;
407 let name_mut = format_ident!("{}_mut", name);
408
409 if info.nested {
410 let storage_path = info
411 .storage_path
412 .as_ref()
413 .expect("nested field missing storage path");
414 let storage_wire_path = info
415 .storage_wire_path
416 .as_ref()
417 .expect("nested field missing storage wire path");
418 let serde_name = format_ident!("{}_serde", name);
419
420 field_decls.push(quote!(pub #name: #storage_path<N>));
421 new_inits.push(quote!(#name: #storage_path::<N>::new(default.#name.clone())));
422 default_inits.push(quote!(#name: #storage_path::<N>::default()));
423 storage_clone_bounds.push(quote!(#storage_path<N>: Clone,));
424
425 accessors.push(quote! {
426 pub fn #name(&self) -> &#storage_path<N> {
427 &self.#name
428 }
429
430 pub fn #name_mut(&mut self) -> &mut #storage_path<N> {
431 &mut self.#name
432 }
433 });
434
435 get_fields.push(quote!(#name: self.#name.get(index),));
436 set_fields.push(quote!(self.#name.set(index, value.#name.clone());));
437
438 soa_push_fields.push(quote!(self.#name.set(self.len, value.#name.clone());));
439 soa_pop_fields.push(quote!(#name: self.#name.get(self.len),));
440
441 soa_apply_args.push(quote!(self.#name.get(_idx)));
442 soa_apply_sets.push(quote!(self.#name.set(_idx, #name);));
443
444 storage_encode_fields.push(quote!(self.#name.encode_len(encoder, len)?;));
445 storage_decode_fields
446 .push(quote!(result.#name = #storage_path::<N>::decode_len(decoder, len)?;));
447
448 soa_encode_fields.push(quote!(self.#name.encode_len(encoder, self.len)?;));
449 soa_decode_fields
450 .push(quote!(result.#name = #storage_path::<N>::decode_len(decoder, result.len)?;));
451
452 storage_serialize_fields.push(quote! {
453 {
454 let #serde_name = self.storage.#name.serialize_len(self.len);
455 state.serialize_field(stringify!(#name), &#serde_name)?;
456 }
457 });
458 soa_serialize_fields.push(quote! {
459 {
460 let #serde_name = self.#name.serialize_len(self.len);
461 state.serialize_field(stringify!(#name), &#serde_name)?;
462 }
463 });
464
465 storage_wire_fields.push(quote!(#name: #storage_wire_path,));
466 storage_wire_assignments.push(quote!(
467 result.#name = #storage_path::<N>::from_wire(#name, len)
468 .map_err(|err| format!("field {}: {}", stringify!(#name), err))?;
469 ));
470
471 soa_wire_fields.push(quote!(#name: #storage_wire_path,));
472 soa_wire_assignments.push(quote!(
473 result.#name = #storage_path::<N>::from_wire(#name, len)
474 .map_err(|err| serde::de::Error::custom(format!(
475 "field {}: {}",
476 stringify!(#name),
477 err
478 )))?;
479 ));
480 } else {
481 let name_range = format_ident!("{}_range", name);
482 let name_range_mut = format_ident!("{}_range_mut", name);
483
484 field_decls.push(quote!(pub #name: [#ty; N]));
485 new_inits.push(quote!(#name: from_fn(|_| default.#name.clone())));
486 default_inits.push(quote!(#name: from_fn(|_| #ty::default())));
487 storage_clone_bounds.push(quote!(#ty: Clone,));
488
489 accessors.push(quote! {
490 pub fn #name(&self) -> &[#ty] {
491 &self.#name
492 }
493
494 pub fn #name_mut(&mut self) -> &mut [#ty] {
495 &mut self.#name
496 }
497
498 pub fn #name_range(&self, range: std::ops::Range<usize>) -> &[#ty] {
499 &self.#name[range]
500 }
501
502 pub fn #name_range_mut(&mut self, range: std::ops::Range<usize>) -> &mut [#ty] {
503 &mut self.#name[range]
504 }
505 });
506
507 get_fields.push(quote!(#name: self.#name[index].clone(),));
508 set_fields.push(quote!(self.#name[index] = value.#name.clone();));
509
510 soa_push_fields.push(quote!(self.#name[self.len] = value.#name.clone();));
511 soa_pop_fields.push(quote!(#name: self.#name[self.len].clone(),));
512
513 soa_apply_args.push(quote!(self.#name[_idx].clone()));
514 soa_apply_sets.push(quote!(self.#name[_idx] = #name;));
515
516 storage_encode_fields.push(quote! {
517 for _idx in 0..len {
518 self.#name[_idx].encode(encoder)?;
519 }
520 });
521 storage_decode_fields.push(quote! {
522 for _idx in 0..len {
523 result.#name[_idx] = Decode::decode(decoder)?;
524 }
525 });
526
527 soa_encode_fields.push(quote! {
528 for _idx in 0..self.len {
529 self.#name[_idx].encode(encoder)?;
530 }
531 });
532 soa_decode_fields.push(quote! {
533 for _idx in 0..result.len {
534 result.#name[_idx] = Decode::decode(decoder)?;
535 }
536 });
537
538 storage_serialize_fields.push(quote! {
539 state.serialize_field(stringify!(#name), &self.storage.#name[..self.len])?;
540 });
541 soa_serialize_fields.push(quote! {
542 state.serialize_field(stringify!(#name), &self.#name[..self.len])?;
543 });
544
545 storage_serialize_bounds.push(quote!(#ty: Serialize,));
546 soa_serialize_bounds.push(quote!(#ty: Serialize,));
547 soa_deserialize_bounds.push(quote!(#ty: Deserialize<'de> + Default,));
548
549 storage_wire_fields.push(quote!(#name: Vec<#ty>,));
550 storage_wire_checks.push(quote! {
551 if #name.len() != len {
552 return Err(format!(
553 "field {} has length {} but len is {}",
554 stringify!(#name),
555 #name.len(),
556 len
557 ));
558 }
559 });
560 storage_wire_assignments.push(quote! {
561 for (idx, value) in #name.into_iter().enumerate() {
562 result.#name[idx] = value;
563 }
564 });
565
566 soa_wire_fields.push(quote!(#name: Vec<#ty>,));
567 soa_wire_checks.push(quote! {
568 if #name.len() != len {
569 return Err(serde::de::Error::custom(format!(
570 "field {} has length {} but len is {}",
571 stringify!(#name),
572 #name.len(),
573 len
574 )));
575 }
576 });
577 soa_wire_assignments.push(quote! {
578 for (idx, value) in #name.into_iter().enumerate() {
579 result.#name[idx] = value;
580 }
581 });
582 }
583 }
584
585 let storage_clone_where = if storage_clone_bounds.is_empty() {
586 quote!()
587 } else {
588 quote!(where #(#storage_clone_bounds)*)
589 };
590 let storage_serialize_where = if storage_serialize_bounds.is_empty() {
591 quote!()
592 } else {
593 quote!(where #(#storage_serialize_bounds)*)
594 };
595 let soa_serialize_where = if soa_serialize_bounds.is_empty() {
596 quote!()
597 } else {
598 quote!(where #(#soa_serialize_bounds)*)
599 };
600 let soa_deserialize_where = if soa_deserialize_bounds.is_empty() {
601 quote!()
602 } else {
603 quote!(where #(#soa_deserialize_bounds)*)
604 };
605
606 let expanded = quote! {
607 #visibility mod #module_name {
608 use bincode::{Decode, Encode};
609 use bincode::enc::Encoder;
610 use bincode::de::Decoder;
611 use bincode::error::{DecodeError, EncodeError};
612 use serde::Deserialize;
613 use serde::Serialize;
614 use serde::Serializer;
615 use serde::ser::SerializeStruct;
616 use std::ops::{Index, IndexMut};
617 #( use super::#unique_imports; )*
618 #reflect_import
619 use core::array::from_fn;
620
621 #[derive(Debug)]
622 #visibility struct #soa_storage_name<const N: usize> {
623 #(#field_decls,)*
624 }
625
626 #[doc(hidden)]
627 #[derive(Deserialize)]
628 #visibility struct #soa_storage_wire_name {
629 #(#storage_wire_fields)*
630 }
631
632 #[doc(hidden)]
633 #visibility struct #soa_storage_serde_name<'a, const N: usize> {
634 storage: &'a #soa_storage_name<N>,
635 len: usize,
636 }
637
638 impl<const N: usize> #soa_storage_name<N> {
639 pub fn new(default: super::#name) -> Self {
640 Self {
641 #(#new_inits,)*
642 }
643 }
644
645 pub fn set(&mut self, index: usize, value: super::#name) {
646 assert!(index < N, "Index out of bounds");
647 #(#set_fields)*
648 }
649
650 pub fn get(&self, index: usize) -> super::#name {
651 assert!(index < N, "Index out of bounds");
652 super::#name {
653 #(#get_fields)*
654 }
655 }
656
657 pub fn encode_len<E: Encoder>(
658 &self,
659 encoder: &mut E,
660 len: usize,
661 ) -> Result<(), EncodeError> {
662 #(#storage_encode_fields)*
663 Ok(())
664 }
665
666 pub fn decode_len<D: Decoder<Context = ()>>(
667 decoder: &mut D,
668 len: usize,
669 ) -> Result<Self, DecodeError> {
670 let mut result = Self::default();
671 #(#storage_decode_fields)*
672 Ok(result)
673 }
674
675 pub fn serialize_len(&self, len: usize) -> #soa_storage_serde_name<'_, N> {
676 #soa_storage_serde_name {
677 storage: self,
678 len,
679 }
680 }
681
682 pub fn from_wire(wire: #soa_storage_wire_name, len: usize) -> Result<Self, String> {
683 let #soa_storage_wire_name { #( #field_names ),* } = wire;
684
685 if len > N {
686 return Err(format!(
687 "len {} exceeds capacity {}",
688 len,
689 N
690 ));
691 }
692
693 #(#storage_wire_checks)*
694
695 let mut result = Self::default();
696 #(#storage_wire_assignments)*
697 Ok(result)
698 }
699
700 #(#accessors)*
701 }
702
703 impl<'a, const N: usize> Serialize for #soa_storage_serde_name<'a, N>
704 #storage_serialize_where
705 {
706 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
707 where
708 S: Serializer,
709 {
710 let mut state = serializer.serialize_struct(
711 stringify!(#soa_storage_name),
712 #storage_field_count,
713 )?;
714 #(#storage_serialize_fields)*
715 state.end()
716 }
717 }
718
719 impl<const N: usize> Default for #soa_storage_name<N> {
720 fn default() -> Self {
721 Self {
722 #(#default_inits,)*
723 }
724 }
725 }
726
727 impl<const N: usize> Clone for #soa_storage_name<N>
728 #storage_clone_where
729 {
730 fn clone(&self) -> Self {
731 Self {
732 #( #field_names: self.#field_names.clone(), )*
733 }
734 }
735 }
736
737 #[derive(Debug)]
738 #soa_reflect_attrs
739 #visibility struct #soa_struct_name<const N: usize> {
740 pub len: usize,
741 #(#field_decls,)*
742 }
743
744 impl<const N: usize> #soa_struct_name<N> {
745 pub fn new(default: super::#name) -> Self {
746 Self {
747 #(#new_inits,)*
748 len: 0,
749 }
750 }
751
752 pub fn len(&self) -> usize {
753 self.len
754 }
755
756 pub fn is_empty(&self) -> bool {
757 self.len == 0
758 }
759
760 pub fn push(&mut self, value: super::#name) {
761 if self.len < N {
762 #(#soa_push_fields)*
763 self.len += 1;
764 } else {
765 panic!("Capacity exceeded")
766 }
767 }
768
769 pub fn pop(&mut self) -> Option<super::#name> {
770 if self.len == 0 {
771 None
772 } else {
773 self.len -= 1;
774 Some(super::#name {
775 #(#soa_pop_fields)*
776 })
777 }
778 }
779
780 pub fn set(&mut self, index: usize, value: super::#name) {
781 assert!(index < self.len, "Index out of bounds");
782 #(#set_fields)*
783 }
784
785 pub fn get(&self, index: usize) -> super::#name {
786 assert!(index < self.len, "Index out of bounds");
787 super::#name {
788 #(#get_fields)*
789 }
790 }
791
792 pub fn apply<F>(&mut self, mut f: F)
793 where
794 F: FnMut(#(#field_types),*) -> (#(#field_types),*)
795 {
796 for _idx in 0..self.len {
798 let result = f(#(#soa_apply_args),*);
799 let (#(#field_names),*) = result;
800 #(#soa_apply_sets)*
801 }
802 }
803
804 pub fn iter(&self) -> #soa_struct_name_iterator<N> {
805 #soa_struct_name_iterator::new(self)
806 }
807
808 #(#accessors)*
809 }
810
811 impl<const N: usize> Encode for #soa_struct_name<N> {
812 fn encode<E: Encoder>(&self, encoder: &mut E) -> Result<(), EncodeError> {
813 self.len.encode(encoder)?;
814 #(#soa_encode_fields)*
815 Ok(())
816 }
817 }
818
819 impl<const N: usize> Decode<()> for #soa_struct_name<N> {
820 fn decode<D: Decoder<Context = ()>>(decoder: &mut D) -> Result<Self, DecodeError> {
821 let mut result = Self::default();
822 result.len = Decode::decode(decoder)?;
823 #(#soa_decode_fields)*
824 Ok(result)
825 }
826 }
827
828 impl<const N: usize> Default for #soa_struct_name<N> {
829 fn default() -> Self {
830 Self {
831 #(#default_inits,)*
832 len: 0,
833 }
834 }
835 }
836
837 impl<const N: usize> Clone for #soa_struct_name<N>
838 #storage_clone_where
839 {
840 fn clone(&self) -> Self {
841 Self {
842 #( #field_names: self.#field_names.clone(), )*
843 len: self.len,
844 }
845 }
846 }
847
848 impl<const N: usize> Serialize for #soa_struct_name<N>
849 #soa_serialize_where
850 {
851 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
852 where
853 S: Serializer,
854 {
855 let mut state =
856 serializer.serialize_struct(stringify!(#soa_struct_name), #field_count)?;
857 state.serialize_field("len", &self.len)?;
858 #(#soa_serialize_fields)*
859 state.end()
860 }
861 }
862
863 impl<'de, const N: usize> Deserialize<'de> for #soa_struct_name<N>
864 #soa_deserialize_where
865 {
866 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
867 where
868 D: serde::Deserializer<'de>,
869 {
870 #[derive(Deserialize)]
871 struct #soa_struct_wire_name {
872 len: usize,
873 #(#soa_wire_fields)*
874 }
875
876 let wire = #soa_struct_wire_name::deserialize(deserializer)?;
877 let #soa_struct_wire_name { len, #( #field_names ),* } = wire;
878
879 if len > N {
880 return Err(serde::de::Error::custom(format!(
881 "len {} exceeds capacity {}",
882 len,
883 N
884 )));
885 }
886
887 #(#soa_wire_checks)*
888
889 let mut result = Self::default();
890 result.len = len;
891 #(#soa_wire_assignments)*
892 Ok(result)
893 }
894 }
895
896 #iterator
897 }
898 #visibility use #module_name::#soa_struct_name;
899 #visibility use #module_name::#soa_storage_name;
900 #visibility use #module_name::#soa_storage_wire_name;
901 };
902
903 let tokens: TokenStream = expanded.into();
904
905 tokens
906}