Skip to content

Commit

Permalink
Greatly simplify AdminHolder context as hold_and_reset_prev_attrib_va…
Browse files Browse the repository at this point in the history
…lue_context
  • Loading branch information
sveinugu committed Aug 13, 2024
1 parent 2426d35 commit 1bac796
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 168 deletions.
19 changes: 16 additions & 3 deletions src/omnipy/data/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from contextlib import suppress
from contextlib import contextmanager, suppress
import functools
import inspect
import json
Expand All @@ -19,6 +19,7 @@
get_args,
get_origin,
Hashable,
Iterator,
Literal,
Optional,
ParamSpec,
Expand Down Expand Up @@ -498,10 +499,22 @@ def validate(cls: type['Model'], value: Any) -> 'Model':
Hack to allow overwriting of __iter__ method without compromising pydantic validation. Part
of the pydantic API and not the Omnipy API.
"""
# TODO: Doublecheck if validate() method is still needed for pydantic v2

_validate_cls_counts[cls.__name__] += 1
if is_model_instance(value):
with AttribHolder(
value, '__iter__', GenericModel.__iter__, switch_to_other=True, on_class=True):

@contextmanager
def temporary_set_value_iter_to_pydantic_method() -> Iterator[None]:
prev_iter = value.__class__.__iter__
value.__class__.__iter__ = GenericModel.__iter__

try:
yield
finally:
value.__class__.__iter__ = prev_iter

with temporary_set_value_iter_to_pydantic_method():
return super().validate(value)
else:
return super().validate(value)
Expand Down
71 changes: 12 additions & 59 deletions src/omnipy/util/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,66 +54,19 @@ def raise_derived(self, exc: Exception):
raise exc


Undefined = object()


# TODO: Perhaps the two use cases of this are so dissimilar that the class should be split into two
# distinct subclasses to make the code easier to understand?
class AttribHolder(AbstractContextManager):
def __init__(self,
obj: object,
attr_name: str,
other_value: object = Undefined,
reset_to_other: bool = False,
switch_to_other: bool = False,
on_class: bool = False,
copy_attr: bool = False):
self._obj_or_cls = obj.__class__ if on_class else obj
self._attr_name = attr_name
self._other_value = None if other_value is Undefined else other_value
self._prev_value: object | None = None
self._reset_to_other = reset_to_other
self._switch_to_other = switch_to_other
self._copy_attr = copy_attr
self._store_prev_attr = False
self._set_attr_to_other = False

assert not (reset_to_other and switch_to_other), \
'Only one of `reset_to_other` and `switch_to_other` can be specified.'

if other_value is not Undefined:
assert reset_to_other or switch_to_other, \
('If other_value is specified, you also need to set one of `reset_to_other` and '
'`switch_to_other`')

if reset_to_other or switch_to_other:
assert other_value is not Undefined, \
('If one of `reset_to_other` and `switch_to_other` are specified, you also need '
'to provide a value for `other_value`')

def __enter__(self):
if hasattr(self._obj_or_cls, self._attr_name):
from omnipy.util.helpers import all_equals

attr_value = getattr(self._obj_or_cls, self._attr_name)

self._store_prev_attr = \
not self._reset_to_other and not all_equals(attr_value, self._other_value)

if self._store_prev_attr:
self._prev_value = deepcopy(attr_value) if self._copy_attr else attr_value

if self._switch_to_other:
setattr(self._obj_or_cls, self._attr_name, self._other_value)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self._switch_to_other or exc_val is not None:
reset_value = self._other_value if self._reset_to_other else self._prev_value
setattr(self._obj_or_cls, self._attr_name, reset_value)
@contextmanager
def hold_and_reset_prev_attrib_value_context(
obj: object,
attr_name: str,
copy_attr: bool = False,
) -> Iterator[None]:
attr_value = getattr(obj, attr_name)
prev_value = deepcopy(attr_value) if copy_attr else attr_value

if self._store_prev_attr:
self._prev_value = None
try:
yield
finally:
setattr(obj, attr_name, prev_value)


@contextmanager
Expand Down
118 changes: 16 additions & 102 deletions tests/util/test_contexts.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from contextlib import suppress
import sys
from textwrap import dedent
from typing import Callable, TypeAlias

import pytest

from omnipy.util.contexts import (AttribHolder,
from omnipy.util.contexts import (hold_and_reset_prev_attrib_value_context,
LastErrorHolder,
print_exception,
setup_and_teardown_callback_context)
Expand Down Expand Up @@ -239,27 +240,7 @@ def test_with_last_error() -> None:
assert 'a=4 is even' in str(exc_info.getrepr())


def test_attrib_holder_init() -> None:
class A:
def __init__(self, num: int) -> None:
self.num = num

a = A(5)

with pytest.raises(AssertionError):
AttribHolder(a, 'num', reset_to_other=True)

with pytest.raises(AssertionError):
AttribHolder(a, 'num', switch_to_other=True)

with pytest.raises(AssertionError):
AttribHolder(a, 'num', 9)

with pytest.raises(AssertionError):
AttribHolder(a, 'num', 9, reset_to_other=True, switch_to_other=True)


def test_with_class_attrib_holder_reset_attr_if_exception() -> None:
def test_hold_and_reset_prev_attrib_value_context_at_teardown_and_exception() -> None:
class A:
...

Expand All @@ -268,108 +249,41 @@ def __init__(self, num: int) -> None:
self.num = num

a = A()
with AttribHolder(a, 'num') as ms:
assert ms._prev_value is None
with pytest.raises(AttributeError):
with hold_and_reset_prev_attrib_value_context(a, 'num'):
pass

b = B(5)
with AttribHolder(b, 'num') as ms:
with hold_and_reset_prev_attrib_value_context(b, 'num'):
b.num = 7
assert ms._prev_value == 5
assert b.num == 7
assert b.num == 5

try:
with suppress(RuntimeError):
b.num = 5
with AttribHolder(b, 'num') as ms:
with hold_and_reset_prev_attrib_value_context(b, 'num'):
b.num = 7
assert ms._prev_value == 5
raise RuntimeError()
except RuntimeError:
pass
assert b.num == 5


def test_with_class_attrib_holder_set_attr_to_other_if_exception() -> None:
class A:
def __init__(self, num: int) -> None:
self.num = num

a = A(5)

try:
a.num = 5
with AttribHolder(a, 'num', 9, reset_to_other=True) as ms:
a.num = 7
assert ms._prev_value is None
raise RuntimeError()
except RuntimeError:
pass
assert a.num == 9


def test_with_class_attrib_holder_reset_attr_if_exception_deepcopy() -> None:
def test_hold_and_reset_prev_attrib_value_context_at_exception_deepcopy() -> None:
class B:
def __init__(self, numbers: list[list[int]]) -> None:
self.numbers = numbers

b = B([[5]])
try:
with AttribHolder(b, 'numbers') as ms:

with suppress(RuntimeError):
with hold_and_reset_prev_attrib_value_context(b, 'numbers'):
b.numbers[0][0] += 2
assert b.numbers == [[7]]
assert ms._prev_value == [[7]]
raise RuntimeError()
except RuntimeError:
pass
assert b.numbers == [[7]]

try:
with suppress(RuntimeError):
b.numbers = [[5]]
with AttribHolder(b, 'numbers', copy_attr=True) as ms:
with hold_and_reset_prev_attrib_value_context(b, 'numbers', copy_attr=True):
b.numbers[0][0] += 2
assert b.numbers == [[7]]
assert ms._prev_value == [[5]]
raise RuntimeError()
except RuntimeError:
pass
assert b.numbers == [[5]]


def test_with_class_attrib_holder_method_switching() -> None:
class A:
...

class B:
def method(self):
return 'method'

def other_method(self):
return 'other_method'

a = A()
with AttribHolder(a, 'method', other_method, switch_to_other=True, on_class=True) as ms:
assert ms._prev_value is None
with pytest.raises(AttributeError):
a.method() # type: ignore

A.method = other_method

a = A()
with AttribHolder(a, 'method', other_method, switch_to_other=True, on_class=True) as ms:
assert a.method() == 'other_method' # type: ignore
assert ms._prev_value is None

b = B()
with AttribHolder(b, 'method', other_method, switch_to_other=True, on_class=True) as ms:
assert b.method() == 'other_method'
assert ms._prev_value.__name__ == 'method'
assert b.method() == 'method'

b = B()
try:
with AttribHolder(b, 'method', other_method, switch_to_other=True, on_class=True) as ms:
assert b.method() == 'other_method'
assert ms._prev_value.__name__ == 'method'
raise RuntimeError()
except RuntimeError:
pass
assert b.method() == 'method'
8 changes: 5 additions & 3 deletions tests/util/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from omnipy.util.contexts import AttribHolder
from omnipy.util.contexts import hold_and_reset_prev_attrib_value_context
from omnipy.util.decorators import (add_callback_after_call,
add_callback_if_exception,
apply_decorator_to_property,
Expand Down Expand Up @@ -73,7 +73,8 @@ def my_callback_after_call(ret: A | None, x: A, *, y: int) -> None:

my_a = A([1, 2, 3])

restore_numbers_context = AttribHolder(my_a, 'numbers', copy_attr=True)
restore_numbers_context = hold_and_reset_prev_attrib_value_context(
my_a, 'numbers', copy_attr=True)
decorated_my_appender = add_callback_after_call(
my_appender, my_callback_after_call, restore_numbers_context, my_a, y=0)

Expand Down Expand Up @@ -103,7 +104,8 @@ def my_callback_after_call(ret: A | None, x: A, *, y: int) -> None:
my_a = A([1, 2, 3])
my_other_a = A([1, 2, 3])

restore_numbers_context = AttribHolder(my_a, 'numbers', copy_attr=True)
restore_numbers_context = hold_and_reset_prev_attrib_value_context(
my_a, 'numbers', copy_attr=True)
decorated_my_appender = add_callback_after_call(
my_appender, my_callback_after_call, restore_numbers_context, my_other_a, y=4)

Expand Down
2 changes: 1 addition & 1 deletion tests/util/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import weakref

from pydantic import BaseModel
from pydantic.fields import Undefined
from pydantic.generics import GenericModel
import pytest
from typing_inspect import get_generic_type

from omnipy.api.protocols.private.util import HasContents, IsSnapshotHolder
from omnipy.data.dataset import Dataset
from omnipy.data.model import Model
from omnipy.util.contexts import Undefined
from omnipy.util.helpers import (all_type_variants,
called_from_omnipy_tests,
ensure_non_str_byte_iterable,
Expand Down

0 comments on commit 1bac796

Please sign in to comment.