From d5ff54d4b5c1c6f4a9d856c3fc133ca8f0ad47fc Mon Sep 17 00:00:00 2001 From: Numbers Date: Mon, 2 Jun 2025 11:53:15 +0200 Subject: [PATCH] implement enums --- derive_macro/src/lib.rs | 257 ++++++++++++++++++++++++---------------- 1 file changed, 153 insertions(+), 104 deletions(-) diff --git a/derive_macro/src/lib.rs b/derive_macro/src/lib.rs index 963b9f8..681e43f 100644 --- a/derive_macro/src/lib.rs +++ b/derive_macro/src/lib.rs @@ -4,7 +4,7 @@ use proc_macro::TokenStream; use std::fmt::Write; use proc_macro2::Ident; use quote::ToTokens; -use syn::{Data, DataStruct, DeriveInput, Fields, parse_macro_input}; +use syn::{Data, DataStruct, DeriveInput, Fields, parse_macro_input, DataEnum}; #[proc_macro_derive(OverTheWire)] pub fn derive_bin_serializer(input: TokenStream) -> TokenStream { @@ -12,128 +12,177 @@ pub fn derive_bin_serializer(input: TokenStream) -> TokenStream { match input.data { Data::Struct(data) => { derive_struct(input.ident, &data).unwrap() } - Data::Enum(_data) => { panic!("Cannot derive OverTheWire for Enums (yet)"); } + Data::Enum(data) => { derive_enum(input.ident, &data).unwrap() } Data::Union(_data) => { panic!("Cannot derive OverTheWire for Unions"); } } } + #[allow(unused_variables)] fn derive_struct(name: Ident, data: &DataStruct) -> Result { let mut output = String::new(); - - // collect the fields... - let input_fields = match &data.fields { - - // field_name: - Fields::Named(fields) => - fields.named.iter().filter_map( - |n| { n.ident.as_ref().map(|i| i.to_string()) } - ).collect::>(), - - // field names are 0, 1, 2, etc. - Fields::Unnamed(fields) => - (0..fields.unnamed.len()) - .map(|i| i.to_string()) - .collect::>(), - - // no fields to worry about - Fields::Unit => Vec::new(), - }; - - - const PAD: &str = " "; - let deserialize = match &data.fields { - Fields::Named(_) => { - let mut output = String::new(); - writeln!(&mut output, "{PAD}Ok(Self{{")?; - for fname in &input_fields { - writeln!(&mut output, "{PAD} {fname}: otw::OverTheWire::deserialize(reader)?,")? - } - writeln!(&mut output, "{PAD}}})")?; - output - } - - Fields::Unnamed(_) => { - let mut output = String::new(); - writeln!(&mut output, "{PAD}Ok(Self(")?; - for fname in &input_fields { - writeln!(&mut output, "{PAD} otw::OverTheWire::deserialize(reader)?,")? - } - writeln!(&mut output, "{PAD}))")?; - output - } - - Fields::Unit => "Ok(Self)".to_string(), - }; - - let serialize = match data.fields { - Fields::Named(_) | Fields::Unnamed(_) => { - let mut output = String::new(); - for fname in &input_fields { - writeln!(&mut output, "{PAD}self.{fname}.serialize(writer)?;")? - } - writeln!(&mut output, "{PAD}Ok(())")?; - output - } - Fields::Unit => "Ok(())".to_string(), - }; - - - let input_types = match &data.fields { - - // field_name: - Fields::Named(fields) => - fields.named.iter() - .map(|n| { n.ty.to_token_stream().to_string() }) - .collect::>(), - - // field names are 0, 1, 2, etc. - Fields::Unnamed(fields) => fields.unnamed.iter() - .map(|f|f.ty.to_token_stream().to_string()) - .collect::>(), - - // no fields to worry about - Fields::Unit => Vec::new(), - }; - - - let size_hint = match data.fields { - Fields::Named(_) | Fields::Unnamed(_) => { - let mut output = String::new(); - writeln!(&mut output, "{PAD}0usize")?; - for ftype in &input_types { - writeln!(&mut output, "{PAD} .saturating_add(otw::min_wire_size::<{ftype}>())")? - } - output - } - Fields::Unit => "Ok(())".to_string(), - }; + let tname = name.to_string(); writeln!(&mut output, "#[automatically_derived]")?; writeln!(&mut output, "impl otw::OverTheWire for {name} {{")?; writeln!(&mut output, " fn serialize(&self, writer: &mut T) -> e::Result<()> {{")?; - writeln!(&mut output, "{serialize}")?; + wsf(&mut output, " ", Some("self"), &data.fields)?; writeln!(&mut output, " }}")?; writeln!(&mut output, " fn deserialize(reader: &mut T) -> e::Result {{")?; - writeln!(&mut output, "{deserialize}")?; + wdf(&mut output, " ", tname.as_str(), &data.fields)?; writeln!(&mut output, " }}")?; + writeln!(&mut output, " #[inline(always)]")?; writeln!(&mut output, " fn size_hint() -> usize {{")?; - writeln!(&mut output, "{size_hint}")?; + wshf(&mut output, " ", &data.fields)?; writeln!(&mut output, " }}")?; - - - /* - fn size_hint() -> usize { - size_of::() - } - */ - writeln!(&mut output, "}}")?; + Ok(output.parse().unwrap()) +} - if let Err(error) = output.parse::() { - panic!("{}", output); +#[allow(unused_variables)] +fn derive_enum(name: Ident, data: &DataEnum) -> Result { + let mut output = String::new(); + let tname = name.to_string(); + + writeln!(&mut output, "#[automatically_derived]")?; + writeln!(&mut output, "impl otw::OverTheWire for {name} {{")?; + writeln!(&mut output, " fn serialize(&self, writer: &mut T) -> e::Result<()> {{")?; + writeln!(&mut output, " match self {{")?; + + for (v_index, v) in data.variants.iter().enumerate() { + let vname = v.ident.to_string(); + const P: &str = " "; + + match &v.fields { + Fields::Named(f) => { + let fields = f.named.iter().enumerate() + .map(|(i, r)| r.ident.as_ref() + .expect("named enum has unnamed fields") + .to_string()) + .collect::>(); + let fields = fields.join(", "); + writeln!(&mut output, "{P}Self::{vname}{{{fields}}} => {{")?; + writeln!(&mut output, "{P} {v_index}u8.serialize(writer)?;")?; + wsf(&mut output, " ", None, &v.fields)?; + writeln!(&mut output, "{P}}}")?; + }, + Fields::Unnamed(f) => { + let fields = f.unnamed.iter().enumerate() + .map(|(i, _)| format!("_{i}")) + .collect::>(); + let fields = fields.join(", "); + writeln!(&mut output, "{P}Self::{vname}({fields}) => {{")?; + writeln!(&mut output, "{P} {v_index}u8.serialize(writer)?;")?; + wsf(&mut output, " ", None, &v.fields)?; + writeln!(&mut output, "{P}}}")?; + } + Fields::Unit => { + // unit types are simple and need no other help... + writeln!(&mut output, "{P}Self::{vname} => {{")?; + writeln!(&mut output, "{P} {v_index}u8.serialize(writer)?;")?; + writeln!(&mut output, "{P} Ok(())")?; + writeln!(&mut output, "{P}}}")?; + } + } } + writeln!(&mut output, " }}")?; + writeln!(&mut output, " }}")?; + + //╶───╴Deserialize╶──────────────────────────────────────────────────────────╴ + writeln!(&mut output, " fn deserialize(reader: &mut T) -> e::Result {{")?; + writeln!(&mut output, " match ::deserialize(reader)? {{")?; + + const P: &str = " "; + for (v_index, v) in data.variants.iter().enumerate() { + let vname = format!("Self::{}", v.ident); + writeln!(&mut output, "{P}{v_index} => {{")?; + wdf(&mut output, " ", vname.as_str(), &v.fields)?; + writeln!(&mut output, "{P}}}")?; + } + + writeln!(&mut output, "{P}_ => Err(otw::MalformedData)?,")?; + writeln!(&mut output, " }}")?; + writeln!(&mut output, " }}")?; + + //╶───╴Size Hint╶────────────────────────────────────────────────────────────╴ + writeln!(&mut output, " #[inline(always)]")?; + writeln!(&mut output, " fn size_hint() -> usize {{")?; + if data.variants.iter().any(|v|matches!(v.fields, Fields::Unit)) { + // if any of the variants are a unit, the minimum size is just one + writeln!(&mut output, " 1usize")?; + } else { + // TODO Could just do min(variants...) but for now im lazy + writeln!(&mut output, " 1usize")?; + } + writeln!(&mut output, " }}")?; + writeln!(&mut output, "}}")?; + Ok(output.parse().unwrap()) -} \ No newline at end of file +} + + +// write serialize func +fn wsf(out: &mut dyn Write, pad: &str, thiz: Option<&str>, fields: &Fields) -> Result<(), std::fmt::Error> { + for (i, field) in fields.iter().enumerate() { + let fname = field.ident.as_ref() + .map(|a| a.to_string()) + .unwrap_or_else(|| match thiz { + None => format!("_{i}"), + Some(_) => format!("{i}") + }); + match thiz { + // @formatter:off + None => writeln!(out, "{pad}{fname}.serialize(writer)?;")?, + Some(thiz) => writeln!(out, "{pad}{thiz}.{fname}.serialize(writer)?;")?, + // @formatter:on + } + } + writeln!(out, "{pad}Ok(())")?; + Ok(()) +} + +// write deserialize func +fn wdf(out: &mut dyn Write, pad: &str, type_name: &str, fields: &Fields) -> Result<(), std::fmt::Error> { + match fields { + // @formatter:off + Fields::Named(_) => writeln!(out, "{pad}Ok({type_name}{{")?, + Fields::Unnamed(_) => writeln!(out, "{pad}Ok({type_name}(")?, + Fields::Unit => { + // unit types need no extra logic, they are done + writeln!(out, "{pad}Ok({type_name})")?; + return Ok(()); + }, + // @formatter:on + } + + + for field in fields { + match &field.ident { + // @formatter:off + None => writeln!(out, "{pad} otw::OverTheWire::deserialize(reader)?,")?, + Some(name) => writeln!(out, "{pad} {name}: otw::OverTheWire::deserialize(reader)?,")?, + // @formatter:on + } + } + + match fields { + // @formatter:off + Fields::Named(_) => writeln!(out, "{pad}}})")?, + Fields::Unnamed(_) => writeln!(out, "{pad}))")?, + Fields::Unit => unreachable!() + // @formatter:on + } + + Ok(()) +} + +// write sizehint func +fn wshf(out: &mut dyn Write, pad: &str, fields: &Fields) -> Result<(), std::fmt::Error> { + writeln!(out, "{pad}0usize")?; + for field in fields { + writeln!(out, "{pad} .saturating_add(otw::min_wire_size::<{}>())", field.ty.to_token_stream().to_string())?; + } + Ok(()) +}