From 3bfb35e3efc52321dd056b65fb863f4cfe059fc9 Mon Sep 17 00:00:00 2001 From: Jachym Putta Date: Wed, 22 May 2024 16:49:19 -0400 Subject: [PATCH 1/5] feat: Greater + GreaterOrEqual onnx import --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 4 +- crates/burn-import/onnx-tests/build.rs | 2 + .../onnx-tests/tests/greater/greater.onnx | 17 +++++++++ .../onnx-tests/tests/greater/greater.py | 38 +++++++++++++++++++ .../greater_or_equal/greater_or_equal.onnx | 17 +++++++++ .../greater_or_equal/greater_or_equal.py | 38 +++++++++++++++++++ .../onnx-tests/tests/onnx_tests.rs | 28 ++++++++++++++ crates/burn-import/src/burn/node/binary.rs | 38 +++++++++++++++++++ crates/burn-import/src/onnx/dim_inference.rs | 26 +++++++++++++ crates/burn-import/src/onnx/to_burn.rs | 18 +++++++++ 10 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/greater/greater.onnx create mode 100644 crates/burn-import/onnx-tests/tests/greater/greater.py create mode 100644 crates/burn-import/onnx-tests/tests/greater_or_equal/greater_or_equal.onnx create mode 100644 crates/burn-import/onnx-tests/tests/greater_or_equal/greater_or_equal.py diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index a21cb56663..d0c775a7e4 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -71,8 +71,8 @@ represent the corresponding Burn Op. | [GlobalAveragePool][63] | ✅ | ✅ | | [GlobalLpPool][64] | ❌ | ❌ | | [GlobalMaxPool][65] | ❌ | ❌ | -| [Greater][66] | ❌ | ✅ | -| [GreaterOrEqual][67] | ❌ | ✅ | +| [Greater][66] | ✅ | ✅ | +| [GreaterOrEqual][67] | ✅ | ✅ | | [GridSample][68] | ❌ | ❌ | | [GroupNormalization][69] | ❌ | ✅ | | [GRU][70] | ❌ | ✅ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index bd486050b6..4e8753778a 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -40,6 +40,8 @@ fn main() { .input("tests/mul/mul.onnx") .input("tests/neg/neg.onnx") .input("tests/not/not.onnx") + .input("tests/greater/greater.onnx") + .input("tests/greater_or_equal/greater_or_equal.onnx") .input("tests/recip/recip.onnx") .input("tests/relu/relu.onnx") .input("tests/leaky_relu/leaky_relu.onnx") diff --git a/crates/burn-import/onnx-tests/tests/greater/greater.onnx b/crates/burn-import/onnx-tests/tests/greater/greater.onnx new file mode 100644 index 0000000000..adfac2bdd6 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/greater/greater.onnx @@ -0,0 +1,17 @@ +pytorch2.3.0: +8 +onnx::Greater_0 +onnx::Greater_12/Greater"Greater +main_graphZ! +onnx::Greater_0 +  + +Z! +onnx::Greater_1 +  + +b +2 +   + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/greater/greater.py b/crates/burn-import/onnx-tests/tests/greater/greater.py new file mode 100644 index 0000000000..9406b00318 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/greater/greater.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/greater/greater.onnx + +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + return torch.gt(x,y) + +def main(): + # Set seed for reproducibility + torch.manual_seed(42) + torch.set_printoptions(precision=8) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + + onnx_name = "greater.onnx" + + test_input1 = torch.randn(4, 4, device=device) + test_input2 = torch.randn(4, 4, device=device) + torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + print("Test input data: {} {}".format(test_input1, test_input2)) + output = model.forward(test_input1, test_input2) + print("Test output data: {}".format(output)) + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/greater_or_equal/greater_or_equal.onnx b/crates/burn-import/onnx-tests/tests/greater_or_equal/greater_or_equal.onnx new file mode 100644 index 0000000000..83320ae2c5 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/greater_or_equal/greater_or_equal.onnx @@ -0,0 +1,17 @@ +pytorch2.3.0: +T +onnx::GreaterOrEqual_0 +onnx::GreaterOrEqual_12/GreaterOrEqual"GreaterOrEqual +main_graphZ( +onnx::GreaterOrEqual_0 +  + +Z( +onnx::GreaterOrEqual_1 +  + +b +2 +   + +B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/greater_or_equal/greater_or_equal.py b/crates/burn-import/onnx-tests/tests/greater_or_equal/greater_or_equal.py new file mode 100644 index 0000000000..e977c34f2f --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/greater_or_equal/greater_or_equal.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/less_or_equal/less_or_equal.onnx + +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y): + return torch.ge(x,y) + +def main(): + # Set seed for reproducibility + torch.manual_seed(42) + torch.set_printoptions(precision=8) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + + onnx_name = "greater_or_equal.onnx" + + test_input1 = torch.randn(4, 4, device=device) + test_input2 = torch.randn(4, 4, device=device) + torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + print("Test input data: {} {}".format(test_input1, test_input2)) + output = model.forward(test_input1, test_input2) + print("Test output data: {}".format(output)) + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index 1235c62df4..f96fef11d0 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -51,6 +51,8 @@ include_models!( mul, neg, not, + greater, + greater_or_equal, prelu, recip, reduce_max, @@ -1171,6 +1173,32 @@ mod tests { assert_eq!(output, expected); } + #[test] + fn greater() { + let device = Default::default(); + let model: greater::Model = greater::Model::new(&device); + + let input1 = Tensor::::from_floats([[1.0, 4.0, 9.0, 25.0]], &device); + let input2 = Tensor::::from_floats([[1.0, 5.0, 8.0, -25.0]], &device); + + let output = model.forward(input1, input2); + let expected = Data::from([[false, false, true, true]]); + assert_eq!(output.to_data(), expected); + } + + #[test] + fn greater_or_equal() { + let device = Default::default(); + let model: greater_or_equal::Model = greater_or_equal::Model::new(&device); + + let input1 = Tensor::::from_floats([[1.0, 4.0, 9.0, 25.0]], &device); + let input2 = Tensor::::from_floats([[1.0, 5.0, 8.0, -25.0]], &device); + + let output = model.forward(input1, input2); + let expected = Data::from([[true, false, true, true]]); + assert_eq!(output.to_data(), expected); + } + #[test] fn test_model_creation_with_a_default_device() { let device = Default::default(); diff --git a/crates/burn-import/src/burn/node/binary.rs b/crates/burn-import/src/burn/node/binary.rs index b4d409e17d..da299f23cb 100644 --- a/crates/burn-import/src/burn/node/binary.rs +++ b/crates/burn-import/src/burn/node/binary.rs @@ -16,6 +16,8 @@ pub enum BinaryType { Powi, Min, Max, + Greater, + GreaterOrEqual, } impl BinaryType { @@ -30,6 +32,8 @@ impl BinaryType { BinaryType::Powf => "powf", BinaryType::Min => "min_pair", BinaryType::Max => "max_pair", + BinaryType::Greater => "greater", + BinaryType::GreaterOrEqual => "greater_equal", } } } @@ -193,6 +197,30 @@ impl BinaryNode { }; Self::new(lhs, rhs, output, BinaryType::Max, Arc::new(function)) } + + pub(crate) fn greater(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => move |lhs, rhs| quote! { #lhs.greater(#rhs) }, + _ => panic!("greater is supported for tensor only"), + }; + Self::new(lhs, rhs, output, BinaryType::Greater, Arc::new(function)) + } + + pub(crate) fn greater_equal(lhs: Type, rhs: Type, output: Type) -> Self { + let function = match (&lhs, &rhs) { + (Type::Tensor(_), Type::Tensor(_)) => { + move |lhs, rhs| quote! { #lhs.greater_equal(#rhs) } + } + _ => panic!("greater_equal is supported for tensor only"), + }; + Self::new( + lhs, + rhs, + output, + BinaryType::GreaterOrEqual, + Arc::new(function), + ) + } } #[cfg(test)] @@ -358,6 +386,16 @@ mod tests { test_binary_operator_on_tensors!(max_pair); } + #[test] + fn test_binary_codegen_greater() { + test_binary_operator_on_tensors!(greater); + } + + #[test] + fn test_binary_codegen_greater_or_equal() { + test_binary_operator_on_tensors!(greater_equal); + } + #[test] fn test_binary_codegen_equal_tensors() { let mut graph = BurnGraph::::default(); diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index f8e95e1c10..df36c6a8d1 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -46,6 +46,8 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { NodeType::Mul => same_as_input(node), NodeType::Neg => same_as_input(node), NodeType::Not => same_as_input(node), + NodeType::Greater => greater_update_outputs(node), + NodeType::GreaterOrEqual => greater_or_equal_update_outputs(node), NodeType::Reciprocal => same_as_input(node), NodeType::ReduceMax => reduce_max_update_outputs(node), NodeType::ReduceMean => reduce_mean_update_outputs(node), @@ -237,6 +239,30 @@ fn reshape_update_outputs(node: &mut Node) { } } +fn greater_update_outputs(node: &mut Node) { + match &node.inputs[0].ty { + ArgType::Tensor(tensor) => { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + ..tensor.clone() + }); + } + _ => panic!("Only tensor input is valid"), + } +} + +fn greater_or_equal_update_outputs(node: &mut Node) { + match &node.inputs[0].ty { + ArgType::Tensor(tensor) => { + node.outputs[0].ty = ArgType::Tensor(TensorType { + elem_type: ElementType::Bool, + ..tensor.clone() + }); + } + _ => panic!("Only tensor input is valid"), + } +} + fn reduce_mean_update_outputs(node: &mut Node) { if node.inputs.len() != 1 { panic!("Mean: multiple inputs are not supported"); diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 31a2454aa0..b014849c35 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -251,6 +251,8 @@ impl OnnxGraph { NodeType::MatMul => graph.register(Self::matmul_conversion(node)), NodeType::Neg => graph.register(Self::neg_conversion(node)), NodeType::Not => graph.register(Self::not_conversion(node)), + NodeType::Greater => graph.register(Self::greater_conversion(node)), + NodeType::GreaterOrEqual => graph.register(Self::greater_or_equal_conversion(node)), NodeType::LayerNormalization => { graph.register(Self::layer_norm_conversion::(node)) } @@ -822,6 +824,22 @@ impl OnnxGraph { UnaryNode::not(input, output) } + fn greater_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.first().unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.first().unwrap().to_type(); + + BinaryNode::greater(lhs, rhs, output) + } + + fn greater_or_equal_conversion(node: Node) -> BinaryNode { + let lhs = node.inputs.first().unwrap().to_type(); + let rhs = node.inputs.get(1).unwrap().to_type(); + let output = node.outputs.first().unwrap().to_type(); + + BinaryNode::greater_equal(lhs, rhs, output) + } + fn pow_conversion(node: Node) -> BinaryNode { let lhs = node.inputs.first().unwrap().to_type(); let rhs = node.inputs.get(1).unwrap().to_type(); From 96667463de3f90e8bc567afcf1683676b1753be6 Mon Sep 17 00:00:00 2001 From: Jachym Putta Date: Fri, 24 May 2024 16:48:30 -0400 Subject: [PATCH 2/5] WIP: skeleton, arg2 type problem --- crates/burn-import/SUPPORTED-ONNX-OPS.md | 2 +- crates/burn-import/onnx-tests/build.rs | 1 + .../onnx-tests/tests/expand/expand.onnx | Bin 0 -> 1301 bytes .../onnx-tests/tests/expand/expand.py | 41 ++++++++ .../onnx-tests/tests/onnx_tests.rs | 15 +++ crates/burn-import/src/burn/node/base.rs | 12 ++- crates/burn-import/src/burn/node/expand.rs | 92 ++++++++++++++++++ crates/burn-import/src/burn/node/mod.rs | 1 + crates/burn-import/src/onnx/dim_inference.rs | 28 ++++++ .../burn-import/src/onnx/op_configuration.rs | 25 +++++ crates/burn-import/src/onnx/to_burn.rs | 10 ++ 11 files changed, 222 insertions(+), 5 deletions(-) create mode 100644 crates/burn-import/onnx-tests/tests/expand/expand.onnx create mode 100644 crates/burn-import/onnx-tests/tests/expand/expand.py create mode 100644 crates/burn-import/src/burn/node/expand.rs diff --git a/crates/burn-import/SUPPORTED-ONNX-OPS.md b/crates/burn-import/SUPPORTED-ONNX-OPS.md index ed4115fbd1..88392e03b5 100644 --- a/crates/burn-import/SUPPORTED-ONNX-OPS.md +++ b/crates/burn-import/SUPPORTED-ONNX-OPS.md @@ -59,7 +59,7 @@ represent the corresponding Burn Op. | [Equal][51] | ✅ | ✅ | | [Erf][52] | ✅ | ✅ | | [Exp][53] | ✅ | ✅ | -| [Expand][54] | ❌ | ✅ | +| [Expand][54] | ✅ | ✅ | | [EyeLike][55] | ❌ | ❌ | | [Flatten][56] | ✅ | ✅ | | [Floor][57] | ❌ | ❌ | diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 07eb6a2026..17a335477b 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -40,6 +40,7 @@ fn main() { .input("tests/mul/mul.onnx") .input("tests/neg/neg.onnx") .input("tests/not/not.onnx") + .input("tests/expand/expand.onnx") .input("tests/greater/greater.onnx") .input("tests/greater_or_equal/greater_or_equal.onnx") .input("tests/less/less.onnx") diff --git a/crates/burn-import/onnx-tests/tests/expand/expand.onnx b/crates/burn-import/onnx-tests/tests/expand/expand.onnx new file mode 100644 index 0000000000000000000000000000000000000000..0e17f6a8b4f54d388a553e8010a0a2de12d7a74a GIT binary patch literal 1301 zcma)5T~FIE6m3eAxR|&P&iy#Y$Jd0Q9zI6F{C3h)orA$afArFZF`{)8_=|}7ksHKO7)Nfu z*Df$h@~GJZ^%tAPp4laep`VgQ_B<;sxP^Gj3X4wO$)&u7i5B_A&>WscZgZdsS*R$2;BHNoL0tY_fc;(#x{uN5a-NTxLl(XD(yGoUi3iZrLk@<DuRF=%!bEybPIt=k_^axdNN^u%}yToNDQcaL=Z` z`(w_+$<6!q9@vEr(UCxtZ17jPc~`xtHfqq~NsYRolLx=`fzzW2A9?BrFtD5Q+e7zm Ode7CaBlcm_F#ZBJ5SUW{ literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/expand/expand.py b/crates/burn-import/onnx-tests/tests/expand/expand.py new file mode 100644 index 0000000000..3dba5cd5a2 --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/expand/expand.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# used to generate model: onnx-tests/tests/expand/expand.onnx + +import torch +import torch.nn as nn + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, *y): + return x.expand(*y) + +def main(): + # Set seed for reproducibility + torch.manual_seed(42) + + # Export to onnx + model = Model() + model.eval() + device = torch.device("cpu") + + onnx_name = "expand.onnx" + + test_input1 = torch.tensor([[1], [2], [3]]) + test_input2 = (3,4) + torch.onnx.export(model, (test_input1, test_input2), onnx_name, verbose=False, opset_version=16) + + print("Finished exporting model to {}".format(onnx_name)) + + print("Test input data: {} (3,4)".format(test_input1)) + output = model.forward(test_input1, test_input2) + print("Test output data: {}".format(output)) + # Output should be: + # tensor([[ 1, 1, 1, 1], + # [ 2, 2, 2, 2], + # [ 3, 3, 3, 3]]) + +if __name__ == '__main__': + main() diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index c423b9f1f3..1c7359e342 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -33,6 +33,7 @@ include_models!( equal, erf, exp, + expand, flatten, gather, gelu, @@ -1116,6 +1117,20 @@ mod tests { output.to_data().assert_approx_eq(&expected, 2); } + #[test] + fn expand() { + let device = Default::default(); + let model: expand::Model = expand::Model::new(&device); + + let input1 = Tensor::::from_ints([[-1], [1], [42]], &device); + let input_shape = Shape::from([3, 2]); + + let output = model.forward(input1, input_shape); + let expected = Data::from([[-1, -1], [1, 1], [42, 42]]); + + assert_eq!(output.to_data(), expected); + } + #[test] fn gelu() { let device = Default::default(); diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index b90bec6960..4911b90cc9 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -2,10 +2,11 @@ use super::{ avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode, - dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode, - layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode, - max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, - squeeze::SqueezeNode, unary::UnaryNode, unsqueeze::UnsqueezeNode, + dropout::DropoutNode, expand::ExpandNode, gather::GatherNode, + global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode, + mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode, + max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, squeeze::SqueezeNode, + unary::UnaryNode, unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; use burn::backend::NdArray; @@ -87,6 +88,7 @@ pub enum Node { ConvTranspose2d(ConvTranspose2dNode), PRelu(PReluNode), Dropout(DropoutNode), + Expand(ExpandNode), Gather(GatherNode), GlobalAvgPool(GlobalAvgPoolNode), LayerNorm(LayerNormNode), @@ -117,6 +119,7 @@ macro_rules! match_all { Node::ConvTranspose2d(node) => $func(node), Node::PRelu(node) => $func(node), Node::Dropout(node) => $func(node), + Node::Expand(node) => $func(node), Node::Gather(node) => $func(node), Node::GlobalAvgPool(node) => $func(node), Node::LayerNorm(node) => $func(node), @@ -157,6 +160,7 @@ impl Node { Node::ConvTranspose2d(_) => "conv_transpose2d", Node::PRelu(_) => "prelu", Node::Dropout(_) => "dropout", + Node::Expand(_) => "expand", Node::Gather(_) => "gather", Node::GlobalAvgPool(_) => "global_avg_pool", Node::LayerNorm(_) => "layer_norm", diff --git a/crates/burn-import/src/burn/node/expand.rs b/crates/burn-import/src/burn/node/expand.rs new file mode 100644 index 0000000000..20452c2f0f --- /dev/null +++ b/crates/burn-import/src/burn/node/expand.rs @@ -0,0 +1,92 @@ +use super::{Node, NodeCodegen}; +use crate::burn::{Scope, TensorType, ToTokens, Type}; +use burn::record::PrecisionSettings; +use proc_macro2::TokenStream; +use quote::quote; + +#[derive(Debug, Clone, new)] +pub struct ExpandNode { + pub input: TensorType, + pub output: TensorType, + pub shape: Vec, +} + +impl NodeCodegen for ExpandNode { + fn output_types(&self) -> Vec { + vec![Type::Tensor(self.output.clone())] + } + + fn input_types(&self) -> Vec { + vec![Type::Tensor(self.input.clone())] + } + + fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { + let input = scope.tensor_use_owned(&self.input, node_position); + let output = &self.output.name; + let shape_values = &self.shape.to_tokens(); + + quote! { + let #output = #input.expand(#shape_values); + } + } + + fn into_node(self) -> Node { + Node::Expand(self) + } +} + +#[cfg(test)] +mod tests { + use burn::record::FullPrecisionSettings; + + use super::*; + use crate::burn::{ + graph::BurnGraph, + node::{expand::ExpandNode, test::assert_tokens}, + TensorType, + }; + + #[test] + fn test_codegen_nodes() { + let mut graph = BurnGraph::::default(); + + graph.register(ExpandNode::new( + TensorType::new_float("tensor1", 4), + TensorType::new_float("tensor2", 4), + [4, 4].into(), + )); + + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + + let expected = quote! { + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + + #[derive(Module, Debug)] + pub struct Model { + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + Self { + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.expand([4, 4]); + + tensor2 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } +} diff --git a/crates/burn-import/src/burn/node/mod.rs b/crates/burn-import/src/burn/node/mod.rs index b22876d8fd..79021a1d54 100644 --- a/crates/burn-import/src/burn/node/mod.rs +++ b/crates/burn-import/src/burn/node/mod.rs @@ -11,6 +11,7 @@ pub(crate) mod conv1d; pub(crate) mod conv2d; pub(crate) mod conv_transpose_2d; pub(crate) mod dropout; +pub(crate) mod expand; pub(crate) mod gather; pub(crate) mod global_avg_pool; pub(crate) mod layer_norm; diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index 331868227e..010939594e 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -29,6 +29,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { NodeType::Equal => equal_update_outputs(node), NodeType::Erf => same_as_input(node), NodeType::Exp => same_as_input(node), + NodeType::Expand => expand_update_outputs(node), NodeType::Flatten => flatten_update_outputs(node), NodeType::Gelu => same_as_input(node), NodeType::GatherElements => same_as_input(node), @@ -429,6 +430,33 @@ fn equal_update_outputs(node: &mut Node) { } } +fn expand_update_outputs(node: &mut Node) { + let shape = if node.inputs.len() == 2 { + match &node.inputs[1].value { + Some(value) => match value { + Data::Int64s(shape) => Some(shape.clone()), + _ => panic!("Expand: invalid input types"), + }, + None => None, + } + } else { + node.attrs.get("shape").cloned().map(|v| v.into_i64s()) + }; + + let output = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.clone(), + _ => panic!("Expand: invalid output types"), + }; + + if let Some(shape) = shape { + node.outputs[0].ty = ArgType::Tensor(TensorType { + dim: shape.len(), + shape: None, // shape is calculated at runtime + ..output + }); + } +} + fn shape_update_outputs(node: &mut Node) { if node.inputs.len() != 1 { panic!("Shape: multiple inputs are not supported: {:?}", node); diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index e8300d060d..b760a01597 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -677,6 +677,31 @@ pub fn unsqueeze_config(node: &Node) -> Vec { } } +pub fn expand_config(node: &Node) -> Vec { + for (key, value) in node.attrs.iter() { + match key.as_str() { + "shape" => return value.clone().into_i64s(), + _ => {} + } + } + assert!( + !node.inputs.is_empty(), + "Expand: shape tensor must be present" + ); + let input_value = &node.inputs[1]; + match &node.inputs[1].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.dim, 1, "Expand: shape tensor must be 1D"); + if let Some(Data::Int64s(shape)) = input_value.value.as_ref() { + shape.clone() + } else { + panic!("Tensor data type must be int64, got {:?}", input_value) + } + } + _ => panic!("Arg for expand must be tensor or scalar"), + } +} + pub fn clip_config(node: &Node) -> (Option, Option) { let mut min_result: Option = None; let mut max_result: Option = None; diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index d89b374480..5e3b5ea367 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -25,6 +25,7 @@ use crate::{ conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, + expand::ExpandNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, @@ -238,6 +239,7 @@ impl OnnxGraph { NodeType::Equal => graph.register(Self::equal_conversion(node)), NodeType::Erf => graph.register(Self::erf_conversion(node)), NodeType::Exp => graph.register(Self::exp_conversion(node)), + NodeType::Expand => graph.register(Self::expand_conversion(node)), NodeType::Clip => graph.register(Self::clip_conversion(node)), NodeType::Cos => graph.register(Self::cos_conversion(node)), NodeType::Conv1d => graph.register(Self::conv1d_conversion::(node)), @@ -814,6 +816,14 @@ impl OnnxGraph { UnaryNode::exp(input, output) } + fn expand_conversion(node: Node) -> ExpandNode { + let input = node.inputs.first().unwrap().to_tensor_type(); + let output = node.outputs.first().unwrap().to_tensor_type(); + let shape = expand_config(&node); + + ExpandNode::new(input, output, shape) + } + fn neg_conversion(node: Node) -> UnaryNode { let input = node.inputs.first().unwrap().to_type(); let output = node.outputs.first().unwrap().to_type(); From e832a6836f1a0316df6b7cf74de8d483993bcd6b Mon Sep 17 00:00:00 2001 From: Jachym Putta Date: Tue, 28 May 2024 18:12:54 -0400 Subject: [PATCH 3/5] feat: fix model generation --- .../onnx-tests/tests/expand/expand.onnx | Bin 232 -> 158 bytes .../onnx-tests/tests/expand/expand.py | 34 ++++-------------- crates/burn-import/src/burn/node/base.rs | 1 + 3 files changed, 7 insertions(+), 28 deletions(-) diff --git a/crates/burn-import/onnx-tests/tests/expand/expand.onnx b/crates/burn-import/onnx-tests/tests/expand/expand.onnx index 55bd417529d253d3a30e9bb671bdf3425856be1e..4bb7aa8fb89438978ae828b45cd48f8ff3bdebe1 100644 GIT binary patch delta 81 zcmaFCIFFH+gG-1lwW1&~FU6{tabmxk diff --git a/crates/burn-import/onnx-tests/tests/expand/expand.py b/crates/burn-import/onnx-tests/tests/expand/expand.py index 4682446cf0..c22f9e54a4 100644 --- a/crates/burn-import/onnx-tests/tests/expand/expand.py +++ b/crates/burn-import/onnx-tests/tests/expand/expand.py @@ -6,31 +6,6 @@ from onnx import helper, TensorProto def main() -> None: - # Create a constant node for the input tensor - input_node: onnx.NodeProto = helper.make_node( - 'Constant', - inputs=[], - outputs=['input_tensor'], - value=helper.make_tensor( - name='const_input', - data_type=TensorProto.FLOAT, - dims=[2, 1], - vals=[1.0, 2.0] - ) - ) - - # Create a constant node for the shape tensor which specifies the expansion - shape_node: onnx.NodeProto = helper.make_node( - 'Constant', - inputs=[], - outputs=['shape_tensor'], - value=helper.make_tensor( - name='const_shape', - data_type=TensorProto.INT64, - dims=[2], - vals=[2, 2] # Expanding each dimension to have 2 elements - ) - ) # Define the Expand node that uses the outputs from the constant nodes expand_node: onnx.NodeProto = helper.make_node( @@ -41,12 +16,15 @@ def main() -> None: # Create the graph graph_def: onnx.GraphProto = helper.make_graph( - nodes=[input_node, shape_node, expand_node], + nodes=[expand_node], name='ExpandGraph', - inputs=[], # No inputs since all are provided by constants within the graph + inputs=[ + helper.make_tensor_value_info('input_tensor', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info('shape_tensor', TensorProto.INT64, [1]), + ], # No inputs since all are provided by constants within the graph outputs=[ helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2]) - ] + ], ) # Create the model diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index 36e78d9eca..867276991f 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -1,3 +1,4 @@ +use super::expand::ExpandNode; use super::{ avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode, From da8a52205b9ad4836f264d1eb3bc8e69da7b25eb Mon Sep 17 00:00:00 2001 From: Jachym Putta Date: Thu, 30 May 2024 16:57:23 -0400 Subject: [PATCH 4/5] WIP: value not propagating --- .../onnx-tests/tests/expand/expand.onnx | 20 ++++++------ .../onnx-tests/tests/expand/expand.py | 31 ++++++++++++++----- crates/burn-import/src/burn/node/expand.rs | 27 ++++++---------- .../burn-import/src/onnx/op_configuration.rs | 15 +++++++++ crates/burn-import/src/onnx/to_burn.rs | 4 +-- 5 files changed, 59 insertions(+), 38 deletions(-) diff --git a/crates/burn-import/onnx-tests/tests/expand/expand.onnx b/crates/burn-import/onnx-tests/tests/expand/expand.onnx index 010daef317..d09cd73e61 100644 --- a/crates/burn-import/onnx-tests/tests/expand/expand.onnx +++ b/crates/burn-import/onnx-tests/tests/expand/expand.onnx @@ -1,17 +1,15 @@  -expand: -5 +expand: +>shapeshape_constant"Constant* +value*:Bshape +. input_tensor - shape_tensoroutput/Expand"Expand ExpandGraphZ +shapeoutput/Expand"Expand ExpandGraphZ input_tensor  -Z - shape_tensor - - -b -output - - +b +output +  + B \ No newline at end of file diff --git a/crates/burn-import/onnx-tests/tests/expand/expand.py b/crates/burn-import/onnx-tests/tests/expand/expand.py index c407b036c5..1d8be8437e 100644 --- a/crates/burn-import/onnx-tests/tests/expand/expand.py +++ b/crates/burn-import/onnx-tests/tests/expand/expand.py @@ -6,30 +6,45 @@ from onnx import helper, TensorProto def main() -> None: + # Define the shape tensor as a constant node + shape_value = [2, 2] # Example shape value + shape_tensor = helper.make_tensor( + name='shape', + data_type=TensorProto.INT64, + dims=[len(shape_value)], + vals=shape_value, + ) + + shape_node = helper.make_node( + 'Constant', + name='shape_constant', + inputs=[], + outputs=['shape'], + value=shape_tensor, + ) # Define the Expand node that uses the outputs from the constant nodes - expand_node: onnx.NodeProto = helper.make_node( + expand_node = helper.make_node( 'Expand', name='/Expand', - inputs=['input_tensor', 'shape_tensor'], + inputs=['input_tensor', 'shape'], outputs=['output'] ) # Create the graph - graph_def: onnx.GraphProto = helper.make_graph( - nodes=[expand_node], + graph_def = helper.make_graph( + nodes=[shape_node, expand_node], name='ExpandGraph', inputs=[ helper.make_tensor_value_info('input_tensor', TensorProto.FLOAT, [2]), - helper.make_tensor_value_info('shape_tensor', TensorProto.INT64, [1]), - ], + ], outputs=[ - helper.make_tensor_value_info('output', TensorProto.FLOAT, [2]) + helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2]) ], ) # Create the model - model_def: onnx.ModelProto = helper.make_model(graph_def, producer_name='expand') + model_def = helper.make_model(graph_def, producer_name='expand') # Save the model to a file onnx.save(model_def, 'expand.onnx') diff --git a/crates/burn-import/src/burn/node/expand.rs b/crates/burn-import/src/burn/node/expand.rs index db0532e0f5..25f8de7928 100644 --- a/crates/burn-import/src/burn/node/expand.rs +++ b/crates/burn-import/src/burn/node/expand.rs @@ -1,6 +1,5 @@ use super::{Node, NodeCodegen}; -use crate::burn::{Scope, TensorType, Type}; -use burn::prelude::Shape; +use crate::burn::{Scope, TensorType, ToTokens, Type}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; @@ -9,7 +8,7 @@ use quote::quote; pub struct ExpandNode { pub input: TensorType, pub output: TensorType, - pub shape: TensorType, + pub shape: Vec, } impl NodeCodegen for ExpandNode { @@ -18,19 +17,16 @@ impl NodeCodegen for ExpandNode { } fn input_types(&self) -> Vec { - vec![ - Type::Tensor(self.input.clone()), - Type::Tensor(self.shape.clone()), - ] + vec![Type::Tensor(self.input.clone())] } fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { let input = scope.tensor_use_owned(&self.input, node_position); - let shape = scope.tensor_use_owned(&self.shape, node_position); + let shape = &self.shape.to_tokens(); let output = &self.output.name; quote! { - let #output = #input.expand(Shape::new(#shape)); + let #output = #input.expand(#shape); } } @@ -57,13 +53,10 @@ mod tests { graph.register(ExpandNode::new( TensorType::new_float("tensor1", 4), TensorType::new_float("tensor2", 4), - TensorType::new_int("tensor3", 1), + [4, 4, 4, 4].into(), )); - graph.register_input_output( - vec!["tensor1".to_string(), "tensor3".to_string()], - vec!["tensor2".to_string()], - ); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); let expected = quote! { use burn::{ @@ -86,10 +79,10 @@ mod tests { } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward(&self, tensor1: Tensor, tensor2: Tensor) -> Tensor { - let tensor3 = tensor1.expand(Shape::new(tensor2)); + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = tensor1.expand([4,4,4,4]); - tensor3 + tensor2 } } }; diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index e8300d060d..37d13edb1e 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -262,6 +262,21 @@ pub fn avg_pool2d_config(curr: &Node) -> AvgPool2dConfig { .with_count_include_pad(count_include_pad == 1) } +pub fn expand_config(node: &Node) -> Vec { + let input_value = &node.inputs[1].value; + match &node.inputs[1].ty { + ArgType::Tensor(tensor) => { + assert_eq!(tensor.dim, 1, "Expand: shape tensor must be 1D"); + if let Some(Data::Int64s(shape)) = input_value.as_ref() { + shape.clone() + } else { + panic!("Tensor data type must be int64") + } + } + _ => panic!("Only tensor input is valid for shape"), + } +} + /// Create a FlattenConfig from the attributes of the node pub fn flatten_config(curr: &Node) -> (usize, usize) { // the begin dimension is the first dimension (Default: 1 per ONNX spec) diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 9861aafddb..3dc4cb93c2 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -876,10 +876,10 @@ impl OnnxGraph { } fn expand_conversion(node: Node) -> ExpandNode { - println!("{:?}", node); let input = node.inputs.first().unwrap().to_tensor_type(); - let shape = node.inputs.get(1).unwrap().to_tensor_type(); let output = node.outputs.first().unwrap().to_tensor_type(); + println!("{:?}", node); + let shape = expand_config(&node); ExpandNode::new(input, output, shape) } From c282d0806d8e491779b7711f2385ce4037f6f7a3 Mon Sep 17 00:00:00 2001 From: Jachym Putta Date: Fri, 31 May 2024 13:32:14 -0400 Subject: [PATCH 5/5] feat: added expand to import --- .../burn-import/onnx-tests/tests/expand/expand.onnx | 12 ++++++------ crates/burn-import/onnx-tests/tests/expand/expand.py | 2 +- crates/burn-import/onnx-tests/tests/onnx_tests.rs | 11 +++++------ crates/burn-import/src/onnx/dim_inference.rs | 2 +- crates/burn-import/src/onnx/from_onnx.rs | 3 ++- crates/burn-import/src/onnx/to_burn.rs | 1 - 6 files changed, 15 insertions(+), 16 deletions(-) diff --git a/crates/burn-import/onnx-tests/tests/expand/expand.onnx b/crates/burn-import/onnx-tests/tests/expand/expand.onnx index d09cd73e61..a3bc85bf2d 100644 --- a/crates/burn-import/onnx-tests/tests/expand/expand.onnx +++ b/crates/burn-import/onnx-tests/tests/expand/expand.onnx @@ -1,14 +1,14 @@  -expand: +expand: >shapeshape_constant"Constant* value*:Bshape . input_tensor -shapeoutput/Expand"Expand ExpandGraphZ - input_tensor - - -b +shapeoutput/Expand"Expand ExpandGraphZ + input_tensor +  + +b output   diff --git a/crates/burn-import/onnx-tests/tests/expand/expand.py b/crates/burn-import/onnx-tests/tests/expand/expand.py index 1d8be8437e..45e9c231c8 100644 --- a/crates/burn-import/onnx-tests/tests/expand/expand.py +++ b/crates/burn-import/onnx-tests/tests/expand/expand.py @@ -36,7 +36,7 @@ def main() -> None: nodes=[shape_node, expand_node], name='ExpandGraph', inputs=[ - helper.make_tensor_value_info('input_tensor', TensorProto.FLOAT, [2]), + helper.make_tensor_value_info('input_tensor', TensorProto.FLOAT, [2, 1]), ], outputs=[ helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 2]) diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index e262ec151d..33822f426d 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -1124,13 +1124,12 @@ mod tests { let device = Default::default(); let model: expand::Model = expand::Model::new(&device); - let input1 = Tensor::::from_floats([[[[-1.0, 1.0, 42.0, 3.0]]]], &device); - let input2 = Tensor::::from_ints([3, 2], &device); + let input1 = Tensor::::from_floats([[-1.0], [1.0]], &device); - // let output = model.forward(input1, input2); - // let expected_shape = Shape::from([3, 2]); - // - // assert_eq!(output.shape(), expected_shape); + let output = model.forward(input1); + let expected_shape = Shape::from([2, 2]); + + assert_eq!(output.shape(), expected_shape); } #[test] diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs index 99687a13fc..f2780be837 100644 --- a/crates/burn-import/src/onnx/dim_inference.rs +++ b/crates/burn-import/src/onnx/dim_inference.rs @@ -481,7 +481,7 @@ fn expand_update_outputs(node: &mut Node) { None => None, } } else { - node.attrs.get("shape").cloned().map(|v| v.into_i64s()) + panic!("Expand: invalid number of inputs"); }; let output = match &node.outputs[0].ty { diff --git a/crates/burn-import/src/onnx/from_onnx.rs b/crates/burn-import/src/onnx/from_onnx.rs index f2517ccef4..fbcba0dbc9 100644 --- a/crates/burn-import/src/onnx/from_onnx.rs +++ b/crates/burn-import/src/onnx/from_onnx.rs @@ -17,12 +17,13 @@ use super::ir::{ArgType, Argument, Node, NodeType}; use protobuf::Message; -const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 9] = [ +const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [ NodeType::BatchNormalization, NodeType::Clip, NodeType::Conv1d, NodeType::Conv2d, NodeType::Dropout, + NodeType::Expand, NodeType::Reshape, NodeType::Unsqueeze, NodeType::ReduceSum, diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index 3dc4cb93c2..7b59123d51 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -878,7 +878,6 @@ impl OnnxGraph { fn expand_conversion(node: Node) -> ExpandNode { let input = node.inputs.first().unwrap().to_tensor_type(); let output = node.outputs.first().unwrap().to_tensor_type(); - println!("{:?}", node); let shape = expand_config(&node); ExpandNode::new(input, output, shape)