Skip to content

Commit

Permalink
ENH: adapt implementation to QRules v0.10
Browse files Browse the repository at this point in the history
  • Loading branch information
redeboer committed Oct 13, 2023
1 parent 00e3b56 commit e7c3327
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 25 deletions.
31 changes: 26 additions & 5 deletions src/ampform/helicity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,14 @@
from attrs.validators import deep_iterable, instance_of, optional
from qrules.combinatorics import perform_external_edge_identical_particle_combinatorics
from qrules.particle import Particle
from qrules.transition import ReactionInfo, StateTransition
from qrules.transition import (
InteractionProperties,
ReactionInfo,
State,
StateTransition,
)

from ampform._qrules import get_qrules_version
from ampform.dynamics.builder import (
ResonanceDynamicsBuilder,
TwoBodyKinematicVariableSet,
Expand Down Expand Up @@ -70,6 +76,7 @@

if TYPE_CHECKING:
from IPython.lib.pretty import PrettyPrinter
from qrules.topology import MutableTransition

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -450,11 +457,9 @@ def __formulate_topology_amplitude(
) -> sp.Expr:
sequential_expressions: list[sp.Expr] = []
for transition in transitions:
sequential_graphs = perform_external_edge_identical_particle_combinatorics(
transition.to_graph()
)
sequential_graphs = _perform_combinatorics(transition)
for graph in sequential_graphs:
first_transition = StateTransition.from_graph(graph)
first_transition = _freeze(graph)
expression = self.__formulate_sequential_decay(first_transition)
sequential_expressions.append(expression)

Expand Down Expand Up @@ -558,6 +563,22 @@ def __generate_amplitude_prefactor(
return None


def _perform_combinatorics(
transition: StateTransition,
) -> list[MutableTransition[State, InteractionProperties]]:
if get_qrules_version() < (0, 10):
return perform_external_edge_identical_particle_combinatorics(
transition.to_graph() # type: ignore[attr-defined]
)
return perform_external_edge_identical_particle_combinatorics(transition.unfreeze())


def _freeze(graph: MutableTransition[State, InteractionProperties]) -> StateTransition:
if get_qrules_version() < (0, 10):
return StateTransition.from_graph(graph) # type: ignore[attr-defined]
return graph.freeze()


class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder):
r"""Amplitude model generator for the canonical helicity formalism.
Expand Down
52 changes: 40 additions & 12 deletions src/ampform/helicity/align/dpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
from attrs import define, field
from attrs.validators import in_
from qrules.topology import Topology
from qrules.transition import ReactionInfo, StateTransition, StateTransitionCollection
from qrules.transition import ReactionInfo, StateTransition
from sympy.physics.quantum.spin import Rotation as Wigner

from ampform._qrules import get_qrules_version
from ampform.helicity.align import SpinAlignment
from ampform.helicity.decay import (
get_outer_state_ids,
Expand All @@ -34,6 +35,11 @@
if TYPE_CHECKING:
from sympy.physics.quantum.spin import WignerD

if get_qrules_version() < (0, 10):
from qrules.transition import ( # type: ignore[attr-defined]
StateTransitionCollection,
)


@define
class DalitzPlotDecomposition(SpinAlignment):
Expand Down Expand Up @@ -109,8 +115,14 @@ def __call__(
return Wigner.d(j, m, m_prime, zeta)


T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology)
"""Allowed types for :func:`relabel_edge_ids`."""
if get_qrules_version() < (0, 10):
T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology)
"""Allowed types for :func:`relabel_edge_ids`."""
else:
T = TypeVar( # type: ignore[misc] # pyright: ignore[reportConstantRedefinition]
"T", ReactionInfo, StateTransition, Topology
)
"""Allowed types for :func:`relabel_edge_ids`."""


@singledispatch
Expand All @@ -121,21 +133,29 @@ def relabel_edge_ids(obj: T) -> T:

@relabel_edge_ids.register(ReactionInfo)
def _(obj: ReactionInfo) -> ReactionInfo: # type: ignore[misc]
return ReactionInfo( # no attrs.evolve() in order to call __attrs_post_init__()
transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups],
if get_qrules_version() < (0, 10):
return ReactionInfo( # type: ignore[call-arg]
transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups], # type: ignore[attr-defined]
formalism=obj.formalism,
)
return ReactionInfo(
# no attrs.evolve() in order to call __attrs_post_init__()
transitions=[relabel_edge_ids(g) for g in obj.transitions],
formalism=obj.formalism,
)


@relabel_edge_ids.register(StateTransitionCollection)
def _(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc]
return StateTransitionCollection( # no attrs.evolve() for __attrs_post_init__()
[relabel_edge_ids(transition) for transition in obj.transitions]
)
if get_qrules_version() < (0, 10):

def __relabel_stc(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc]
return StateTransitionCollection(
[relabel_edge_ids(transition) for transition in obj.transitions]
)

@relabel_edge_ids.register(StateTransition)
def _(obj: StateTransition) -> StateTransition: # type: ignore[misc]
relabel_edge_ids.register(StateTransitionCollection)(__relabel_stc)


def __relabel_st(obj: StateTransition) -> StateTransition: # type: ignore[misc]
mapping = __get_default_relabel_mapping()
return attrs.evolve(
obj,
Expand All @@ -144,6 +164,14 @@ def _(obj: StateTransition) -> StateTransition: # type: ignore[misc]
)


if get_qrules_version() < (0, 10):
relabel_edge_ids.register(StateTransition)(__relabel_st)
else:
from qrules.topology import FrozenTransition

relabel_edge_ids.register(FrozenTransition)(__relabel_st)


@relabel_edge_ids.register(Topology)
def _(obj: Topology) -> Topology: # type: ignore[misc]
mapping = __get_default_relabel_mapping()
Expand Down
48 changes: 42 additions & 6 deletions src/ampform/helicity/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@
from typing import TYPE_CHECKING, Iterable

from attrs import frozen
from qrules.quantum_numbers import InteractionProperties
from qrules.transition import ReactionInfo, State, StateTransition

from ampform._qrules import get_qrules_version

if TYPE_CHECKING:
from qrules.quantum_numbers import InteractionProperties
from qrules.topology import Topology

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal
if sys.version_info < (3, 10):
from typing_extensions import TypeGuard
else:
from typing import TypeGuard


@frozen
Expand Down Expand Up @@ -103,12 +109,30 @@ def _(obj: TwoBodyDecay) -> TwoBodyDecay:
def _(obj: tuple) -> TwoBodyDecay:
if len(obj) == 2: # noqa: PLR2004
transition, node_id = obj
if isinstance(transition, StateTransition) and isinstance(node_id, int):
return TwoBodyDecay.from_transition(*obj)
if _is_qrules_state_transition(transition) and isinstance(node_id, int):
return TwoBodyDecay.from_transition(transition, node_id)
msg = f"Cannot create a {TwoBodyDecay.__name__} from {obj}"
raise NotImplementedError(msg)


def _is_qrules_state_transition(obj) -> TypeGuard[StateTransition]:
if get_qrules_version() >= (0, 10):
from qrules.topology import FrozenTransition

if isinstance(obj, FrozenTransition):
if any(not isinstance(s, State) for s in obj.states.values()):
return False
if any(
not isinstance(i, InteractionProperties)
for i in obj.interactions.values()
):
return False
return True
if get_qrules_version() < (0, 10) and isinstance(obj, StateTransition): # type: ignore[misc]
return True
return False


@lru_cache(maxsize=None)
def is_opposite_helicity_state(topology: Topology, state_id: int) -> bool:
"""Determine if an edge is an "opposite helicity" state.
Expand Down Expand Up @@ -328,8 +352,13 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in
>>> from qrules.topology import create_isobar_topologies
>>> topologies = create_isobar_topologies(5)
>>> determine_attached_final_state(topologies[0], state_id=5)
>>> determine_attached_final_state(topologies[3], state_id=5)
[0, 3, 4]
>>> import pytest
>>> from ampform._qrules import get_qrules_version
>>> if get_qrules_version() < (0, 10):
... pytest.skip('Doctest only works for qrules>=0.10')
...
"""
edge = topology.edges[state_id]
if edge.ending_node_id is None:
Expand All @@ -343,13 +372,20 @@ def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]:
raise NotImplementedError(msg)


@get_outer_state_ids.register(StateTransition)
def _(transition: StateTransition) -> list[int]:
def __convert_state_transition(transition: StateTransition) -> list[int]:
outer_state_ids = list(transition.initial_states)
outer_state_ids += sorted(transition.final_states)
return outer_state_ids


if get_qrules_version() < (0, 10):
get_outer_state_ids.register(StateTransition)(__convert_state_transition)
else:
from qrules.topology import FrozenTransition

get_outer_state_ids.register(FrozenTransition)(__convert_state_transition)


@get_outer_state_ids.register(ReactionInfo)
def _(reaction: ReactionInfo) -> list[int]:
return get_outer_state_ids(reaction.transitions[0])
Expand Down
12 changes: 10 additions & 2 deletions src/ampform/kinematics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from qrules.topology import Topology
from qrules.transition import ReactionInfo, StateTransition

from ampform._qrules import get_qrules_version
from ampform.helicity.decay import assert_isobar_topology
from ampform.kinematics.angles import compute_helicity_angles
from ampform.kinematics.lorentz import (
Expand Down Expand Up @@ -120,6 +121,13 @@ def _(obj: Topology) -> Topology:
return obj


@_get_topology.register(StateTransition)
def _(obj: StateTransition) -> Topology:
def __get_state_transition(obj: StateTransition) -> Topology:
return obj.topology


if get_qrules_version() < (0, 10):
_get_topology.register(StateTransition)(__get_state_transition)
else:
from qrules.topology import FrozenTransition

_get_topology.register(FrozenTransition)(__get_state_transition)

0 comments on commit e7c3327

Please sign in to comment.