Skip to content

Commit

Permalink
Recursively resolve references on args and kwargs passed to an reacti…
Browse files Browse the repository at this point in the history
…ve operation (#944)
  • Loading branch information
maximlt authored Jun 18, 2024
1 parent 6f6f92c commit 35a4729
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 5 deletions.
6 changes: 3 additions & 3 deletions param/parameterized.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,11 @@ def resolve_ref(reference, recursive=False):
"""
if recursive:
if isinstance(reference, (list, tuple, set)):
return [r for v in reference for r in resolve_ref(v)]
return [r for v in reference for r in resolve_ref(v, recursive)]
elif isinstance(reference, dict):
return [r for kv in reference.items() for o in kv for r in resolve_ref(o)]
return [r for kv in reference.items() for o in kv for r in resolve_ref(o, recursive)]
elif isinstance(reference, slice):
return [r for v in (reference.start, reference.stop, reference.step) for r in resolve_ref(v)]
return [r for v in (reference.start, reference.stop, reference.step) for r in resolve_ref(v, recursive)]
reference = transform_reference(reference)
if hasattr(reference, '_dinfo'):
dinfo = getattr(reference, '_dinfo', {})
Expand Down
2 changes: 1 addition & 1 deletion param/reactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,7 @@ def _compute_params(self) -> list[Parameter]:
if ref not in ps:
ps.append(ref)
for arg in list(self._operation['args'])+list(self._operation['kwargs'].values()):
for ref in resolve_ref(arg):
for ref in resolve_ref(arg, recursive=True):
if ref not in ps:
ps.append(ref)

Expand Down
20 changes: 20 additions & 0 deletions tests/testreactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,26 @@ def test_reactive_getitem_list():
assert rx([1, 2, 3])[1].rx.value == 2
assert rx([1, 2, 3])[2].rx.value == 3

def test_reactive_getitem_list_with_slice():
i = rx(1)
j = rx(5)
lst = list(range(10))
lstx = rx(lst)
sx = lstx[i: j]
assert sx.rx.value == lst[i.rx.value: j.rx.value]
i.rx.value = 2
assert sx.rx.value == lst[i.rx.value: j.rx.value]

def test_reactive_getitem_numpy_with_tuple():
i = rx(0)
j = rx(1)
arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
arrx = rx(arr)
selx = arrx[i, j]
assert selx.rx.value == arr[i.rx.value, j.rx.value]
i.rx.value = 1
assert selx.rx.value == arr[i.rx.value, j.rx.value]

@pytest.mark.parametrize('ufunc', NUMPY_UFUNCS)
def test_numpy_ufunc(ufunc):
l = [1, 2, 3]
Expand Down
49 changes: 48 additions & 1 deletion tests/testrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import param
import pytest

from param.parameterized import Skip
from param.parameterized import Skip, resolve_ref
from param.reactive import bind, rx

class Parameters(param.Parameterized):
Expand Down Expand Up @@ -286,3 +286,50 @@ def gen_strings2():
assert task1.done()
assert not task2.done()
assert len(threads) == 2

def test_resolve_ref_parameter():
p = Parameters()
refs = resolve_ref(p.param.string)
assert len(refs) == 1
assert refs[0] is p.param.string

def test_resolve_ref_depends_method():
p = Parameters()
refs = resolve_ref(p.formatted_string)
assert len(refs) == 1
assert refs[0] is p.param.string

def test_resolve_ref_recursive_list():
p = Parameters()
nested = [[p.param.string]]
refs = resolve_ref(nested, recursive=True)
assert len(refs) == 1
assert refs[0] is p.param.string

def test_resolve_ref_recursive_set():
p = Parameters()
nested = {(p.param.string,)} # Parameters aren't hashable
refs = resolve_ref(nested, recursive=True)
assert len(refs) == 1
assert refs[0] is p.param.string

def test_resolve_ref_recursive_tuple():
p = Parameters()
nested = ((p.param.string,),)
refs = resolve_ref(nested, recursive=True)
assert len(refs) == 1
assert refs[0] is p.param.string

def test_resolve_ref_recursive_dict():
p = Parameters()
nested = {'0': {'0': p.param.string}}
refs = resolve_ref(nested, recursive=True)
assert len(refs) == 1
assert refs[0] is p.param.string

def test_resolve_ref_recursive_slice():
p = Parameters()
nested = [slice(p.param.string)]
refs = resolve_ref(nested, recursive=True)
assert len(refs) == 1
assert refs[0] is p.param.string

0 comments on commit 35a4729

Please sign in to comment.