diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index 7aa250ce4e..322dcf9219 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -496,11 +496,18 @@ mod tests { [ [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], [11., 12., 13., 14., 15., 16., 17., 18., 19., 20.], + [21., 22., 23., 24., 25., 26., 27., 28., 29., 30.], + [31., 32., 33., 34., 35., 36., 37., 38., 39., 40.], + [41., 42., 43., 44., 45., 46., 47., 48., 49., 50.], ], &device, ); let output = model.forward(input); - let expected = TensorData::from([[1f32, 2., 3., 4., 5.]]); + let expected = TensorData::from([ + [1f32, 2., 3., 4., 5.], + [11f32, 12., 13., 14., 15.], + [21., 22., 23., 24., 25.], + ]); output.to_data().assert_eq(&expected, true); } diff --git a/crates/burn-import/onnx-tests/tests/slice/slice.onnx b/crates/burn-import/onnx-tests/tests/slice/slice.onnx index f8123a417b..437c0be67a 100644 Binary files a/crates/burn-import/onnx-tests/tests/slice/slice.onnx and b/crates/burn-import/onnx-tests/tests/slice/slice.onnx differ diff --git a/crates/burn-import/onnx-tests/tests/slice/slice.py b/crates/burn-import/onnx-tests/tests/slice/slice.py index f02bc771b3..7341617381 100644 --- a/crates/burn-import/onnx-tests/tests/slice/slice.py +++ b/crates/burn-import/onnx-tests/tests/slice/slice.py @@ -5,9 +5,10 @@ import onnx from onnx import helper, TensorProto + def main() -> None: # Starts - starts_val = [0,0] # Example shape value + starts_val = [-5, 0] # Equivalently [0, 0] starts_tensor = helper.make_tensor( name="starts", data_type=TensorProto.INT64, @@ -23,7 +24,7 @@ def main() -> None: ) # Ends - ends_val = [1,5] # Example shape value + ends_val = [3, -5] # Equivalently [3, 5] ends_tensor = helper.make_tensor( name="ends", data_type=TensorProto.INT64, @@ -39,7 +40,7 @@ def main() -> None: ) # Axes - axes_val = [0,1] # Example shape value + axes_val = [0, 1] # Example shape value axes_tensor = helper.make_tensor( name="axes", data_type=TensorProto.INT64, @@ -83,11 +84,9 @@ def main() -> None: nodes=[starts_node, ends_node, axes_node, steps_node, slice_node], name="SliceGraph", inputs=[ - helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]), - ], - outputs=[ - helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5]) + helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [5, 10]), ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 5])], ) # Create the model diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 60a8c93a6b..6904489e89 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -1284,44 +1284,39 @@ pub fn shape_config(curr: &Node) -> (usize, usize) { } pub fn slice_config(node: &Node) -> (Vec, Vec) { - let start_value = &node.inputs[1].value; - let end_value = &node.inputs[2].value; - - let starts = match &node.inputs[1].ty { - ArgType::Tensor(tensor) => { - assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D"); - if let Some(Data::Int64s(shape)) = start_value.as_ref() { - shape - .iter() - .map(|x| { - assert!(*x >= 0, "Slice: start must be positive"); + fn ensure_1d_tensor(node: &Node, index: usize) { + match &node.inputs[index].ty { + ArgType::Tensor(tensor) => assert_eq!(tensor.dim, 1, "Slice: tensor must be 1D"), + _ => panic!("Only tensor input is valid"), + }; + } + + fn get_input_values(node: &Node, index: usize) -> Vec { + let tensor_shape = match &node.inputs[0].ty { + ArgType::Tensor(tensor) => tensor.shape.as_ref().unwrap(), + _ => panic!("Only tensor input is valid"), + }; + match &node.inputs[index].value { + Some(Data::Int64s(shape)) => shape + .iter() + .enumerate() + .map(|(i, x)| { + if x.is_negative() { + tensor_shape[i] - x.wrapping_abs() as usize + } else { *x as usize - }) - .collect() - } else { - panic!("Tensor data type must be int64") - } + } + }) + .collect(), + _ => panic!("Tensor data type must be int64"), } - _ => panic!("Only tensor input is valid for shape"), - }; + } - let ends = match &node.inputs[2].ty { - ArgType::Tensor(tensor) => { - assert_eq!(tensor.dim, 1, "Slice: ends tensor must be 1D"); - if let Some(Data::Int64s(shape)) = end_value.as_ref() { - shape - .iter() - .map(|x| { - assert!(*x >= 0, "Slice: end must be positive"); - *x as usize - }) - .collect() - } else { - panic!("Tensor data type must be int64") - } - } - _ => panic!("Only tensor input is valid for shape"), - }; + ensure_1d_tensor(node, 1); + ensure_1d_tensor(node, 2); + + let starts = get_input_values(node, 1); + let ends = get_input_values(node, 2); for (key, value) in node.attrs.iter() { match key.as_str() {