Skip to content

Commit

Permalink
ENH: adapt implementation to QRules v0.10
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Oct 17, 2023
1 parent 00e3b56 commit 90194d1
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 31 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ reportPrivateImportUsage = false
reportPrivateUsage = false
reportUnboundVariable = false
reportUnknownArgumentType = false
reportUnknownLambdaType = false
reportUnknownMemberType = false
reportUnknownParameterType = false
reportUnknownVariableType = false
Expand Down
33 changes: 28 additions & 5 deletions src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@
from attrs.validators import deep_iterable, instance_of, optional
from qrules.combinatorics import perform_external_edge_identical_particle_combinatorics
from qrules.particle import Particle
from qrules.transition import ReactionInfo, StateTransition
from qrules.transition import (
InteractionProperties,
ReactionInfo,
State,
StateTransition,
)

from ampform._qrules import get_qrules_version
from ampform.dynamics.builder import (
ResonanceDynamicsBuilder,
TwoBodyKinematicVariableSet,
Expand Down Expand Up @@ -70,6 +76,7 @@

if TYPE_CHECKING:
from IPython.lib.pretty import PrettyPrinter
from qrules.topology import MutableTransition

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -450,11 +457,9 @@ def __formulate_topology_amplitude(
) -> sp.Expr:
sequential_expressions: list[sp.Expr] = []
for transition in transitions:
sequential_graphs = perform_external_edge_identical_particle_combinatorics(
transition.to_graph()
)
sequential_graphs = _perform_combinatorics(transition)
for graph in sequential_graphs:
first_transition = StateTransition.from_graph(graph)
first_transition = _freeze(graph)
expression = self.__formulate_sequential_decay(first_transition)
sequential_expressions.append(expression)

Expand Down Expand Up @@ -558,6 +563,24 @@ def __generate_amplitude_prefactor(
return None


def _perform_combinatorics(
transition: StateTransition,
) -> list[MutableTransition[State, InteractionProperties]]:
if get_qrules_version() < (0, 10):
return perform_external_edge_identical_particle_combinatorics(
transition.to_graph() # type: ignore[attr-defined]
)
graph = transition.convert(lambda s: (s.particle, s.spin_projection)).unfreeze()
combinations = perform_external_edge_identical_particle_combinatorics(graph)
return [g.freeze().convert(lambda s: State(*s)).unfreeze() for g in combinations]


def _freeze(graph: MutableTransition[State, InteractionProperties]) -> StateTransition:
if get_qrules_version() < (0, 10):
return StateTransition.from_graph(graph) # type: ignore[attr-defined]
return graph.freeze()


class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder):
r"""Amplitude model generator for the canonical helicity formalism.
Expand Down
52 changes: 40 additions & 12 deletions src/ampform/helicity/align/dpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from attrs import define, field
from attrs.validators import in_
from qrules.topology import Topology
from qrules.transition import ReactionInfo, StateTransition, StateTransitionCollection
from qrules.transition import ReactionInfo, StateTransition
from sympy.physics.quantum.spin import Rotation as Wigner

from ampform._qrules import get_qrules_version
from ampform.helicity.align import SpinAlignment
from ampform.helicity.decay import (
get_outer_state_ids,
Expand All @@ -34,6 +35,11 @@
if TYPE_CHECKING:
from sympy.physics.quantum.spin import WignerD

if get_qrules_version() < (0, 10):
from qrules.transition import ( # type: ignore[attr-defined]
StateTransitionCollection,
)


@define
class DalitzPlotDecomposition(SpinAlignment):
Expand Down Expand Up @@ -109,8 +115,14 @@ def __call__(
return Wigner.d(j, m, m_prime, zeta)


T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology)
"""Allowed types for :func:`relabel_edge_ids`."""
if get_qrules_version() < (0, 10):
T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology)
"""Allowed types for :func:`relabel_edge_ids`."""
else:
T = TypeVar( # type: ignore[misc] # pyright: ignore[reportConstantRedefinition]
"T", ReactionInfo, StateTransition, Topology
)
"""Allowed types for :func:`relabel_edge_ids`."""


@singledispatch
Expand All @@ -121,21 +133,29 @@ def relabel_edge_ids(obj: T) -> T:

@relabel_edge_ids.register(ReactionInfo)
def _(obj: ReactionInfo) -> ReactionInfo: # type: ignore[misc]
return ReactionInfo( # no attrs.evolve() in order to call __attrs_post_init__()
transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups],
if get_qrules_version() < (0, 10):
return ReactionInfo( # type: ignore[call-arg]
transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups], # type: ignore[attr-defined]
formalism=obj.formalism,
)
return ReactionInfo(
# no attrs.evolve() in order to call __attrs_post_init__()
transitions=[relabel_edge_ids(g) for g in obj.transitions],
formalism=obj.formalism,
)


@relabel_edge_ids.register(StateTransitionCollection)
def _(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc]
return StateTransitionCollection( # no attrs.evolve() for __attrs_post_init__()
[relabel_edge_ids(transition) for transition in obj.transitions]
)
if get_qrules_version() < (0, 10):

def __relabel_stc(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc]
return StateTransitionCollection(
[relabel_edge_ids(transition) for transition in obj.transitions]
)

@relabel_edge_ids.register(StateTransition)
def _(obj: StateTransition) -> StateTransition: # type: ignore[misc]
relabel_edge_ids.register(StateTransitionCollection)(__relabel_stc)


def __relabel_st(obj: StateTransition) -> StateTransition: # type: ignore[misc]
mapping = __get_default_relabel_mapping()
return attrs.evolve(
obj,
Expand All @@ -144,6 +164,14 @@ def _(obj: StateTransition) -> StateTransition: # type: ignore[misc]
)


if get_qrules_version() < (0, 10):
relabel_edge_ids.register(StateTransition)(__relabel_st)
else:
from qrules.topology import FrozenTransition

relabel_edge_ids.register(FrozenTransition)(__relabel_st)


@relabel_edge_ids.register(Topology)
def _(obj: Topology) -> Topology: # type: ignore[misc]
mapping = __get_default_relabel_mapping()
Expand Down
48 changes: 42 additions & 6 deletions src/ampform/helicity/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@
from typing import TYPE_CHECKING, Iterable

from attrs import frozen
from qrules.quantum_numbers import InteractionProperties
from qrules.transition import ReactionInfo, State, StateTransition

from ampform._qrules import get_qrules_version

if TYPE_CHECKING:
from qrules.quantum_numbers import InteractionProperties
from qrules.topology import Topology

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if sys.version_info < (3, 10):
from typing_extensions import TypeGuard
else:
from typing import TypeGuard


@frozen
Expand Down Expand Up @@ -103,12 +109,30 @@ def _(obj: TwoBodyDecay) -> TwoBodyDecay:
def _(obj: tuple) -> TwoBodyDecay:
if len(obj) == 2: # noqa: PLR2004
transition, node_id = obj
if isinstance(transition, StateTransition) and isinstance(node_id, int):
return TwoBodyDecay.from_transition(*obj)
if _is_qrules_state_transition(transition) and isinstance(node_id, int):
return TwoBodyDecay.from_transition(transition, node_id)
msg = f"Cannot create a {TwoBodyDecay.__name__} from {obj}"
raise NotImplementedError(msg)


def _is_qrules_state_transition(obj) -> TypeGuard[StateTransition]:
if get_qrules_version() >= (0, 10):
from qrules.topology import FrozenTransition

if isinstance(obj, FrozenTransition):
if any(not isinstance(s, State) for s in obj.states.values()):
return False
if any(
not isinstance(i, InteractionProperties)
for i in obj.interactions.values()
):
return False
return True
if get_qrules_version() < (0, 10) and isinstance(obj, StateTransition): # type: ignore[misc]
return True
return False


@lru_cache(maxsize=None)
def is_opposite_helicity_state(topology: Topology, state_id: int) -> bool:
"""Determine if an edge is an "opposite helicity" state.
Expand Down Expand Up @@ -328,8 +352,13 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in
>>> from qrules.topology import create_isobar_topologies
>>> topologies = create_isobar_topologies(5)
>>> determine_attached_final_state(topologies[0], state_id=5)
>>> determine_attached_final_state(topologies[3], state_id=5)
[0, 3, 4]
>>> import pytest
>>> from ampform._qrules import get_qrules_version
>>> if get_qrules_version() < (0, 10):
... pytest.skip('Doctest only works for qrules>=0.10')
...
"""
edge = topology.edges[state_id]
if edge.ending_node_id is None:
Expand All @@ -343,13 +372,20 @@ def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]:
raise NotImplementedError(msg)


@get_outer_state_ids.register(StateTransition)
def _(transition: StateTransition) -> list[int]:
def __convert_state_transition(transition: StateTransition) -> list[int]:
outer_state_ids = list(transition.initial_states)
outer_state_ids += sorted(transition.final_states)
return outer_state_ids


if get_qrules_version() < (0, 10):
get_outer_state_ids.register(StateTransition)(__convert_state_transition)
else:
from qrules.topology import FrozenTransition

get_outer_state_ids.register(FrozenTransition)(__convert_state_transition)


@get_outer_state_ids.register(ReactionInfo)
def _(reaction: ReactionInfo) -> list[int]:
return get_outer_state_ids(reaction.transitions[0])
Expand Down
5 changes: 3 additions & 2 deletions src/ampform/helicity/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,9 @@ def get_boost_chain_suffix(topology: Topology, state_id: int) -> str:
the internal decay topology.
>>> from qrules.topology import create_isobar_topologies
>>> from ampform._qrules import get_qrules_version
>>> topologies = create_isobar_topologies(5)
>>> topology = topologies[0]
>>> topology = topologies[0 if get_qrules_version() < (0, 10) else 3]
>>> for i in topology.intermediate_edge_ids | topology.outgoing_edge_ids:
... suffix = get_boost_chain_suffix(topology, i)
... print(f"{i}: 'phi{suffix}'")
Expand All @@ -364,7 +365,7 @@ def get_boost_chain_suffix(topology: Topology, state_id: int) -> str:
5: 'phi_034'
6: 'phi_12'
7: 'phi_34^034'
>>> topology = topologies[1]
>>> topology = topologies[1 if get_qrules_version() < (0, 10) else 2]
>>> for i in topology.intermediate_edge_ids | topology.outgoing_edge_ids:
... suffix = get_boost_chain_suffix(topology, i)
... print(f"{i}: 'phi{suffix}'")
Expand Down
12 changes: 10 additions & 2 deletions src/ampform/kinematics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from qrules.topology import Topology
from qrules.transition import ReactionInfo, StateTransition

from ampform._qrules import get_qrules_version
from ampform.helicity.decay import assert_isobar_topology
from ampform.kinematics.angles import compute_helicity_angles
from ampform.kinematics.lorentz import (
Expand Down Expand Up @@ -120,6 +121,13 @@ def _(obj: Topology) -> Topology:
return obj


@_get_topology.register(StateTransition)
def _(obj: StateTransition) -> Topology:
def __get_state_transition(obj: StateTransition) -> Topology:
return obj.topology


if get_qrules_version() < (0, 10):
_get_topology.register(StateTransition)(__get_state_transition)
else:
from qrules.topology import FrozenTransition

_get_topology.register(FrozenTransition)(__get_state_transition)
4 changes: 3 additions & 1 deletion src/ampform/kinematics/lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,8 +742,10 @@ def get_invariant_mass_symbol(topology: Topology, state_id: int) -> sp.Symbol:
state :math:`5` is :math:`m_{034}`, because :math:`p_5=p_0+p_3+p_4`:
>>> from qrules.topology import create_isobar_topologies
>>> from ampform._qrules import get_qrules_version
>>> topologies = create_isobar_topologies(5)
>>> get_invariant_mass_symbol(topologies[0], state_id=5)
>>> topology = topologies[0 if get_qrules_version() < (0, 10) else 3]
>>> get_invariant_mass_symbol(topology, state_id=5)
m_034
Naturally, the 'invariant' mass label for a final state is just the mass of the
Expand Down
5 changes: 3 additions & 2 deletions tests/helicity/test_decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from qrules.topology import Topology, create_isobar_topologies

from ampform._qrules import get_qrules_version
from ampform.helicity.decay import (
determine_attached_final_state,
get_sibling_state_id,
Expand All @@ -24,10 +25,10 @@ def test_determine_attached_final_state():
topology.outgoing_edge_ids
)
# intermediate states
topology = topologies[0]
topology = topologies[0 if get_qrules_version() < (0, 10) else 1]
assert determine_attached_final_state(topology, state_id=4) == [0, 1]
assert determine_attached_final_state(topology, state_id=5) == [2, 3]
topology = topologies[1]
topology = topologies[1 if get_qrules_version() < (0, 10) else 0]
assert determine_attached_final_state(topology, state_id=4) == [1, 2, 3]
assert determine_attached_final_state(topology, state_id=5) == [2, 3]

Expand Down
3 changes: 2 additions & 1 deletion tests/kinematics/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from qrules.topology import Topology, create_isobar_topologies

from ampform._qrules import get_qrules_version
from ampform.kinematics.lorentz import FourMomenta, create_four_momentum_symbols

if TYPE_CHECKING:
Expand All @@ -18,6 +19,6 @@ def topology_and_momentum_symbols(
n = len(data_sample)
assert n == 4
topologies = create_isobar_topologies(n)
topology = topologies[1]
topology = topologies[1 if get_qrules_version() < (0, 10) else 0]
momentum_symbols = create_four_momentum_symbols(topology)
return topology, momentum_symbols

0 comments on commit 90194d1

Please sign in to comment.