When I pass tensors that track gradients into the forward function of `sys` (an instance of `BuildingEnvelope`), the gradients are dropped at the point where `get_q` is called. This happens even when `sys` is instantiated with the kwargs `backend='torch', requires_grad=True`. It affects any system that uses `BuildingEnvelope`. For example, when the loss is computed from `'yn'`, its gradients cannot propagate back to `policy` (a `blocks.MLP_bounds`).
The cause appears to be that `BuildingEnvelope.get_q` is wrapped by `@cast_backend`, which calls `torch.tensor(return_tensor, dtype=torch.float32)` on the tensor returned by `get_q`. `torch.tensor` copies the data into a new tensor that is detached from the autograd graph, so the gradient is lost. If this line is removed, the gradients propagate and the policy can be trained normally.
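A minimal sketch of the behavior and of one possible alternative to deleting the line outright. `toy_get_q` is a hypothetical stand-in for `get_q` (any differentiable op), and `cast_backend_preserving_grad` is a hypothetical decorator, not the actual `@cast_backend` implementation. It assumes the cast is still wanted for non-tensor return values, so it only routes those through `torch.tensor` and uses the differentiable `.to()` for tensors.

```python
import functools

import torch


def cast_backend_preserving_grad(fn):
    """Hypothetical cast decorator that keeps the autograd graph intact.

    Tensor outputs are cast with .to(), which is differentiable; only
    non-tensor outputs (lists, NumPy arrays, scalars) go through torch.tensor.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, torch.Tensor):
            # .to() is recorded by autograd, so gradients flow back through it.
            return out.to(torch.float32)
        # torch.tensor builds a brand-new tensor with no autograd history.
        return torch.tensor(out, dtype=torch.float32)
    return wrapper


def toy_get_q(u):
    # Hypothetical stand-in for BuildingEnvelope.get_q.
    return 2.0 * u


u = torch.ones(3, requires_grad=True)

# Reproduces the reported problem: torch.tensor() on the output drops the graph.
# (PyTorch also emits a UserWarning suggesting clone().detach() here.)
detached = torch.tensor(toy_get_q(u), dtype=torch.float32)
print(detached.requires_grad)  # False

# With the grad-preserving cast, backward() reaches u.
q = cast_backend_preserving_grad(toy_get_q)(u)
q.sum().backward()
print(u.grad)  # tensor([2., 2., 2.])
```

Because `.to()` returns the input unchanged when it already has the target dtype, this variant costs nothing for `float32` tensors and still casts other dtypes (e.g. `float64`) without breaking backpropagation.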