More NONDET_TOL
marvinfriede committed Aug 26, 2024
1 parent 9847814 commit d19679f
Showing 6 changed files with 15 additions and 23 deletions.
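For context on the argument added throughout this commit: nondet_tol is the tolerance that torch.autograd.gradcheck (and gradgradcheck) allows between repeated evaluations of the same input, which otherwise fail the check when a backward pass uses nondeterministic kernels (e.g. scatter/index reductions). The sketch below only illustrates that pattern; dgradcheck/dgradgradcheck are presumably thin wrappers around the PyTorch functions, and the concrete value of NONDET_TOL lives in test/conftest.py and is not part of this diff, so the placeholder value and the toy function here are assumptions.

import torch
from torch.autograd import gradcheck

NONDET_TOL = 1e-7  # assumed placeholder; the real constant is defined in test/conftest.py

def energy(pos: torch.Tensor) -> torch.Tensor:
    # Stand-in for an energy term whose backward pass may hit
    # nondeterministic reductions (e.g. index_add/scatter on the GPU).
    return (pos * pos).sum()

pos = torch.rand(4, 3, dtype=torch.double, requires_grad=True)

# Without nondet_tol, gradcheck fails if two evaluations of the same input
# differ at all; with it, deviations up to NONDET_TOL are tolerated.
assert gradcheck(energy, (pos,), atol=1e-6, nondet_tol=NONDET_TOL)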
6 changes: 3 additions & 3 deletions test/test_coulomb/test_es2_shell.py
@@ -34,7 +34,7 @@
from dxtb._src.param.utils import get_elem_param
from dxtb._src.typing import DD, Tensor

-from ..conftest import DEVICE
+from ..conftest import DEVICE, NONDET_TOL
from .samples import samples

sample_list = ["MB16_43_07", "MB16_43_08", "SiH4"]
@@ -129,7 +129,7 @@ def func(p: Tensor):
cache = es.get_cache(numbers, p, ihelp)
return es.get_shell_energy(qsh, cache)

-assert dgradcheck(func, pos)
+assert dgradcheck(func, pos, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -160,4 +160,4 @@ def func(gexp: Tensor, hubbard: Tensor):
cache = es.get_cache(numbers, positions, ihelp)
return es.get_shell_energy(qsh, cache)

-assert dgradcheck(func, (gexp, hubbard))
+assert dgradcheck(func, (gexp, hubbard), nondet_tol=NONDET_TOL)
4 changes: 2 additions & 2 deletions test/test_coulomb/test_es3.py
@@ -31,7 +31,7 @@
from dxtb._src.param.utils import get_elem_param
from dxtb._src.typing import DD, Tensor

-from ..conftest import DEVICE
+from ..conftest import DEVICE, NONDET_TOL
from .samples import samples

sample_list = ["MB16_43_01", "MB16_43_02", "SiH4_atom"]
@@ -120,4 +120,4 @@ def func(hubbard_derivs: Tensor):
cache = es.get_cache(numbers=numbers, ihelp=ihelp)
return es.get_atom_energy(qat, cache)

-assert dgradcheck(func, hd)
+assert dgradcheck(func, hd, nondet_tol=NONDET_TOL)
4 changes: 0 additions & 4 deletions test/test_coulomb/test_grad_atom.py
@@ -84,8 +84,6 @@ def test_single(dtype: torch.dtype, name: str) -> None:
assert pytest.approx(ref.cpu(), abs=tol) == egrad.cpu()
assert pytest.approx(egrad.cpu(), abs=tol) == agrad.cpu()

-pos.detach_()


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("name1", ["SiH4_atom"])
@@ -145,8 +143,6 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
assert pytest.approx(ref.cpu(), abs=tol) == egrad.cpu()
assert pytest.approx(egrad.cpu(), abs=tol) == agrad.cpu()

-pos.detach_()


def calc_numerical_gradient(
numbers: Tensor, positions: Tensor, ihelp: IndexHelper, charges: Tensor
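Side note on the two files that only delete lines (this one and test_grad_shell.py below): Tensor.detach_() detaches a tensor from the autograd graph in place, so the removed calls look like post-test cleanup of the position tensor, and the commit drops them without replacement. A minimal illustration of what such a call does (variable names assumed, not taken from the tests):

import torch

pos = torch.rand(3, 3, dtype=torch.double, requires_grad=True)
energy = (pos * pos).sum()
(grad,) = torch.autograd.grad(energy, pos)

# In-place detach: pos stops requiring gradients and is cut from the graph.
pos.detach_()
assert pos.requires_grad is False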
10 changes: 5 additions & 5 deletions test/test_coulomb/test_grad_atom_param.py
@@ -31,7 +31,7 @@
from dxtb._src.param import get_elem_param
from dxtb._src.typing import DD, Callable, Tensor

-from ..conftest import DEVICE
+from ..conftest import DEVICE, NONDET_TOL
from .samples import samples

sample_list = ["LiH", "SiH4", "MB16_43_01"]
@@ -81,7 +81,7 @@ def test_grad_param(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradcheck_param(dtype, name)
-assert dgradcheck(func, diffvars, atol=tol)
+assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -93,7 +93,7 @@ def test_gradgrad_param(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradcheck_param(dtype, name)
-assert dgradgradcheck(func, diffvars, atol=tol)
+assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


def gradcheck_param_batch(dtype: torch.dtype, name1: str, name2: str) -> tuple[
@@ -157,7 +157,7 @@ def test_grad_param_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
# same for both values.
diffvars[0].requires_grad_(False)

-assert dgradcheck(func, diffvars, atol=tol)
+assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -177,4 +177,4 @@ def test_gradgrad_param_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
# same for both values.
diffvars[0].requires_grad_(False)

-assert dgradgradcheck(func, diffvars, atol=tol)
+assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)
10 changes: 5 additions & 5 deletions test/test_coulomb/test_grad_atom_pos.py
@@ -31,7 +31,7 @@
from dxtb._src.param import get_elem_param
from dxtb._src.typing import DD, Callable, Tensor

-from ..conftest import DEVICE
+from ..conftest import DEVICE, NONDET_TOL
from .samples import samples

sample_list = ["LiH", "SiH4", "MB16_43_01"]
@@ -85,7 +85,7 @@ def test_grad_pos(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradcheck_pos(dtype, name)
-assert dgradcheck(func, diffvars, atol=tol)
+assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -97,7 +97,7 @@ def test_gradgrad_pos(dtype: torch.dtype, name: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradcheck_pos(dtype, name)
-assert dgradgradcheck(func, diffvars, atol=tol)
+assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


def gradcheck_pos_batch(
@@ -156,7 +156,7 @@ def test_grad_pos_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradcheck`.
"""
func, diffvars = gradcheck_pos_batch(dtype, name1, name2)
-assert dgradcheck(func, diffvars, atol=tol)
+assert dgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)


@pytest.mark.grad
@@ -169,4 +169,4 @@ def test_gradgrad_pos_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
gradient from `torch.autograd.gradgradcheck`.
"""
func, diffvars = gradcheck_pos_batch(dtype, name1, name2)
-assert dgradgradcheck(func, diffvars, atol=tol)
+assert dgradgradcheck(func, diffvars, atol=tol, nondet_tol=NONDET_TOL)
4 changes: 0 additions & 4 deletions test/test_coulomb/test_grad_shell.py
@@ -82,8 +82,6 @@ def test_single(dtype: torch.dtype, name: str) -> None:
assert pytest.approx(ref.cpu(), abs=tol) == egrad.cpu()
assert pytest.approx(egrad.cpu(), abs=tol) == agrad.cpu()

-pos.detach_()


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("name1", ["SiH4"])
@@ -144,8 +142,6 @@ def test_batch(dtype: torch.dtype, name1: str, name2: str) -> None:
assert pytest.approx(ref.cpu(), abs=tol) == egrad.cpu()
assert pytest.approx(egrad.cpu(), abs=tol) == agrad.cpu()

-pos.detach_()


def calc_numerical_gradient(
numbers: Tensor, positions: Tensor, ihelp: IndexHelper, charges: Tensor
