Skip to content

Commit

Permalink
Fix caching
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Aug 24, 2024
1 parent 63cb82d commit 6f805ac
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 45 deletions.
20 changes: 20 additions & 0 deletions src/dxtb/_src/calculators/config/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,26 @@ class ConfigCacheStore:
potential: bool
"""Whether to store the potential matrix."""

def set(self, key: str, value: bool) -> None:
"""
Set configuration options using keyword arguments.
Parameters
----------
key : str
The configuration key.
value : bool
The configuration value.
Example
-------
config.set("hcore", True)
"""
if not hasattr(self, key):
raise ValueError(f"Unknown configuration key: {key}")

Check warning on line 81 in src/dxtb/_src/calculators/config/cache.py

View check run for this annotation

Codecov / codecov/patch

src/dxtb/_src/calculators/config/cache.py#L81

Added line #L81 was not covered by tests

setattr(self, key, value)

Check warning on line 83 in src/dxtb/_src/calculators/config/cache.py

View check run for this annotation

Codecov / codecov/patch

src/dxtb/_src/calculators/config/cache.py#L83

Added line #L83 was not covered by tests


class ConfigCache:
"""
Expand Down
19 changes: 13 additions & 6 deletions src/dxtb/_src/calculators/types/analytical.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,9 @@ def dipole_analytical(
Tensor
Electric dipole moment of shape `(..., 3)`.
"""
# require caching for analytical calculation at end of function
kwargs["store_dipole"] = True

# run single point and check if integral is populated
result = self.singlepoint(positions, chrg, spin, **kwargs)

Expand All @@ -534,19 +537,23 @@ def dipole_analytical(
f"be added automatically if the '{efield.LABEL_EFIELD}' "
"interaction is added to the Calculator."
)
if dipint.matrix is None:

# Use try except to raise more informative error message, because
# `dipint.matrix` already raises a RuntimeError if the matrix is None.
try:
_ = dipint.matrix
except RuntimeError as e:

Check warning on line 545 in src/dxtb/_src/calculators/types/analytical.py

View check run for this annotation

Codecov / codecov/patch

src/dxtb/_src/calculators/types/analytical.py#L545

Added line #L545 was not covered by tests
raise RuntimeError(
"Dipole moment requires a dipole integral. They should "
f"be added automatically if the '{efield.LABEL_EFIELD}' "
"interaction is added to the Calculator."
)
"interaction is added to the Calculator. This is "
"probably a bug. Check the cache setup.\n\n"
f"Original error: {str(e)}"
) from e

# pylint: disable=import-outside-toplevel
from ..properties.moments.dip import dipole

# dip = properties.dipole(
# numbers, positions, result.density, self.integrals.dipole
# )
qat = self.ihelp.reduce_orbital_to_atom(result.charges.mono)
dip = dipole(qat, positions, result.density, dipint.matrix)
return dip
Expand Down
57 changes: 29 additions & 28 deletions src/dxtb/_src/calculators/types/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from dxtb import OutputHandler, timer
from dxtb._src.components.interactions.field import efield as efield
from dxtb._src.constants import defaults
from dxtb._src.typing import Any, Literal, Tensor
from dxtb._src.typing import Any, Callable, Literal, Tensor

