Skip to content

Commit

Permalink
Add Lie group check setting to suppress internal warnings
Browse files Browse the repository at this point in the history
 ## Problem

The Manifold class will automatically normalize Tensors to make
them valid for their associated Lie group. A warning is emitted
when this occurs, which can be somewhat noisy and undesirable.

Issue: facebookresearch#486

 ## Solution

Add setting to _LieGroupCheckContext which, when set to True, can
be used in downstream code to suppress various warnings.

 ## Testing

Ensure existing unit tests pass.
  • Loading branch information
jacoblubecki committed Nov 9, 2024
1 parent ab57673 commit 969bcb7
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 52 deletions.
97 changes: 57 additions & 40 deletions theseus/geometry/lie_group_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.

import threading
from typing import Any
from contextlib import contextmanager


class _LieGroupCheckContext:
Expand All @@ -15,51 +15,68 @@ def get_context(cls):
if not hasattr(cls.contexts, "check_lie_group"):
cls.contexts.check_lie_group = True
cls.contexts.silent = False
return cls.contexts.check_lie_group, cls.contexts.silent
cls.contexts.silence_internal_warnings = False
return (
cls.contexts.check_lie_group,
cls.contexts.silent,
cls.contexts.silence_internal_warnings,
)

@classmethod
def set_context(cls, check_lie_group: bool, silent: bool):
def set_context(
cls, check_lie_group: bool, silent: bool, silence_internal_warnings: bool
):
if not check_lie_group and not silent:
print(
"Warnings for disabled Lie group checks can be turned "
"off by passing silent=True."
)
cls.contexts.check_lie_group = check_lie_group
cls.contexts.silent = silent


class set_lie_group_check_enabled:
def __init__(self, mode: bool, silent: bool = False) -> None:
self.prev = _LieGroupCheckContext.get_context()
_LieGroupCheckContext.set_context(mode, silent)

def __enter__(self) -> None:
pass

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
_LieGroupCheckContext.set_context(*self.prev)


class enable_lie_group_check(_LieGroupCheckContext):
def __init__(self, silent: bool = False) -> None:
self._silent = silent

def __enter__(self) -> None:
self.prev = _LieGroupCheckContext.get_context()
_LieGroupCheckContext.set_context(True, self._silent)

def __exit__(self, typ, value, traceback) -> None:
_LieGroupCheckContext.set_context(*self.prev)


class no_lie_group_check(_LieGroupCheckContext):
def __init__(self, silent: bool = False) -> None:
self._silent = silent

def __enter__(self):
self.prev = super().get_context()
_LieGroupCheckContext.set_context(False, self._silent)
return self

def __exit__(self, typ, value, traceback):
_LieGroupCheckContext.set_context(*self.prev)
cls.contexts.silence_internal_warnings = silence_internal_warnings


@contextmanager
def set_lie_group_check_enabled(
mode: bool, silent: bool = False, silence_internal_warnings: bool = False
):
"""Sets whether or not Lie group checks are enabled within a context.
:param check_lie_group: Disables Lie group checks if false.
:param silent: Disables a warning that Lie group checks are disabled.
:param silence_internal_warnings: Whether to suppress recoverable
warning messages during Lie group checks, e.g. when normalization
is performed automatically.
"""
prev = _LieGroupCheckContext.get_context()
_LieGroupCheckContext.set_context(mode, silent, silence_internal_warnings)
yield
_LieGroupCheckContext.set_context(*prev)


@contextmanager
def enable_lie_group_check(
silent: bool = False, silence_internal_warnings: bool = False
):
"""Enables Lie group checks while the context is active.
:param silent: Disables a warning that Lie group checks are disabled.
:param silence_internal_warnings: Whether to suppress recoverable
warning messages during Lie group checks, e.g. when normalization
is performed automatically.
"""
with set_lie_group_check_enabled(True, silent, silence_internal_warnings):
yield


