diff --git a/onnx2torch/node_converters/range.py b/onnx2torch/node_converters/range.py index feef661e..8535e613 100644 --- a/onnx2torch/node_converters/range.py +++ b/onnx2torch/node_converters/range.py @@ -10,12 +10,13 @@ from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode -from onnx2torch.utils.common import OnnxToTorchModule from onnx2torch.utils.common import OperationConverterResult from onnx2torch.utils.common import onnx_mapping_from_node +from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx +from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport -class OnnxRange(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring +class OnnxRange(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring def __init__(self): super().__init__() self.register_buffer('dummy_buffer', torch.Tensor(), persistent=False) @@ -27,7 +28,7 @@ def _get_scalar(value) -> Union[float, int]: return value - def forward( # pylint: disable=missing-function-docstring + def _arange( self, start: Union[torch.Tensor, float, int], limit: Union[torch.Tensor, float, int], @@ -40,6 +41,19 @@ def forward( # pylint: disable=missing-function-docstring device=self.dummy_buffer.device, ) + def forward( # pylint: disable=missing-function-docstring + self, + start: Union[torch.Tensor, float, int], + limit: Union[torch.Tensor, float, int], + delta: Union[torch.Tensor, float, int], + ) -> torch.Tensor: + forward_lambda = lambda: self._arange(start, limit, delta) + + if torch.onnx.is_in_onnx_export(): + return DefaultExportToOnnx.export(forward_lambda, 'Range', start, limit, delta, {}) + + return forward_lambda() + @add_converter(operation_type='Range', version=11) def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument