Symbolic model serialization

from pathlib import Path
from textwrap import shorten

import graphviz
import polarimetry
import sympy as sp
from ampform.io import aslatex
from ampform.sympy import unevaluated
from IPython.display import Markdown, Math
from polarimetry.amplitude import simplify_latex_rendering
from polarimetry.io import perform_cached_doit
from polarimetry.lhcb import load_model
from polarimetry.lhcb.particle import load_particles
from sympy.printing.mathml import MathMLPresentationPrinter

simplify_latex_rendering()

Expression trees

SymPy expressions are built up from symbols and mathematical operations as follows:

x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z
expression
\[\displaystyle - x^{2} + \frac{\sin{\left(x y \right)}}{2} + \frac{1}{z}\]

Under the hood, SymPy represents these expressions as trees. There are a few ways to visualize this for this particular example:

sp.printing.tree.print_tree(expression, assumptions=False)
Add: -x**2 + sin(x*y)/2 + 1/z
+-Pow: 1/z
| +-Symbol: z
| +-NegativeOne: -1
+-Mul: sin(x*y)/2
| +-Half: 1/2
| +-sin: sin(x*y)
|   +-Mul: x*y
|     +-Symbol: x
|     +-Symbol: y
+-Mul: -x**2
  +-NegativeOne: -1
  +-Pow: x**2
    +-Symbol: x
    +-Integer: 2
src = sp.dotprint(
    expression,
    styles=[
        (sp.Number, {"color": "grey", "fontcolor": "grey"}),
        (sp.Symbol, {"color": "royalblue", "fontcolor": "royalblue"}),
    ],
)
graphviz.Source(src)
[Figure: expression tree of expression, rendered with sp.dotprint() and graphviz; numbers in grey, symbols in blue]
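Each node in these trees is an object with a func (its class) and a tuple of args (its child nodes), and SymPy’s key invariant is that expr.func(*expr.args) reproduces expr. The following minimal sketch uses this to walk the tree; the helpers rebuild() and count_nodes() are hypothetical names introduced only for this illustration:

```python
import sympy as sp

x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z


def rebuild(expr: sp.Basic) -> sp.Basic:
    """Reconstruct an expression bottom-up from its nodes (hypothetical helper)."""
    if not expr.args:  # leaves: Symbol, Integer, Rational, ...
        return expr
    return expr.func(*(rebuild(arg) for arg in expr.args))


def count_nodes(expr: sp.Basic) -> int:
    """Count all nodes of the tree, including repeated leaves (hypothetical helper)."""
    return sum(1 for _ in sp.preorder_traversal(expr))


assert rebuild(expression) == expression
print(count_nodes(expression))  # 15
```

The count includes every occurrence of repeated leaves such as x and -1, which is why it equals the 15 lines printed by print_tree above.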

Expression trees are powerful because we can use them as templates for any human-readable presentation we are interested in. In fact, the LaTeX representation that we saw when constructing the expression was generated by SymPy’s LaTeX printer.

src = sp.latex(expression)
Markdown(f"```latex\n{src}\n```")
- x^{2} + \frac{\sin{\left(x y \right)}}{2} + \frac{1}{z}
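SymPy’s printers work by walking this tree and dispatching on the class of each node (to methods named _print_<ClassName>). The sketch below applies the same idea by hand to produce a Lisp-like prefix notation; to_prefix() is a hypothetical helper, not part of SymPy:

```python
import sympy as sp


def to_prefix(expr: sp.Basic) -> str:
    """Render an expression tree in Lisp-like prefix notation (hypothetical helper)."""
    if not expr.args:  # leaf node: print it as SymPy's str printer would
        return str(expr)
    head = type(expr).__name__  # e.g. "Add", "Mul", "sin"
    return "(" + " ".join([head, *(to_prefix(arg) for arg in expr.args)]) + ")"


x, y, z = sp.symbols("x y z")
print(to_prefix(sp.sin(x * y) / 2 - x**2 + 1 / z))
```

The order of the arguments inside each node follows SymPy’s internal canonical ordering, so it need not match the order in which the expression was typed.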

Hint

SymPy expressions can serve as a template for generating code!

Here are a number of other representations:

def to_mathml(expr: sp.Expr) -> str:
    printer = MathMLPresentationPrinter()
    xml = printer._print(expr)
    return xml.toprettyxml().replace("\t", "  ")


Markdown(
    f"""
```python
# Python
{sp.pycode(expression)}
```
```cpp
// C++
{sp.cxxcode(expression, standard="c++17")}
```
```fortran
! Fortran
{sp.fcode(expression).strip()}
```
```matlab
% Matlab / Octave
{sp.octave_code(expression)}
```
```julia
# Julia
{sp.julia_code(expression)}
```
```rust
// Rust
{sp.rust_code(expression)} 
```
```xml
<!-- MathML -->
{to_mathml(expression)}
```
"""
)
# Python
-x**2 + (1/2)*math.sin(x*y) + 1/z
// C++
-std::pow(x, 2) + (1.0/2.0)*std::sin(x*y) + 1.0/z
! Fortran
-x**2 + (1.0d0/2.0d0)*sin(x*y) + 1d0/z
% Matlab / Octave
-x.^2 + sin(x.*y)/2 + 1./z
# Julia
-x .^ 2 + sin(x .* y) / 2 + 1 ./ z
// Rust
-x.powi(2) + (1_f64/2.0)*(x*y).sin() + z.recip() 
<!-- MathML -->
<mrow>
  <mrow>
    <mo>-</mo>
    <msup>
      <mi>x</mi>
      <mn>2</mn>
    </msup>
  </mrow>
  <mo>+</mo>
  <mrow>
    <mfrac>
      <mrow>
        <mi>sin</mi>
        <mfenced>
          <mrow>
            <mi>x</mi>
            <mo>&InvisibleTimes;</mo>
            <mi>y</mi>
          </mrow>
        </mfenced>
      </mrow>
      <mn>2</mn>
    </mfrac>
  </mrow>
  <mo>+</mo>
  <mfrac>
    <mn>1</mn>
    <mi>z</mi>
  </mfrac>
</mrow>
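The generated snippets are plain source code. As a quick sanity check in Python, the output of sp.pycode() can be wrapped in a lambda and compared with sp.lambdify(), which builds a callable directly from the same tree. This is a minimal sketch; the evaluation point is arbitrary:

```python
import math

import sympy as sp

x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z

# Turn the generated Python source into a callable; pycode() emits math.sin etc.
src = sp.pycode(expression)
f_generated = eval(f"lambda x, y, z: {src}", {"math": math})

# lambdify() builds a callable directly from the same expression tree
f_lambdified = sp.lambdify((x, y, z), expression, modules="math")

# Reference value computed by hand at x=1.5, y=2.0, z=4.0
expected = -(1.5**2) + math.sin(1.5 * 2.0) / 2 + 1 / 4.0
assert math.isclose(f_generated(1.5, 2.0, 4.0), expected)
assert math.isclose(f_lambdified(1.5, 2.0, 4.0), expected)
```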

Foldable expressions

The previous example is quite simple, but SymPy works just as well with huge expressions, as we will see in Large expressions. Before that, though, let’s have a look at how to define these larger expressions in such a way that we can still read them. A nice solution is to define sp.Expr classes with the @unevaluated decorator (see ComPWA/ampform#364). Here, we define a Chew-Mandelstam function \(\rho^\text{CM}\) for \(S\)-waves. This function requires the definition of a break-up momentum \(q\).

@unevaluated(real=False)
class PhspFactorSWave(sp.Expr):
    s: sp.Symbol
    m1: sp.Symbol
    m2: sp.Symbol
    _latex_repr_ = R"\rho^\text{{CM}}\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        q = BreakupMomentum(s, m1, m2)
        cm = (
            (2 * q / sp.sqrt(s))
            * sp.log((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2))
            - (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
        ) / (16 * sp.pi**2)
        return 16 * sp.pi * sp.I * cm


@unevaluated(real=False)
class BreakupMomentum(sp.Expr):
    s: sp.Symbol
    m1: sp.Symbol
    m2: sp.Symbol
    _latex_repr_ = R"q\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))

We now have a very clean mathematical representation of how the \(\rho^\text{CM}\) function is defined in terms of \(q\):

s, m1, m2 = sp.symbols("s m1 m2")
q_expr = BreakupMomentum(s, m1, m2)
ρ_expr = PhspFactorSWave(s, m1, m2)
Math(aslatex({e: e.evaluate() for e in [ρ_expr, q_expr]}))
\[\begin{split}\displaystyle \begin{array}{rcl} \rho^\text{CM}\left(s\right) &=& \frac{i \left(- \left(m_{1}^{2} - m_{2}^{2}\right) \left(- \frac{1}{\left(m_{1} + m_{2}\right)^{2}} + \frac{1}{s}\right) \log{\left(\frac{m_{1}}{m_{2}} \right)} + \frac{2 \log{\left(\frac{m_{1}^{2} + m_{2}^{2} + 2 \sqrt{s} q\left(s\right) - s}{2 m_{1} m_{2}} \right)} q\left(s\right)}{\sqrt{s}}\right)}{\pi} \\ q\left(s\right) &=& \frac{\sqrt{\frac{\left(s - \left(m_{1} - m_{2}\right)^{2}\right) \left(s - \left(m_{1} + m_{2}\right)^{2}\right)}{s}}}{2} \\ \end{array}\end{split}\]
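The @unevaluated decorator comes from ampform. To make the idea of a ‘folded’ node concrete without that dependency, here is a plain-SymPy sketch of a break-up momentum node that keeps its short LaTeX form until .doit() is called. This is a simplified illustration under our own assumptions (the class name BreakupMomentumSketch is hypothetical, and setting is_commutative is our choice so that products treat it as a scalar); it is not how @unevaluated is implemented:

```python
import sympy as sp


class BreakupMomentumSketch(sp.Expr):
    """Plain-SymPy stand-in for a folded q(s); NOT ampform's implementation."""

    is_commutative = True  # lets Mul treat the node like an ordinary scalar factor

    def __new__(cls, s, m1, m2, **kwargs):
        # Extra keyword arguments (e.g. evaluate=False) are ignored: the node
        # never evaluates itself on construction.
        return super().__new__(cls, *map(sp.sympify, (s, m1, m2)))

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))

    def doit(self, deep=True, **hints):
        # Unfold: optionally unfold the arguments first, then replace this node
        args = [a.doit(deep=deep, **hints) for a in self.args] if deep else self.args
        return self.func(*args).evaluate()

    def _latex(self, printer, *args, **kwargs) -> str:
        return rf"q\left({printer._print(self.args[0])}\right)"


s, m1, m2 = sp.symbols("s m1 m2")
q = BreakupMomentumSketch(s, m1, m2)
print(sp.latex(2 * q))  # folded form
print((2 * q).doit())  # unfolded into sqrt(...)
```

At threshold, \(s = (m_1 + m_2)^2\), the unfolded expression vanishes, which gives a quick consistency check of the definition.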

Now, let’s build up a more complicated expression that contains this phase space factor. Here, we use SymPy to derive a Breit-Wigner using a single-channel \(K\) matrix [Chung et al., 1995]:

I = sp.Identity(n=1)
K = sp.MatrixSymbol("K", m=1, n=1)
ρ = sp.MatrixSymbol("rho", m=1, n=1)
T = (I - sp.I * K * ρ).inv() * K
T
\[\displaystyle \left(\mathbb{I} + - i K \rho\right)^{-1} K\]
T.as_explicit()[0, 0]
\[\displaystyle \frac{K_{0, 0}}{- i K_{0, 0} \rho_{0, 0} + 1}\]

Here, we need to provide definitions for the matrix elements of \(K\) and \(\rho\). For \(K\), we choose a single pole term, and for \(\rho\), the phase space factor for \(S\) waves that we defined above:

m0, Γ0, γ0 = sp.symbols("m0 Gamma0 gamma0")
K_expr = (γ0**2 * m0 * Γ0) / (s - m0**2)
substitutions = {
    K[0, 0]: K_expr,
    ρ[0, 0]: ρ_expr,
}
Math(aslatex(substitutions))
\[\begin{split}\displaystyle \begin{array}{rcl} K_{0, 0} &=& \frac{\Gamma_{0} \gamma_{0}^{2} m_{0}}{- m_{0}^{2} + s} \\ \rho_{0, 0} &=& \rho^\text{CM}\left(s\right) \\ \end{array}\end{split}\]

And there we have it! After some algebraic simplifications, we get a Breit-Wigner with a Chew-Mandelstam phase space factor for \(S\) waves:

T_expr = T.as_explicit().xreplace(substitutions)
BW_expr = T_expr[0, 0].simplify(doit=False)
BW_expr
\[\displaystyle \frac{\Gamma_{0} \gamma_{0}^{2} m_{0}}{- i \Gamma_{0} \gamma_{0}^{2} m_{0} \rho^\text{CM}\left(s\right) - m_{0}^{2} + s}\]

The expression tree now has a node that is ‘folded’:

dot_style = [
    (sp.Basic, {"style": "filled", "fillcolor": "white"}),
    (sp.Atom, {"color": "gray", "style": "filled", "fillcolor": "white"}),
    (sp.Symbol, {"color": "dodgerblue1"}),
    (PhspFactorSWave, {"color": "indianred2"}),
]
dot = sp.dotprint(BW_expr, bgcolor=None, styles=dot_style)
graphviz.Source(dot)
[Figure: expression tree of BW_expr, with the folded \(\rho^\text{CM}\) node highlighted in red]

After unfolding, we get the full expression tree of fundamental mathematical operations:

dot = sp.dotprint(BW_expr.doit(), bgcolor=None, styles=dot_style)
graphviz.Source(dot)
[Figure: fully unfolded expression tree of BW_expr.doit()]

Large expressions

Here, we import the large symbolic intensity expression that was used for 10.1007/JHEP07(2023)228 and see how well SymPy serialization performs on a much more complicated model.

DATA_DIR = Path(polarimetry.__file__).parent / "lhcb"
PARTICLES = load_particles(DATA_DIR / "particle-definitions.yaml")
MODEL = load_model(DATA_DIR / "model-definitions.yaml", PARTICLES, model_id=0)
unfolded_intensity_expr = perform_cached_doit(MODEL.full_expression)

The model contains 43,198 mathematical operations. See ComPWA/polarimetry#319 for the origin of this investigation.

Serialization with srepr

SymPy expressions can also be serialized directly to Python code with the function srepr(). For the full intensity expression, this is done with:

%%time

eval_str = sp.srepr(unfolded_intensity_expr)
CPU times: user 966 ms, sys: 0 ns, total: 966 ms
Wall time: 966 ms
n_nodes = sp.count_ops(unfolded_intensity_expr)
byt = len(eval_str.encode("utf-8"))
mb = f"{1e-6 * byt:.2f}"
rendering = shorten(eval_str, placeholder=" ...", width=85)
src = f"""
This serializes the intensity expression of {n_nodes:,d} nodes
to a string of **{mb} MB**.

```python
{rendering} {")" * (rendering.count("(") - rendering.count(")"))}
```
"""
Markdown(src)

This serializes the intensity expression of 43,198 nodes to a string of 1.04 MB.

Add(Pow(Abs(Add(Mul(Add(Mul(Integer(-1), Pow(Add(Mul(Integer(-1), I, ... ))))))))))

It is up to the user, however, to import the classes of each exported node before the string can be parsed with eval() (see this comment).

imported_intensity_expr = eval(eval_str)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[21], line 1
----> 1 imported_intensity_expr = eval(eval_str)

File <string>:1

NameError: name 'Add' is not defined

In the case of this intensity expression, it is sufficient to import all definitions from the main sympy module and the Str class. Optionally, the required import statements can be embedded into the string:

exec_str = f"""\
from sympy import *
from sympy.core.symbol import Str

def get_intensity_function() -> Expr:
    return {eval_str}
"""
exec_filename = Path("../_static/exported_intensity_model.py")
with open(exec_filename, "w") as f:
    f.write(exec_str)

See exported_intensity_model.py for the exported model.

The parsing is then done with exec() instead of the eval() function:

%%time

exec(exec_str)
imported_intensity_expr = get_intensity_function()
CPU times: user 411 ms, sys: 88.1 ms, total: 499 ms
Wall time: 498 ms
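Since the exported file is an ordinary Python module, it could alternatively be loaded with importlib, which keeps the star-import inside the module’s own namespace instead of the notebook’s globals. A self-contained sketch with a small stand-in expression and a temporary file (the file and module name exported_model are hypothetical):

```python
import importlib.util
import tempfile
from pathlib import Path

import sympy as sp

x, y = sp.symbols("x y")
expr = sp.sin(x * y) + x**2

# Same layout as the exported file above: star-import plus a getter function
module_src = f"""\
from sympy import *

def get_intensity_function() -> Expr:
    return {sp.srepr(expr)}
"""

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "exported_model.py"  # hypothetical file name
    path.write_text(module_src)
    spec = importlib.util.spec_from_file_location("exported_model", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file in its own namespace

imported = module.get_intensity_function()
assert imported == expr
```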

Notice how the imported expression is exactly the same as the serialized one, including assumptions:

assert imported_intensity_expr == unfolded_intensity_expr
assert hash(imported_intensity_expr) == hash(unfolded_intensity_expr)
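For the eval() route, the required classes can also be supplied through an explicit namespace instead of a star-import. A minimal sketch, assuming that all node classes in the string are either exported by the top-level sympy module or are Str; load_srepr() is a hypothetical helper:

```python
import sympy as sp
from sympy.core.symbol import Str


def load_srepr(src: str) -> sp.Basic:
    """Evaluate an srepr() string in an isolated namespace (hypothetical helper)."""
    namespace = dict(vars(sp))  # copy, so that eval() does not modify the sympy module
    namespace["Str"] = Str  # imported separately, as in the exported file above
    return eval(src, namespace)


x = sp.Symbol("x", positive=True)  # assumptions are part of the srepr() string
y, z = sp.symbols("y z")
expr = sp.sin(x * y) / 2 - x**2 + 1 / z

restored = load_srepr(sp.srepr(expr))
assert restored == expr
assert sp.Symbol("x", positive=True) in restored.free_symbols
```

As with exec(), only evaluate strings from trusted sources: eval() executes arbitrary Python code.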

Common sub-expressions

A problem is that, for large expressions, the string generated with srepr() is not human-readable in practice. One way out may be to extract common components of the main expression with Foldable expressions. Another may be to let SymPy detect and collect common sub-expressions.

sub_exprs, common_expr = sp.cse(unfolded_intensity_expr, order="none")
Math(sp.multiline_latex(sp.Symbol("I"), common_expr[0], environment="eqnarray"))
\[\begin{split}\displaystyle \begin{eqnarray} I & = & \left|{x_{113} x_{118} + x_{205} x_{210} + x_{220} x_{223} + x_{239} x_{240} - x_{247} x_{249} - x_{256} x_{258} - x_{262} x_{263} - x_{268} x_{269} + x_{273} \mathcal{H}^\mathrm{production}_{K(892), -1, - \frac{1}{2}} - x_{275} \mathcal{H}^\mathrm{production}_{K(892), 1, \frac{1}{2}} + x_{29} x_{34} + x_{35} x_{38}}\right|^{2} \nonumber\\ & & + \left|{- x_{113} x_{249} - x_{118} x_{247} + x_{205} x_{269} + x_{210} x_{268} - x_{220} x_{240} + x_{223} x_{239} + x_{256} x_{263} - x_{258} x_{262} - x_{270} x_{35} + x_{272} x_{34} \mathcal{H}^\mathrm{production}_{K(892), 1, \frac{1}{2}} + x_{272} x_{38} \mathcal{H}^\mathrm{production}_{K(892), -1, - \frac{1}{2}} + x_{274} x_{29}}\right|^{2} \nonumber\\ & & + \left|{- x_{113} x_{263} + x_{118} x_{262} + x_{205} x_{223} + x_{210} x_{220} - x_{239} x_{269} + x_{240} x_{268} - x_{247} x_{258} - x_{249} x_{256} + x_{273} \mathcal{H}^\mathrm{production}_{K(892), 1, \frac{1}{2}} - x_{275} \mathcal{H}^\mathrm{production}_{K(892), -1, - \frac{1}{2}} + x_{29} x_{38} + x_{34} x_{35}}\right|^{2} \nonumber\\ & & + \left|{x_{118} \left(x_{251} x_{281} + x_{253} x_{283} + x_{255} x_{285}\right) + x_{210} \left(- x_{226} x_{303} - x_{228} x_{304} - x_{230} x_{305} - x_{232} x_{306} - x_{236} x_{299} x_{307} - x_{238} x_{301} x_{307}\right) + x_{223} \left(x_{126} x_{130} x_{163} x_{177} x_{178} x_{179} x_{185} x_{186} x_{187} x_{188} x_{233} x_{299} \mathcal{H}^\mathrm{decay}_{L(1520), 0, \frac{1}{2}} \mathcal{H}^\mathrm{production}_{L(1520), \frac{1}{2}, 0} + x_{126} x_{130} x_{177} x_{178} x_{179} x_{191} x_{199} x_{200} x_{201} x_{202} x_{233} x_{301} \mathcal{H}^\mathrm{decay}_{L(1690), 0, \frac{1}{2}} \mathcal{H}^\mathrm{production}_{L(1690), \frac{1}{2}, 0} - x_{264} x_{303} - x_{265} x_{304} - x_{266} x_{305} - x_{267} x_{306}\right) + x_{240} \left(x_{135} x_{292} + x_{141} x_{294} + x_{146} x_{296} + x_{162} x_{298} + x_{190} x_{300} + x_{204} x_{302}\right) + x_{249} 
\left(x_{259} x_{281} + x_{260} x_{283} + x_{261} x_{285}\right) + x_{258} \left(x_{112} x_{290} + x_{287} x_{76} + x_{289} x_{96}\right) + x_{263} \left(x_{243} x_{286} x_{287} + x_{245} x_{288} x_{289} + x_{246} x_{288} x_{290}\right) + x_{269} \left(- x_{211} x_{292} - x_{212} x_{294} - x_{213} x_{296} - x_{215} x_{298} - x_{217} x_{300} - x_{219} x_{302}\right) + x_{270} \left(x_{276} \mathcal{H}^\mathrm{production}_{K(1430), 0, \frac{1}{2}} + x_{277} \mathcal{H}^\mathrm{production}_{K(700), 0, \frac{1}{2}} + x_{279} \mathcal{H}^\mathrm{production}_{K(892), 0, \frac{1}{2}}\right) + x_{274} \left(- x_{276} \mathcal{H}^\mathrm{production}_{K(1430), 0, - \frac{1}{2}} - x_{277} \mathcal{H}^\mathrm{production}_{K(700), 0, - \frac{1}{2}} - x_{279} \mathcal{H}^\mathrm{production}_{K(892), 0, - \frac{1}{2}}\right) - x_{308} x_{34} \mathcal{H}^\mathrm{production}_{K(892), -1, - \frac{1}{2}} - x_{308} x_{38} \mathcal{H}^\mathrm{production}_{K(892), 1, \frac{1}{2}}}\right|^{2} \end{eqnarray}\end{split}\]
Math(aslatex(dict(sub_exprs[:10])))
\[\begin{split}\displaystyle \begin{array}{rcl} x_{0} &=& m_{K(1430)}^{2} \\ x_{1} &=& m_{2}^{2} \\ x_{2} &=& m_{3}^{2} \\ x_{3} &=& \frac{x_{1}}{2} - x_{2} \\ x_{4} &=& i \left(\sigma_{1} + x_{3}\right) \\ x_{5} &=& \frac{\Gamma_{K(1430)} m_{K(1430)} x_{4} e^{- \gamma_{K(1430)} \sigma_{1}}}{x_{0} + x_{3}} + \sigma_{1} - x_{0} \\ x_{6} &=& \frac{\mathcal{H}^\mathrm{decay}_{K(1430), 0, 0}}{x_{5}} \\ x_{7} &=& m_{K(700)}^{2} \\ x_{8} &=& \frac{\Gamma_{K(700)} m_{K(700)} x_{4} e^{- \gamma_{K(700)} \sigma_{1}}}{x_{3} + x_{7}} + \sigma_{1} - x_{7} \\ x_{9} &=& \frac{\mathcal{H}^\mathrm{decay}_{K(700), 0, 0}}{x_{8}} \\ \end{array}\end{split}\]
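sp.cse() returns a list of (symbol, sub-expression) replacements together with a list of reduced expressions. On a small example, the original can be recovered by substituting the definitions back in reverse order, since later definitions may refer to earlier symbols. A minimal sketch:

```python
import sympy as sp

x, y = sp.symbols("x y")
expr = sp.sin(x * y) ** 2 + sp.cos(x * y)

# replacements: [(symbol, sub-expression), ...]; reduced: list of expressions
replacements, reduced = sp.cse(expr)

# Inline the definitions again, last one first, to recover the original
restored = reduced[0]
for symbol, sub_expr in reversed(replacements):
    restored = restored.subs(symbol, sub_expr)
assert restored == expr
```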

This already works quite well with sp.lambdify (without cse=True, this would take minutes):

%%time

args = sorted(unfolded_intensity_expr.free_symbols, key=str)
_ = sp.lambdify(args, unfolded_intensity_expr, cse=True, dummify=True)
CPU times: user 1.5 s, sys: 3.97 ms, total: 1.5 s
Wall time: 1.5 s
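With cse=True, lambdify() writes each common sub-expression to a local variable in the generated function, so that it is evaluated only once. The generated source can be inspected with inspect.getsource(); a small sketch:

```python
import inspect
import math

import sympy as sp

x, y = sp.symbols("x y")
expr = sp.sin(x * y) ** 2 + sp.cos(x * y)

f = sp.lambdify((x, y), expr, modules="math", cse=True)
source = inspect.getsource(f)
print(source)  # the common factor x*y is assigned to a local variable first

assert math.isclose(f(0.5, 2.0), math.sin(1.0) ** 2 + math.cos(1.0))
```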

Still, as can be seen above, there are many sub-expressions that have the same form and differ only in the symbols they contain. It would be better to identify such structurally similar expressions, so that we can serialize them as functions or custom sub-definitions.

In SymPy, whether two expressions are equivalent up to their symbols can be determined with the match() method and Wild symbols. We therefore first have to make all symbols in the common sub-expressions ‘wild’. In addition, in the case of this intensity expression, some of the symbols are indexed and need to be replaced with plain symbols first.

pure_symbol_expr = unfolded_intensity_expr.replace(
    query=lambda z: isinstance(z, sp.Indexed),
    value=lambda z: sp.Symbol(sp.latex(z), **z.assumptions0),
)
sub_exprs, common_expr = sp.cse(pure_symbol_expr, order="none")

Note, for example, that the following two common sub-expressions have the same structure:

\[\begin{split}\displaystyle \begin{array}{rcl} x_{5} &=& \frac{\Gamma_{K(1430)} m_{K(1430)} x_{4} e^{- \gamma_{K(1430)} \sigma_{1}}}{x_{0} + x_{3}} + \sigma_{1} - x_{0} \\ x_{8} &=& \frac{\Gamma_{K(700)} m_{K(700)} x_{4} e^{- \gamma_{K(700)} \sigma_{1}}}{x_{3} + x_{7}} + \sigma_{1} - x_{7} \\ \end{array}\end{split}\]

Wild symbols now allow us to find how these expressions relate to each other.

is_symbol = lambda z: isinstance(z, sp.Symbol)
make_wild = lambda z: sp.Wild(z.name)
X = [x.replace(is_symbol, make_wild) for _, x in sub_exprs]
Math(aslatex(X[5].match(X[8])))
\[\begin{split}\displaystyle \begin{array}{rcl} \gamma_{K(700)} &=& \gamma_{K(1430)} \\ x_{7} &=& x_{0} \\ m_{K(700)} &=& m_{K(1430)} \\ \Gamma_{K(700)} &=& \Gamma_{K(1430)} \\ \end{array}\end{split}\]
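Note the direction of the call: expr.match(pattern) treats its argument as the pattern, so in X[5].match(X[8]) above the pattern is X[8], and the keys of the result are the wild symbols of X[8]. A minimal sketch of the same mechanism on linear expressions, with Wild symbols that are not allowed to contain x:

```python
import sympy as sp

x, g1, c1, g2, c2 = sp.symbols("x g1 c1 g2 c2")
a = sp.Wild("a", exclude=[x])
b = sp.Wild("b", exclude=[x])

# expr.match(pattern): the receiver is the expression, the argument the pattern
assert (3 * x + 5).match(a * x + b) == {a: 3, b: 5}

# Same structure, different symbols: the match recovers the symbol mapping
pattern = (g1 * x + c1).xreplace({g1: a, c1: b})
mapping = (g2 * x + c2).match(pattern)
assert mapping == {a: g2, b: c2}
```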

Hint

This can be used to define functions for larger, common expression blocks.
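As a sketch of that idea, with hypothetical symbols: once two sub-expressions are known to share a structure, the common block can be written once, for instance as an sp.Lambda, and instantiated with the parameters that match() returns. Here, the parameters are plugged in by hand:

```python
import sympy as sp

s = sp.Symbol("s")
m_a, w_a, m_b, w_b = sp.symbols("m_a w_a m_b w_b")

# Two sub-expressions with identical structure but different parameters
expr_a = m_a * w_a / (s - m_a**2)
expr_b = m_b * w_b / (s - m_b**2)

# Hypothetical shared definition for the common block
m, w = sp.symbols("m w")
pole = sp.Lambda((s, m, w), m * w / (s - m**2))

assert pole(s, m_a, w_a) == expr_a
assert pole(s, m_b, w_b) == expr_b
```

Serializing pole once and referring to it twice is one possible way to shrink exported code; whether this pays off depends on how large and how frequent the shared blocks are.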