feat: migrate to pydantic v2 #51

Open · wants to merge 5 commits into main
15 changes: 0 additions & 15 deletions .flake8

This file was deleted.

8 changes: 4 additions & 4 deletions .github/workflows/test.yml
@@ -20,14 +20,14 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install coverage flake8
pip install hatch
pip install .
- name: Lint with flake8
- name: Lint with ruff
run: |
flake8 . --count --exit-zero --show-source --statistics
hatch run dev:check
- name: Test with unittest
run: |
coverage run -m unittest
hatch run dev:cov
- name: Upload Coverage to Codecov
if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9
uses: codecov/codecov-action@v2
3 changes: 2 additions & 1 deletion .gitignore
@@ -50,6 +50,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache/

# Translations
*.mo
@@ -129,4 +130,4 @@ dmypy.json
.pyre/

# IDEs
.vscode
.vscode
File renamed without changes.
107 changes: 98 additions & 9 deletions bpx/__init__.py
@@ -1,13 +1,102 @@
"""BPX schema and parsers"""
# flake8: noqa F401
from .expression_parser import ExpressionParser
from .function import Function
from .interpolated_table import InterpolatedTable
from .schema import BPX, check_sto_limits
from .utilities import get_electrode_concentrations, get_electrode_stoichiometries

__version__ = "0.4.0"

__all__ = [
"BPX",
"ExpressionParser",
"Function",
"InterpolatedTable",
"check_sto_limits",
"get_electrode_concentrations",
"get_electrode_stoichiometries",
"parse_bpx_file",
"parse_bpx_obj",
"parse_bpx_str",
]

from .interpolated_table import InterpolatedTable
from .expression_parser import ExpressionParser
from .function import Function
from .validators import check_sto_limits
from .schema import BPX
from .parsers import parse_bpx_str, parse_bpx_obj, parse_bpx_file
from .utilities import get_electrode_stoichiometries, get_electrode_concentrations

def parse_bpx_obj(bpx: dict, v_tol: float = 0.001) -> BPX:
"""
A convenience function to parse a bpx dict into a BPX model.

Parameters
----------
bpx: dict
a dict object in bpx format
v_tol: float
absolute tolerance in [V] to validate the voltage limits, 1 mV by default

Returns
-------
BPX: :class:`bpx.BPX`
a parsed BPX model
"""
if v_tol < 0:
error_msg = "v_tol should not be negative"
raise ValueError(error_msg)

BPX.Settings.tolerances["Voltage [V]"] = v_tol

return BPX.model_validate(bpx)


def parse_bpx_file(filename: str, v_tol: float = 0.001) -> BPX:
"""
A convenience function to parse a bpx file into a BPX model.

Parameters
----------
filename: str
a filepath to a bpx file
v_tol: float
absolute tolerance in [V] to validate the voltage limits, 1 mV by default

Returns
-------
BPX: :class:`bpx.BPX`
a parsed BPX model
"""

from pathlib import Path

bpx = ""
if filename.endswith((".yml", ".yaml")):
import yaml

with Path(filename).open(encoding="utf-8") as f:
bpx = yaml.safe_load(f)
else:
import orjson as json

with Path(filename).open(encoding="utf-8") as f:
bpx = json.loads(f.read())

return parse_bpx_obj(bpx, v_tol)


def parse_bpx_str(bpx: str, v_tol: float = 0.001) -> BPX:
"""
A convenience function to parse a json formatted string in bpx format into a BPX
model.

Parameters
----------
bpx: str
a json formatted string in bpx format
v_tol: float
absolute tolerance in [V] to validate the voltage limits, 1 mV by default

Returns
-------
BPX:
a parsed BPX model
"""
import orjson as json

bpx = json.loads(bpx)
return parse_bpx_obj(bpx, v_tol)
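For context, a minimal usage sketch of the convenience parsers introduced above; the file path is a placeholder, not a file from this PR:

from bpx import parse_bpx_file

# Parse and validate a BPX parameter file (JSON or YAML), checking the
# voltage limits to within 1 mV (the default tolerance).
params = parse_bpx_file("example.json", v_tol=0.001)

# Equivalent entry points exist for in-memory data:
#   parse_bpx_obj(bpx_dict, v_tol=0.001)
#   parse_bpx_str(json_string, v_tol=0.001)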
20 changes: 7 additions & 13 deletions bpx/expression_parser.py
@@ -15,7 +15,7 @@ class ExpressionParser:

ParseException = pp.ParseException

def __init__(self):
def __init__(self) -> None:
fnumber = ppc.number()
ident = pp.Literal("x")
fn_ident = pp.Literal("x")
@@ -31,21 +31,15 @@ def __init__(self):

expr_list = pp.delimitedList(pp.Group(expr))

def insert_fn_argcount_tuple(t):
def insert_fn_argcount_tuple(t: tuple) -> None:
fn = t.pop(0)
num_args = len(t[0])
t.insert(0, (fn, num_args))

fn_call = (fn_ident + lpar - pp.Group(expr_list) + rpar).setParseAction(
insert_fn_argcount_tuple
)
fn_call = (fn_ident + lpar - pp.Group(expr_list) + rpar).setParseAction(insert_fn_argcount_tuple)

atom = (
addop[...]
+ (
(fn_call | fnumber | ident).set_parse_action(self.push_first)
| pp.Group(lpar + expr + rpar)
)
addop[...] + ((fn_call | fnumber | ident).set_parse_action(self.push_first) | pp.Group(lpar + expr + rpar))
).set_parse_action(self.push_unary_minus)

# by defining exponentiation as "atom [ ^ factor ]..." instead of "atom
@@ -59,16 +53,16 @@ def insert_fn_argcount_tuple(t):
self.expr_stack = []
self.parser = expr

def push_first(self, toks):
def push_first(self, toks: tuple) -> None:
self.expr_stack.append(toks[0])

def push_unary_minus(self, toks):
def push_unary_minus(self, toks: tuple) -> None:
for t in toks:
if t == "-":
self.expr_stack.append("unary -")
else:
break

def parse_string(self, model_str, parse_all=True):
def parse_string(self, model_str: str, *, parse_all: bool = True) -> None:
self.expr_stack = []
self.parser.parseString(model_str, parseAll=parse_all)
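As a quick illustration of the parser whose signatures are annotated above (a sketch, not code from this diff):

from bpx import ExpressionParser

parser = ExpressionParser()
parser.parse_string("1 + x")  # valid: populates parser.expr_stack

try:
    parser.parse_string("1 +")  # malformed expression
except ExpressionParser.ParseException as err:
    print(err)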
66 changes: 49 additions & 17 deletions bpx/function.py
@@ -1,11 +1,20 @@
from __future__ import annotations

import copy
from importlib import util
import tempfile
from typing import Callable
from importlib import util
from pathlib import Path
from typing import TYPE_CHECKING, Any

from pydantic_core import CoreSchema, core_schema

from bpx import ExpressionParser

if TYPE_CHECKING:
from collections.abc import Callable

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler


class Function(str):
"""
@@ -16,31 +25,47 @@ class Function(str):
- single variable 'x'
"""

__slots__ = ()

parser = ExpressionParser()
default_preamble = "from math import exp, tanh, cosh"

@classmethod
def __get_validators__(cls):
yield cls.validate

@classmethod
def __modify_schema__(cls, field_schema):
field_schema.update(examples=["1 + x", "1.9793 * exp(-39.3631 * x)" "2 * x**2"])
def __get_pydantic_json_schema__(
cls,
core_schema: CoreSchema,
handler: GetJsonSchemaHandler,
) -> dict[str, Any]:
json_schema = handler(core_schema)
json_schema["examples"] = ["1 + x", "1.9793 * exp(-39.3631 * x)" "2 * x**2"]
return handler.resolve_ref_schema(json_schema)

@classmethod
def validate(cls, v: str) -> Function:
if not isinstance(v, str):
raise TypeError("string required")
error_msg = "string required"
raise TypeError(error_msg)
try:
cls.parser.parse_string(v)
except ExpressionParser.ParseException as e:
raise ValueError(str(e))
raise ValueError(str(e)) from e
return cls(v)

def __repr__(self):
@classmethod
def __get_pydantic_core_schema__(
cls,
source_type: str,
handler: GetCoreSchemaHandler,
) -> CoreSchema:
return core_schema.no_info_after_validator_function(
cls.validate,
handler(str),
)

def __repr__(self) -> str:
return f"Function({super().__repr__()})"

def to_python_function(self, preamble: str = None) -> Callable:
def to_python_function(self, preamble: str | None = None) -> Callable:
"""
Return a python function that can be called with a single argument 'x'

@@ -61,9 +86,7 @@ def to_python_function(self, preamble: str = None) -> Callable:
function_body = f" return {self}"
source_code = preamble + function_def + function_body

with tempfile.NamedTemporaryFile(
suffix="{}.py".format(function_name), delete=False
) as tmp:
with tempfile.NamedTemporaryFile(suffix=f"{function_name}.py", delete=False) as tmp:
# write to a temporary file so we can
# get the source later on using inspect.getsource
# (as long as the file still exists)
@@ -75,6 +98,15 @@ def to_python_function(self, preamble: str = None) -> Callable:
module = util.module_from_spec(spec)
spec.loader.exec_module(module)

# Delete the temporary source file and any cached bytecode it produced
tmp.close()
Path(tmp.name).unlink(missing_ok=True)
if module.__cached__:
cached_file = Path(module.__cached__)
cached_path = cached_file.parent
cached_file.unlink(missing_ok=True)
if not any(cached_path.iterdir()):
cached_path.rmdir()

# return the new function object
value = getattr(module, function_name)
return value
return getattr(module, function_name)
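The block above is the pydantic v2 replacement for the v1 __get_validators__ and __modify_schema__ hooks. A minimal, self-contained sketch of the same pattern (illustrative only, not code from this PR):

from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema


class UpperStr(str):
    """A str subtype that upper-cases its input during validation."""

    @classmethod
    def _validate(cls, v: str) -> "UpperStr":
        return cls(v.upper())

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
        # Run _validate after the plain str schema, mirroring Function above
        return core_schema.no_info_after_validator_function(cls._validate, handler(str))


class Model(BaseModel):
    name: UpperStr


print(Model(name="abc").name)  # -> "ABC"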
14 changes: 9 additions & 5 deletions bpx/interpolated_table.py
@@ -1,6 +1,8 @@
from __future__ import annotations

from typing import List

from pydantic import BaseModel, validator
from pydantic import BaseModel, ValidationInfo, field_validator


class InterpolatedTable(BaseModel):
@@ -12,8 +14,10 @@ class InterpolatedTable(BaseModel):
x: List[float]
y: List[float]

@validator("y")
def same_length(cls, v: list, values: dict) -> list:
if "x" in values and len(v) != len(values["x"]):
raise ValueError("x & y should be same length")
@field_validator("y")
@classmethod
def same_length(cls, v: list, info: ValidationInfo) -> list:
if "x" in info.data and len(v) != len(info.data["x"]):
error_msg = "x & y should be same length"
raise ValueError(error_msg)
return v
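A small usage sketch of the migrated validator (assumes the package is installed from this branch):

from pydantic import ValidationError

from bpx import InterpolatedTable

table = InterpolatedTable(x=[0.0, 0.5, 1.0], y=[1.0, 2.0, 4.0])  # validates

try:
    InterpolatedTable(x=[0.0, 1.0], y=[1.0, 2.0, 4.0])
except ValidationError as err:
    print(err)  # reports "x & y should be same length"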