Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates #41

Merged
merged 7 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
fail-fast: false
matrix:
os: [macOS-latest, ubuntu-latest]
python-version: ["3.9", "3.9", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
pydantic-version: ["1", "2"]
openmm: [true, false]

Expand All @@ -43,6 +43,7 @@ jobs:
run: python -m pip install -e .

- name: Run mypy
if: ${{ matrix.python-version == 3.10 && matrix.pydantic-version == 2}}
run: mypy -p "openff.models"

- name: Run tests
Expand Down
10 changes: 7 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ repos:
- id: trailing-whitespace
- id: debug-statements
- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.3.0
hooks:
- id: black
files: ^openff
Expand All @@ -29,10 +29,14 @@ repos:
'flake8-pytest-style',
'flake8-no-pep420',
]
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
hooks:
- id: yesqa
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.15.1
hooks:
- id: pyupgrade
files: ^openff
exclude: openff/models/_version.py|setup.py
args: [--py38-plus]
args: [--py39-plus]
1 change: 0 additions & 1 deletion devtools/conda-envs/test_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ dependencies:
- pytest
- pytest-cov
- pytest-randomly
- openmm
- mypy
- unyt =3
- pip:
Expand Down
8 changes: 4 additions & 4 deletions openff/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Callable, Dict
from typing import Any, Callable

from openff.units import unit
from openff.units import Quantity

from openff.models._pydantic import BaseModel
from openff.models.types import custom_quantity_encoder, json_loader
Expand All @@ -12,8 +12,8 @@ class DefaultModel(BaseModel):
class Config:
"""Custom Pydantic configuration."""

