Skip to content

Commit

Permalink
FIX: fixed export to onnx (Range) (#161)
Browse files Browse the repository at this point in the history
fixed export to onnx (Range)

---------

Co-authored-by: igor <i.kalgin@expasoft.tech>
  • Loading branch information
ivkalgin and igor authored Jun 1, 2023
1 parent 2ed856f commit 567037e
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions onnx2torch/node_converters/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node
from onnx2torch.utils.custom_export_to_onnx import DefaultExportToOnnx
from onnx2torch.utils.custom_export_to_onnx import OnnxToTorchModuleWithCustomExport


class OnnxRange(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring
class OnnxRange(nn.Module, OnnxToTorchModuleWithCustomExport): # pylint: disable=missing-class-docstring
def __init__(self):
super().__init__()
self.register_buffer('dummy_buffer', torch.Tensor(), persistent=False)
Expand All @@ -27,7 +28,7 @@ def _get_scalar(value) -> Union[float, int]:

return value

def forward( # pylint: disable=missing-function-docstring
def _arange(
self,
start: Union[torch.Tensor, float, int],
limit: Union[torch.Tensor, float, int],
Expand All @@ -40,6 +41,19 @@ def forward( # pylint: disable=missing-function-docstring
device=self.dummy_buffer.device,
)

def forward( # pylint: disable=missing-function-docstring
self,
start: Union[torch.Tensor, float, int],
limit: Union[torch.Tensor, float, int],
delta: Union[torch.Tensor, float, int],
) -> torch.Tensor:
forward_lambda = lambda: self._arange(start, limit, delta)

if torch.onnx.is_in_onnx_export():
return DefaultExportToOnnx.export(forward_lambda, 'Range', start, limit, delta, {})

return forward_lambda()


@add_converter(operation_type='Range', version=11)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
Expand Down

0 comments on commit 567037e

Please sign in to comment.