
converting sympy.NumberSymbol to torch.tensor in export_torch.py #726

Open
wants to merge 6 commits into
base: master
Conversation

tbuckworth

attempting to address #656

@MilesCranmer
Owner

Nice! Do you want to add a unit test for the MWE you described in the issue?

@coveralls

coveralls commented Sep 26, 2024

Pull Request Test Coverage Report for Build 11200885649

Details

  • 1 of 1 (100.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage remained the same at 93.735%

Totals (Coverage Status)
  • Change from base Build 10518940981: 0.0%
  • Covered Lines: 1137
  • Relevant Lines: 1213

💛 - Coveralls

@tbuckworth
Author

Hi, I've added a unit test here. Let me know if that is sufficient or if I need to do anything else.

@MilesCranmer
Owner

Seems like the test is failing

@tbuckworth
Author

Apologies, I won't be able to address it until next week

@MilesCranmer
Owner

No worries!

@tbuckworth
Author

I figured out what's going on.

`sin(sign(-0.04))` gets simplified to `-sin(1)`.

`sin(1)` requires converting the argument `1`, which is a `sympy.core.numbers.One` instance. This is a subclass of `sympy.Rational`, so it is picked up by that branch and converted into `float(1)`.

But `torch.sin` requires a `torch.Tensor` as input.

So I have added a line to make `sympy.core.numbers.One` get treated the same way as `sympy.Float` subclasses. All the torch tests pass on my machine, so I think this is fine behaviour.

Perhaps you would rather fix it at a different level of abstraction?
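The simplification chain described above can be verified directly. This is a standalone sketch (not part of the PR) showing why the `Rational` branch catches `1`:

```python
import sympy

# sign(-0.04) evaluates to -1, and since sin is odd, sympy pulls the
# sign out during automatic evaluation:
expr = sympy.sin(sympy.sign(-0.04))
print(expr)  # -sin(1)

# The remaining argument 1 is the One singleton, which subclasses
# Rational, so the Rational branch converts it to a plain float(1);
# torch.sin() then rejects that because it only accepts a torch.Tensor.
one = sympy.S.One
print(type(one).__name__)               # One
print(isinstance(one, sympy.Rational))  # True
```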

@MilesCranmer
Owner

I wonder if the underlying issue is the use of `issubclass` instead of `isinstance`... Maybe try replacing those conditions with the following code?

def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
    super().__init__(**kwargs)

    self._sympy_func = expr.func

    if isinstance(expr, sympy.Float):
        self._value = torch.nn.Parameter(torch.tensor(float(expr)))
        self._torch_func = lambda: self._value
        self._args = ()
    elif isinstance(expr, sympy.Rational) and not isinstance(expr, sympy.Integer):
        # This is some fraction fixed in the operator.
        self._value = float(expr)
        self._torch_func = lambda: self._value
        self._args = ()
    elif isinstance(expr, sympy.UnevaluatedExpr):
        if len(expr.args) != 1 or not isinstance(expr.args[0], sympy.Float):
            raise ValueError(
                "UnevaluatedExpr should only be used to wrap floats."
            )
        self.register_buffer("_value", torch.tensor(float(expr.args[0])))
        self._torch_func = lambda: self._value
        self._args = ()
    elif isinstance(expr, sympy.Integer):
        # Handles Integer special cases like NegativeOne, One, Zero
        self._value = int(expr)
        self._torch_func = lambda: self._value
        self._args = ()
    elif isinstance(expr, sympy.NumberSymbol):
        # Handles mathematical constants like pi, E
        self._value = float(expr)
        self._torch_func = lambda: self._value
        self._args = ()
    elif isinstance(expr, sympy.Symbol):
        self._name = expr.name
        self._torch_func = lambda value: value
        self._args = ((lambda memodict: memodict[expr.name]),)
    else:
        try:
            self._torch_func = _func_lookup[expr.func]
        except KeyError:
            raise KeyError(
                f"Function {expr.func} was not found in Torch function mappings. "
                "Please add it to extra_torch_mappings in the format, e.g., "
                "{sympy.sqrt: torch.sqrt}."
            )
        args = []
        for arg in expr.args:
            try:
                arg_ = _memodict[arg]
            except KeyError:
                arg_ = type(self)(
                    expr=arg,
                    _memodict=_memodict,
                    _func_lookup=_func_lookup,
                    **kwargs,
                )
                _memodict[arg] = arg_
            args.append(arg_)
        self._args = torch.nn.ModuleList(args)
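As a side note on the `issubclass` vs. `isinstance` question: the `issubclass` style goes through `expr.func` (the head class of the expression), and the two checks agree on the `One` case. A minimal illustrative sketch, separate from the code above:

```python
import sympy

one = sympy.Integer(1)  # auto-specializes to the One singleton
print(type(one).__name__)  # One

# The issubclass style checks the expression's head class...
print(issubclass(one.func, sympy.Rational))  # True

# ...while isinstance checks the expression object directly.
print(isinstance(one, sympy.Rational))  # True
print(isinstance(one, sympy.Integer))   # True
```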

@tbuckworth
Author

Thanks!

That code gives the following error:

TypeError: sin(): argument 'input' (position 1) must be Tensor, not int

Note that it is now saying `int` instead of `float`.

That's due to this code:

elif isinstance(expr, sympy.Integer):
    # Handles Integer special cases like NegativeOne, One, Zero
    self._value = int(expr)
    self._torch_func = lambda: self._value
    self._args = ()

Is there a reason that ints need to be treated this way?

For the `sin(1)` case, `torch.sin` requires that `1` is a `torch.Tensor`.

Would it be ok to just add the following code before the integer case?

elif isinstance(expr, sympy.core.numbers.One):
    # Wrap One in a tensor so e.g. sin(1) gets a Tensor, not a plain int
    self._value = torch.tensor(int(expr))
    self._torch_func = lambda: self._value
    self._args = ()

Or should it be `torch.nn.Parameter(torch.tensor(float(expr)))` so that it can be fine-tuned?

Either way, it then passes the test.
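The `TypeError` in question is easy to reproduce in isolation, which also shows that wrapping the argument in a tensor fixes it and that a `Parameter` version would be trainable. A minimal sketch, independent of the export code:

```python
import torch

# A plain Python int is rejected by torch.sin:
try:
    torch.sin(1)
    raised = False
except TypeError as e:
    raised = True
    print(e)  # sin(): argument 'input' (position 1) must be Tensor, not int

# Wrapped in a tensor it works:
print(torch.sin(torch.tensor(1.0)))

# And a Parameter version participates in autograd (i.e. is fine-tunable):
p = torch.nn.Parameter(torch.tensor(1.0))
print(torch.sin(p).requires_grad)  # True
```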
