Skip to content

Commit

Permalink
ENH: adapt implementation to QRules v0.10 (#362)
Browse files Browse the repository at this point in the history
* DX: run AmpForm tests with QRules v0.9.x
* ENH: adapt implementation to QRules v0.10
* ENH: remove QRules version restriction
* FEAT: add `get_qrules_version()` function
* MAINT: update links to QRules v0.10.x API
* MAINT: upgrade constraints to QRules v0.10.1
  • Loading branch information
redeboer authored Mar 1, 2024
1 parent 3d3aa88 commit 3e26e8d
Show file tree
Hide file tree
Showing 23 changed files with 215 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .constraints/py3.10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ python-lsp-server==1.10.0
pytoolconfig==1.3.1
pyyaml==6.0.1
pyzmq==25.1.2
qrules==0.9.8
qrules==0.10.1
referencing==0.33.0
requests==2.31.0
rfc3339-validator==0.1.4
Expand Down
2 changes: 1 addition & 1 deletion .constraints/py3.11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ python-lsp-server==1.10.0
pytoolconfig==1.3.1
pyyaml==6.0.1
pyzmq==25.1.2
qrules==0.9.8
qrules==0.10.1
referencing==0.33.0
requests==2.31.0
rfc3339-validator==0.1.4
Expand Down
2 changes: 1 addition & 1 deletion .constraints/py3.12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ python-lsp-server==1.10.0
pytoolconfig==1.3.1
pyyaml==6.0.1
pyzmq==25.1.2
qrules==0.9.8
qrules==0.10.1
referencing==0.33.0
requests==2.31.0
rfc3339-validator==0.1.4
Expand Down
2 changes: 1 addition & 1 deletion .constraints/py3.7.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ pytoolconfig==1.3.0
pytz==2024.1
pyyaml==6.0.1
pyzmq==24.0.1
qrules==0.9.8
qrules==0.10.1
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
Expand Down
2 changes: 1 addition & 1 deletion .constraints/py3.8.txt
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ pytoolconfig==1.3.1
pytz==2024.1
pyyaml==6.0.1
pyzmq==25.1.2
qrules==0.9.8
qrules==0.10.1
referencing==0.33.0
requests==2.31.0
rfc3339-validator==0.1.4
Expand Down
2 changes: 1 addition & 1 deletion .constraints/py3.9.txt
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ python-lsp-server==1.10.0
pytoolconfig==1.3.1
pyyaml==6.0.1
pyzmq==25.1.2
qrules==0.9.8
qrules==0.10.1
referencing==0.33.0
requests==2.31.0
rfc3339-validator==0.1.4
Expand Down
36 changes: 36 additions & 0 deletions .github/workflows/ci-qrules-v0.9.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Test with QRules v0.9

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

env:
PYTHONHASHSEED: "0"

on:
push:
branches:
- main
- epic/*
- "[0-9]+.[0-9]+.x"
pull_request:
branches:
- main
- epic/*
- "[0-9]+.[0-9]+.x"
workflow_dispatch:

jobs:
pytest:
name: Run unit tests
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- uses: ComPWA/actions/pip-install@v1
with:
additional-packages: tox
editable: "yes"
extras: test
python-version: "3.9"
specific-packages: qrules==0.9.*
- run: pytest -n auto
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
"ReactionInfo": "qrules.transition.ReactionInfo",
"Slider": ("obj", "symplot.Slider"),
"State": "qrules.transition.State",
"StateTransition": "qrules.transition.StateTransition",
"StateTransition": "qrules.topology.Transition",
"T": "TypeVar",
"Topology": "qrules.topology.Topology",
"WignerD": "sympy.physics.quantum.spin.WignerD",
Expand Down Expand Up @@ -238,7 +238,7 @@
"numpy": (f"https://numpy.org/doc/{pin_minor('numpy')}", None),
"pwa": ("https://pwa.readthedocs.io", None),
"python": ("https://docs.python.org/3", None),
"qrules": (f"https://qrules.readthedocs.io/en/{pin('qrules')}", None),
"qrules": (f"https://qrules.readthedocs.io/{pin('qrules')}", None),
"sympy": ("https://docs.sympy.org/latest", None),
}
linkcheck_anchors = False
Expand Down
2 changes: 1 addition & 1 deletion docs/usage/amplitude.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In {doc}`qrules:usage/reaction`, we used {func}`~qrules.generate_transitions` to create a list of allowed {class}`~qrules.transition.StateTransition`s for a specific decay channel:"
"In {doc}`qrules:usage/reaction`, we used {func}`~qrules.generate_transitions` to create a list of allowed {class}`~qrules.topology.Transition`s for a specific decay channel:"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion docs/usage/dynamics/custom.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"A function that behaves like a {class}`.ResonanceDynamicsBuilder` should return a {class}`tuple` of some {class}`~sympy.core.expr.Expr` (which formulates your lineshape) and a {class}`dict` of {class}`~sympy.core.symbol.Symbol`s to some suggested initial values. This signature is required so the builder knows how to extract the correct symbol names and their suggested initial values from a {class}`~qrules.transition.StateTransition`."
"A function that behaves like a {class}`.ResonanceDynamicsBuilder` should return a {class}`tuple` of some {class}`~sympy.core.expr.Expr` (which formulates your lineshape) and a {class}`dict` of {class}`~sympy.core.symbol.Symbol`s to some suggested initial values. This signature is required so the builder knows how to extract the correct symbol names and their suggested initial values from a {class}`~qrules.topology.Transition`."
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/usage/helicity/formalism.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"See {func}`.formulate_isobar_wigner_d` and {func}`.formulate_isobar_cg_coefficients` for how these Wigner-$D$ functions and Clebsch-Gordan coefficients are computed for each node on a {class}`~qrules.transition.StateTransition`.\n",
"See {func}`.formulate_isobar_wigner_d` and {func}`.formulate_isobar_cg_coefficients` for how these Wigner-$D$ functions and Clebsch-Gordan coefficients are computed for each node on a {class}`~qrules.topology.Transition`.\n",
"\n",
"We can see this also from the original {class}`~qrules.transition.ReactionInfo` objects. Let's select only the {attr}`~qrules.transition.ReactionInfo.transitions` where the $a_1(1260)^+$ resonance has spin projection $-1$ (taken to be helicity $-1$ in the helicity formalism). We then see just one {class}`~qrules.transition.StateTransition` in the helicity basis and three transitions in the canonical basis:"
"We can see this also from the original {class}`~qrules.transition.ReactionInfo` objects. Let's select only the {attr}`~qrules.transition.ReactionInfo.transitions` where the $a_1(1260)^+$ resonance has spin projection $-1$ (taken to be helicity $-1$ in the helicity formalism). We then see just one {class}`~qrules.topology.Transition` in the helicity basis and three transitions in the canonical basis:"
]
},
{
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ classifiers = [
]
dependencies = [
"attrs >=20.1.0", # on_setattr and https://www.attrs.org/en/stable/api.html#next-gen
"qrules ==0.9.*, >=0.9.6", # https://github.com/ComPWA/qrules/pull/145
"qrules >=0.9.6",
"sympy >=1.10",
'importlib-metadata; python_version <"3.8.0"',
'singledispatchmethod; python_version <"3.8.0"',
'typing-extensions; python_version <"3.8.0"',
]
Expand Down Expand Up @@ -208,6 +209,7 @@ reportPrivateUsage = false
reportReturnType = false
reportUnboundVariable = false
reportUnknownArgumentType = false
reportUnknownLambdaType = false
reportUnknownMemberType = false
reportUnknownParameterType = false
reportUnknownVariableType = false
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Build amplitude models with different PWA formalisms.
AmpForm formalizes formalisms from :doc:`Partial Wave Analysis <pwa:index>`. It provides
tools to convert `~qrules.transition.StateTransition` solutions that the `.qrules`
tools to convert `~qrules.topology.Transition` solutions that the `.qrules`
package found into an `.HelicityModel`. The output `.HelicityModel` can then be used by
external fitter packages to generate a data set (toy Monte Carlo) for this specific
reaction process, or to optimize ('fit') its parameters so that they resemble the data
Expand Down
24 changes: 24 additions & 0 deletions src/ampform/_qrules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import sys
from functools import lru_cache

if sys.version_info < (3, 8):
from importlib_metadata import version
else:
from importlib.metadata import version


@lru_cache(maxsize=1)
def get_qrules_version() -> tuple[int, ...]:
"""Get the version of qrules as a tuple of integers.
>>> get_qrules_version() >= (0, 10)
True
>>> import pytest
>>> from ampform._qrules import get_qrules_version
>>> if get_qrules_version() < (0, 10):
... pytest.skip("Doctest only works for qrules>=0.10")
"""
v = version("qrules")
return tuple(int(i) for i in v.split(".") if i.strip().isdigit())
35 changes: 29 additions & 6 deletions src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@
from attrs.validators import deep_iterable, instance_of, optional
from qrules.combinatorics import perform_external_edge_identical_particle_combinatorics
from qrules.particle import Particle
from qrules.transition import ReactionInfo, StateTransition
from qrules.transition import (
InteractionProperties,
ReactionInfo,
State,
StateTransition,
)

from ampform._qrules import get_qrules_version
from ampform.dynamics.builder import (
ResonanceDynamicsBuilder,
TwoBodyKinematicVariableSet,
Expand Down Expand Up @@ -75,6 +81,7 @@
from typing import override
if TYPE_CHECKING:
from IPython.lib.pretty import PrettyPrinter
from qrules.topology import MutableTransition

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -453,11 +460,9 @@ def __formulate_topology_amplitude(
) -> sp.Expr:
sequential_expressions: list[sp.Expr] = []
for transition in transitions:
sequential_graphs = perform_external_edge_identical_particle_combinatorics(
transition.to_graph()
)
sequential_graphs = _perform_combinatorics(transition)
for graph in sequential_graphs:
first_transition = StateTransition.from_graph(graph)
first_transition = _freeze(graph)
expression = self.__formulate_sequential_decay(first_transition)
sequential_expressions.append(expression)

Expand Down Expand Up @@ -561,6 +566,24 @@ def __generate_amplitude_prefactor(
return None


def _perform_combinatorics(
transition: StateTransition,
) -> list[MutableTransition[State, InteractionProperties]]:
if get_qrules_version() < (0, 10):
return perform_external_edge_identical_particle_combinatorics(
transition.to_graph() # type: ignore[attr-defined]
)
graph = transition.convert(lambda s: (s.particle, s.spin_projection)).unfreeze()
combinations = perform_external_edge_identical_particle_combinatorics(graph)
return [g.freeze().convert(lambda s: State(*s)).unfreeze() for g in combinations]


def _freeze(graph: MutableTransition[State, InteractionProperties]) -> StateTransition:
if get_qrules_version() < (0, 10):
return StateTransition.from_graph(graph) # type: ignore[attr-defined]
return graph.freeze()


class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder):
r"""Amplitude model generator for the canonical helicity formalism.
Expand Down Expand Up @@ -656,7 +679,7 @@ def assign( # noqa: PLR6301
- `str`: Select transition nodes by the name of the `~.TwoBodyDecay.parent`
`~qrules.particle.Particle`.
- `.TwoBodyDecay` or `tuple` of a `~qrules.transition.StateTransition` with a
- `.TwoBodyDecay` or `tuple` of a `~qrules.topology.Transition` with a
node ID: set dynamics for one specific transition node.
"""
msg = (
Expand Down
2 changes: 1 addition & 1 deletion src/ampform/helicity/align/axisangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def formulate_wigner_rotation(
:cite:`marangottoHelicityAmplitudesGeneric2020`, p.6, especially Eq.(36).
Args:
transition: The `~qrules.transition.StateTransition` in which you
transition: The `~qrules.topology.Transition` in which you
want to rotate one of the spin states.
rotated_state_id: The state ID of a spin `~qrules.transition.State`
that you want to rotate.
Expand Down
54 changes: 41 additions & 13 deletions src/ampform/helicity/align/dpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from attrs import define, field
from attrs.validators import in_
from qrules.topology import Topology
from qrules.transition import ReactionInfo, StateTransition, StateTransitionCollection
from qrules.transition import ReactionInfo, StateTransition
from sympy.physics.quantum.spin import Rotation as Wigner

from ampform._qrules import get_qrules_version
from ampform.helicity.align import SpinAlignment
from ampform.helicity.decay import (
get_outer_state_ids,
Expand All @@ -35,6 +36,11 @@
if TYPE_CHECKING:
from sympy.physics.quantum.spin import WignerD

if get_qrules_version() < (0, 10):
from qrules.transition import ( # type: ignore[attr-defined]
StateTransitionCollection,
)


@define
class DalitzPlotDecomposition(SpinAlignment):
Expand Down Expand Up @@ -112,33 +118,47 @@ def __call__(
return Wigner.d(j, m, m_prime, zeta)


T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology)
"""Allowed types for :func:`relabel_edge_ids`."""
if get_qrules_version() < (0, 10):
T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology)
"""Allowed types for :func:`relabel_edge_ids`."""
else:
T = TypeVar( # type: ignore[misc] # pyright: ignore[reportConstantRedefinition]
"T", ReactionInfo, StateTransition, Topology
)
"""Allowed types for :func:`relabel_edge_ids`."""


@singledispatch
def relabel_edge_ids(obj: T) -> T:
def relabel_edge_ids(obj: T) -> T: # type: ignore[reportInvalidTypeForm]
msg = f"Cannot relabel edge IDs of a {type(obj).__name__}"
raise NotImplementedError(msg)


@relabel_edge_ids.register(ReactionInfo)
def _(obj: ReactionInfo) -> ReactionInfo: # type: ignore[misc]
return ReactionInfo( # no attrs.evolve() in order to call __attrs_post_init__()
transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups],
if get_qrules_version() < (0, 10):
return ReactionInfo( # type: ignore[call-arg]
transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups], # type: ignore[attr-defined]
formalism=obj.formalism,
)
return ReactionInfo(
# no attrs.evolve() in order to call __attrs_post_init__()
transitions=[relabel_edge_ids(g) for g in obj.transitions],
formalism=obj.formalism,
)


@relabel_edge_ids.register(StateTransitionCollection)
def _(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc]
return StateTransitionCollection([ # no attrs.evolve() for __attrs_post_init__()
relabel_edge_ids(transition) for transition in obj.transitions
])
if get_qrules_version() < (0, 10):

def __relabel_stc(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc]
return StateTransitionCollection([
relabel_edge_ids(transition) for transition in obj.transitions
])

relabel_edge_ids.register(StateTransitionCollection)(__relabel_stc)

@relabel_edge_ids.register(StateTransition)
def _(obj: StateTransition) -> StateTransition: # type: ignore[misc]

def __relabel_st(obj: StateTransition) -> StateTransition: # type: ignore[misc]
mapping = __get_default_relabel_mapping()
return attrs.evolve(
obj,
Expand All @@ -147,6 +167,14 @@ def _(obj: StateTransition) -> StateTransition: # type: ignore[misc]
)


if get_qrules_version() < (0, 10):
relabel_edge_ids.register(StateTransition)(__relabel_st)
else:
from qrules.topology import FrozenTransition

relabel_edge_ids.register(FrozenTransition)(__relabel_st)


@relabel_edge_ids.register(Topology)
def _(obj: Topology) -> Topology: # type: ignore[misc]
mapping = __get_default_relabel_mapping()
Expand Down
Loading

0 comments on commit 3e26e8d

Please sign in to comment.