Extended DataSample performance

Hide code cell content
import logging

import ampform
import numpy as np
import qrules
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)
from tensorwaves.data import (
    IntensityDistributionGenerator,
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)
from tensorwaves.function.sympy import create_parametrized_function

# Only show ERROR-level log output, first for the "absl" logger (presumably
# used by the TensorFlow/JAX backends — TODO confirm) and then for the root
# logger. After the loop, LOGGER refers to the root logger.
for LOGGER in map(logging.getLogger, ("absl", None)):
    LOGGER.setLevel(logging.ERROR)

Generate amplitude model

Formulate a HelicityModel just like in the usual workflow:

# Find all J/psi(1S) -> gamma pi0 pi0 transitions that go through f(0)
# resonances via strong or EM interactions, in the helicity formalism.
reaction = qrules.generate_transitions(
    formalism="helicity",
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_interaction_types=["strong", "EM"],
    allowed_intermediate_particles=["f(0)"],
)

# Dynamics: the J/psi gets `create_non_dynamic_with_ff`; every intermediate
# resonance gets `create_relativistic_breit_wigner_with_ff`.
builder = ampform.get_builder(reaction)
builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
for name in reaction.get_intermediate_particles().names:
    builder.set_dynamics(
        name,
        create_relativistic_breit_wigner_with_ff,
    )
model = builder.formulate()

Now register more topologies with HelicityAdapter.permutate_registered_topologies() and formulate a new ‘extended’ model:

# Register additional (permuted) topologies on the builder's adapter. This
# mutates `builder`, so it must happen before the second `formulate()` call.
builder.adapter.permutate_registered_topologies()
# Formulate again with the same builder; the result is the 'extended' model.
extended_model = builder.formulate()

Create computational functions

Now, create ParametrizedFunctions and SympyDataTransformers for both the normal model and the extended model:

# Turn the fully evaluated (`doit()`) expression of each model into a
# JAX-backed parametrized function, initialised with that model's default
# parameter values.
intensity, extended_intensity = (
    create_parametrized_function(
        expression=amplitude_model.expression.doit(),
        parameters=amplitude_model.parameter_defaults,
        backend="jax",
    )
    for amplitude_model in (model, extended_model)
)

# One transformer per model. Each computes that model's kinematic variables
# from the input sample (four-momenta, as used further below).
helicity_transformer, extended_helicity_transformer = (
    SympyDataTransformer.from_sympy(
        amplitude_model.kinematic_variables, backend="jax"
    )
    for amplitude_model in (model, extended_model)
)

Generate data

Generate a phase space sample (the domain) and a hit-and-miss data sample with the normal intensity function and helicity transformer…

# Phase-space generator built from the reaction's initial-state mass and
# the masses of its final-state particles.
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={
        state_id: particle.mass
        for state_id, particle in reaction.final_state.items()
    },
)
# Hit-and-miss sampler: draws phase-space events, transforms them with the
# helicity transformer and weights them with the normal intensity.
data_generator = IntensityDistributionGenerator(
    domain_generator=phsp_generator,
    domain_transformer=helicity_transformer,
    function=intensity,
)
# Seeded RNG for reproducibility. The phase-space sample is drawn first and
# the data sample second; both consume the same random stream.
rng = TFUniformRealNumberGenerator(seed=0)
phsp_momenta = phsp_generator.generate(100_000, rng)
data_momenta = data_generator.generate(10_000, rng)
phsp, data = map(helicity_transformer, (phsp_momenta, data_momenta))

…and with the extended function and transformer:

extended_phsp_generator = TFPhaseSpaceGenerator(
    # Same configuration as phsp_generator: the extended model describes the
    # same reaction, so it has the same initial and final state.
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
# Wire the hit-and-miss sampler to the *extended* generator and transformer
# (previously it used phsp_generator and helicity_transformer). This way the
# extended data sample is produced entirely by the extended pipeline that
# this comparison is meant to exercise.
extended_data_generator = IntensityDistributionGenerator(
    function=extended_intensity,
    domain_generator=extended_phsp_generator,
    domain_transformer=extended_helicity_transformer,
)
# Re-create the RNG with the same seed so that the extended samples are drawn
# from the same random stream as the normal samples generated above.
rng = TFUniformRealNumberGenerator(seed=0)
phsp_momenta = extended_phsp_generator.generate(100_000, rng)
data_momenta = extended_data_generator.generate(10_000, rng)
extended_phsp = extended_helicity_transformer(phsp_momenta)
extended_data = extended_helicity_transformer(data_momenta)

Conclusion

# Evaluate each intensity function on its own phase-space sample.
intensities, extended_intensities = (
    intensity(phsp),
    extended_intensity(extended_phsp),
)
# Displayed as the cell output.
extended_intensities.shape
(100000,)

The computation time per call is the same for the normal and the extended intensity function:

# Time the normal intensity function on the normal phase-space sample
# (10 loops per run).
%timeit -n10 intensity(phsp)

# Time the extended intensity function on the extended phase-space sample.
%timeit -n10 extended_intensity(extended_phsp)
14.7 ms ± 761 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
14.7 ms ± 669 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

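The output arrays are also the same. Moreover, the extended samples contain a strict superset of the kinematic variables of the normal samples, with identical values for the shared variables: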

# Both intensity functions give the same values on their phase-space samples.
np.testing.assert_allclose(intensities, extended_intensities)

# The extended samples hold strictly more kinematic variables than the
# normal ones ...
assert set(extended_data) > set(data)
assert set(extended_phsp) > set(phsp)
# ... and every variable shared with the normal samples has identical values.
for var in data:
    np.testing.assert_allclose(phsp[var], extended_phsp[var])
    np.testing.assert_allclose(data[var], extended_data[var])