implement enums

This commit is contained in:
Numbers
2025-06-02 11:53:15 +02:00
parent b0f369e809
commit d5ff54d4b5

View File

@@ -4,7 +4,7 @@ use proc_macro::TokenStream;
use std::fmt::Write;
use proc_macro2::Ident;
use quote::ToTokens;
use syn::{Data, DataStruct, DeriveInput, Fields, parse_macro_input};
use syn::{Data, DataStruct, DeriveInput, Fields, parse_macro_input, DataEnum};
#[proc_macro_derive(OverTheWire)]
pub fn derive_bin_serializer(input: TokenStream) -> TokenStream {
@@ -12,128 +12,177 @@ pub fn derive_bin_serializer(input: TokenStream) -> TokenStream {
match input.data {
Data::Struct(data) => { derive_struct(input.ident, &data).unwrap() }
Data::Enum(_data) => { panic!("Cannot derive OverTheWire for Enums (yet)"); }
Data::Enum(data) => { derive_enum(input.ident, &data).unwrap() }
Data::Union(_data) => { panic!("Cannot derive OverTheWire for Unions"); }
}
}
#[allow(unused_variables)]
fn derive_struct(name: Ident, data: &DataStruct) -> Result<TokenStream, std::fmt::Error> {
let mut output = String::new();
// collect the fields...
let input_fields = match &data.fields {
// field_name: <value>
Fields::Named(fields) =>
fields.named.iter().filter_map(
|n| { n.ident.as_ref().map(|i| i.to_string()) }
).collect::<Vec<String>>(),
// field names are 0, 1, 2, etc.
Fields::Unnamed(fields) =>
(0..fields.unnamed.len())
.map(|i| i.to_string())
.collect::<Vec<String>>(),
// no fields to worry about
Fields::Unit => Vec::new(),
};
const PAD: &str = " ";
let deserialize = match &data.fields {
Fields::Named(_) => {
let mut output = String::new();
writeln!(&mut output, "{PAD}Ok(Self{{")?;
for fname in &input_fields {
writeln!(&mut output, "{PAD} {fname}: otw::OverTheWire::deserialize(reader)?,")?
}
writeln!(&mut output, "{PAD}}})")?;
output
}
Fields::Unnamed(_) => {
let mut output = String::new();
writeln!(&mut output, "{PAD}Ok(Self(")?;
for fname in &input_fields {
writeln!(&mut output, "{PAD} otw::OverTheWire::deserialize(reader)?,")?
}
writeln!(&mut output, "{PAD}))")?;
output
}
Fields::Unit => "Ok(Self)".to_string(),
};
let serialize = match data.fields {
Fields::Named(_) | Fields::Unnamed(_) => {
let mut output = String::new();
for fname in &input_fields {
writeln!(&mut output, "{PAD}self.{fname}.serialize(writer)?;")?
}
writeln!(&mut output, "{PAD}Ok(())")?;
output
}
Fields::Unit => "Ok(())".to_string(),
};
let input_types = match &data.fields {
// field_name: <value>
Fields::Named(fields) =>
fields.named.iter()
.map(|n| { n.ty.to_token_stream().to_string() })
.collect::<Vec<String>>(),
// field names are 0, 1, 2, etc.
Fields::Unnamed(fields) => fields.unnamed.iter()
.map(|f|f.ty.to_token_stream().to_string())
.collect::<Vec<String>>(),
// no fields to worry about
Fields::Unit => Vec::new(),
};
let size_hint = match data.fields {
Fields::Named(_) | Fields::Unnamed(_) => {
let mut output = String::new();
writeln!(&mut output, "{PAD}0usize")?;
for ftype in &input_types {
writeln!(&mut output, "{PAD} .saturating_add(otw::min_wire_size::<{ftype}>())")?
}
output
}
Fields::Unit => "Ok(())".to_string(),
};
let tname = name.to_string();
writeln!(&mut output, "#[automatically_derived]")?;
writeln!(&mut output, "impl otw::OverTheWire for {name} {{")?;
writeln!(&mut output, " fn serialize<T: otw::Writer>(&self, writer: &mut T) -> e::Result<()> {{")?;
writeln!(&mut output, "{serialize}")?;
wsf(&mut output, " ", Some("self"), &data.fields)?;
writeln!(&mut output, " }}")?;
writeln!(&mut output, " fn deserialize<T: otw::Reader>(reader: &mut T) -> e::Result<Self> {{")?;
writeln!(&mut output, "{deserialize}")?;
wdf(&mut output, " ", tname.as_str(), &data.fields)?;
writeln!(&mut output, " }}")?;
writeln!(&mut output, " #[inline(always)]")?;
writeln!(&mut output, " fn size_hint() -> usize {{")?;
writeln!(&mut output, "{size_hint}")?;
wshf(&mut output, " ", &data.fields)?;
writeln!(&mut output, " }}")?;
/*
fn size_hint() -> usize {
size_of::<Self>()
}
*/
writeln!(&mut output, "}}")?;
if let Err(error) = output.parse::<TokenStream>() {
panic!("{}", output);
}
Ok(output.parse().unwrap())
}
#[allow(unused_variables)]
fn derive_enum(name: Ident, data: &DataEnum) -> Result<TokenStream, std::fmt::Error> {
let mut output = String::new();
let tname = name.to_string();
writeln!(&mut output, "#[automatically_derived]")?;
writeln!(&mut output, "impl otw::OverTheWire for {name} {{")?;
writeln!(&mut output, " fn serialize<T: otw::Writer>(&self, writer: &mut T) -> e::Result<()> {{")?;
writeln!(&mut output, " match self {{")?;
for (v_index, v) in data.variants.iter().enumerate() {
let vname = v.ident.to_string();
const P: &str = " ";
match &v.fields {
Fields::Named(f) => {
let fields = f.named.iter().enumerate()
.map(|(i, r)| r.ident.as_ref()
.expect("named enum has unnamed fields")
.to_string())
.collect::<Vec<String>>();
let fields = fields.join(", ");
writeln!(&mut output, "{P}Self::{vname}{{{fields}}} => {{")?;
writeln!(&mut output, "{P} {v_index}u8.serialize(writer)?;")?;
wsf(&mut output, " ", None, &v.fields)?;
writeln!(&mut output, "{P}}}")?;
},
Fields::Unnamed(f) => {
let fields = f.unnamed.iter().enumerate()
.map(|(i, _)| format!("_{i}"))
.collect::<Vec<String>>();
let fields = fields.join(", ");
writeln!(&mut output, "{P}Self::{vname}({fields}) => {{")?;
writeln!(&mut output, "{P} {v_index}u8.serialize(writer)?;")?;
wsf(&mut output, " ", None, &v.fields)?;
writeln!(&mut output, "{P}}}")?;
}
Fields::Unit => {
// unit types are simple and need no other help...
writeln!(&mut output, "{P}Self::{vname} => {{")?;
writeln!(&mut output, "{P} {v_index}u8.serialize(writer)?;")?;
writeln!(&mut output, "{P} Ok(())")?;
writeln!(&mut output, "{P}}}")?;
}
}
}
writeln!(&mut output, " }}")?;
writeln!(&mut output, " }}")?;
//╶───╴Deserialize╶──────────────────────────────────────────────────────────╴
writeln!(&mut output, " fn deserialize<T: otw::Reader>(reader: &mut T) -> e::Result<Self> {{")?;
writeln!(&mut output, " match <u8 as otw::OverTheWire>::deserialize(reader)? {{")?;
const P: &str = " ";
for (v_index, v) in data.variants.iter().enumerate() {
let vname = format!("Self::{}", v.ident);
writeln!(&mut output, "{P}{v_index} => {{")?;
wdf(&mut output, " ", vname.as_str(), &v.fields)?;
writeln!(&mut output, "{P}}}")?;
}
writeln!(&mut output, "{P}_ => Err(otw::MalformedData)?,")?;
writeln!(&mut output, " }}")?;
writeln!(&mut output, " }}")?;
//╶───╴Size Hint╶────────────────────────────────────────────────────────────╴
writeln!(&mut output, " #[inline(always)]")?;
writeln!(&mut output, " fn size_hint() -> usize {{")?;
if data.variants.iter().any(|v|matches!(v.fields, Fields::Unit)) {
// if any of the variants are a unit, the minimum size is just one
writeln!(&mut output, " 1usize")?;
} else {
// TODO Could just do min(variants...) but for now im lazy
writeln!(&mut output, " 1usize")?;
}
writeln!(&mut output, " }}")?;
writeln!(&mut output, "}}")?;
Ok(output.parse().unwrap())
}
// write serialize func
fn wsf(out: &mut dyn Write, pad: &str, thiz: Option<&str>, fields: &Fields) -> Result<(), std::fmt::Error> {
for (i, field) in fields.iter().enumerate() {
let fname = field.ident.as_ref()
.map(|a| a.to_string())
.unwrap_or_else(|| match thiz {
None => format!("_{i}"),
Some(_) => format!("{i}")
});
match thiz {
// @formatter:off
None => writeln!(out, "{pad}{fname}.serialize(writer)?;")?,
Some(thiz) => writeln!(out, "{pad}{thiz}.{fname}.serialize(writer)?;")?,
// @formatter:on
}
}
writeln!(out, "{pad}Ok(())")?;
Ok(())
}
// write deserialize func
fn wdf(out: &mut dyn Write, pad: &str, type_name: &str, fields: &Fields) -> Result<(), std::fmt::Error> {
match fields {
// @formatter:off
Fields::Named(_) => writeln!(out, "{pad}Ok({type_name}{{")?,
Fields::Unnamed(_) => writeln!(out, "{pad}Ok({type_name}(")?,
Fields::Unit => {
// unit types need no extra logic, they are done
writeln!(out, "{pad}Ok({type_name})")?;
return Ok(());
},
// @formatter:on
}
for field in fields {
match &field.ident {
// @formatter:off
None => writeln!(out, "{pad} otw::OverTheWire::deserialize(reader)?,")?,
Some(name) => writeln!(out, "{pad} {name}: otw::OverTheWire::deserialize(reader)?,")?,
// @formatter:on
}
}
match fields {
// @formatter:off
Fields::Named(_) => writeln!(out, "{pad}}})")?,
Fields::Unnamed(_) => writeln!(out, "{pad}))")?,
Fields::Unit => unreachable!()
// @formatter:on
}
Ok(())
}
// write sizehint func
fn wshf(out: &mut dyn Write, pad: &str, fields: &Fields) -> Result<(), std::fmt::Error> {
writeln!(out, "{pad}0usize")?;
for field in fields {
writeln!(out, "{pad} .saturating_add(otw::min_wire_size::<{}>())", field.ty.to_token_stream().to_string())?;
}
Ok(())
}