Feature/refactor relative epsilon #602
Open
michalk8 wants to merge 10 commits into main from feature/refactor-relative-epsilon
Commits
f14846c Don't install tensorstore on 3.13 yet (michalk8)
b4f7245 Fix epsilon scheduler (michalk8)
28da1be Undo changes in the test (michalk8)
ade9aea Fix eps sched docs (michalk8)
1da51f6 Remove mention of `scale_epsilon` (michalk8)
3ee8bc0 Remove mention of norms in point cloud (michalk8)
2b1edc3 Nicer pointcloud docs (michalk8)
723f6eb Update docs of relative epsilon (michalk8)
21ce601 Update geometry.py (marcocuturi)
86aa841 Update epsilon_scheduler.py (marcocuturi)
@@ -11,95 +11,63 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Optional

-import jax
 import jax.numpy as jnp
+import jax.tree_util as jtu

 __all__ = ["Epsilon", "DEFAULT_SCALE"]

 #: Scaling applied to statistic (mean/std) of cost to compute default epsilon.
 DEFAULT_SCALE = 0.05


-@jax.tree_util.register_pytree_node_class
+@jtu.register_pytree_node_class
 class Epsilon:
-  """Scheduler class for the regularization parameter epsilon.
+  r"""Scheduler class for the regularization parameter epsilon.

-  An epsilon scheduler outputs a regularization strength, to be used by in a
-  Sinkhorn-type algorithm, at any iteration count. That value is either the
-  final, targeted regularization, or one that is larger, obtained by
-  geometric decay of an initial value that is larger than the intended target.
-  Concretely, the value returned by such a scheduler will consider first
-  the max between ``target`` and ``init * target * decay ** iteration``.
-  If the ``scale_epsilon`` parameter is provided, that value is used to
-  multiply the max computed previously by ``scale_epsilon``.
+  An epsilon scheduler outputs a regularization strength, to be used by the
+  :term:`Sinkhorn algorithm` or variant, at any iteration count. That value is
+  either the final, targeted regularization, or one that is larger, obtained by
+  geometric decay of an initial multiplier.

   Args:
-    target: the epsilon regularizer that is targeted. If :obj:`None`,
-      use :obj:`DEFAULT_SCALE`, currently set at :math:`0.05`.
-    scale_epsilon: if passed, used to multiply the regularizer, to rescale it.
-      If :obj:`None`, use :math:`1`.
-    init: initial value when using epsilon scheduling, understood as multiple
-      of target value. if passed, ``int * decay ** iteration`` will be used
-      to rescale target.
-    decay: geometric decay factor, :math:`<1`.
+    target: The epsilon regularizer that is targeted.
+    init: Initial value when using epsilon scheduling, understood as a multiple
+      of the ``target``, following :math:`\text{init} \text{decay}^{\text{it}}`.
+    decay: Geometric decay factor, :math:`\leq 1`.
   """

-  def __init__(
-      self,
-      target: Optional[float] = None,
-      scale_epsilon: Optional[float] = None,
-      init: float = 1.0,
-      decay: float = 1.0
-  ):
-    self._target_init = target
-    self._scale_epsilon = scale_epsilon
-    self._init = init
-    self._decay = decay
+  def __init__(self, target: jnp.array, init: float = 1.0, decay: float = 1.0):
+    assert decay <= 1.0, f"Decay must be <= 1, found {decay}."
+    self.target = target
+    self.init = init
+    self.decay = decay

-  @property
-  def target(self) -> float:
-    """Return the final regularizer value of scheduler."""
-    target = DEFAULT_SCALE if self._target_init is None else self._target_init
-    scale = 1.0 if self._scale_epsilon is None else self._scale_epsilon
-    return scale * target
+  def __call__(self, it: Optional[int]) -> jnp.array:
+    """Intermediate regularizer value at a given iteration number.

-  def at(self, iteration: Optional[int] = 1) -> float:
-    """Return (intermediate) regularizer value at a given iteration."""
-    if iteration is None:
+    Args:
+      it: Current iteration. If :obj:`None`, return :attr:`target`.
+
+    Returns:
+      The epsilon value at the iteration.
+    """
+    if it is None:
       return self.target
-    # check the decay is smaller than 1.0.
-    decay = jnp.minimum(self._decay, 1.0)
     # the multiple is either 1.0 or a larger init value that is decayed.
-    multiple = jnp.maximum(self._init * (decay ** iteration), 1.0)
+    multiple = jnp.maximum(self.init * (self.decay ** it), 1.0)
     return multiple * self.target

-  def done(self, eps: float) -> bool:
-    """Return whether the scheduler is done at a given value."""
-    return eps == self.target
-
-  def done_at(self, iteration: Optional[int]) -> bool:
-    """Return whether the scheduler is done at a given iteration."""
-    return self.done(self.at(iteration))
-
-  def set(self, **kwargs: Any) -> "Epsilon":
-    """Return a copy of self, with potential overwrites."""
-    kwargs = {
-        "target": self._target_init,
-        "scale_epsilon": self._scale_epsilon,
-        "init": self._init,
-        "decay": self._decay,
-        **kwargs
-    }
-    return Epsilon(**kwargs)
+  def __repr__(self) -> str:
+    return (
+        f"{self.__class__.__name__}(target={self.target:.4f}, "
+        f"init={self.init:.4f}, decay={self.decay:.4f})"
+    )

   def tree_flatten(self):  # noqa: D102
-    return (
-        self._target_init, self._scale_epsilon, self._init, self._decay
-    ), None
+    return (self.target,), {"init": self.init, "decay": self.decay}

   @classmethod
   def tree_unflatten(cls, aux_data, children):  # noqa: D102
-    del aux_data
-    return cls(*children)
+    return cls(*children, **aux_data)

Review comment on the added line `def __repr__(self) -> str:`
cool!
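As a reading aid, the schedule computed by the new `__call__` can be sketched in plain Python, without JAX. The helper name `epsilon_at` and the numeric values below are illustrative assumptions, not part of the PR; the formula itself (`max(init * decay ** it, 1.0) * target`) and the `decay <= 1` assertion are taken from the diff.

```python
from typing import Optional


def epsilon_at(
    target: float, it: Optional[int], init: float = 1.0, decay: float = 1.0
) -> float:
  """Plain-Python mirror of the new ``Epsilon.__call__`` (hypothetical helper)."""
  assert decay <= 1.0, f"Decay must be <= 1, found {decay}."
  if it is None:
    # No iteration given: return the final, targeted value.
    return target
  # The multiplier starts at `init`, decays geometrically, and is floored at 1,
  # so the returned value never drops below `target`.
  multiple = max(init * decay ** it, 1.0)
  return multiple * target


# Illustrative values only: start at 10x the target and halve every iteration.
schedule = [epsilon_at(0.05, it, init=10.0, decay=0.5) for it in range(6)]
```

With these values the multiplier falls below 1 at the fifth entry (`it = 4`), so from that point on the schedule stays at `target`. Unlike the removed code, `target` must now be supplied explicitly: the `None` fallback to `DEFAULT_SCALE` and the `scale_epsilon` rescaling are gone from this class.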
I think we should move `DEFAULT_SCALE` to `geometry.py`, since this is where it's used? Maybe also change its name? I could suggest `DEFAULT_EPSILON_SCALE`.
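To make the suggestion concrete, here is a minimal sketch of how such a constant might be consumed on the geometry side. It assumes the default epsilon is the scale times the mean of the cost matrix, which is one reading of the comment on `DEFAULT_SCALE` ("statistic (mean/std) of cost"). The function `default_epsilon` is hypothetical, and `DEFAULT_EPSILON_SCALE` is only the name proposed above; neither is confirmed by the diff shown here.

```python
from statistics import fmean
from typing import Sequence

# Name proposed in the review comment; same value as DEFAULT_SCALE in the diff.
DEFAULT_EPSILON_SCALE = 0.05


def default_epsilon(cost_matrix: Sequence[Sequence[float]]) -> float:
  """Default epsilon as a fixed fraction of the mean cost (assumed statistic)."""
  flat = [c for row in cost_matrix for c in row]
  return DEFAULT_EPSILON_SCALE * fmean(flat)


cost = [[0.0, 2.0], [2.0, 4.0]]  # toy cost matrix whose mean is 2.0
eps = default_epsilon(cost)
```

Keeping the constant next to the code that computes the statistic would mean the scheduler only ever sees an already-resolved `target`, which matches the simplified constructor in this PR.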