@contextmanager
def no_lie_group_check(silent: bool = False, silence_internal_warnings: bool = False):
"""Disables Lie group checks while the context is active.
:param silent: Disables a warning that Lie group checks are disabled.
:param silence_internal_warnings: Whether to suppress recoverable
warning messages during Lie group checks, e.g. when normalization
is performed automatically.
"""
with set_lie_group_check_enabled(False, silent, silence_internal_warnings):
yield
32 changes: 24 additions & 8 deletions theseus/geometry/manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,21 @@ def __init__(
dtype = torch.get_default_dtype()
if tensor is not None:
if disable_checks:
checks_enabled, silent_unchecks = False, True
checks_enabled = False
silent_unchecks = True
disable_internal_warnings = True
else:
checks_enabled, silent_unchecks = _LieGroupCheckContext.get_context()
(
checks_enabled,
silent_unchecks,
disable_internal_warnings,
) = _LieGroupCheckContext.get_context()
if checks_enabled:
tensor = self._check_tensor(tensor, strict_checks)
tensor = self._check_tensor(
tensor,
strict_checks,
silent_normalization=disable_internal_warnings,
)
elif not silent_unchecks:
warnings.warn(
f"Manifold consistency checks are disabled "
Expand Down Expand Up @@ -113,18 +123,24 @@ def _check_tensor_impl(tensor: torch.Tensor) -> bool:
pass

@classmethod
def _check_tensor(cls, tensor: torch.Tensor, strict: bool = True) -> torch.Tensor:
def _check_tensor(
cls,
tensor: torch.Tensor,
strict: bool = True,
silent_normalization: bool = False,
) -> torch.Tensor:
check = cls._check_tensor_impl(tensor)

if not check:
if strict:
raise ValueError(f"The input tensor is not valid for {cls.__name__}.")
else:
tensor = cls.normalize(tensor)
warnings.warn(
f"The input tensor is not valid for {cls.__name__} "
f"and has been normalized."
)
if not silent_normalization:
warnings.warn(
f"The input tensor is not valid for {cls.__name__} "
f"and has been normalized."
)

return tensor

Expand Down
2 changes: 1 addition & 1 deletion theseus/geometry/se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _hat_matrix_check(matrix: torch.Tensor):
if matrix.ndim != 3 or matrix.shape[1:] != (4, 4):
raise ValueError("Hat matrices of SE3 can only be 4x4 matrices")

checks_enabled, silent_unchecks = _LieGroupCheckContext.get_context()
checks_enabled, silent_unchecks, _ = _LieGroupCheckContext.get_context()
if checks_enabled:
return SE3_base.check_hat_tensor(matrix)
elif not silent_unchecks:
Expand Down
2 changes: 1 addition & 1 deletion theseus/geometry/so2.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _check_tensor_impl(tensor: torch.Tensor) -> bool:
def _hat_matrix_check(matrix: torch.Tensor):
_check = matrix.ndim == 3 and matrix.shape[1:] == (2, 2)

checks_enabled, silent_unchecks = _LieGroupCheckContext.get_context()
checks_enabled, silent_unchecks, _ = _LieGroupCheckContext.get_context()
if checks_enabled:
_check &= matrix[:, 0, 0].abs().max().item() < theseus.constants.EPS
_check &= matrix[:, 1, 1].abs().max().item() < theseus.constants.EPS
Expand Down
4 changes: 2 additions & 2 deletions theseus/geometry/so3.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _unit_quaternion_check(quaternion: torch.Tensor):
if quaternion.ndim != 2 or quaternion.shape[1] != 4:
raise ValueError("Quaternions can only be 4-D vectors.")

checks_enabled, silent_unchecks = _LieGroupCheckContext.get_context()
checks_enabled, silent_unchecks, _ = _LieGroupCheckContext.get_context()
if checks_enabled:
SO3_base.check_unit_quaternion(quaternion)
elif not silent_unchecks:
Expand All @@ -144,7 +144,7 @@ def _hat_matrix_check(matrix: torch.Tensor):
if matrix.ndim != 3 or matrix.shape[1:] != (3, 3):
raise ValueError("Hat matrices of SO(3) can only be 3x3 matrices")

checks_enabled, silent_unchecks = _LieGroupCheckContext.get_context()
checks_enabled, silent_unchecks, _ = _LieGroupCheckContext.get_context()
if checks_enabled:
SO3_base.check_hat_tensor(matrix)
elif not silent_unchecks:
Expand Down

0 comments on commit 969bcb7

Please sign in to comment.