from ..properties.vibration import (
IRResult,
Expand Down Expand Up @@ -507,19 +507,7 @@ def dipole_deriv(
Tensor
Cartesian dipole derivative of shape ``(..., 3, nat, 3)``.
"""

if use_analytical is True:
if not hasattr(self, "dipole_analytical") or not callable(
getattr(self, "dipole_analytical")
):
raise ValueError(
"Analytical dipole moment not available. "
"Please use a calculator, which subclasses "
"the `AnalyticalCalculator`."
)
dip_fcn = self.dipole_analytical # type: ignore
else:
dip_fcn = self.dipole
dip_fcn = self._get_dipole_fcn(use_analytical)

if use_functorch is True:
# pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -602,20 +590,8 @@ def polarizability(
# retrieve the efield interaction and the field
field = self.interactions.get_interaction(efield.LABEL_EFIELD).field

if use_analytical is True:
if not hasattr(self, "dipole_analytical") or not callable(
getattr(self, "dipole_analytical")
):
raise ValueError(
"Analytical dipole moment not available. "
"Please use a calculator, which subclasses "
"the `AnalyticalCalculator`."
)

# FIXME: Not working for Raman
dip_fcn = self.dipole_analytical # type: ignore
else:
dip_fcn = self.dipole
# FIXME: Not working for Raman
dip_fcn = self._get_dipole_fcn(use_analytical)

if use_functorch is False:
# pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -958,6 +934,31 @@ def raman(

return RamanResult(vib_res.freqs, intensities, depol)

##########################################################################

def _get_dipole_fcn(self, use_analytical: bool) -> Callable:
if use_analytical is False:
return self.dipole

if not hasattr(self, "dipole_analytical"):
raise ValueError(

Check warning on line 944 in src/dxtb/_src/calculators/types/autograd.py

View check run for this annotation

Codecov / codecov/patch

src/dxtb/_src/calculators/types/autograd.py#L944

Added line #L944 was not covered by tests
"Analytical dipole moment not available. "
"Please use a calculator, which subclasses "
"the `AnalyticalCalculator`."
)
if not callable(getattr(self, "dipole_analytical")):
raise ValueError(

Check warning on line 950 in src/dxtb/_src/calculators/types/autograd.py

View check run for this annotation

Codecov / codecov/patch

src/dxtb/_src/calculators/types/autograd.py#L950

Added line #L950 was not covered by tests
"Calculator an attribute `dipole_analytical` but it "
"is not callable. This should not happen and is an "
"implementation error."
)

self.opts.cache.store.dipole = True

return self.dipole_analytical # type: ignore

##########################################################################

def calculate(
self,
properties: list[str],
Expand Down
1 change: 0 additions & 1 deletion src/dxtb/_src/calculators/types/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

if TYPE_CHECKING:
from ..base import Calculator
del TYPE_CHECKING

__all__ = [
"requires_positions_grad",
Expand Down
22 changes: 19 additions & 3 deletions src/dxtb/_src/calculators/types/energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,10 @@ def singlepoint(
# Core Hamiltonian integral (requires overlap internally!)
#
# This should be the final integral, because the others are
# potentially calculated on CPU (libcint) even in GPU runs.
# potentially calculated on CPU (libcint), even in GPU runs.
# To avoid unnecessary data transfer, the core Hamiltonian should
# be last. Internally, the overlap integral is only transfered back
# to GPU when all multipole integrals are calculated.
# be calculated last. Internally, the overlap integral is only
# transfered back to GPU when all multipole integrals are calculated.
if self.opts.ints.level >= labels.INTLEVEL_HCORE:
OutputHandler.write_stdout_nf(" - Core Hamiltonian ... ", v=3)
timer.start("Core Hamiltonian", parent_uid="Integrals")
Expand Down Expand Up @@ -325,15 +325,31 @@ def singlepoint(

if kwargs.get("store_hcore", copts.hcore):
self.cache["hcore"] = self.integrals.hcore
else:
if self.integrals.hcore is not None:
if self.integrals.hcore.requires_grad is False:
self.integrals.hcore.clear()

if kwargs.get("store_overlap", copts.overlap):
self.cache["overlap"] = self.integrals.overlap
else:
if self.integrals.overlap is not None:
if self.integrals.overlap.requires_grad is False:
self.integrals.overlap.clear()

if kwargs.get("store_dipole", copts.dipole):
self.cache["dipint"] = self.integrals.dipole
else:
if self.integrals.dipole is not None:
if self.integrals.dipole.requires_grad is False:
self.integrals.dipole.clear()

if kwargs.get("store_quadrupole", copts.quadrupole):
self.cache["quadint"] = self.integrals.quadrupole
else:
if self.integrals.quadrupole is not None:
if self.integrals.quadrupole.requires_grad is False:
self.integrals.quadrupole.clear()

self._ncalcs += 1
return result
Expand Down
15 changes: 9 additions & 6 deletions src/dxtb/_src/calculators/types/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def dipole_numerical(
chrg: Tensor | float | int = defaults.CHRG,
spin: Tensor | float | int | None = defaults.SPIN,
step_size: int | float = defaults.STEP_SIZE,
**kwargs: Any,
) -> Tensor:
r"""
Numerically calculate the electric dipole moment :math:`\mu`.
Expand Down Expand Up @@ -332,11 +333,11 @@ def dipole_numerical(
with OutputHandler.with_verbosity(0):
field[..., i] += step_size
self.interactions.update_efield(field=field)
gr = self.energy(positions, chrg, spin)
gr = self.energy(positions, chrg, spin, **kwargs)

field[..., i] -= 2 * step_size
self.interactions.update_efield(field=field)
gl = self.energy(positions, chrg, spin)
gl = self.energy(positions, chrg, spin, **kwargs)

field[..., i] += step_size
self.interactions.update_efield(field=field)
Expand All @@ -359,6 +360,7 @@ def dipole_deriv_numerical(
chrg: Tensor | float | int = defaults.CHRG,
spin: Tensor | float | int | None = defaults.SPIN,
step_size: int | float = defaults.STEP_SIZE,
**kwargs: Any,
) -> Tensor:
r"""
Numerically calculate cartesian dipole derivative :math:`\mu'`.
Expand Down Expand Up @@ -411,10 +413,10 @@ def dipole_deriv_numerical(
for j in range(3):
with OutputHandler.with_verbosity(0):
positions[..., i, j] += step_size
r = _dipfcn(positions, chrg, spin)
r = _dipfcn(positions, chrg, spin, **kwargs)

positions[..., i, j] -= 2 * step_size
l = _dipfcn(positions, chrg, spin)
l = _dipfcn(positions, chrg, spin, **kwargs)

positions[..., i, j] += step_size
deriv[..., :, i, j] = 0.5 * (r - l) / step_size
Expand All @@ -438,6 +440,7 @@ def polarizability_numerical(
chrg: Tensor | float | int = defaults.CHRG,
spin: Tensor | float | int | None = defaults.SPIN,
step_size: int | float = defaults.STEP_SIZE,
**kwargs: Any,
) -> Tensor:
r"""
Numerically calculate the polarizability tensor :math:`\alpha`.
Expand Down Expand Up @@ -489,11 +492,11 @@ def polarizability_numerical(
with OutputHandler.with_verbosity(0):
field[..., i] += step_size
self.interactions.update_efield(field=field)
gr = _dipfcn(positions, chrg, spin)
gr = _dipfcn(positions, chrg, spin, **kwargs)

field[..., i] -= 2 * step_size
self.interactions.update_efield(field=field)
gl = _dipfcn(positions, chrg, spin)
gl = _dipfcn(positions, chrg, spin, **kwargs)

field[..., i] += step_size
self.interactions.update_efield(field=field)
Expand Down
8 changes: 7 additions & 1 deletion src/dxtb/_src/integral/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,13 @@ def to_pt(self, path: PathLike | None = None) -> None:
@property
def matrix(self) -> Tensor:
if self._matrix is None:
raise RuntimeError("Integral matrix has not been calculated.")
raise RuntimeError(

Check warning on line 386 in src/dxtb/_src/integral/base.py

View check run for this annotation

Codecov / codecov/patch

src/dxtb/_src/integral/base.py#L386

Added line #L386 was not covered by tests
"Integral matrix not found. This can be caused by two "
"reasons:\n"
"1. The integral has not been calculated yet.\n"
"2. The integral was cleared, despite being required "
"in a subsequent calculation. Check the cache settings."
)
return self._matrix

@matrix.setter
Expand Down
88 changes: 88 additions & 0 deletions test/test_calculator/test_cache/test_integrals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# This file is part of dxtb.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Test caching integrals.
"""

from __future__ import annotations

import pytest
import torch

from dxtb._src.typing import DD, Tensor
from dxtb.calculators import GFN1Calculator

from ...conftest import DEVICE

opts = {"cache_enabled": True, "verbosity": 0}


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_overlap_deleted(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

numbers = torch.tensor([3, 1], device=DEVICE)
positions = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd)

calc = GFN1Calculator(numbers, opts={"verbosity": 0}, **dd)
assert calc._ncalcs == 0

# overlap should not be cached
assert calc.opts.cache.store.overlap == False

# one successful calculation
energy = calc.get_energy(positions)
assert calc._ncalcs == 1
assert isinstance(energy, Tensor)

# cache should be empty
assert calc.cache.overlap is None

# ... but also the tensors in the calculator should be deleted
assert calc.integrals.overlap is not None
assert calc.integrals.overlap._matrix is None
assert calc.integrals.overlap._norm is None
assert calc.integrals.overlap._gradient is None


@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_overlap_retained_for_grad(dtype: torch.dtype) -> None:
dd: DD = {"device": DEVICE, "dtype": dtype}

numbers = torch.tensor([3, 1], device=DEVICE)
positions = torch.tensor(
[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], **dd, requires_grad=True
)

calc = GFN1Calculator(numbers, opts={"verbosity": 0}, **dd)
assert calc._ncalcs == 0

# overlap should not be cached
assert calc.opts.cache.store.overlap == False

# one successful calculation
energy = calc.get_energy(positions)
assert calc._ncalcs == 1
assert isinstance(energy, Tensor)

# cache should still be empty ...
assert calc.cache.overlap is None

# ... but the tensors in the calculator should still be there
assert calc.integrals.overlap is not None
assert calc.integrals.overlap._matrix is not None
assert calc.integrals.overlap._norm is not None

0 comments on commit 6f805ac

Please sign in to comment.