Skip to content

Commit

Permalink
Add forward_mut in Sequential derive (#884)
Browse files Browse the repository at this point in the history
* add test for try_forward_mut (fails)

* add forward_mut to Sequential derive

* also use the _mut for tuple structs
  • Loading branch information
swfsql authored Nov 6, 2023
1 parent 370334f commit 53469e9
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 12 deletions.
49 changes: 37 additions & 12 deletions dfdx-derives/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,23 +489,44 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
};

let impl_module = {
let src = match input.data {
let (src, src_mut) = match input.data {
Data::Struct(ref data) => match data.fields {
Fields::Named(ref fields) => {
let recurse = fields.named.iter().map(|f| {
let name = &f.ident;
quote_spanned! {f.span()=> self.#name.try_forward(x)? }
});
quote! { #(let x = #recurse;)* }
let (recurse, recurse_mut) = fields
.named
.iter()
.map(|f| {
let name = &f.ident;
(
quote_spanned! {f.span()=> self.#name.try_forward(x)? },
quote_spanned! {f.span()=> self.#name.try_forward_mut(x)? },
)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
(
quote! { #(let x = #recurse;)* },
quote! { #(let x = #recurse_mut;)* },
)
}
Fields::Unnamed(ref fields) => {
let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
let index = Index::from(i);
quote_spanned! {f.span()=> self.#index.try_forward(x)? }
});
quote! { #(let x = #recurse;)* }
let (recurse, recurse_mut) = fields
.unnamed
.iter()
.enumerate()
.map(|(i, f)| {
let index = Index::from(i);
(
quote_spanned! {f.span()=> self.#index.try_forward(x)? },
quote_spanned! {f.span()=> self.#index.try_forward_mut(x)? },
)
})
.unzip::<_, _, Vec<_>, Vec<_>>();
(
quote! { #(let x = #recurse;)* },
quote! { #(let x = #recurse_mut;)* },
)
}
Fields::Unit => quote! { let x = x; },
Fields::Unit => (quote! { let x = x; }, quote! { let x = x; }),
},
_ => unreachable!(),
};
Expand All @@ -520,6 +541,10 @@ pub fn sequential(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
#src
Ok(x)
}
fn try_forward_mut(&mut self, x: Input) -> Result<Self::Output, Error> {
#src_mut
Ok(x)
}
}
}
};
Expand Down
18 changes: 18 additions & 0 deletions dfdx/src/nn/layers/batch_norm2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,4 +311,22 @@ mod tests {
let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default());
opt.update(&mut bn, &g).expect("");
}

#[derive(Default, Clone, Sequential)]
struct Arch {
pub batch: BatchNorm2DConstConfig<3>,
}

#[test]
fn test_batchnorm2d_update_with_derive() {
let dev: TestDevice = Default::default();

let x1: Tensor<Rank3<3, 4, 5>, TestDtype, _> = dev.sample_normal();
let mut bn = dev.build_module::<TestDtype>(Arch::default());
let y = bn.forward_mut(x1.leaky_trace());
let g = y.square().mean().backward();

let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default());
opt.update(&mut bn, &g).expect("");
}
}

0 comments on commit 53469e9

Please sign in to comment.