Skip to content

Commit

Permalink
Updates (#41)
Browse files Browse the repository at this point in the history
* MAINT: Update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Only run type checker with Pydantic v2

* Add OpenMM unit stubs

* Update type checker config

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
mattwthompson and pre-commit-ci[bot] authored Mar 27, 2024
1 parent 355ceb8 commit afcef40
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 51 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
fail-fast: false
matrix:
os: [macOS-latest, ubuntu-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
pydantic-version: ["1", "2"]
openmm: [true, false]

Expand All @@ -43,6 +43,7 @@ jobs:
run: python -m pip install -e .

- name: Run mypy
if: ${{ matrix.python-version == 3.10 && matrix.pydantic-version == 2}}
run: mypy -p "openff.models"

- name: Run tests
Expand Down
10 changes: 7 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
- id: trailing-whitespace
- id: debug-statements
- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.3.0
hooks:
- id: black
files: ^openff
Expand All @@ -29,10 +29,14 @@ repos:
'flake8-pytest-style',
'flake8-no-pep420',
]
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
hooks:
- id: yesqa
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.15.1
hooks:
- id: pyupgrade
files: ^openff
exclude: openff/models/_version.py|setup.py
args: [--py38-plus]
args: [--py39-plus]
1 change: 0 additions & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ dependencies:
- pytest
- pytest-cov
- pytest-randomly
- openmm
- mypy
- unyt =3
- pip:
Expand Down
8 changes: 4 additions & 4 deletions openff/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable, Dict
from typing import Any, Callable

from openff.units import unit
from openff.units import Quantity

from openff.models._pydantic import BaseModel
from openff.models.types import custom_quantity_encoder, json_loader
Expand All @@ -12,8 +12,8 @@ class DefaultModel(BaseModel):
class Config:
"""Custom Pydantic configuration."""

json_encoders: Dict[Any, Callable] = {
unit.Quantity: custom_quantity_encoder,
json_encoders: dict[Any, Callable] = {
Quantity: custom_quantity_encoder,
}
json_loads: Callable = json_loader
validate_assignment: bool = True
Expand Down
59 changes: 28 additions & 31 deletions openff/models/types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Custom models for dealing with unit-bearing quantities in a Pydantic-compatible manner."""

import json
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any

import numpy as np
from openff.units import Quantity, Unit, unit
import numpy
from openff.units import Quantity, Unit
from openff.utilities import has_package, requires_package

from openff.models.exceptions import (
Expand All @@ -23,7 +23,7 @@ def __getitem__(self, t):


if TYPE_CHECKING:
FloatQuantity = unit.Quantity # np.ndarray
FloatQuantity = Quantity
else:

class FloatQuantity(float, metaclass=_FloatQuantityMeta):
Expand All @@ -43,25 +43,23 @@ def validate_type(cls, val):
raise MissingUnitError(
f"Value {val} needs to be tagged with a unit"
)
elif isinstance(val, unit.Quantity):
return unit.Quantity(val)
elif isinstance(val, Quantity):
return Quantity(val)
elif _is_openmm_quantity(val):
return _from_omm_quantity(val)
else:
raise UnitValidationError(
f"Could not validate data of type {type(val)}"
)
else:
unit_ = unit(unit_)
if isinstance(val, unit.Quantity):
unit_ = Unit(unit_)
if isinstance(val, Quantity):
# some custom behavior could go here
assert unit_.dimensionality == val.dimensionality
# return through converting to some intended default units (taken from the class)
val._magnitude = float(val.m)
return val.to(unit_)
# could return here, without converting
# (could be inconsistent with data model - heterogeneous but compatible units)
# return val

if _is_openmm_quantity(val):
return _from_omm_quantity(val).to(unit_)
if isinstance(val, int) and not isinstance(val, bool):
Expand All @@ -71,7 +69,7 @@ def validate_type(cls, val):
return val * unit_
if isinstance(val, str):
# could do custom deserialization here?
val = unit.Quantity(val).to(unit_)
val = Quantity(val).to(unit_)
val._magnitude = float(val._magnitude)
return val
if "unyt" in str(val.__class__):
Expand All @@ -98,7 +96,7 @@ def _is_openmm_quantity(obj: object) -> bool:


@requires_package("openmm.unit")
def _from_omm_quantity(val: "openmm.unit.Quantity") -> unit.Quantity: # type: ignore[name-defined]
def _from_omm_quantity(val: "openmm.unit.Quantity") -> Quantity:
"""
Convert float or array quantities tagged with SimTK/OpenMM units to a Pint-compatible quantity.
"""
Expand All @@ -108,17 +106,17 @@ def _from_omm_quantity(val: "openmm.unit.Quantity") -> unit.Quantity: # type: i
unit_ = val.unit
return float(val_) * Unit(str(unit_))
# Here is where the toolkit's ValidatedList could go, if present in the environment
elif (type(val_) in {tuple, list, np.ndarray}) or (
elif (type(val_) in {tuple, list, numpy.ndarray}) or (
type(val_).__module__ == "openmm.vec3"
):
array = np.asarray(val_)
array = numpy.asarray(val_)
return array * Unit(str(unit_))
elif isinstance(val_, (float, int)) and type(val_).__module__ == "numpy":
return val_ * Unit(str(unit_))
else:
raise UnitValidationError(
"Found a openmm.unit.Unit wrapped around something other than a float-like "
f"or np.ndarray-like. Found a unit wrapped around type {type(val_)}."
f"or numpy.ndarray-like. Found a unit wrapped around type {type(val_)}."
)


Expand All @@ -129,11 +127,11 @@ class QuantityEncoder(json.JSONEncoder):
This is intended to operate on FloatQuantity and ArrayQuantity objects.
"""

def default(self, obj): # noqa
if isinstance(obj, unit.Quantity):
def default(self, obj):
if isinstance(obj, Quantity):
if isinstance(obj.magnitude, (float, int)):
data = obj.magnitude
elif isinstance(obj.magnitude, np.ndarray):
elif isinstance(obj.magnitude, numpy.ndarray):
data = obj.magnitude.tolist()
else:
# This shouldn't ever be hit if our object models
Expand All @@ -155,7 +153,7 @@ def custom_quantity_encoder(v):
def json_loader(data: str) -> dict:
"""Load JSON containing custom unit-tagged quantities."""
# TODO: recursively call this function for nested models
out: Dict = json.loads(data)
out: dict = json.loads(data)
for key, val in out.items():
try:
# Directly look for an encoded FloatQuantity/ArrayQuantity,
Expand All @@ -165,7 +163,7 @@ def json_loader(data: str) -> dict:
# Handles some cases of the val being a primitive type
continue
# TODO: More gracefully parse non-FloatQuantity/ArrayQuantity dicts
unit_ = unit(v["unit"])
unit_ = Unit(v["unit"])
val = v["val"]
out[key] = unit_ * val
return out
Expand All @@ -177,7 +175,7 @@ def __getitem__(self, t):


if TYPE_CHECKING:
ArrayQuantity = unit.Quantity # np.ndarray
ArrayQuantity = Quantity
else:

class ArrayQuantity(float, metaclass=_ArrayQuantityMeta):
Expand All @@ -192,7 +190,7 @@ def validate_type(cls, val):
"""Process an array tagged with units into one tagged with "OpenFF" style units."""
unit_ = getattr(cls, "__unit__", Any)
if unit_ is Any:
if isinstance(val, (list, np.ndarray)):
if isinstance(val, (list, numpy.ndarray)):
# Work around a special case in which val might be list[openmm.unit.Quantity]
if isinstance(val, list) and {
type(element).__module__ for element in val
Expand All @@ -208,24 +206,24 @@ def validate_type(cls, val):
f"Value {val} needs to be tagged with a unit"
)

elif isinstance(val, unit.Quantity):
elif isinstance(val, Quantity):
# TODO: This might be a redundant cast causing wasted CPU time.
# But maybe it handles pint vs openff.units.unit?
return unit.Quantity(val)
return Quantity(val)
elif _is_openmm_quantity(val):
return _from_omm_quantity(val)
else:
raise UnitValidationError(
f"Could not validate data of type {type(val)}"
)
else:
unit_ = unit(unit_)
if isinstance(val, unit.Quantity):
unit_ = Unit(unit_)
if isinstance(val, Quantity):
assert unit_.dimensionality == val.dimensionality
return val.to(unit_)
if _is_openmm_quantity(val):
return _from_omm_quantity(val).to(unit_)
if isinstance(val, (np.ndarray, list)):
if isinstance(val, (numpy.ndarray, list)):
if "unyt" in str(val.__class__):
val = val.to_ndarray()
try:
Expand All @@ -239,12 +237,11 @@ def validate_type(cls, val):
raise error
if isinstance(val, bytes):
# Define outside loop
dt = np.dtype(int).newbyteorder("<")
return np.frombuffer(val, dtype=dt) * unit_
dt = numpy.dtype(int).newbyteorder("<")
return numpy.frombuffer(val, dtype=dt) * unit_
if isinstance(val, str):
# could do custom deserialization here?
raise NotImplementedError
# return unit.Quantity(val).to(unit_)
raise UnitValidationError(
f"Could not validate data of type {type(val)}"
)
13 changes: 2 additions & 11 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,13 @@ versionfile_build = openff/models/_version.py
tag_prefix = ''

[mypy]
mypy_path = stubs/
warn_unused_configs = True
# needed for pydantic v1 shim
warn_unused_ignores = False
warn_unused_ignores = True
warn_incomplete_stub = True
show_error_codes = True
exclude = openff/models/_tests/

[mypy-openff.units]
ignore_missing_imports = True

[mypy-openmm]
ignore_missing_imports = True

[mypy-openmm.unit]
ignore_missing_imports = True

[mypy-unyt]
ignore_missing_imports = True

Expand Down
9 changes: 9 additions & 0 deletions stubs/openmm/unit/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Unit:
pass

class Quantity:
@property
def unit(self) -> Unit: ...

def value_in_unit(self, unit: Unit) -> float:
...

0 comments on commit afcef40

Please sign in to comment.