json_encoders: Dict[Any, Callable] = {
unit.Quantity: custom_quantity_encoder,
json_encoders: dict[Any, Callable] = {
Quantity: custom_quantity_encoder,
}
json_loads: Callable = json_loader
validate_assignment: bool = True
Expand Down
59 changes: 28 additions & 31 deletions openff/models/types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Custom models for dealing with unit-bearing quantities in a Pydantic-compatible manner."""

import json
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any

import numpy as np
from openff.units import Quantity, Unit, unit
import numpy
from openff.units import Quantity, Unit
from openff.utilities import has_package, requires_package

from openff.models.exceptions import (
Expand All @@ -23,7 +23,7 @@ def __getitem__(self, t):


if TYPE_CHECKING:
FloatQuantity = unit.Quantity # np.ndarray
FloatQuantity = Quantity
else:

class FloatQuantity(float, metaclass=_FloatQuantityMeta):
Expand All @@ -43,25 +43,23 @@ def validate_type(cls, val):
raise MissingUnitError(
f"Value {val} needs to be tagged with a unit"
)
elif isinstance(val, unit.Quantity):
return unit.Quantity(val)
elif isinstance(val, Quantity):
return Quantity(val)
elif _is_openmm_quantity(val):
return _from_omm_quantity(val)
else:
raise UnitValidationError(
f"Could not validate data of type {type(val)}"
)
else:
unit_ = unit(unit_)
if isinstance(val, unit.Quantity):
unit_ = Unit(unit_)
if isinstance(val, Quantity):
# some custom behavior could go here
assert unit_.dimensionality == val.dimensionality
# return through converting to some intended default units (taken from the class)
val._magnitude = float(val.m)
return val.to(unit_)
# could return here, without converting
# (could be inconsistent with data model - heteregenous but compatible units)
# return val

if _is_openmm_quantity(val):
return _from_omm_quantity(val).to(unit_)
if isinstance(val, int) and not isinstance(val, bool):
Expand All @@ -71,7 +69,7 @@ def validate_type(cls, val):
return val * unit_
if isinstance(val, str):
# could do custom deserialization here?
val = unit.Quantity(val).to(unit_)
val = Quantity(val).to(unit_)
val._magnitude = float(val._magnitude)
return val
if "unyt" in str(val.__class__):
Expand All @@ -98,7 +96,7 @@ def _is_openmm_quantity(obj: object) -> bool:


@requires_package("openmm.unit")
def _from_omm_quantity(val: "openmm.unit.Quantity") -> unit.Quantity: # type: ignore[name-defined]
def _from_omm_quantity(val: "openmm.unit.Quantity") -> Quantity:
"""
Convert float or array quantities tagged with SimTK/OpenMM units to a Pint-compatible quantity.
"""
Expand All @@ -108,17 +106,17 @@ def _from_omm_quantity(val: "openmm.unit.Quantity") -> unit.Quantity: # type: i
unit_ = val.unit
return float(val_) * Unit(str(unit_))
# Here is where the toolkit's ValidatedList could go, if present in the environment
elif (type(val_) in {tuple, list, np.ndarray}) or (
elif (type(val_) in {tuple, list, numpy.ndarray}) or (
type(val_).__module__ == "openmm.vec3"
):
array = np.asarray(val_)
array = numpy.asarray(val_)
return array * Unit(str(unit_))
elif isinstance(val_, (float, int)) and type(val_).__module__ == "numpy":
return val_ * Unit(str(unit_))
else:
raise UnitValidationError(
"Found a openmm.unit.Unit wrapped around something other than a float-like "
f"or np.ndarray-like. Found a unit wrapped around type {type(val_)}."
f"or numpy.ndarray-like. Found a unit wrapped around type {type(val_)}."
)


Expand All @@ -129,11 +127,11 @@ class QuantityEncoder(json.JSONEncoder):
This is intended to operate on FloatQuantity and ArrayQuantity objects.
"""

def default(self, obj): # noqa
if isinstance(obj, unit.Quantity):
def default(self, obj):
if isinstance(obj, Quantity):
if isinstance(obj.magnitude, (float, int)):
data = obj.magnitude
elif isinstance(obj.magnitude, np.ndarray):
elif isinstance(obj.magnitude, numpy.ndarray):
data = obj.magnitude.tolist()
else:
# This shouldn't ever be hit if our object models
Expand All @@ -155,7 +153,7 @@ def custom_quantity_encoder(v):
def json_loader(data: str) -> dict:
"""Load JSON containing custom unit-tagged quantities."""
# TODO: recursively call this function for nested models
out: Dict = json.loads(data)
out: dict = json.loads(data)
for key, val in out.items():
try:
# Directly look for an encoded FloatQuantity/ArrayQuantity,
Expand All @@ -165,7 +163,7 @@ def json_loader(data: str) -> dict:
# Handles some cases of the val being a primitive type
continue
# TODO: More gracefully parse non-FloatQuantity/ArrayQuantity dicts
unit_ = unit(v["unit"])
unit_ = Unit(v["unit"])
val = v["val"]
out[key] = unit_ * val
return out
Expand All @@ -177,7 +175,7 @@ def __getitem__(self, t):


if TYPE_CHECKING:
ArrayQuantity = unit.Quantity # np.ndarray
ArrayQuantity = Quantity
else:

class ArrayQuantity(float, metaclass=_ArrayQuantityMeta):
Expand All @@ -192,7 +190,7 @@ def validate_type(cls, val):
"""Process an array tagged with units into one tagged with "OpenFF" style units."""
unit_ = getattr(cls, "__unit__", Any)
if unit_ is Any:
if isinstance(val, (list, np.ndarray)):
if isinstance(val, (list, numpy.ndarray)):
# Work around a special case in which val might be list[openmm.unit.Quantity]
if isinstance(val, list) and {
type(element).__module__ for element in val
Expand All @@ -208,24 +206,24 @@ def validate_type(cls, val):
f"Value {val} needs to be tagged with a unit"
)

elif isinstance(val, unit.Quantity):
elif isinstance(val, Quantity):
# TODO: This might be a redundant cast causing wasted CPU time.
# But maybe it handles pint vs openff.units.unit?
return unit.Quantity(val)
return Quantity(val)
elif _is_openmm_quantity(val):
return _from_omm_quantity(val)
else:
raise UnitValidationError(
f"Could not validate data of type {type(val)}"
)
else:
unit_ = unit(unit_)
if isinstance(val, unit.Quantity):
unit_ = Unit(unit_)
if isinstance(val, Quantity):
assert unit_.dimensionality == val.dimensionality
return val.to(unit_)
if _is_openmm_quantity(val):
return _from_omm_quantity(val).to(unit_)
if isinstance(val, (np.ndarray, list)):
if isinstance(val, (numpy.ndarray, list)):
if "unyt" in str(val.__class__):
val = val.to_ndarray()
try:
Expand All @@ -239,12 +237,11 @@ def validate_type(cls, val):
raise error
if isinstance(val, bytes):
# Define outside loop
dt = np.dtype(int).newbyteorder("<")
return np.frombuffer(val, dtype=dt) * unit_
dt = numpy.dtype(int).newbyteorder("<")
return numpy.frombuffer(val, dtype=dt) * unit_
if isinstance(val, str):
# could do custom deserialization here?
raise NotImplementedError
# return unit.Quantity(val).to(unit_)
raise UnitValidationError(
f"Could not validate data of type {type(val)}"
)
13 changes: 2 additions & 11 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,13 @@ versionfile_build = openff/models/_version.py
tag_prefix = ''

[mypy]
mypy_path = stubs/
warn_unused_configs = True
# needed for pydantic v1 shim
warn_unused_ignores = False
warn_unused_ignores = True
warn_incomplete_stub = True
show_error_codes = True
exclude = openff/models/_tests/

[mypy-openff.units]
ignore_missing_imports = True

[mypy-openmm]
ignore_missing_imports = True

[mypy-openmm.unit]
ignore_missing_imports = True

[mypy-unyt]
ignore_missing_imports = True

Expand Down
9 changes: 9 additions & 0 deletions stubs/openmm/unit/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Unit:
pass

class Quantity:
@property
def unit(self) -> Unit: ...

def value_in_unit(self, unit: Unit) -> float:
...