Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions crates/sol-macro-expander/src/expand/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,23 +138,44 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result<TokenS
}

let enum_expander = CallLikeExpander { cx, contract_name: name.clone(), extra_methods };
// Remove any `Default` derives.
let mut enum_attrs = item_attrs;
for attr in &mut enum_attrs {
if !attr.path().is_ident("derive") {
continue;
}

let derives = alloy_sol_macro_input::parse_derives(attr);
let mut derives = derives.into_iter().collect::<Vec<_>>();
if derives.is_empty() {
continue;
}

let len = derives.len();
derives.retain(|derive| !derive.is_ident("Default"));
if derives.len() == len {
continue;
}

attr.meta = parse_quote! { derive(#(#derives),*) };
}

let functions_enum = (!functions.is_empty()).then(|| {
let mut attrs = item_attrs.clone();
let mut attrs = enum_attrs.clone();
let doc_str = format!("Container for all the [`{name}`](self) function calls.");
attrs.push(parse_quote!(#[doc = #doc_str]));
enum_expander.expand(ToExpand::Functions(&functions), attrs)
});

let errors_enum = (!errors.is_empty()).then(|| {
let mut attrs = item_attrs.clone();
let mut attrs = enum_attrs.clone();
let doc_str = format!("Container for all the [`{name}`](self) custom errors.");
attrs.push(parse_quote!(#[doc = #doc_str]));
enum_expander.expand(ToExpand::Errors(&errors), attrs)
});

let events_enum = (!events.is_empty()).then(|| {
let mut attrs = item_attrs;
let mut attrs = enum_attrs;
let doc_str = format!("Container for all the [`{name}`](self) events.");
attrs.push(parse_quote!(#[doc = #doc_str]));
enum_expander.expand(ToExpand::Events(&events), attrs)
Expand Down
9 changes: 6 additions & 3 deletions crates/sol-macro-input/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ pub fn derives(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
/// Returns an iterator over all the rust `::` paths in the `#[derive(...)]`
/// attributes.
pub fn derives_mapped(attrs: &[Attribute]) -> impl Iterator<Item = Path> + '_ {
derives(attrs).flat_map(|attr| {
attr.parse_args_with(Punctuated::<Path, Token![,]>::parse_terminated).unwrap_or_default()
})
derives(attrs).flat_map(parse_derives)
}

/// Parses the `#[derive(...)]` attributes into a list of paths.
pub fn parse_derives(attr: &Attribute) -> Punctuated<Path, Token![,]> {
attr.parse_args_with(Punctuated::<Path, Token![,]>::parse_terminated).unwrap_or_default()
}

// When adding a new attribute:
Expand Down
2 changes: 1 addition & 1 deletion crates/sol-macro-input/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ extern crate syn_solidity as ast;

/// Tools for working with `#[...]` attributes.
mod attr;
pub use attr::{derives_mapped, docs_str, mk_doc, ContainsSolAttrs, SolAttrs};
pub use attr::{derives_mapped, docs_str, mk_doc, parse_derives, ContainsSolAttrs, SolAttrs};

mod input;
pub use input::{SolInput, SolInputKind};
Expand Down
20 changes: 20 additions & 0 deletions crates/sol-types/tests/macros/sol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -876,3 +876,23 @@ fn event_overrides() {
assert_eq!(two::TestEvent_1::SIGNATURE, "TestEvent(bytes32,bytes32)");
assert_eq!(two::TestEvent_1::SIGNATURE_HASH, keccak256("TestEvent(bytes32,bytes32)"));
}

#[test]
fn contract_derive_default() {
sol! {
#[derive(Debug, Default)]
contract MyContract {
function f1();
function f2();
event e1();
event e2();
error c();
}
}

let MyContract::f1Call {} = MyContract::f1Call::default();
let MyContract::f2Call {} = MyContract::f2Call::default();
let MyContract::e1 {} = MyContract::e1::default();
let MyContract::e2 {} = MyContract::e2::default();
let MyContract::c {} = MyContract::c::default();
}