Common Partial Wave Analysis#


The “Common Partial Wave Analysis” organization (ComPWA) aims to make amplitude analysis accessible through transparent and interactive documentation, modern software development tools, and collaboration-independent frameworks. Contact details can be found here.

Main projects#

ComPWA maintains three main Python packages with which you can do a full partial wave analysis. The packages are designed as libraries, so that they can be used separately by other projects.

Each of these libraries come with interactive and interlinked documentation that is intended to bring theory and code closer together. The PWA Pages takes that one step further: it is an independent and easy-to-maintain documentation project that can serve as a central place to gather links to PWA theory and software.

QRules

Rule-based particle reaction problem solver on a quantum number level

AmpForm

Automatically generate symbolic amplitude models for Partial Wave Analysis

TensorWaves

Convert large symbolic expressions to numerical differentiable functions for performing fast fits on large data samples

PWA Pages

A central knowledge-base for Partial Wave Analysis theory and software

Finally, the technical reports on these pages (compwa.github.io) describe more general tips and tricks, some of which can be of interest to general Python users as well!

Deprecated projects

The main packages listed above originate from the following, deprecated projects:

  • ComPWA: a single framework for Partial Wave Analysis written in C++.

  • pycompwa: the Python interface of ComPWA, which also hosted a first version of the PWA Expert System.

  • PWA Expert System (split into QRules and AmpForm).

Long-term development#

Partial Wave Analysis is a complicated research discipline, where several aspects of quantum field theory, experimental physics, statistics, regression analysis, and high-performance computing come together. This has led to a large number of PWA frameworks that are tailored to the needs of each collaboration.

This state of affairs is only natural: research requires a flexible and specialized approach. If, say, some background component shows up in an ongoing analysis, one may need to implement some formalism that can handle it or add some alignment parameters that were not yet supplied by the framework. It’s therefore hard to facilitate ongoing research, while at the same time developing a general, long-term PWA tool.

This situation however has a few major disadvantages:

  • Collaborations usually start developing a PWA framework from scratch. It therefore takes years for packages to move beyond the basic PWA formalisms.

  • Development is slow, because expertise is splintered: every group is working on their own package.

  • Results become unreliable: the fewer people use a package, the more bugs will remain unnoticed.

  • Once developers leave, the framework collapses. Sadly, this happens more often than not, as developers are usually scientists on short-term contracts.

ComPWA attempts to break this cycle with the ideals described in the following sections:

The first point is crucial: ComPWA would rather sacrifice functionality than compromise on design and developer experience. Many other frameworks have started with the same ideal of good software design, but soon begin to drop it for the understandable reasons described above. We believe that only by sticking to these ideals is it possible to maintain a long-term, collaboration-independent common tool.

Developer Experience#

PWA is performed by researchers and the field of study is relatively small. In practice, users of PWA frameworks are therefore close to the development of the frameworks themselves. Ideally, the gap between developer and user should remain as small as possible.

To close this gap as much as possible, ComPWA follows these ideas:

  1. Frameworks are built up as modular libraries with an accessible and well-documented API (see for example here). Users preferably set up their analysis through Jupyter notebooks or scripts that use these libraries. This allows users to adapt to the specific challenges that their research poses, while at the same time thinking along with, or improving on, the design and features offered by the library.

  2. New features are added to the library with care, so as not to let the library accumulate features over time that are specific to certain analyses. Procedures are kept small and general enough that they can be used by different user scripts. This relates to 1., because the API and underlying code base should remain understandable.

  3. It should be easy to make the step from using the library to contributing to the source code. The main starting points are the example pages provided by each library (accessible without any software but a browser). From there, it should be just a matter of running a few commands to start modifying and trying out changes to the code base (see e.g. Local set-up) in a standardized and automated environment.

  4. Developments in industry techniques and software development tools are followed closely. Despite the specific nature of PWA, many aspects can be performed or supported by existing technologies or packages. Think of the regression process that forms the basis of Machine Learning, but also of the growing number of tools that are popular in data science, like Pandas and Jupyter notebooks.

  5. As frameworks are open source and users and developers come and go, decision making is recorded as thoroughly as possible. Fortunately, Git (commit history) and GitHub (issues, PRs, and discussions) make this extremely easy. In addition, larger decisions are recorded in the form of ADRs and explorations of challenges posed by physics and software are hosted in the form of Technical reports.

Design#

  • Code modularity and transparency. For example, separation of qrules, ampform, and tensorwaves. The former two include all of the physics, while tensorwaves can use these amplitude models and perform fits, but aims to keep physics logic contained upstream.

  • Keep the code simple by sticking to the core responsibility: constructing amplitude models and fitting them to data. Avoid “feature creep”!

  • Accommodate both stable development and flexibility for ongoing analyses (see e.g. Branching model).

  • ComPWA values the open-closed principle. Where possible, libraries are intended to give users the flexibility to insert custom behavior (like custom dynamics) without having to introduce new updates to the library (see the conceptual sketch below).
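
The following is a conceptual sketch of what this open-closed idea looks like in practice. The names are purely illustrative and are not the actual ComPWA API:

from typing import Callable

import sympy as sp

# "Library" code: defines an extension point and stays closed for modification
DynamicsBuilder = Callable[[sp.Symbol], sp.Expr]


def formulate_intensity(mass: sp.Symbol, dynamics: DynamicsBuilder) -> sp.Expr:
    return sp.Abs(dynamics(mass)) ** 2


# "User" code: a custom line shape is injected without touching the library
def my_custom_dynamics(mass: sp.Symbol) -> sp.Expr:
    m0, gamma0 = sp.symbols("m0 Gamma0", positive=True)
    return 1 / (m0**2 - mass**2 - sp.I * m0 * gamma0)


expression = formulate_intensity(sp.Symbol("m"), my_custom_dynamics)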

Open source#

ComPWA repositories are intended to be collaboration-independent. As such, they are always public and open source and free to try out and re-use under the GPLv3+ license.

At the same time, open source projects come with many challenges: when anyone is allowed to contribute, it is crucial to maintain strict standards for the code base from the start.

User Experience#

PWA is a difficult field to get into and to navigate around. ComPWA therefore puts most effort into maintaining easily navigable and interactive documentation that narrows the gap between code and theory.

  • All libraries provide example pages (see e.g. here). These pages are written as Jupyter notebooks and provide a close link between code (how to use the library in a script, with links to the API) and theory (explanations of the basics that are being performed).

  • Texts on the web pages are thoroughly interlinked, so that the reader can easily navigate to literature or external resources for more information. The intention is to make PWA a more accessible field for newcomers and to provide references to the literature that was consulted for the implementation.

  • Almost all ComPWA libraries are written in Python. This makes them easy for analysis users to install and use. In addition, the Python community has developed excellent tools for documenting and maintaining a clean code base, so that it is easy to make the step to becoming a developer.

This page combines documentation for the projects provided by the ComPWA organization on GitHub. It is more technical than the PWA Pages and focuses on the ComPWA organization only. Read more about our ideals and ongoing projects on the main page.

Help developing#


This page describes some of the tools and conventions followed by Common Partial Wave Analysis. Where possible, we use the source code of the AmpForm repository as example, because its file structure is comparable to that of other ComPWA repositories.

Tip

To start developing, simply run the following from a cloned repository on your machine. With Conda:

conda env create
conda activate ampform
pre-commit install --install-hooks

Or, with Python's venv:

python3 -m venv ./venv
source ./venv/bin/activate
python3 -m pip install -c .constraints/py3.8.txt -e ".[dev]"
pre-commit install --install-hooks

Replace 3.8 with the Python version you use on your machine.

See Virtual environment for more info.

Local set-up#

Virtual environment#

When developing source code, it is safest to work within a virtual environment, so that all package dependencies and developer tools are safely contained. This is helpful in case something goes wrong with the dependencies: just trash the environment and recreate it. In addition, you can easily install other versions of the dependencies, without affecting other packages you may be working on.

Two common tools to manage virtual environments are Conda and Python’s built-in venv. In either case, you have to activate the environment whenever you want to run the framework or use the developer tools.

Conda can be installed without administrator rights. It is recommended to download Miniconda, as it is much smaller than Anaconda. In addition, Conda can install more than just Python packages.

All packages maintained by the ComPWA organization provide a Conda environment file (environment.yml) that defines all requirements when working on the source code of that repository. To create an environment specific for this repository, simply navigate to the main folder of the source code and run:

conda env create

Conda now creates an environment with a name that is defined in the environment.yml file. In addition, it will install the framework itself in “editable” mode, so that you can start developing right away.

If you have Python's venv available on your system, you can create a virtual environment with it. Navigate to some convenient folder and run:

python3 -m venv ./venv

This creates a folder called venv where all Python packages will be contained. To activate the environment, run:

source ./venv/bin/activate

Now you can safely install the package you want to work on (see “editable” mode), as well as any additional required packages (see optional dependencies):

pip install -e .
Editable installation#

When developing a package, it is most convenient if you install it in “editable” mode. This allows you to tweak the source code and try out new ideas immediately, because the source code is considered the ‘installation’.

With pip install, a package can be installed in “editable” mode with the -e flag. Simply clone the repository you want to work on, navigate into it, and run:

python3 -m pip install -e .

Internally, this calls:

python3 setup.py develop

This will also install all dependencies required by the package.

Optional dependencies#

Some packages provide optional dependencies. They can be installed with pip's “extras” syntax. Some examples would be:

pip install "tensorwaves[jax,scipy]"
pip install ".[test]"  # local directory, not editable
pip install -e ".[dev]"  # editable + all dev requirements

Developers require several additional tools besides the dependencies required to run the package itself (see Automated coding conventions). All those additional requirements can be installed with the last example.

Pinning dependency versions#

To ensure that developers use exactly the same versions of the package dependencies and developer requirements, some of the repositories provide constraint files. These files can be used to ‘pin’ all versions of installed packages as follows:

python3 -m pip install -c .constraints/py3.8.txt -e .

The syntax works just as well for Optional dependencies:

python3 -m pip install -c .constraints/py3.8.txt -e ".[doc,sty]"
python3 -m pip install -c .constraints/py3.8.txt -e ".[test]"
python3 -m pip install -c .constraints/py3.8.txt -e ".[dev]"

The constraint files are updated automatically with pip-tools through GitHub Actions. See requirements-pr.yml and requirements-cron.yml.

Note

Constraint files ensure that the framework is deterministic and reproducible (up to testing) for all commits and versions, which is vital for both users (doing analysis) and for developers (for instance with continuous integration). In other words, it provides a way out of “dependency hell”.

Updating#

It may be that new commits in the repository modify the dependencies. In that case, you have to rerun this command after pulling new commits from the repository:

git checkout main
git pull
pip install -c .constraints/py3.8.txt -e ".[dev]"

If you still have problems, it may be that certain dependencies have become redundant. In that case, trash the virtual environment and create a new one.

Julia#

Julia is an upcoming programming language in High-Energy Physics. While ComPWA is mainly developed in Python, we try to follow new trends and are experimenting with Julia as well.

Julia can be downloaded here or can be installed within your virtual environment with juliaup.

To install juliaup for installing and managing Julia versions, for instance through Conda:

conda install juliaup -c conda-forge

By default, this provides you with the latest Julia release. Optionally, you can switch versions as follows:

juliaup add 1.9
juliaup default 1.9

You can switch back to the latest version with:

juliaup default release

To install Julia system-wide on Linux or Mac, you'll have to unpack the downloaded tar file to a location that is easily accessible. Here's an example for a user install, where we also make the Julia executable available to the system:

cd ~/Downloads
tar xzf julia-1.9.2-linux-x86_64.tar.gz
mkdir ~/opt ~/bin
mv julia-1.9.2 ~/opt/
ln -s ~/opt/julia-1.9.2/bin/julia ~/bin/julia

Make sure that ~/bin is listed in the PATH environment variable, e.g. by updating it through your .bashrc file:

export PATH="~/bin:$PATH"

Alternatively, install Julia for all users:

cd ~/Downloads
tar xzf julia-1.9.2-linux-x86_64.tar.gz
sudo mv julia-1.9.2 /opt/
sudo ln -s /opt/julia-1.9.2/bin/julia /usr/local/bin/julia

Just as in Python, it's safest to work with a virtual environment. You can read more about Julia environments here. An environment is defined through a Project.toml file (which defines direct dependencies) and a Manifest.toml file (which exactly pins the installed versions of all recursive dependencies). Don't touch these files; they are automatically managed by the package manager. It does make sense though to commit both Project.toml and Manifest.toml files, so that the environment is reproducible for each commit (see also Pinning dependency versions).

See also

Have a look here if you want to integrate Jupyter notebooks with Julia kernels into your documentation.

Automated coding conventions#

Where possible, we define and enforce our coding conventions through automated tools, instead of describing them in documentation. These tools perform their checks when you commit files locally (see Pre-commit), when running tox, and when you make a pull request.

The tools are mainly configured through pyproject.toml, tox.ini, and the workflow files under .github. These configuration files are kept up to date through the ComPWA/policy repository, which essentially defines the developer environment across all ComPWA repositories.

If you run into persistent linting errors, this may mean we need to further specify our conventions. In that case, it’s best to create an issue or a pull request at ComPWA/policy and propose a policy change that can be formulated through those config files.

Pre-commit#

All style checks are enforced through a tool called pre-commit. It’s best to activate this tool locally as well. This has to be done only once, after you clone the repository:

pre-commit install --install-hooks

Upon committing, pre-commit runs a set of checks as defined in the file .pre-commit-config.yaml over all staged files. You can also quickly run all checks over all indexed files in the repository with the command:

pre-commit run -a

Whenever you submit a pull request, this command is automatically run on GitHub Actions and on pre-commit.ci, ensuring that all files in the repository follow the same conventions as set in the config files of these tools.

Tox#

More thorough checks can be run in one go with the following command:

tox -p

This command will run pytest, perform all style checks, build the documentation, and verify cross-references in the documentation and the API. It’s especially recommended to run tox before submitting a pull request!

More specialized tox jobs are defined in the tox.ini config file, under each testenv section. You can list all environments, along with a description of what they do, by running:

tox -av
GitHub Actions#

All style checks, testing of the documentation and links, and unit tests are performed upon each pull request through GitHub Actions (see status overview here). The checks are defined under the .github folder. All checks performed for each PR have to pass before the PR can be merged.

Style checks#

Formatting#

Formatters are tools that automatically format source code or other documents. Naturally, this speeds up your own programming, but these tools are particularly important when collaborating, because a standardized format avoids line conflicts in Git and makes diffs in code review easier to read.

For the Python source code, we use black and isort (through Ruff). For other code, we use Prettier. All of these formatters are “opinionated formatters”: they offer only limited configuration options, so as to make formatting as uniform as possible.

Pre-commit performs some additional formatting jobs. For instance, it formats Jupyter notebooks with nbQA and strips them of any output cells with nbstripout.

Linting#

Linters point out when certain style conventions are not correctly followed. Unlike with formatters, you have to fix the errors yourself. As mentioned in Automated coding conventions, style conventions are formulated in config files. The main linter that ComPWA projects use is Ruff.

Spelling#

Throughout this repository, we follow American English (en-us) spelling conventions. As a tool, we use cSpell, because it can check variable names written in camel case and snake case. This way, a spelling checker helps you avoid mistakes in the code as well! cSpell is enforced through pre-commit.

Accepted words are tracked through the .cspell.json file. As with the other config files, .cspell.json formulates our conventions with regard to spelling and can be continuously updated while our code base develops. In the file, the words section lists words that you want to see as suggested corrections, while ignoreWords are just the words that won't be flagged. Try to be sparing when adding words: if some word is just specific to one file, you can ignore it inline, or you can add the file to the ignorePaths section if you want to ignore it completely.

It is easiest to use cSpell in Visual Studio Code, through the Code Spell Checker extension: it provides linting, suggests corrections from the words section, and enables you to quickly add or ignore words through the .cspell.json file.

Testing#

The fastest way to run all tests is with the command:

pytest -n auto

The flag -n auto causes pytest to distribute the tests over all available CPU cores (a feature provided by the pytest-xdist plugin).

Try to keep test coverage high. You can compute current coverage by running

tox -e cov

and opening htmlcov/index.html in a browser.

To get an idea of performance per component, run

pytest --profile-svg

and check the stats and the prof/combined.svg output file.

Note

Jupyter notebooks can also be used as tests. See more info here.

Documentation#

The documentation that you find on ComPWA pages like pwa.rtfd.io is built with Sphinx. Sphinx also builds the API page of the packages and therefore checks whether the docstrings in the Python source code are valid and correctly interlinked.

We make use of Markedly Structured Text (MyST), so you can write the documentation in both Markdown and reStructuredText. In addition, it’s easy to write (interactive) code examples in Jupyter notebooks and host them on the website (see MyST-NB)!

Documentation preview#

You can quickly build the documentation with the command:

tox -e doc

If you are doing a lot of work on the documentation, sphinx-autobuild is a nice tool to use. Just run:

tox -e doclive

This will start a server http://127.0.0.1:8000 where you can continuously preview the changes you make to the documentation.

Finally, a nice feature of Read the Docs, where we host our documentation, is that the documentation is built for each pull request as well. This means that you can preview how the documentation renders with your changes before the PR is merged. For more info, see here, or just click “details” under the RTD check once you submit your PR.

Jupyter Notebooks#

The docs folder can also contain Jupyter notebooks. These notebooks are rendered as HTML by MyST-NB. The notebooks are also run and tested whenever you make a pull request, so they also serve as integration tests.

If you want to improve those notebooks, we recommend working with Jupyter Lab, which is installed with the dev requirements. Jupyter Lab offers a nicer developer experience than the default Jupyter notebook editor does. A few useful Jupyter Lab plugins are also installed through the optional dependencies.

Now, if you want to test all notebooks in the documentation folder and check what their output cells will look like in the Documentation, you can do this with:

tox -e docnb

This command takes more time than tox -e doc, but it is good practice to do this before you submit a pull request. It’s also possible to continuously generate the HTML pages including cell output while you work on the notebooks with:

EXECUTE_NB= tox -e doclive

Tip

Notebooks are automatically formatted through pre-commit (see Formatting). If you want to format the notebooks automatically as you’re working, you can do so with jupyterlab-code-formatter, which is automatically installed with the dev requirements.

IJulia notebooks#

It’s also possible to execute and render Jupyter notebooks with Julia kernels. For this, install Julia and install IJulia:

julia -e 'import Pkg; Pkg.add("IJulia")'

or, from within the Julia REPL:

import Pkg
Pkg.add("IJulia")

Usually, this also installs a Jupyter kernel directly. Optionally, you can define a Jupyter kernel manually:

julia -e 'using IJulia; installkernel("julia")'

or, from within the Julia REPL:

using IJulia
installkernel("julia")

and select it as kernel in the Jupyter notebook.

Note

As mentioned in Julia, Julia can be installed within your Conda environment through juliaup. This is, however, not yet a virtual environment for Julia itself. You can create one by, for instance, defining it through a code cell like this:

using Pkg
Pkg.activate(".")  # if environment is defined in this folder
Pkg.instantiate()

See Jupyter notebook with Julia kernel for an example.

Additionally, you can install a Language Server for Julia in Jupyter Lab. To do so, run:

julia -e 'import Pkg; Pkg.add("LanguageServer")'

or, from within the Julia REPL:

using Pkg
Pkg.add("LanguageServer")

Collaboration#

The source code of all ComPWA repositories is maintained with Git and GitHub. We keep track of issues with the code, documentation, and developer set-up with GitHub issues (see for instance here). This is also the place where you can report bugs.

Tip

If you are new to working with GitHub, have a look at the tutorials on GitHub Skills. For good tutorials on Git, see:

Issue management#

We keep track of issue dependencies, time estimates, planning, pipeline statuses, et cetera with GitHub project boards (GitHub Issues). The main project boards are:

Some issues are not public. To get access, you can request to become a member of the ComPWA GitHub organization. Other publicly available information includes:

  • Issue labels: help to categorize issues by type (maintenance, enhancement, bug, etc.). The labels are also used for the sub-sections of the release notes.

  • Milestones: a way to bundle issues and PRs for upcoming releases.

  • Releases.

All of these are important for the Release flow and therefore also serve as a way to document the framework.

Branching model#

While our aim is to maintain long-term, stable projects, PWA software projects are academic projects that are subject to change and often require swift modifications or new features for ongoing analyses. For this reason, we work in different layers of development. These layers are represented by Git branches.


stable branch#

Represents the latest release of the package that can be found on both the GitHub release page and on PyPI (see Release flow). The documentation of the stable branch is also the default view you see on Read the Docs (RTD). See e.g. ampform.rtfd.io/en/stable.

main branch#

Represents the upcoming release of the package. This branch is not guaranteed to be stable, but has high CI standards and can only be updated through reviewed pull requests. The documentation of the main branch can be found on RTD under “latest”, see e.g. ampform.rtfd.io/en/latest.

Epic branches#

When working on a feature or larger refactoring that may take a longer time (think of implementing a new PWA formalism), we isolate its development under an ‘epic branch’, separate from the main branch. Eventually, this epic branch is to be merged back into main; until then, it is available for discussion and testing.

Pull requests to an epic branch require no code review and the CI checks are less strict. This allows for faster development, while still offering the possibility to discuss new implementations and keeping track of related issues.

Epic branches can be installed through PyPI as well. Say a certain epic is located under the branch epic/some-title and the source code is located at https://github.com/ComPWA/ampform; it can then be installed as follows:

python3 -m pip install git+https://github.com/ComPWA/ampform@epic/some-title
Feature branches#

The main branch and Epic branches can be updated through pull requests. It is best to create such a pull request from a separate branch, which does not have any CI or code review restrictions. We call this a “feature branch”.

Commit conventions#
  • Please use conventional commit messages: start the commit with one of the semantic keywords below in UPPER CASE, followed by a colon, then the commit header. The message itself should be in imperative mood; just imagine the commit giving a command to the framework. So for instance:

    DX: implement coverage report tools
    FIX: remove typo in raised `ValueError` message
    MAINT: remove redundant print statements
    DOC: rewrite welcome pages
    BREAK: remove `formulate_model()` alias method
    

    The allowed semantic keywords (commit types) are as follows:[1]

    | Commit type | Repository label | Description |
    |---|---|---|
    | FEAT | ✨ Feature | New feature added to the package |
    | ENH | ⚙️ Enhancement | Improvements and optimizations of existing features |
    | FIX | 🐛 Bug | Bug has been fixed |
    | BREAK | ⚠️ Interface | Breaking changes to the API |
    | BEHAVIOR | ❗ Behavior | Changes that may affect the framework output |
    | DOC | 📝 Docs | Improvements or additions to documentation |
    | MAINT | 🔨 Maintenance | Maintenance and upkeep improvements |
    | DX | 🖱️ DX | Improvements to the Developer Experience |

  • Keep pull requests small. If the issue you try to address is too big, discuss in the team whether the issue can be converted into an Epic and split up into smaller tasks.

  • Before creating a pull request, run Tox.

  • Also use a conventional commit message style for the PR title. This is because we follow a linear commit history and the PR title will become the eventual commit message. A linear commit history is important for the Release flow and makes it easier to navigate through changes once something goes wrong. In fact, in a linear commit history, commits that have been merged into the main branch become more like small intermediate patches between the minor and major releases.

    Note that a conventional commit message style is enforced through GitHub Actions with commitlint, as well as a check on PR labels (see example here). The commit messages are centrally defined for the ComPWA organization at ComPWA/commitlint-config.

  • PRs can only be merged through ‘squash and merge’. There, you will see a summary based on the separate commits that constitute this PR. Leave the relevant commits in as bullet points. See the commit history for examples. This comes in especially handy when drafting a release!

Release flow#

Releases are managed with the GitHub release page, see for instance the one for AmpForm. The release notes there are automatically generated from the PRs that were merged into the main branch since the previous tag and can be viewed and edited as a release draft if you are a member of the ComPWA organization. Each entry is generated from its PR title, categorized by issue label (see configuration in .github/release-drafter.yml).

Once a release is made on GitHub for a repository with source code for a Python package, a new version is automatically published on PyPI and the stable branch is updated to this latest tag. The package version is taken from the Git tag associated with the release on GitHub (see setuptools-scm). This way, the release notes on GitHub serve as a changelog as well!

Release tags have to follow the Semantic Versioning scheme! This ensures that the tag can be used by setuptools-scm (in case the repository is a Python package). In addition, milestones with the same name as the release tag are automatically closed.

Code editors#

Even though we try to standardize the developer set-up of the repositories, we encourage you to use the code editors that you feel comfortable with. Where possible, we therefore define settings for linters, formatters, etc. in config files that are specific to those tools (preferably pyproject.toml), not in the configuration files of the editors.

Still, where code editor settings can be shared through configuration files in the repository, we provide recommended settings for the code editor as well. This is especially the case for VSCode.

Tip

We are open to other code editors as well. An example would be maintaining a local vimrc for users who prefer VIM. Other IDEs we’d like to support are PyCharm, Atom, IntelliJ with Python. So we’ll gladly integrate your editor settings where possible as you contribute to the frameworks!

Visual Studio code#

We recommend using Visual Studio Code as it's free, regularly updated, and very flexible through its wide range of extensions.

If you add or open this repository as a VSCode workspace, the file .vscode/settings.json will ensure that you have the right developer settings for this repository. In addition, VSCode will automatically recommend you to install a number of extensions that we use when working on this code base. They are defined in the .vscode/extensions.json file.

You can still specify your own settings in either your user settings or an encompassing workspace's settings, because the VSCode settings that come with the repository are folder settings.

Conda and VSCode

ComPWA projects are best developed with Conda and VSCode. The complete developer install procedure then becomes:

git clone https://github.com/ComPWA/ampform  # or some other repo
cd ampform
conda env create
conda activate pwa  # or whatever the environment name is
code .  # open folder in VSCode

Writing durable software#

ComPWA strives to follow best practices from software development in industry. Following these standards not only makes the code easier to maintain and the software more reliable, it also provides you with the opportunity to learn about these practices while developing the code-base. Below you can find some resources we highly recommend you to be familiar with.

  • Software development in Python

  • Clean Code

  • Test-Driven Development

  • Software Design

  • Algorithms


Architectural Decision Records#

This log lists the architectural decisions for the ComPWA Organization:

[ADR-000] Use ADRs#

Status: accepted

Deciders: @redeboer, @spflueger

Technical story: A large number of issues in the expertsystem are correlated (e.g. ComPWA/expertsystem#40, ComPWA/expertsystem#44, ComPWA/expertsystem#22) so that resulting PRs (in this case, ComPWA/expertsystem#42) lacked direction. This led us to consider ADRs.

Context and Problem Statement#

We want to record architectural decisions made in this project. Which format and structure should these records follow?

Considered Options#
Decision Outcome#

Chosen option: “MADR 2.1.2”, because

  • Implicit assumptions should be made explicit. Design documentation is important to enable people understanding the decisions later on. See also A rational design process: How and why to fake it.

  • The MADR format is lean and fits our development style.

  • The MADR structure is comprehensible and facilitates usage & maintenance.

  • The MADR project is vivid.

  • Version 2.1.2 is the latest one available when starting to document ADRs.

[ADR-001] Amplitude model#

  • Status: accepted

  • Deciders: @redeboer @spflueger

Context and problem statement#

From the perspective of a PWA fitter package, the responsibility of the expertsystem is to construct an AmplitudeModel that serves as a blueprint for a function that can be evaluated. Such a function has the following requirements:

  1. It should be able to compute a list of real-valued intensities \(\mathbb{R}^m\) from a dataset of four-momenta \(\mathbb{R}^{m\times n\times4}\), where \(m\) is the number of events and \(n\) is the number of final state particles.

  2. It should contain parameters that can be tweaked, so that they can be optimized with regard to a certain estimator (see the sketch below this list).
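
As an illustration, a hypothetical function satisfying both requirements would look roughly as follows (a sketch only, not ComPWA code):

import numpy as np


def intensity(four_momenta: np.ndarray, **parameters: float) -> np.ndarray:
    """Map an (m, n, 4) array of four-momenta to m real-valued intensities."""
    assert four_momenta.ndim == 3 and four_momenta.shape[-1] == 4
    m_events = four_momenta.shape[0]
    # ...evaluate the amplitude model with the tweakable parameters here...
    return np.ones(m_events)  # placeholder output of shape (m,)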

Technical story#
  • ComPWA/ampform#5: Coupling parameters in the AmplitudeModel is difficult (has to be done through the place where they are used in the dynamics or intensity section) and counter-intuitive (cannot be done through the parameters section)

  • ComPWA/expertsystem#440: when overwriting existing dynamics, old parameters are not cleaned up from the parameters section

  • ComPWA/expertsystem#441: parameters contain a name that can be changed, but that results in a mismatch between the key that is used in the parameters section and the name of the parameter to which that entry points.

  • ComPWA/ComPWA-legacy#226: Use a math language for the blueprint of the function. This was also discussed early to mid 2020, but dropped in favor of custom python code + amplitf. The reasoning was that the effort of writing some new math language plus generators converting a mathematical expression into a function (using various back-ends) requires too much manpower.

Decision drivers#
Solution requirements#
  1. The AmplitudeModel has to be convertible to a function which can be evaluated using various computation back-ends (numpy, tensorflow, theano, jax, …)

  2. Ideally, the model should be complete, in the sense that it contains all information needed to construct the function. This means that some “common” functions like a Breit-Wigner and Blatt-Weisskopf form factors should also be contained inside the AmplitudeModel. This guarantees reproducibility!

  3. Adding new operators/models should not trigger many code modifications (open-closed principle), for instance adding new dynamics or formalisms.

  4. Extensible:

    • Add or replace current parts of an existing model. For example replace the dynamics part of some decay.

    • Change a function plus a dataset to an estimator function. This is a subtle but important point: the function should hide its details (which backend it uses and its mathematical expression) and yet be extensible to an estimator (see the sketch after this list).

  5. Definition and easy extraction of components. Components are certain sub-parts of the complete mathematical expression. This is at least needed for the calculation of fit fractions, or plotting individual parts of the intensity.
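
To make requirement 4 concrete, combining a function with a dataset into an estimator could look roughly like the following sketch (names are illustrative, not ComPWA code):

from typing import Callable, Mapping

import numpy as np

# The function hides its backend and mathematical expression behind a plain call signature
IntensityFunction = Callable[..., np.ndarray]


def unbinned_nll(
    function: IntensityFunction, dataset: np.ndarray
) -> Callable[[Mapping[str, float]], float]:
    """Extend a function plus a data sample to an estimator over its parameters."""

    def estimator(parameters: Mapping[str, float]) -> float:
        intensities = function(dataset, **parameters)
        return -float(np.sum(np.log(intensities)))

    return estimator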

Considered solutions#
Customized Python classes#

Currently (v0.6.8), the AmplitudeModel contains five sections (instances of specific classes):

  • kinematics: defines initial and final state

  • particles: particle definitions (spin, etc.)

  • dynamics: a mapping that defines which dynamics type to apply to which particle

  • intensity: the actual amplitude model that is to be converted by a fitter package into a function as described above

  • parameters: an inventory of parameters that are used in intensity and dynamics

This structure can be represented in YAML, see an example here.

A fitter package converts intensity together with dynamics into a function. Any references to parameters that intensity or dynamics contain are converted into a parameter of the function. The parameters are initialized with the value as listed in the parameters section of the AmplitudeModel.

Alternative solutions#
SymPy#

Some useful SymPy pages:

Defining the amplitude model in terms of SymPy#

Parameters would become a mapping of Symbols to initial, suggested values and dynamics would be a mapping of Symbols to ‘suggested’ expressions. Intensity will be the eventual combined expression.

from __future__ import annotations

import sympy as sp
from attrs import define, field


@define
class AmplitudeModel:
    initial_values: dict[sp.Symbol, float] = field(factory=dict)
    dynamics: dict[sp.Symbol, sp.Function] = field(factory=dict)
    intensity: sp.Expr = field(default=None)

There needs to be one symbol \(x\) that represents the four-momentum input:

x = sp.Symbol("x")

As an example, let's create an AmplitudeModel with an intensity that is a sum of Gaussians. Each Gaussian here takes the role of a dynamics function:

model = AmplitudeModel()

N_COMPONENTS = 3
for i in range(1, N_COMPONENTS + 1):
    mu = sp.Symbol(Rf"\mu_{i}")
    sigma = sp.Symbol(Rf"\sigma_{i}")
    model.initial_values.update({
        mu: float(i),
        sigma: 1 / (2 * i),
    })
    gauss = sp.exp(-((x - mu) ** 2) / (sigma**2)) / (sigma * sp.sqrt(2 * sp.pi))
    dyn_symbol = sp.Symbol(Rf"\mathrm{{dyn}}_{i}")
    model.dynamics[dyn_symbol] = gauss

coherent_sum = sum(model.dynamics)
model.intensity = coherent_sum
model.initial_values
{\mu_1: 1.0,
 \sigma_1: 0.5,
 \mu_2: 2.0,
 \sigma_2: 0.25,
 \mu_3: 3.0,
 \sigma_3: 0.16666666666666666}
model.intensity
\[\displaystyle \mathrm{dyn}_1 + \mathrm{dyn}_2 + \mathrm{dyn}_3\]

Dynamics are inserted into the intensity expression of the model:

model.intensity.subs(model.dynamics)
\[\displaystyle \frac{\sqrt{2} e^{- \frac{\left(- \mu_{3} + x\right)^{2}}{\sigma_{3}^{2}}}}{2 \sqrt{\pi} \sigma_{3}} + \frac{\sqrt{2} e^{- \frac{\left(- \mu_{2} + x\right)^{2}}{\sigma_{2}^{2}}}}{2 \sqrt{\pi} \sigma_{2}} + \frac{\sqrt{2} e^{- \frac{\left(- \mu_{1} + x\right)^{2}}{\sigma_{1}^{2}}}}{2 \sqrt{\pi} \sigma_{1}}\]

And, for evaluating, the ‘suggested’ initial parameter values are inserted:

model.intensity.subs(model.dynamics).subs(model.initial_values)
\[\displaystyle \frac{1.0 \sqrt{2} e^{- 4.0 \left(x - 1.0\right)^{2}}}{\sqrt{\pi}} + \frac{2.0 \sqrt{2} e^{- 64.0 \left(0.5 x - 1\right)^{2}}}{\sqrt{\pi}} + \frac{3.0 \sqrt{2} e^{- 324.0 \left(0.333333333333333 x - 1\right)^{2}}}{\sqrt{\pi}}\]

Here’s a small helper function to plot this model:

def plot_model(model: AmplitudeModel) -> None:
    total_plot = sp.plotting.plot(
        model.intensity.subs(model.dynamics).subs(model.initial_values),
        (x, 0, 4),
        show=False,
        line_color="black",
    )
    p1 = sp.plotting.plot(
        model.dynamics[sp.Symbol(R"\mathrm{dyn}_1")].subs(model.initial_values),
        (x, 0, 4),
        line_color="red",
        show=False,
    )
    p2 = sp.plotting.plot(
        model.dynamics[sp.Symbol(R"\mathrm{dyn}_2")].subs(model.initial_values),
        (x, 0, 4),
        line_color="blue",
        show=False,
    )
    p3 = sp.plotting.plot(
        model.dynamics[sp.Symbol(R"\mathrm{dyn}_3")].subs(model.initial_values),
        (x, 0, 4),
        line_color="green",
        show=False,
    )
    total_plot.extend(p1)
    total_plot.extend(p2)
    total_plot.extend(p3)
    total_plot.show()
plot_model(model)

Now we can couple parameters like this:

model.initial_values[sp.Symbol(R"\sigma_1")] = sp.Symbol(R"\sigma_3")
plot_model(model)

model.initial_values[sp.Symbol(R"\sigma_3")] = 1
plot_model(model)

And it’s also possible to insert custom dynamics:

model.dynamics[sp.Symbol(R"\mathrm{dyn}_3")] = sp.sqrt(x)
plot_model(model)

Implementation in TensorWaves#

Credits @spflueger

1. Create a double gaussian amp with SymPy#

When building the model, we should be careful to pass the parameters as arguments as well, otherwise frameworks like jax can’t determine the gradient.

import math

import sympy as sp

x, A1, mu1, sigma1, A2, mu2, sigma2 = sp.symbols(
    "x, A1, mu1, sigma1, A2, mu2, sigma2"
)
gaussian1 = (
    A1 / (sigma1 * sp.sqrt(2.0 * math.pi)) * sp.exp(-((x - mu1) ** 2) / (2 * sigma1))
)
gaussian2 = (
    A2 / (sigma2 * sp.sqrt(2.0 * math.pi)) * sp.exp(-((x - mu2) ** 2) / (2 * sigma2))
)

gauss_sum = gaussian1 + gaussian2
gauss_sum
\[\displaystyle \frac{0.398942280401433 A_{1} e^{- \frac{\left(- \mu_{1} + x\right)^{2}}{2 \sigma_{1}}}}{\sigma_{1}} + \frac{0.398942280401433 A_{2} e^{- \frac{\left(- \mu_{2} + x\right)^{2}}{2 \sigma_{2}}}}{\sigma_{2}}\]
2. Convert this expression to a function using lambdify#

TensorFlow as backend:

import inspect

tf_gauss_sum = sp.lambdify(
    (x, A1, mu1, sigma1, A2, mu2, sigma2), gauss_sum, "tensorflow"
)
print(inspect.getsource(tf_gauss_sum))
def _lambdifygenerated(x, A1, mu1, sigma1, A2, mu2, sigma2):
    return (0.398942280401433*A1*exp(-1/2*pow(-mu1 + x, 2)/sigma1)/sigma1 + 0.398942280401433*A2*exp(-1/2*pow(-mu2 + x, 2)/sigma2)/sigma2)

NumPy as backend:

numpy_gauss_sum = sp.lambdify(
    (x, A1, mu1, sigma1, A2, mu2, sigma2), gauss_sum, "numpy"
)
print(inspect.getsource(numpy_gauss_sum))
def _lambdifygenerated(x, A1, mu1, sigma1, A2, mu2, sigma2):
    return (0.398942280401433*A1*exp(-1/2*(-mu1 + x)**2/sigma1)/sigma1 + 0.398942280401433*A2*exp(-1/2*(-mu2 + x)**2/sigma2)/sigma2)

Jax as backend:

from jax import numpy as jnp
from jax import scipy as jsp

jax_gauss_sum = sp.lambdify(
    (x, A1, mu1, sigma1, A2, mu2, sigma2),
    gauss_sum,
    modules=(jnp, jsp.special),
)
print(inspect.getsource(jax_gauss_sum))
def _lambdifygenerated(x, A1, mu1, sigma1, A2, mu2, sigma2):
    return (0.398942280401433*A1*exp(-1/2*(-mu1 + x)**2/sigma1)/sigma1 + 0.398942280401433*A2*exp(-1/2*(-mu2 + x)**2/sigma2)/sigma2)
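
Since the parameters appear as explicit arguments of the lambdified function (see the remark at the start of step 1), JAX can differentiate with respect to them, for instance:

import jax

# Gradient with respect to A1 (argument index 1), evaluated at an arbitrary point
grad_A1 = jax.grad(jax_gauss_sum, argnums=1)
grad_A1(1.0, 1.0, 0.0, 0.1, 2.0, 2.0, 0.2)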
3. Natively create the respective packages#
import math

import tensorflow as tf


def gaussian(x, A, mu, sigma):
    return (
        A
        / (sigma * tf.sqrt(tf.constant(2.0, dtype=tf.float64) * math.pi))
        * tf.exp(
            -tf.pow(
                -tf.constant(0.5, dtype=tf.float64) * (x - mu) / sigma,
                2,
            )
        )
    )


def native_tf_gauss_sum(x_, A1_, mu1_, sigma1_, A2_, mu2_, sigma2_):
    return gaussian(x_, A1_, mu1_, sigma1_) + gaussian(x_, A2_, mu2_, sigma2_)


# @jx.pmap
def jax_gaussian(x, A, mu, sigma):
    return (
        A
        / (sigma * jnp.sqrt(2.0 * math.pi))
        * jnp.exp(-((-0.5 * (x - mu) / sigma) ** 2))
    )


def native_jax_gauss_sum(x_, A1_, mu1_, sigma1_, A2_, mu2_, sigma2_):
    return jax_gaussian(x_, A1_, mu1_, sigma1_) + jax_gaussian(
        x_, A2_, mu2_, sigma2_
    )
4. Compare performance#
import numpy as np

parameter_values = (1.0, 0.0, 0.1, 2.0, 2.0, 0.2)
rng = np.random.default_rng(0)
np_x = rng.uniform(-1, 3, 10000)
tf_x = tf.constant(np_x)


def evaluate_with_parameters(function):
    def wrapper():
        return function(np_x, *(parameter_values))

    return wrapper


def call_native_tf():
    func = native_tf_gauss_sum
    params = tuple(tf.Variable(v, dtype=tf.float64) for v in parameter_values)

    def wrapper():
        return func(tf_x, *params)

    return wrapper
import timeit

from jax.config import config

config.update("jax_enable_x64", True)

print(
    "sympy tf lambdify",
    timeit.timeit(evaluate_with_parameters(tf_gauss_sum), number=100),
)
print(
    "sympy numpy lambdify",
    timeit.timeit(evaluate_with_parameters(numpy_gauss_sum), number=100),
)
print(
    "sympy jax lambdify",
    timeit.timeit(evaluate_with_parameters(jax_gauss_sum), number=100),
)
print("native tf", timeit.timeit(call_native_tf(), number=100))

print(
    "native jax",
    timeit.timeit(evaluate_with_parameters(native_jax_gauss_sum), number=100),
)
sympy tf lambdify 0.22086703500099247
sympy numpy lambdify 0.02661015900048369
sympy jax lambdify 0.24337401299999328
native tf 0.25517284799934714
native jax 0.2992750530011108
5. Handling parameters#

Some options:

5.1 Changing parameter values#

Can be done in the model itself…

But how can the values be propagated to the AmplitudeModel?

Well, if an amplitude model only defines parameters with a name and the values are supplied in the function evaluation, then everything is decoupled and there are no problems.
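
For instance, with the lambdified functions from above, the values only enter at the call:

# Parameter values are supplied at evaluation time only; the model itself
# merely declares the parameter symbols
tweaked_values = (1.0, 0.0, 0.1, 2.0, 2.0, 0.3)  # arbitrary example values
numpy_gauss_sum(np_x, *tweaked_values)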

5.2 Changing parameter names#

Names can be changed in the sympy AmplitudeModel. Since this sympy model serves as the source of truth for the Function, all things generated from this model will reflect the name changes as well.

But going even further, since the Parameters are passed into the functions as arguments, the whole naming becomes irrelevant anyways.

tf_var_A1 = tf.Variable(1.0, dtype=tf.float64) <- does not carry a name!!
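
In the SymPy expression itself, renaming is just a substitution, and everything generated from the renamed expression picks up the new name:

# Rename parameter A1 to "strength_1" in the symbolic model (illustrative name)
renamed = gauss_sum.subs(A1, sp.Symbol("strength_1"))
sorted(map(str, renamed.free_symbols))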

5.3 Coupling parameters#

This means that one parameter is just assigned to another one?
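
In the SymPy approach, this can be expressed as a substitution on the symbolic model, for example coupling sigma2 to sigma1:

# After the substitution, only sigma1 remains as a free width parameter
coupled = gauss_sum.subs(sigma2, sigma1)
sorted(map(str, coupled.free_symbols))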

result = evaluate_with_parameters(jax_gauss_sum)()
result
DeviceArray([0.60618145, 2.03309932, 3.59630909, ..., 0.26144946,
             3.05430146, 2.88912312], dtype=float64)
np_x
array([ 2.86815263,  2.51926301,  1.7962957 , ...,  0.81910382,
        1.67313811, -0.25404634])
import matplotlib.pyplot as plt

plt.hist(np_x, bins=100, weights=result);

parameter_values = (1.0, 0.0, 0.1, 2.0, 2.0, 0.1)
result = evaluate_with_parameters(jax_gauss_sum)()
plt.hist(np_x, bins=100, weights=result);

6. Exchange a gaussian with some other function#

This should be easy if you know the exact expression that you want to replace:

from sympy.abc import C, a, b, x

expr = sp.sin(a * x) + sp.cos(b * x)
expr
\[\displaystyle \sin{\left(a x \right)} + \cos{\left(b x \right)}\]
expr.subs(sp.sin(a * x), C)
\[\displaystyle C + \cos{\left(b x \right)}\]
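
The same works for the double-Gaussian model from step 1, for instance replacing its first Gaussian term by an arbitrary other function:

# Exchange the first Gaussian for a square root (arbitrary choice)
gauss_sum.subs(gaussian1, sp.sqrt(x))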
7. Matrix operations?#
from sympy.physics.quantum.dagger import Dagger

spin_density = sp.MatrixSymbol("rho", 3, 3)
amplitudes = sp.Matrix([[1 + sp.I], [2 + sp.I], [3 + sp.I]])

dummy_intensity = sp.re(
    Dagger(amplitudes) * spin_density * amplitudes,
    evaluate=False,
    # evaluate=False is important, otherwise it generates some function that can't
    # be lambdified anymore
)
dummy_intensity
\[\begin{split}\displaystyle \operatorname{re}{\left(\left[\begin{matrix}1 - i & 2 - i & 3 - i\end{matrix}\right] \rho \left[\begin{matrix}1 + i\\2 + i\\3 + i\end{matrix}\right]\right)}\end{split}\]
tf_intensity = sp.lambdify(
    spin_density,
    dummy_intensity,
    modules=(tf,),
)
print(inspect.getsource(tf_intensity))
def _lambdifygenerated(rho):
    return (real(matmul(matmul(constant([[1 - 1j, 2 - 1j, 3 - 1j]]), rho), constant([[1 + 1j], [2 + 1j], [3 + 1j]]))))
real0 = tf.constant(0, dtype=tf.float64)
real1 = tf.constant(1, dtype=tf.float64)
intensity_result = tf_intensity(
    np.array([
        [
            tf.complex(real1, real0),
            tf.complex(real0, real0),
            -tf.complex(real0, real1),
        ],
        [
            tf.complex(real0, real0),
            tf.complex(real1, real0),
            tf.complex(real0, real0),
        ],
        [
            tf.complex(real0, real1),
            tf.complex(real0, real0),
            tf.complex(real1, real0),
        ],
    ]),
)
intensity_result
<tf.Tensor: shape=(1, 1), dtype=float64, numpy=array([[13.]])>
Python operator library#

See Python’s built-in operator library

What we have now#

Build test model

import expertsystem as es

result = es.generate_transitions(
    initial_state=[("J/psi(1S)", [-1, 1])],
    final_state=["p", "p~", "eta"],
    allowed_intermediate_particles=["N(1440)"],
    allowed_interaction_types="strong",
)
model = es.generate_amplitudes(result)
for particle in result.get_intermediate_particles():
    model.dynamics.set_breit_wigner(particle.name)
es.io.write(model, "recipe.yml")

Visualize the decay:

import graphviz

graphs = result.collapse_graphs()
dot = es.io.convert_to_dot(graphs)
graphviz.Source(dot)

model.parameters
FitParameters([
    FitParameter(name='Magnitude_J/psi(1S)_to_N(1440)+_0.5+p~_-0.5;N(1440)+_to_eta_0+p_0.5;', value=1.0, fix=False),
    FitParameter(name='Magnitude_J/psi(1S)_to_N(1440)+_0.5+p~_0.5;N(1440)+_to_eta_0+p_0.5;', value=1.0, fix=False),
    FitParameter(name='Magnitude_J/psi(1S)_to_N(1440)~-_0.5+p_-0.5;N(1440)~-_to_eta_0+p~_0.5;', value=1.0, fix=False),
    FitParameter(name='Magnitude_J/psi(1S)_to_N(1440)~-_0.5+p_0.5;N(1440)~-_to_eta_0+p~_0.5;', value=1.0, fix=False),
    FitParameter(name='MesonRadius_J/psi(1S)', value=1.0, fix=True),
    FitParameter(name='MesonRadius_N(1440)+', value=1.0, fix=True),
    FitParameter(name='MesonRadius_N(1440)~-', value=1.0, fix=True),
    FitParameter(name='Phase_J/psi(1S)_to_N(1440)+_0.5+p~_-0.5;N(1440)+_to_eta_0+p_0.5;', value=0.0, fix=False),
    FitParameter(name='Phase_J/psi(1S)_to_N(1440)+_0.5+p~_0.5;N(1440)+_to_eta_0+p_0.5;', value=0.0, fix=False),
    FitParameter(name='Phase_J/psi(1S)_to_N(1440)~-_0.5+p_-0.5;N(1440)~-_to_eta_0+p~_0.5;', value=0.0, fix=False),
    FitParameter(name='Phase_J/psi(1S)_to_N(1440)~-_0.5+p_0.5;N(1440)~-_to_eta_0+p~_0.5;', value=0.0, fix=False),
    FitParameter(name='Position_N(1440)+', value=1.44, fix=False),
    FitParameter(name='Position_N(1440)~-', value=1.44, fix=False),
    FitParameter(name='Width_N(1440)+', value=0.35, fix=False),
    FitParameter(name='Width_N(1440)~-', value=0.35, fix=False),
])
Implementation with operators#

See this answer on Stack Overflow:

import operator

MAKE_BINARY = lambda opfn: lambda self, other: BinaryOp(  # noqa: E731
    self, asMagicNumber(other), opfn
)
MAKE_RBINARY = lambda opfn: lambda self, other: BinaryOp(  # noqa: E731
    asMagicNumber(other), self, opfn
)


class MagicNumber:
    __add__ = MAKE_BINARY(operator.add)
    __sub__ = MAKE_BINARY(operator.sub)
    __mul__ = MAKE_BINARY(operator.mul)
    __radd__ = MAKE_RBINARY(operator.add)
    __rsub__ = MAKE_RBINARY(operator.sub)
    __rmul__ = MAKE_RBINARY(operator.mul)
    __truediv__ = MAKE_BINARY(operator.truediv)
    __rtruediv__ = MAKE_RBINARY(operator.truediv)
    __floordiv__ = MAKE_BINARY(operator.floordiv)
    __rfloordiv__ = MAKE_RBINARY(operator.floordiv)

    def __neg__(self):
        return UnaryOp(self, lambda x: -x)

    @property
    def value(self):
        return self.eval()


class Constant(MagicNumber):
    def __init__(self, value):
        self.value_ = value

    def eval(self):
        return self.value_


class Parameter(Constant):
    def __init__(self):
        super().__init__(0.0)

    def setValue(self, v):
        self.value_ = v

    value = property(fset=setValue, fget=lambda self: self.value_)


class BinaryOp(MagicNumber):
    def __init__(self, op1, op2, operation):
        self.op1 = op1
        self.op2 = op2
        self.opn = operation

    def eval(self):
        return self.opn(self.op1.eval(), self.op2.eval())


class UnaryOp(MagicNumber):
    def __init__(self, op1, operation):
        self.op1 = op1
        self.opn = operation

    def eval(self):
        return self.opn(self.op1.eval())


asMagicNumber = lambda x: (  # noqa: E731
    x if isinstance(x, MagicNumber) else Constant(x)
)
asMagicNumber(2).eval()
2
Other ideas#
Option 1: parameter container

Remove name from the FitParameter class and give the FitParameters collection class the responsibility to keep track of ‘names’ of the FitParameters as keys in a dict. In the AmplitudeModel, locations where a FitParameter should be inserted are indicated by an immutable (!) str that should exist as a key in the FitParameters.

Such a setup best reflects the structure of the AmplitudeModel that we have now (best illustrated by expected_recipe; note in particular YAML anchors like &par1/*par1). It also allows one to couple FitParameters. See the following snippet:

from attrs import define, frozen


# the new FitParameter class would have this structure
@define
class Parameter:
    value: float
    fix: bool = False


# the new FitParameters collection would have such a structure
mapping = {
    "par1": Parameter(1.0),
    "par2": Parameter(2.0, fix=False),
}


# intensity nodes and dynamics classes contain immutable strings
class Dynamics:
    pass


@frozen
class CustomDynamics(Dynamics):
    par: str


dyn1 = CustomDynamics(par="par1")
dyn2 = CustomDynamics(par="par2")

# Parameters would be coupled like this
mapping["par1"] = mapping["par2"]
assert mapping["par2"] is mapping["par1"]
assert mapping == {
    "par1": Parameter(2.0),
    "par2": Parameter(2.0),
}
Option 2: read-only parameter manager

Remove the FitParameters collection class altogether and use something like immutable InitialParameter instances in the dynamics and intensity section of the AmplitudeModel. The AmplitudeModel then starts to serve as a ‘read-only’ template. A fitter package like tensorwaves can then loop over the AmplitudeModel structure to extract the InitialParameter instances and convert them to something like a FitParameter.

Here’s a rough sketch with tensorwaves in mind.

from typing import Generator

import attrs
from attrs import define, field

from expertsystem.amplitude.model import (
    AmplitudeModel,
    Dynamics,
    Node,
    ParticleDynamics,
)
from expertsystem.reaction.particle import Particle


@define
class InitialParameter:
    name: str = field()
    value: float = field()
    # fix: bool = field(default=False)


@define
class FitParameter:
    name: str = field(on_setattr=attrs.setters.frozen)
    value: float = field()
    fix: bool = field(default=False)


class FitParameterManager:
    """Manages all fit parameters of the model"""

    def __init__(self, model: AmplitudeModel) -> None:
        self.__model: AmplitudeModel
        self.__parameter_couplings: dict[str, str]

    @property
    def parameters(self) -> list[FitParameter]:
        initial_parameters = list(__yield_parameter(self.__model))
        self.__apply_couplings()
        return self.__convert(initial_parameters)

    def couple_parameters(self, parameter1: str, parameter2: str) -> None:
        pass

    def __convert(self, params: list[InitialParameter]) -> list[FitParameter]:
        pass


@define
class CustomDynamics(Dynamics):
    parameter: InitialParameter = field(kw_only=True)

    @staticmethod
    def from_particle(particle: Particle):
        pass


def _yield_parameter(
    instance: object,
) -> Generator[InitialParameter, None, None]:
    if isinstance(instance, InitialParameter):
        yield instance
    elif isinstance(instance, (ParticleDynamics, Node)):
        for item in instance.values():
            yield from _yield_parameter(item)
    elif isinstance(instance, list):
        for item in instance:
            yield from _yield_parameter(item)
    elif attrs.has(instance.__class__):
        for attribute in attrs.fields(instance.__class__):
            attribute_value = getattr(instance, attribute.name)
            yield from _yield_parameter(attribute_value)


# usage in tensorwaves
amp_model = AmplitudeModel()
kinematics: HelicityKinematics = ...
builder = IntensityBuilder(kinematics)

intensity = builder.create(amp_model)  # this would call amp_model.parameters
parameters: dict[str, float] = intensity.parameters
# PROBLEM?: fix status is lost at this point

data_sample = generate_data(...)
dataset = kinematics.convert(data_sample)

parameters["Width_f(0)(980)"] = 0.2  # name is immutable at this point

# name of a parameter can be changed in the AmplitudeModel though
# and then call builder again
intensity(dataset, parameters)
Evaluation#
Pros and Cons#
Customized Python classes (current state)#
  • Positive

    • “Faster” implementation / prototyping possible compared to python operators

    • No additional dependencies

  • Negative

    • Not open-closed to new models

    • Conversion to various back-ends not DRY

    • Function replacement or extension feature becomes very difficult to handle.

    • Model is not complete, since no complete mathematical description is used. For example, Breit-Wigner functions are referred to directly and their implementation is not defined in the amplitude model.

SymPy#
  • Positive

    • Easy to render amplitude model as LaTeX

    • Model description is complete! Absolutely all information about the model is included. (reproducibility)

    • Follows open-closed principle. New models and formalisms can be added without any changes to other interfacing components (here: tensorwaves)

    • Use lambdify to convert the expression to any back-end

    • Use Expr.subs (substitute) to couple parameters or replace components of the model, for instance to set custom dynamics (see the sketch below this list)

  • Negative

    • lambdify becomes a core dependency, while its behavior cannot be modified: it is defined by sympy.

    • Need to keep track of components in the expression tree with symbol mappings
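
The following minimal sketch (with hypothetical symbols, not an actual amplitude model) illustrates the lambdify and Expr.subs points listed above under Positive:

import sympy as sp

# Hypothetical stand-in for an amplitude model: two Breit-Wigner terms
m, m1, m2, gamma1, gamma2 = sp.symbols("m m1 m2 Gamma1 Gamma2")
intensity = (
    sp.Abs(gamma1 * m1 / (m1**2 - m**2 - sp.I * gamma1 * m1)) ** 2
    + sp.Abs(gamma2 * m2 / (m2**2 - m**2 - sp.I * gamma2 * m2)) ** 2
)

# Couple parameters (or replace components of the model) with Expr.subs
coupled_intensity = intensity.subs(gamma2, gamma1)

# Convert the expression to a numerical function for a specific back-end
numpy_function = sp.lambdify((m, m1, m2, gamma1), coupled_intensity, "numpy")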

Python’s operator library#
  • Positive

    • More control over the different components in the expression tree

    • More control over the conversion to numerical functions

    • No additional dependencies

  • Negative

    • Essentially re-inventing SymPy

Decision outcome#

Use SymPy. Initially, we leave the existing amplitude builders (modules helicity_decay and canonical_decay) alongside a SymPy implementation, so that it's possible to compare the results. Once it turns out that this set-up produces the same results with comparable performance, we replace the old amplitude builders with the new SymPy implementation.

[ADR-002] Inserting dynamics#

  • Status: proposed

  • Deciders: @redeboer @spflueger

Context and problem statement#

Physics models usually include assumptions that simplify the structure of the model, for example by splitting the model into a product of independent parts, where each part carries a certain responsibility. In the case of partial wave amplitude models, we can make a separation into a spin part and a dynamical part. While the spin part controls the probability w.r.t. the angular kinematic variables, the dynamics part controls the probability as a function of variables like the invariant mass of the intermediate states.

Generally, a dynamics part is simply a function, which is defined in complex space, and consists of:

  • a mathematical expression (sympy.Expr)

  • a set of parameters in that expression that can be tweaked (optimized)

  • a set of (kinematic) variables to which the expression applies
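
For instance, for a relativistic Breit-Wigner describing a resonance in the invariant mass \(m\) of two final-state particles, the expression is \(\Gamma_0 m_0 / \left(m_0^2 - m^2 - i\Gamma_0 m_0\right)\), the parameters are \(\{m_0, \Gamma_0\}\), and the only kinematic variable is \(m\).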

Technical story#
Issues with existing set-up#
  • There is no clear way to apply dynamics functions to a specific decaying particle, that is, to a specific edge of the StateTransitionGraphs (STG). Currently, we just work with a mapping of Particles to some dynamics expression, but this is not feasible when there are identical particles on the edges.

  • The set of variables to which a dynamics expression applies, is determined by the position within the STG that it is applied to. For instance, a relativistic Breit-Wigner that is applied to the resonance in some 1-to-3 isobar decay (described by an STG with final state edges 2, 3, 4 and intermediate edge 1) would work on the invariant mass of edges 3 and 4 (mSq_3+4).

  • Just like variables, parameters need to be identifiable from their position within the STG (take a relativistic Breit-Wigner with form factors, which would require break-up momentum as a parameter), but also require some suggested starting value (e.g. expected pole position). These starting values are usually taken from the edge and node properties within the STG.

Decision drivers#

The following points are nice to have or can influence the decision, but they are not essential and can be part of the user's responsibility.

  1. The parameters that a dynamics function requires are registered automatically and linked together.

  2. Kinematic variables used in dynamics functions are also linked appropriately.

  3. It is easy to define custom dynamics (no boilerplate code).

Solution requirements#
  1. It is easy to apply dynamics to specific components of the STGs. Note: it’s important that the dynamics can be applied to resonances of some selected graphs and not generally all graphs in which the resonance appears.

  2. Where possible, suggested (initial) parameter values are provided as well.

  3. It is possible to use and inspect the dynamics expression itself independently from the expertsystem.

  4. Follow open-closed principle. Probably the most important decision driver. The solution should be flexible enough to handle any possible scenario, without having to change the interface defined in requirement 1!

Considered solutions#
Group 1: expression builder#

To satisfy requirement 1, we propose the following syntax:

# model: ModelInfo
# graph: StateTransitionGraph
model.set_dynamics(graph, expression_builder, edge_id=1)

Another style would be to have ModelInfo contain a reference to the list of StateTransitionGraphs. The user then needs some other way to express which edges to apply the dynamics function to:

model.set_dynamics(
    expression_builder,
    filter=lambda p: p.name.startswith("f(0)"),
    edge_id=1,
)

Here, expression_builder is some function or method that can create a dynamics expression. It can also be a class that contains both the implementation of the expression and a static method to build itself from a StateTransitionGraph.

The dynamics expression needs to be formulated in such a way that it satisfies the rest of the requirements. The following options illustrate three different ways of formulating a dynamics expression, each taking a relativistic Breit-Wigner and a relativistic Breit-Wigner with form factor as example.

Using composition#
Hide code cell content
from __future__ import annotations

from typing import Callable

import sympy as sp
from attrs import frozen
from helpers import (
    StateTransitionGraph,
    blatt_weisskopf,
    determine_attached_final_state,
    two_body_momentum_squared,
)

try:
    from typing import Protocol
except ImportError:
    from typing_extensions import Protocol

A frozen DynamicsExpression class keeps track of variables, parameters, and the dynamics expression in which they should appear:

@frozen
class DynamicsExpression:
    variables: tuple[sp.Symbol, ...]
    parameters: tuple[sp.Symbol, ...]
    expression: Callable[..., sp.Expr]

    def substitute(self) -> sp.Expr:
        return self.expression(*self.variables, *self.parameters)

The expression attribute can be formulated as a simple Python function that takes sympy.Symbols as arguments and returns a sympy.Expr:

def relativistic_breit_wigner(
    mass: sp.Symbol, mass0: sp.Symbol, gamma0: sp.Symbol
) -> sp.Expr:
    return gamma0 * mass0 / (mass0**2 - mass**2 - gamma0 * mass0 * sp.I)
def relativistic_breit_wigner_with_form_factor(
    mass: sp.Symbol,
    mass0: sp.Symbol,
    gamma0: sp.Symbol,
    m_a: sp.Symbol,
    m_b: sp.Symbol,
    angular_momentum: sp.Symbol,
    meson_radius: sp.Symbol,
) -> sp.Expr:
    q_squared = two_body_momentum_squared(mass, m_a, m_b)
    q0_squared = two_body_momentum_squared(mass0, m_a, m_b)
    ff2 = blatt_weisskopf(q_squared, meson_radius, angular_momentum)
    ff02 = blatt_weisskopf(q0_squared, meson_radius, angular_momentum)
    width = gamma0 * (mass0 / mass) * (ff2 / ff02)
    width = width * sp.sqrt(q_squared / q0_squared)
    return (
        relativistic_breit_wigner(mass, mass0, width) * mass0 * gamma0 * sp.sqrt(ff2)
    )

The DynamicsExpression container class enables us to provide the expression with correctly named Symbols for the decay that is being described. Here, we use some naming scheme for an \(f_0(980)\) decaying to final state edges 3 and 4 (say \(\pi^0\pi^0\)):

bw_decay_f0 = DynamicsExpression(
    variables=sp.symbols("m_3+4", seq=True),
    parameters=sp.symbols(R"m_f(0)(980) \Gamma_f(0)(980)"),
    expression=relativistic_breit_wigner,
)
bw_decay_f0.substitute()
\[\displaystyle \frac{\Gamma_{f(0)(980)} m_{f(0)(980)}}{- i \Gamma_{f(0)(980)} m_{f(0)(980)} - m_{3+4}^{2} + m_{f(0)(980)}^{2}}\]

For each dynamics expression, we have to provide a ‘builder’ function that can create a DynamicsExpression for a specific edge within the StateTransitionGraph:

def relativistic_breit_wigner_from_graph(
    graph: StateTransitionGraph, edge_id: int
) -> DynamicsExpression:
    edge_ids = determine_attached_final_state(graph, edge_id)
    final_state_ids = map(str, edge_ids)
    mass = sp.Symbol(f"m_{{{'+'.join(final_state_ids)}}}")
    particle, _ = graph.get_edge_props(edge_id)
    mass0 = sp.Symbol(f"m_{{{particle.latex}}}")
    gamma0 = sp.Symbol(Rf"\Gamma_{{{particle.latex}}}")
    return DynamicsExpression(
        variables=(mass,),
        parameters=(mass0, gamma0),
        expression=relativistic_breit_wigner,
    )
def relativistic_breit_wigner_with_form_factor_from_graph(
    graph: StateTransitionGraph, edge_id: int
) -> DynamicsExpression:
    edge_ids = determine_attached_final_state(graph, edge_id)
    final_state_ids = map(str, edge_ids)
    mass = sp.Symbol(f"m_{{{'+'.join(final_state_ids)}}}")
    particle, _ = graph.get_edge_props(edge_id)
    mass0 = sp.Symbol(f"m_{{{particle.latex}}}")
    gamma0 = sp.Symbol(Rf"\Gamma_{{{particle.latex}}}")
    m_a = sp.Symbol(f"m_{edge_ids[0]}")
    m_b = sp.Symbol(f"m_{edge_ids[1]}")
    angular_momentum = particle.spin  # helicity formalism only!
    meson_radius = sp.Symbol(Rf"R_{{{particle.latex}}}")
    return DynamicsExpression(
        variables=(mass,),
        parameters=(
            mass0,
            gamma0,
            m_a,
            m_b,
            angular_momentum,
            meson_radius,
        ),
        expression=relativistic_breit_wigner_with_form_factor,
    )

The fact that DynamicsExpression.expression is just a Python function allows one to inspect the dynamics formulation of these functions independently, purely in terms of SymPy:

m, m0, w0 = sp.symbols(R"m m_0 \Gamma")
evaluated_bw = relativistic_breit_wigner(m, 1.0, 0.3)
relativistic_breit_wigner(m, m0, w0)
\[\displaystyle \frac{\Gamma m_{0}}{- i \Gamma m_{0} - m^{2} + m_{0}^{2}}\]
sp.plot(sp.Abs(evaluated_bw), (m, 0, 2), axis_center=(0, 0), ylim=(0, 1))
sp.plot(sp.arg(evaluated_bw), (m, 0, 2), axis_center=(0, 0), ylim=(0, sp.pi));

This closes the gap between the code and the theory that is being implemented.

Alternative signature#

An alternative way to specify the expression is:

def expression(
    variables: tuple[sp.Symbol, ...], parameters: tuple[sp.Symbol, ...]
) -> sp.Expr:
    pass

Here, one would however need to unpack the variables and parameters. The advantage is that the signature becomes more general.
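
A rough sketch of the relativistic Breit-Wigner in this alternative form (same expression as above; only the signature changes):

def relativistic_breit_wigner_alternative(
    variables: tuple[sp.Symbol, ...], parameters: tuple[sp.Symbol, ...]
) -> sp.Expr:
    (mass,) = variables  # unpack the single kinematic variable
    mass0, gamma0 = parameters  # unpack the parameters in a fixed order
    return gamma0 * mass0 / (mass0**2 - mass**2 - gamma0 * mass0 * sp.I)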

Type checking#

There is no way to enforce the appropriate signature on the builder function, other than following a Protocol:

class DynamicsBuilder(Protocol):
    def __call__(
        self, graph: StateTransitionGraph, edge_id: int
    ) -> DynamicsExpression: ...

This DynamicsBuilder protocol would be used in the syntax proposed at Considered solutions.
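
For instance, the set_dynamics method from the proposed syntax could be annotated with this protocol (a sketch only; the final signature is still under discussion):

class ModelInfo:
    def set_dynamics(
        self,
        graph: StateTransitionGraph,
        expression_builder: DynamicsBuilder,
        edge_id: int,
    ) -> None: ...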

It carries another subtle problem, though: a Protocol is only used in static type checking, while potential problems with the implementation of the builder and its respective expression only arise at runtime.

Subclassing sympy.Function#
Hide code cell content
from abc import abstractmethod

import sympy as sp
from helpers import (
    StateTransitionGraph,
    blatt_weisskopf,
    determine_attached_final_state,
    two_body_momentum_squared,
)

One way to address the Cons of Using composition is to sub-class Function. The expression is implemented by overriding the inherited eval() method and the builder is provided through an additional from_graph class method. The interface would look like this:

class DynamicsFunction(sp.Function):
    @classmethod
    @abstractmethod
    def eval(cls, *args: sp.Symbol) -> sp.Expr:
        """Implementation of the dynamics function."""

    @classmethod
    @abstractmethod
    def from_graph(cls, graph: StateTransitionGraph, edge_id: int) -> sp.Basic:
        pass

As can be seen from the implementation of a relativistic Breit-Wigner, the implementation of the builder is nicely kept together with the implementation of the expression:

class RelativisticBreitWigner(DynamicsFunction):
    @classmethod
    def eval(cls, *args: sp.Symbol) -> sp.Expr:
        mass = args[0]
        mass0 = args[1]
        gamma0 = args[2]
        return gamma0 * mass0 / (mass0**2 - mass**2 - gamma0 * mass0 * sp.I)

    @classmethod
    def from_graph(
        cls, graph: StateTransitionGraph, edge_id: int
    ) -> "RelativisticBreitWigner":
        edge_ids = determine_attached_final_state(graph, edge_id)
        final_state_ids = map(str, edge_ids)
        mass = sp.Symbol(f"m_{{{'+'.join(final_state_ids)}}}")
        particle, _ = graph.get_edge_props(edge_id)
        mass0 = sp.Symbol(f"m_{{{particle.latex}}}")
        gamma0 = sp.Symbol(Rf"\Gamma_{{{particle.latex}}}")
        return cls(mass, mass0, gamma0)

It becomes a bit less clear when using a form factor, but the DynamicsFunction base class enforces a correct interface:

class RelativisticBreitWignerWithFF(DynamicsFunction):
    @classmethod
    def eval(cls, *args: sp.Symbol) -> sp.Expr:
        # Arguments
        mass = args[0]
        mass0 = args[1]
        gamma0 = args[2]
        m_a = args[3]
        m_b = args[4]
        angular_momentum = args[5]
        meson_radius = args[6]
        # Computed variables
        q_squared = two_body_momentum_squared(mass, m_a, m_b)
        q0_squared = two_body_momentum_squared(mass0, m_a, m_b)
        ff2 = blatt_weisskopf(q_squared, meson_radius, angular_momentum)
        ff02 = blatt_weisskopf(q0_squared, meson_radius, angular_momentum)
        width = gamma0 * (mass0 / mass) * (ff2 / ff02)
        width = width * sp.sqrt(q_squared / q0_squared)
        # Expression
        return (
            RelativisticBreitWigner(mass, mass0, width)
            * mass0
            * gamma0
            * sp.sqrt(ff2)
        )

    @classmethod
    def from_graph(
        cls, graph: StateTransitionGraph, edge_id: int
    ) -> "RelativisticBreitWignerWithFF":
        edge_ids = determine_attached_final_state(graph, edge_id)
        final_state_ids = map(str, edge_ids)
        mass = sp.Symbol(f"m_{{{'+'.join(final_state_ids)}}}")
        particle, _ = graph.get_edge_props(edge_id)
        mass0 = sp.Symbol(f"m_{{{particle.latex}}}")
        gamma0 = sp.Symbol(Rf"\Gamma_{{{particle.latex}}}")
        m_a = sp.Symbol(f"m_{edge_ids[0]}")
        m_b = sp.Symbol(f"m_{edge_ids[1]}")
        angular_momentum = particle.spin  # helicity formalism only!
        meson_radius = sp.Symbol(Rf"R_{{{particle.latex}}}")
        return cls(
            mass,
            mass0,
            gamma0,
            m_a,
            m_b,
            angular_momentum,
            meson_radius,
        )

The expression_builder used in the syntax proposed at Considered solutions would now just be a class that derives from DynamicsFunction.

The sympy.Function class provides mixin methods, so that the derived class behaves as a sympy expression. So the expression can be inspected with the usual sympy tools (compare the Pros of Using composition):

m, m0, w0 = sp.symbols(R"m m_0 \Gamma")
evaluated_bw = RelativisticBreitWigner(m, 1.0, 0.3)
sp.plot(sp.Abs(evaluated_bw), (m, 0, 2), axis_center=(0, 0), ylim=(0, 1))
sp.plot(sp.arg(evaluated_bw), (m, 0, 2), axis_center=(0, 0), ylim=(0, sp.pi))
RelativisticBreitWigner(m, m0, w0)

Subclassing sympy.Expr#
Hide code cell content
from __future__ import annotations

from abc import abstractmethod
from typing import Callable

import sympy as sp
from helpers import (
    StateTransitionGraph,
    blatt_weisskopf,
    determine_attached_final_state,
    two_body_momentum_squared,
)

The major disadvantage of Subclassing sympy.Function is that there is no way to identify which Symbols are variables and which are parameters. This can be solved by sub-classing from sympy.core.expr.Expr.

An example of a class that does this is WignerD. There, the implementation of the dynamics expression can be evaluated through a doit() call. This method can call anything, but sympy seems to follow the convention that it returns an ‘evaluated’ version of the class itself, where ‘evaluated’ means that some (arbitrarily named) evaluation method of the class has been called on the *args that were attached through the __new__ method (the examples below make this clearer).

For our purposes, the following DynamicsExpr base class illustrates the interface that we expect. Here, evaluate is where the expression is implemented and (just as in Subclassing sympy.Function) from_graph is the builder method.

class DynamicsExpr(sp.Expr):
    @classmethod
    @abstractmethod
    def __new__(cls, *args: sp.Symbol, **hints) -> sp.Expr:
        pass

    @abstractmethod
    def doit(self, **hints) -> sp.Expr:
        pass

    @abstractmethod
    def evaluate(self) -> sp.Expr:
        pass

    @classmethod
    @abstractmethod
    def from_graph(cls, graph: StateTransitionGraph, edge_id: int) -> sp.Basic:
        pass

The __new__ and doit methods split the construction from the evaluation of the expression. This allows one to distinguish variables and parameters and present them as properties:

class RelativisticBreitWigner(DynamicsExpr):
    def __new__(cls, *args: sp.Symbol, **hints) -> sp.Expr:
        if len(args) != 3:
            msg = f"3 parameters expected, got {len(args)}"
            raise ValueError(msg)
        args = sp.sympify(args)
        evaluate = hints.get("evaluate", False)
        if evaluate:
            return sp.Expr.__new__(cls, *args).evaluate()
        return sp.Expr.__new__(cls, *args)

    @property
    def mass(self) -> sp.Symbol:
        return self.args[0]

    @property
    def mass0(self) -> sp.Symbol:
        return self.args[1]

    @property
    def gamma0(self) -> sp.Symbol:
        return self.args[2]

    @property
    def variables(self) -> set[sp.Symbol]:
        return {self.mass}

    @property
    def parameters(self) -> set[sp.Symbol]:
        return {self.mass0, self.gamma0}

    def doit(self, **hints) -> sp.Expr:
        return RelativisticBreitWigner(*self.args, **hints, evaluate=True)

    def evaluate(self) -> sp.Expr:
        return (
            self.gamma0
            * self.mass0
            / (self.mass0**2 - self.mass**2 - self.gamma0 * self.mass0 * sp.I)
        )

    @classmethod
    def from_graph(
        cls, graph: StateTransitionGraph, edge_id: int
    ) -> RelativisticBreitWigner:
        edge_ids = determine_attached_final_state(graph, edge_id)
        final_state_ids = map(str, edge_ids)
        mass = sp.Symbol(f"m_{{{'+'.join(final_state_ids)}}}")
        particle, _ = graph.get_edge_props(edge_id)
        mass0 = sp.Symbol(f"m_{{{particle.latex}}}")
        gamma0 = sp.Symbol(Rf"\Gamma_{{{particle.latex}}}")
        return cls(mass, mass0, gamma0)
Hide code cell content
class RelativisticBreitWignerWithFF(DynamicsExpr):
    def __new__(cls, *args: sp.Symbol, **hints) -> sp.Expr:
        if len(args) != 7:
            msg = f"7 parameters expected, got {len(args)}"
            raise ValueError(msg)
        args = sp.sympify(args)
        evaluate = hints.get("evaluate", False)
        if evaluate:
            return sp.Expr.__new__(cls, *args).evaluate()
        return sp.Expr.__new__(cls, *args)

    def doit(self, **hints) -> sp.Expr:
        return RelativisticBreitWignerWithFF(*self.args, **hints, evaluate=True)

    @property
    def mass(self) -> sp.Symbol:
        return self.args[0]

    @property
    def mass0(self) -> sp.Symbol:
        return self.args[1]

    @property
    def gamma0(self) -> sp.Symbol:
        return self.args[2]

    @property
    def m_a(self) -> sp.Symbol:
        return self.args[3]

    @property
    def m_b(self) -> sp.Symbol:
        return self.args[4]

    @property
    def angular_momentum(self) -> sp.Symbol:
        return self.args[5]

    @property
    def meson_radius(self) -> sp.Symbol:
        return self.args[6]

    def evaluate(self) -> sp.Expr:
        # Computed variables
        q_squared = two_body_momentum_squared(self.mass, self.m_a, self.m_b)
        q0_squared = two_body_momentum_squared(self.mass0, self.m_a, self.m_b)
        ff2 = blatt_weisskopf(q_squared, self.meson_radius, self.angular_momentum)
        ff02 = blatt_weisskopf(q0_squared, self.meson_radius, self.angular_momentum)
        width = self.gamma0 * (self.mass0 / self.mass) * (ff2 / ff02)
        width = width * sp.sqrt(q_squared / q0_squared)
        # Expression
        return (
            RelativisticBreitWigner(self.mass, self.mass0, width)
            * self.mass0
            * self.gamma0
            * sp.sqrt(ff2)
        )

    @classmethod
    def from_graph(
        cls, graph: StateTransitionGraph, edge_id: int
    ) -> RelativisticBreitWignerWithFF:
        edge_ids = determine_attached_final_state(graph, edge_id)
        final_state_ids = map(str, edge_ids)
        mass = sp.Symbol(f"m_{{{'+'.join(final_state_ids)}}}")
        particle, _ = graph.get_edge_props(edge_id)
        mass0 = sp.Symbol(f"m_{{{particle.latex}}}")
        gamma0 = sp.Symbol(Rf"\Gamma_{{{particle.latex}}}")
        m_a = sp.Symbol(f"m_{edge_ids[0]}")
        m_b = sp.Symbol(f"m_{edge_ids[1]}")
        angular_momentum = particle.spin  # helicity formalism only!
        meson_radius = sp.Symbol(Rf"R_{{{particle.latex}}}")
        return cls(
            mass,
            mass0,
            gamma0,
            m_a,
            m_b,
            angular_momentum,
            meson_radius,
        )

The following illustrates the difference with Subclassing sympy.Function. First, notice that a class derived from DynamicsExpr is still identifiable upon construction:

m, m0, w0 = sp.symbols(R"m m_0 \Gamma")
rel_bw = RelativisticBreitWigner(m, m0, w0)
rel_bw
\[\displaystyle RelativisticBreitWigner\left(m, m_{0}, \Gamma\right)\]

The way in which this expression is rendered in a Jupyter notebook can be changed by overriding the _pretty and/or _latex methods.
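
For instance, a minimal sketch (an assumption, not part of the existing implementation) that renders the class as \(\mathrm{BW}(m)\) by overriding _latex:

class RelativisticBreitWigner(DynamicsExpr):
    ...  # same implementation as above

    def _latex(self, printer) -> str:
        # The LatexPrinter dispatches to this method when rendering the class
        mass = printer._print(self.mass)
        return Rf"\mathrm{{BW}}\left({mass}\right)"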

Only once doit() is called, is the DynamicsExpr converted into a mathematical expression:

evaluated_bw = rel_bw.doit()
evaluated_bw
\[\displaystyle \frac{\Gamma m_{0}}{- i \Gamma m_{0} - m^{2} + m_{0}^{2}}\]
sp.plot(
    sp.Abs(evaluated_bw.subs({m0: 1, w0: 0.2})),
    (m, 0, 2),
    axis_center=(0, 0),
    ylim=(0, 1),
);

Decorator#

There are a lot of implicit conventions that need to be followed to provide a correct implementation of a DynamicsExpr. Some of this may be mitigated by providing a class decorator that constructs the __new__() and doit() methods for you.

def dynamics_expression(
    n_args: int,
) -> Callable[[type], type[DynamicsExpr]]:
    def decorator(decorated_class: type) -> type[DynamicsExpr]:
        def __new__(cls, *args: sp.Symbol, **hints) -> sp.Expr:
            if len(args) != n_args:
                msg = f"{n_args} parameters expected, got {len(args)}"
                raise ValueError(msg)
            args = sp.sympify(args)
            evaluate = hints.get("evaluate", False)
            if evaluate:
                return sp.Expr.__new__(cls, *args).evaluate()
            return sp.Expr.__new__(cls, *args)

        def doit(self, **hints) -> sp.Expr:
            return decorated_class(*self.args, **hints, evaluate=True)

        decorated_class.__new__ = __new__
        decorated_class.doit = doit
        return decorated_class

    return decorator

This saves some lines of code:

@dynamics_expression(n_args=3)
class Gauss(DynamicsExpr):
    @property
    def mass(self) -> sp.Symbol:
        return self.args[0]

    @property
    def mu(self) -> sp.Symbol:
        return self.args[1]

    @property
    def sigma(self) -> sp.Symbol:
        return self.args[2]

    @property
    def variables(self) -> set[sp.Symbol]:
        return {self.mass}

    @property
    def parameters(self) -> set[sp.Symbol]:
        return {self.mu, self.sigma}

    def evaluate(self) -> sp.Expr:
        return sp.exp(-((self.mass - self.mu) ** 2) / self.sigma**2)

    @classmethod
    def from_graph(
        cls, graph: StateTransitionGraph, edge_id: int
    ) -> Gauss:
        edge_ids = determine_attached_final_state(graph, edge_id)
        final_state_ids = map(str, edge_ids)
        mass = sp.Symbol(f"m_{{{'+'.join(final_state_ids)}}}")
        particle, _ = graph.get_edge_props(edge_id)
        mu = sp.Symbol(Rf"\mu_{{{particle.latex}}}")
        sigma = sp.Symbol(Rf"\sigma_{{{particle.latex}}}")
        return cls(mass, mu, sigma)
x, mu, sigma = sp.symbols(R"x \mu \sigma")
Gauss(x, mu, w0)
\[\displaystyle Gauss\left(x, \mu, \Gamma\right)\]
Gauss(x, mu, sigma).doit()
\[\displaystyle e^{- \frac{\left(- \mu + x\right)^{2}}{\sigma^{2}}}\]
Issue with lambdify#

It’s not possible to plot a DynamicsExpr directly as long as no lambdify hook has been provided: doit() has to be executed first.

sp.plot(sp.Abs(rel_bw.subs({m0: 1, w0: 0.2})).doit(), (m, 0, 2));

Group 2: expression-only#

A second branch of solutions would propose the following interface:

# model: ModelInfo
# graph: StateTransitionGraph
model.set_dynamics(graph, expression, edge_id=1)

The key difference is the usage of a general sympy expression (sympy.Expr) as an argument, instead of constructing it through some builder object.

Solution evaluation#
1: Expression builder#

All of these solutions share a drawback that arises from the choice of interface: using an expression_builder forces the logic of correctly coupling variables and parameters into these builders. This is extremely hard to get right, since the code has to be able to handle arbitrarily complex models, and always knowing what the user would like to do is more or less impossible. It is therefore much better to use an already built expression that is assumed to be correct (see solution group 2).

All of the solutions in this group also have the following additional drawbacks. These are however more related to the correct building of the dynamics expression:

  • There is an implicit assumption on the signature of the expression: the first arguments are assumed to be the (kinematic) variables and the remaining arguments are parameters. In addition, the arguments cannot be keywords, but have to be positional.

  • The number of variables and parameters is only verified at runtime (no static typing, other than a check that each of the elements is a sympy.Symbol).

Composition is the cleanest design, but is less in tune with the design of sympy. Subclassing sympy.Function and Subclassing sympy.Expr follow sympy implementations, but result in an obscure inheritance hierarchy with implicit conventions. This can result in some nasty bugs, for instance if one were to override the __call__ method in either the sympy.Function or sympy.Expr implementation.

Pros and Cons that are specific to each of the implementations are listed below.

Using composition#
  • Positive

    • Implementation of the expression is transparent

  • Negative

    • Alternative signature.

    • The only way to see that relativistic_breit_wigner_from_graph is the builder for relativistic_breit_wigner, is from its name. This makes implementing custom dynamics inconvenient and error-prone.

    • Signature of the builder can only be checked with a Protocol, see Type checking.

Subclassing sympy.Function#
  • Positive

    • DynamicsFunction behaves as a Function

    • Implementation of the builder is kept together with the implementation of the expression.

  • Negative

    • It’s not possible to identify variables and parameters

Subclassing sympy.Expr#
  • Positive

    • When recursing through the amplitude model, it is still possible to identify instances of DynamicsExpr (before doit() has been called).

    • Additional properties and methods can be added and carried around by the class.

  • Negative

2: Expression-only#

Positive: This choice of interface follows the SOLID principles more closely than solution group 1. By handing a complete expression of the dynamics to the setter, its sole responsibility is to insert this expression at the correct place in the full model expression.

Negative: There are no direct negative aspects to this solution, as it just splits up responsibilities. The construction of the expression, with the correct linking of parameters and initial values etc., has to be performed by some other code. This code is subject to the same issues mentioned in the individual solutions of group 1.

Decision outcome#

Use a composition-based solution from group 2.

What is important is the definition of the interface following solution group 2. This keeps the design open-closed and the responsibilities separated.

The expertsystem favors composition over inheritance: we intend to use inheritance only to define interfaces, not to insert behavior. As such, the design of the expertsystem is fundamentally different from that of SymPy. That's another reason to favor composition here: the interfaces are not determined by the dependency and instead remain contained within the dynamics class.

We decide to keep responsibilities as separated as possible. This means that:

  1. The only responsibility of the set_dynamics method is to attribute some expression (sympy.Expr) to the correct symbol within the complete amplitude model. For now, this position is specified using some StateTransitionGraph and an edge ID, but this syntax may be improved later (see ComPWA/expertsystem#458):

    def set_dynamics(
        self,
        graph: StateTransitionGraph,
        edge_id: int,
        expression: sp.Expr,
    ) -> None:
        # dynamics_symbol = graph + edge_id
        # self.dynamics[dynamics_symbol] = expression
        pass
    

    It is assumed that the expression is correct.

  2. The user has the responsibility of formulating the expression with the correct symbols. To aid the user in the construction of such expressions, some building code can handle common tasks, such as:

    • A VariablePool can facilitate finding the correct symbol names (to avoid typos).

      mass = variable_pool.get_invariant_mass(graph, edge_id)
      
    • A dynamics module provides descriptions of common line-shapes as well as some helper functions. An example would be:

      inv_mass, mass0, gamma0 = build_relativistic_breit_wigner(graph, edge_id, particle)
      rel_bw: sympy.Expr = relativistic_breit_wigner(inv_mass, mass0, gamma0)
      model.set_dynamics(graph, edge_id, rel_bw)
      
  3. The SympyModel has the responsibility of defining the full model in terms of an expression and keeping track of variables and parameters, for instance:

    from __future__ import annotations
    from attrs import define, field
    import sympy as sp
    
    
    @define
    class SympyModel:
        top: sp.Expr
        # intensities: dict[sp.Symbol, sp.Expr] = field(factory=dict)
        # amplitudes: dict[sp.Symbol, sp.Expr] = field(factory=dict)
        dynamics: dict[sp.Symbol, sp.Expr] = field(factory=dict)
        parameters: set[sp.Symbol] = field(factory=set)
        variables: set[sp.Symbol] = field(factory=set)  # or: VariablePool
    
        def full_expression(self) -> sp.Expr:
            ...
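
As a rough illustration of that last responsibility (an assumption, not part of the decision), full_expression could simply substitute each dynamics placeholder Symbol in top with its expression:

    def full_expression(self) -> sp.Expr:
        # self.top contains the dynamics Symbols as placeholders
        return self.top.subs(self.dynamics)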
    

For new ADRs, please use adr/template.md as a basis. This template was inspired by MADR. General information about architectural decision records is available at adr.github.io.

Technical reports#

These pages are a collection of findings while working on ComPWA packages such as ampform, qrules, and tensorwaves. Most of these findings were not implemented, but may become relevant later on or could be useful to other frameworks as well.

TR

Title

Details

Tags

Status

TR‑000

Square root over arrays with negative values

This notebook investigates how to write a square root function in sympy that computes the positive square root for negative values. The lambdified version of this ‘complex square root’ should have the same behavior for each computational backend.

lambdification sympy

✅ tensorwaves#284

TR‑001

Custom lambdification

See also SymPy’s tutorial page on the printing modules.

lambdification sympy

✅ ampform#72, tensorwaves#284

TR‑002

Faster lambdification by splitting expressions

This notebook investigates how to speed up sympy.lambdify by splitting up the expression tree of a complicated expression into components, lambdifying those, and then combining them back again.

lambdification sympy

✅ tensorwaves#281

TR‑003

Chew-Mandelstam S-wave and dispersion integrals

Section S-wave has been implemented in ampform#265.

physics sympy

TR‑004

Investigation of analyticity

physics

TR‑005

Symbolic K-matrix expressions

Implementation of this report is tracked through ampform#67.

physics

TR‑006

Interactive 3D plots

This report illustrates how to interact with matplotlib 3D plots through Matplotlib sliders and ipywidgets.

tips

✅ ampform#38

TR‑007

MatrixSymbols

This report is a sequel to TR-005. In that report, the \(\boldsymbol{K}\)-matrix was constructed with a sympy.Matrix, but it might be more elegant to work with MatrixSymbols.

sympy

TR‑008

Indexed free symbols

This report has been implemented in ampform#111. Additionally, tensorwaves#427 makes it possible to lambdify sympy.Expr with Indexed symbols directly.

sympy

✅ ampform#111

TR‑009

Symbolic expressions for Lorentz-invariant K-matrix

This report is a sequel to TR-005.

physics sympy

✅ ampform#120

TR‑010

P-vector

This report is a sequel to TR-005 and TR-009.

physics sympy

✅ ampform#131

TR‑011

Helicity angles as symbolic expressions

This report has been implemented in ampform#177 and tensorwaves#345. The report contains some bugs, which were also addressed in these PRs.

physics sympy

✅ ampform#177, tensorwaves#345

TR‑012

Extended DataSample performance

ampform#198 makes it easier to generate expressions for kinematic variables that are not contained in the HelicityModel.expression. In TensorWaves, this results in a DataSample with more keys.

A question was raised whether this affects the duration of fits. This report shows that this is not the case (see Conclusion).

lambdification sympy

TR‑013

Spin alignment with data

In this report, we attempt to check the effect of activating spin alignment (ampform#245) and compare it with Figure 2 in [Marangotto, 2020].

See also TR-014 and TR-015.

physics

TR‑014

Amplitude model with sum notation

See also TR-013 and TR-015.

sympy

✅ ampform#245

TR‑015

Spin alignment implementation

This report has been implemented through ampform#245. For details on how to use it, see this notebook. See also TR-013 and TR-014.

physics sympy

✅ ampform#245

TR‑016

Complex integral

As noted in TR-003, scipy.integrate.quad() cannot handle complex integrals. In addition, one can get into trouble with vectorized input (numpy.arrays) on a lambdified sympy.Integral. This report discusses both problems and proposes some solutions.

sympy

TR‑017

Symbolic phase space boundary for a three-body decay

This report shows how to define the physical phase space region on a Dalitz plot using a Kibble function.

physics sympy

✅ compwa.github.io#139

TR‑018

Intensity distribution generator with importance sampling

This report sets out how data generation with TensorWaves works and what would be the best approach to tackle tensorwaves#402.

physics tensorwaves

TR‑019

Integrating Jupyter notebook with Julia notebooks in MyST-NB

This report shows how to define a Julia kernel for Jupyter notebooks, so that it can be executed and converted to static pages with MyST-NB.

DX tips

✅ compwa.github.io#174

TR‑020

Amplitude analysis with zfit

This report builds a simple symbolic amplitude model with qrules and ampform and feeds it to zfit instead of tensorwaves.

physics sympy tensorwaves

✅ compwa.github.io#151

TR‑021

Polarimeter vector field

Mikhail Mikhasenko @mmikhasenko,
Remco de Boer @redeboer


This report formulates the polarimeter vector field for \(\Lambda_c \to p\pi K\) with SymPy and visualizes it as an interactive widget with TensorWaves and ipywidgets.

physics polarimetry polarization

✅ compwa.github.io#129

TR‑022

Polarimetry: Computing the B-matrix for Λc→pKπ

The \(B\)-matrix forms an extension of the polarimeter vector field \(\vec\alpha\) (arXiv:2301.07010, see also TR-021) that takes the polarization of the proton into account. See arXiv:2302.07665, Eq. (B6).

physics polarimetry polarization

✅ compwa.github.io#196

TR‑023

Support for Plotly plots in Technical Reports

3d documentation jupyter sphinx

✅ compwa.github.io#206

TR‑024

Symbolic expressions and model serialization

Investigation into dumping SymPy expressions to human-readable format for model preservation. The notebook was motivated by the COMAP-V workshop on analysis preservation. See also SymPy printing, parsing, and expression manipulation.

documentation

🚧 polarimetry#319

TR‑025

Rotated square root cut

Investigation of the branch cut in the two Riemann sheets of a square root and what happens if the cut is rotated around \(z=0\).

✅ compwa.github.io#236

Complex square roots#

Hide code cell content
import inspect

import jax
import jax.numpy as jnp
import numpy as np
import sympy as sp
from black import FileMode, format_str
from IPython.display import display
Negative input values#

When using numpy as a back-end, sympy lambdifies a sqrt() to numpy.sqrt:

x = sp.Symbol("x")
sqrt_expr = sp.sqrt(x)
sqrt_expr
\[\displaystyle \sqrt{x}\]
np_sqrt = sp.lambdify(x, sqrt_expr, "numpy")
source = inspect.getsource(np_sqrt)
print(source)
def _lambdifygenerated(x):
    return (sqrt(x))

As expected, if input values for the numpy.sqrt are negative, numpy raises a RuntimeWarning and returns NaN:

sample = np.linspace(-1, 1, 5)
np_sqrt(sample)
array([       nan,        nan, 0.        , 0.70710678, 1.        ])

If we want numpy to return imaginary numbers for negative input values, one can use complex input data instead (e.g. numpy.complex64). Negative values are then treated as lying just above the real axis, so that their square root is a positive imaginary number:

complex_sample = sample.astype(np.complex64)
np_sqrt(complex_sample)
array([0.        +1.j        , 0.        +0.70710677j,
       0.        +0.j        , 0.70710677+0.j        ,
       1.        +0.j        ], dtype=complex64)

A sympy.sqrt lambdified to JAX exhibits the same behavior:

jax_sqrt = jax.jit(sp.lambdify(x, sqrt_expr, jnp))
source = inspect.getsource(jax_sqrt)
print(source)
def _lambdifygenerated(x):
    return (sqrt(x))
jax_sqrt(sample)
DeviceArray([       nan,        nan, 0.        , 0.70710677, 1.        ],            dtype=float32)
jax_sqrt(complex_sample)
DeviceArray([-4.3711388e-08+1.j        , -3.0908620e-08+0.70710677j,
              0.0000000e+00+0.j        ,  7.0710677e-01+0.j        ,
              1.0000000e+00+0.j        ], dtype=complex64)

There is a problem with this approach though: once input data is complex, all square roots in a larger expression (some amplitude model) compute imaginary solutions for negative values, while this is not always the desired behavior.

Take for instance the two square roots appearing in PhaseSpaceFactor — does the \(\sqrt{s}\) also have to be evaluatable for negative \(s\)?

Complex square root#

NumPy also offers a function that returns complex output for negative input values, even if the input is real: numpy.emath.sqrt():

np.emath.sqrt(-1)
1j

Unfortunately, the jax.numpy API does not interface to numpy.emath. It is possible to decorate numpy.emath.sqrt() with jax.jit(), but that only works with static, hashable arguments:

jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
jax_csqrt_error(-1)
Hide code cell output
---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Input In [13], in <cell line: 2>()
      1 jax_csqrt_error = jax.jit(np.emath.sqrt, backend="cpu")
----> 2 jax_csqrt_error(-1)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/numpy/lib/scimath.py:247, in sqrt(x)
--> 247     x = _fix_real_lt_zero(x)
    248     return nx.sqrt(x)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/numpy/lib/scimath.py:134, in _fix_real_lt_zero(x)
--> 134     x = asarray(x)
    135     if any(isreal(x) & (x < 0)):

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)
jax_csqrt = jax.jit(np.emath.sqrt, backend="cpu", static_argnums=0)
jax_csqrt(-1)
DeviceArray(0.+1.j, dtype=complex64)
jax_csqrt(sample)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [15], in <cell line: 1>()
----> 1 jax_csqrt(sample)

ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, [-1.  -0.5  0.   0.5  1. ]. The error was:
TypeError: unhashable type: 'numpy.ndarray'
Conditional square root#

To be able to control which square roots in the complete expression should be evaluatable for negative values, one could use Piecewise:

def complex_sqrt(x: sp.Symbol) -> sp.Expr:
    return sp.Piecewise(
        (sp.sqrt(-x) * sp.I, x < 0),
        (sp.sqrt(x), True),
    )


complex_sqrt(x)
\[\begin{split}\displaystyle \begin{cases} i \sqrt{- x} & \text{for}\: x < 0 \\\sqrt{x} & \text{otherwise} \end{cases}\end{split}\]
display(
    complex_sqrt(-4),
    complex_sqrt(+4),
)
\[\displaystyle 2 i\]
\[\displaystyle 2\]

Be careful though when lambdifying this expression: do not use the __dict__ of the numpy module as backend, but use the module itself instead. When using __dict__, lambdify() will return an if-else statement, which is inefficient and, worse, will result in problems with JAX:

Warning

Do not use the module __dict__ for the modules argument of lambdify().

np_complex_sqrt_no_select = sp.lambdify(x, complex_sqrt(x), np.__dict__)
source = inspect.getsource(np_complex_sqrt_no_select)
print(source)
def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))
np_complex_sqrt_no_select(-1)
1j
jax_complex_sqrt_no_select = jax.jit(np_complex_sqrt_no_select)
jax_complex_sqrt_no_select(-1)

When instead using the numpy module (or "numpy"), lambdify() correctly lambdifies to numpy.select() to represent the cases.

np_complex_sqrt = sp.lambdify(x, complex_sqrt(x), np)
source = inspect.getsource(np_complex_sqrt)
print(source)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )

Still, JAX does not handle this correctly. First, lambdifying to JAX again results in the if-else syntax:

jnp_complex_sqrt = sp.lambdify(x, complex_sqrt(x), jnp)
source = inspect.getsource(jnp_complex_sqrt)
print(source)
def _lambdifygenerated(x):
    return (((1j*sqrt(-x)) if (x < 0) else (sqrt(x))))

But even if we lambdify to numpy and decorate the result with a jax.jit() decorator, the resulting function does not work properly:

jax_complex_sqrt_error = jax.jit(np_complex_sqrt)
source = inspect.getsource(jax_complex_sqrt_error)
print(source)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
jax_complex_sqrt_error(-1)
Hide code cell output
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/ipykernel_launcher.py:17, in <module>
     15 from ipykernel import kernelapp as app
---> 17 app.launch_new_instance()

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/traitlets/config/application.py:976, in Application.launch_instance(cls, argv, **kwargs)
    975 app.initialize(argv)
--> 976 app.start()

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/ipykernel/kernelapp.py:712, in IPKernelApp.start(self)
    711 try:
--> 712     self.io_loop.start()
    713 except KeyboardInterrupt:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/tornado/platform/asyncio.py:199, in BaseAsyncIOLoop.start(self)
    198     asyncio.set_event_loop(self.asyncio_loop)
--> 199     self.asyncio_loop.run_forever()
    200 finally:

File ~/miniconda3/envs/compwa-org/lib/python3.8/asyncio/base_events.py:570, in BaseEventLoop.run_forever(self)
    569 while True:
--> 570     self._run_once()
    571     if self._stopping:

File ~/miniconda3/envs/compwa-org/lib/python3.8/asyncio/base_events.py:1859, in BaseEventLoop._run_once(self)
   1858     else:
-> 1859         handle._run()
   1860 handle = None

File ~/miniconda3/envs/compwa-org/lib/python3.8/asyncio/events.py:81, in Handle._run(self)
     80 try:
---> 81     self._context.run(self._callback, *self._args)
     82 except (SystemExit, KeyboardInterrupt):

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/ipykernel/kernelbase.py:510, in Kernel.dispatch_queue(self)
    509 try:
--> 510     await self.process_one()
    511 except Exception:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/ipykernel/kernelbase.py:499, in Kernel.process_one(self, wait)
    498         return None
--> 499 await dispatch(*args)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/ipykernel/kernelbase.py:406, in Kernel.dispatch_shell(self, msg)
    405     if inspect.isawaitable(result):
--> 406         await result
    407 except Exception:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/ipykernel/kernelbase.py:730, in Kernel.execute_request(self, stream, ident, parent)
    729 if inspect.isawaitable(reply_content):
--> 730     reply_content = await reply_content
    732 # Flush output before sending the reply.

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/ipykernel/ipkernel.py:383, in IPythonKernel.do_execute(self, code, silent, store_history, user_expressions, allow_stdin, cell_id)
    382 if with_cell_id:
--> 383     res = shell.run_cell(
    384         code,
    385         store_history=store_history,
    386         silent=silent,
    387         cell_id=cell_id,
    388     )
    389 else:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/ipykernel/zmqshell.py:528, in ZMQInteractiveShell.run_cell(self, *args, **kwargs)
    527 self._last_traceback = None
--> 528 return super().run_cell(*args, **kwargs)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2881, in InteractiveShell.run_cell(self, raw_cell, store_history, silent, shell_futures, cell_id)
   2880 try:
-> 2881     result = self._run_cell(
   2882         raw_cell, store_history, silent, shell_futures, cell_id
   2883     )
   2884 finally:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/IPython/core/interactiveshell.py:2936, in InteractiveShell._run_cell(self, raw_cell, store_history, silent, shell_futures, cell_id)
   2935 try:
-> 2936     return runner(coro)
   2937 except BaseException as e:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/IPython/core/async_helpers.py:129, in _pseudo_sync_runner(coro)
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3135, in InteractiveShell.run_cell_async(self, raw_cell, store_history, silent, shell_futures, transformed_cell, preprocessing_exc_tuple, cell_id)
   3133 interactivity = "none" if silent else self.ast_node_interactivity
-> 3135 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3136        interactivity=interactivity, compiler=compiler, result=result)
   3138 self.last_execution_succeeded = not has_raised

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/IPython/core/interactiveshell.py:3338, in InteractiveShell.run_ast_nodes(self, nodelist, cell_name, interactivity, compiler, result)
   3337     asy = compare(code)
-> 3338 if await self.run_code(code, result, async_=asy):
   3339     return True

    [... skipping hidden 1 frame]

Input In [26], in <cell line: 1>()
----> 1 jax_complex_sqrt_error(-1)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/_src/traceback_util.py:143, in api_boundary.<locals>.reraise_with_filtered_traceback(*args, **kwargs)
    142 try:
--> 143   return fun(*args, **kwargs)
    144 except Exception as e:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/_src/api.py:426, in _cpp_jit.<locals>.cache_miss(*args, **kwargs)
    425 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 426 out_flat = xla.xla_call(
    427     flat_fun,
    428     *args_flat,
    429     device=device,
    430     backend=backend,
    431     name=flat_fun.__name__,
    432     donated_invars=donated_invars)
    433 out_pytree_def = out_tree()

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/core.py:1565, in CallPrimitive.bind(self, fun, *args, **params)
   1564 def bind(self, fun, *args, **params):
-> 1565   return call_bind(self, fun, *args, **params)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/core.py:1556, in call_bind(primitive, fun, *args, **params)
   1555 with maybe_new_sublevel(top_trace):
-> 1556   outs = primitive.process(top_trace, fun, tracers, params)
   1557 return map(full_lower, apply_todos(env_trace_todo(), outs))

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/core.py:1568, in CallPrimitive.process(self, trace, fun, tracers, params)
   1567 def process(self, trace, fun, tracers, params):
-> 1568   return trace.process_call(self, fun, tracers, params)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/core.py:609, in EvalTrace.process_call(self, primitive, f, tracers, params)
    608 def process_call(self, primitive, f, tracers, params):
--> 609   return primitive.impl(f, *tracers, **params)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/interpreters/xla.py:578, in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    577 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 578   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
    579                                *unsafe_map(arg_spec, args))
    580   try:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/linear_util.py:262, in cache.<locals>.memoized_fun(fun, *args)
    261 else:
--> 262   ans = call(fun, *args)
    263   cache[key] = (ans, fun.stores)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/interpreters/xla.py:652, in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    651 abstract_args, _ = unzip2(arg_specs)
--> 652 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
    653 if any(isinstance(c, core.Tracer) for c in consts):

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:1209, in trace_to_jaxpr_final(fun, in_avals, transform_name)
   1208 main.jaxpr_stack = ()  # type: ignore
-> 1209 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1210 del fun, main

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:1188, in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1187 in_tracers = map(trace.new_arg, in_avals)
-> 1188 ans = fun.call_wrapped(*in_tracers)
   1189 out_tracers = map(trace.full_raise, ans)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/linear_util.py:166, in WrappedFun.call_wrapped(self, *args, **kwargs)
    165 try:
--> 166   ans = self.f(*args, **dict(self.params, **kwargs))
    167 except:
    168   # Some transformations yield from inside context managers, so we have to
    169   # interrupt them before reraising the exception. Otherwise they will only
    170   # get garbage-collected at some later time, running their cleanup tasks only
    171   # after this exception is handled, which can corrupt the global state.

File <lambdifygenerated-4>:2, in _lambdifygenerated(x)
      1 def _lambdifygenerated(x):
----> 2     return (select([less(x, 0),True], [1j*sqrt(-x),sqrt(x)], default=nan))

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/jax/core.py:472, in Tracer.__array__(self, *args, **kw)
    471 def __array__(self, *args, **kw):
--> 472   raise TracerArrayConversionError(self)

UnfilteredStackTrace: jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TracerArrayConversionError                Traceback (most recent call last)
Input In [26], in <cell line: 1>()
----> 1 jax_complex_sqrt_error(-1)

File <lambdifygenerated-4>:2, in _lambdifygenerated(x)
      1 def _lambdifygenerated(x):
----> 2     return (select([less(x, 0),True], [1j*sqrt(-x),sqrt(x)], default=nan))

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)> (https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError)

The very same function, when written purely with jax.numpy, does work without problems, so it seems this is a SymPy problem:

@jax.jit
def jax_complex_sqrt(x):
    return jnp.select(
        [jnp.less(x, 0), True],
        [1j * jnp.sqrt(-x), jnp.sqrt(x)],
        default=jnp.nan,
    )
jax_complex_sqrt(sample)
DeviceArray([0.        +1.j        , 0.        +0.70710677j,
             0.        +0.j        , 0.70710677+0.j        ,
             1.        +0.j        ], dtype=complex64)

A solution to this is presented in Handle for JAX.

Custom lambdification#

Hide code cell content
import inspect

import jax
import jax.numpy as jnp
import numpy as np
import sympy as sp
from black import FileMode, format_str
Overwrite printer methods#

As noted in TR-000, it’s hard to lambdify a sympy.sqrt to JAX. One possible way out is to define a custom class that derives from sympy.Expr and overwrite its printer methods.

from sympy.printing.printer import Printer


class ComplexSqrt(sp.Expr):
    def __new__(cls, x, *args, **kwargs):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x, *args, **kwargs)
        if hasattr(x, "free_symbols") and not x.free_symbols:
            return expr.evaluate()
        return expr

    def evaluate(self):
        x = self.args[0]
        if not x.is_real:
            return sp.sqrt(x)
        return sp.Piecewise(
            (sp.I * sp.sqrt(-x), x < 0),
            (sp.sqrt(x), True),
        )

    def _latex(self, printer: Printer, *args) -> str:
        x = printer._print(self.args[0])
        return Rf"\sqrt[\mathrm{{c}}]{{{x}}}"

    def _numpycode(self, printer: Printer, *args) -> str:
        printer.module_imports["numpy.lib"].add("scimath")
        x = printer._print(self.args[0])
        return f"scimath.sqrt({x})"

    def _pythoncode(self, printer: Printer, *args) -> str:
        printer.module_imports["cmath"].add("sqrt as csqrt")
        x = printer._print(self.args[0])
        return f"csqrt({x})"

As opposed to a usual derivation of sympy.Expr, this class evaluates directly upon construction, because the evaluate keyword argument is not processed by the __new__ method:

ComplexSqrt(-4)
\[\displaystyle 2 i\]
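By contrast, if one wants to keep SymPy’s usual evaluate keyword behavior, the argument can be consumed in __new__ before it is forwarded. The following is a minimal sketch (not part of the original class) that reuses the evaluate() method defined above:

class EvaluatableComplexSqrt(ComplexSqrt):
    """Variant of ComplexSqrt that respects an ``evaluate`` keyword argument."""

    def __new__(cls, x, *args, evaluate=True, **kwargs):
        x = sp.sympify(x)
        expr = sp.Expr.__new__(cls, x, *args, **kwargs)
        # only evaluate numeric input when evaluate=True (the default)
        if evaluate and hasattr(x, "free_symbols") and not x.free_symbols:
            return expr.evaluate()
        return expr


EvaluatableComplexSqrt(-4, evaluate=False)  # stays unevaluated and still renders with _latex()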

The _latex() method ensures that ComplexSqrt renders nicely in notebooks:

x = sp.Symbol("x")
ComplexSqrt(x)
\[\displaystyle \sqrt[\mathrm{c}]{x}\]
Plot custom class#

In addition, one may modify the Lambdifier class, so that sympy.plot() also works on this custom class:

from sympy.plotting.experimental_lambdify import Lambdifier

Lambdifier.builtin_functions_different["ComplexSqrt"] = "sqrt"
%config InlineBackend.figure_formats = ['svg']
x = sp.Symbol("x")
expr = ComplexSqrt(x)
p1 = sp.plot(sp.re(expr), (x, -1, 2), show=False, line_color="red")
p2 = sp.plot(sp.im(expr), (x, -1, 2), show=False)
p1.append(p2[0])
p1.show()

Lambdifying#

Most importantly, lambdifying to numpy or math now works as well:

lambdified_py = sp.lambdify(x, ComplexSqrt(x), "math")
source = inspect.getsource(lambdified_py)
print(source)
def _lambdifygenerated(x):
    return (csqrt(x))
numpy_lambdified = sp.lambdify(x, ComplexSqrt(x), "numpy")
source = inspect.getsource(numpy_lambdified)
print(source)
def _lambdifygenerated(x):
    return (scimath.sqrt(x))
sample = np.linspace(-1, +1, 5)
numpy_lambdified(sample)
array([0.        +1.j        , 0.        +0.70710678j,
       0.        +0.j        , 0.70710678+0.j        ,
       1.        +0.j        ])
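For reference (added remark): numpy.lib.scimath.sqrt (also available as numpy.emath.sqrt) switches to the complex domain for negative real input, which is why this lambdified version handles the full real line:

from numpy.lib import scimath

scimath.sqrt(-4)  # 2j, whereas plain np.sqrt(-4) returns nan with a RuntimeWarning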

Just as noted in Complex square root though, numpy.emath is not provided by the NumPy API of JAX. As discussed there, we can at most decorate the numpy.emath version with jax.jit() and work with static arguments only:

jax_lambdified = jax.jit(numpy_lambdified, backend="cpu", static_argnums=0)
jax_lambdified(-1)
DeviceArray(0.+1.j, dtype=complex64)

In that case, unhashable (non-static) input samples are still not accepted:

jax_lambdified(sample)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-12-94f797bd7204> in <module>
----> 1 jax_lambdified(sample)

ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class 'numpy.ndarray'>, [-1.  -0.5  0.   0.5  1. ]. The error was:
TypeError: unhashable type: 'numpy.ndarray'
Handle for JAX#

As concluded in Conditional square root, the alternative to lambdifying to numpy.emath is to lambdify to numpy.select(). This has some caveats, though, such as the fact that you should not use __dict__ as the modules argument. Worse, JAX is not immediately supported as backend. Fortunately, we now know how to overwrite lambdify printer methods.

An additional tool we need now is to define a new printer class for JAX, so that we can also define a special rendering method for ComplexSqrt in the case of JAX. Most of its printing methods should be the same as those of SymPy’s NumPyPrinter; the rest we can overwrite:

Note

An alternative would be to add a method _jaxcode to the ComplexSqrt class above. See Printing.

from sympy.printing.numpy import NumPyPrinter


class JaxPrinter(NumPyPrinter):
    _module = "jax"

    def _print_ComplexSqrt(self, expr: sp.Expr) -> str:
        arg = expr.args[0]
        x = self._print(arg)
        return (
            f"select([less({x}, 0), True], [1j * sqrt(-{x}), sqrt({x})],"
            " default=nan,)"
        )
numpy_expr = sp.lambdify(x, ComplexSqrt(x), modules=np, printer=JaxPrinter)
source = inspect.getsource(numpy_expr)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
jax_expr = sp.lambdify(x, ComplexSqrt(x), modules=jnp, printer=JaxPrinter)
source = inspect.getsource(jax_expr)
def _lambdifygenerated(x):
    return select(
        [less(x, 0), True],
        [1j * sqrt(-x), sqrt(x)],
        default=nan,
    )
jax_expr(sample)
DeviceArray([0.        +1.j        , 0.        +0.70710677j,
             0.        +0.j        , 0.70710677+0.j        ,
             1.        +0.j        ], dtype=complex64)

The lambdified function can of course also be decorated with jax.jit():

jit_expr = jax.jit(jax_expr)
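As a quick sanity check (added here), the jitted function should return the same values as the non-jitted one:

assert jnp.allclose(jit_expr(sample), jax_expr(sample))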
Performance check#
rng = np.random.default_rng()
sample = rng.normal(size=1_000_000)
jax_sample = jnp.array(sample)
%timeit jit_expr(jax_sample)
1.91 ms Âą 116 Âľs per loop (mean Âą std. dev. of 7 runs, 100 loops each)
%timeit jax_expr(jax_sample)
6.31 ms Âą 42.6 Âľs per loop (mean Âą std. dev. of 7 runs, 100 loops each)
%timeit numpy_expr(sample)
16.9 ms Âą 614 Âľs per loop (mean Âą std. dev. of 7 runs, 100 loops each)

Speed up lambdifying#

Hide code cell content
%config InlineBackend.figure_formats = ['svg']
from __future__ import annotations

import inspect
import logging
import timeit
import warnings
from typing import Callable, Generator, Sequence

import ampform
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import qrules
import sympy as sp
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff
from tensorwaves.data import generate_phsp
from tensorwaves.data.phasespace import TFUniformRealNumberGenerator
from tensorwaves.data.transform import HelicityTransformer
from tensorwaves.model import LambdifiedFunction, SympyModel

LOGGER = logging.getLogger()
Create dummy expression#

First, let’s create an amplitude model with ampform. We’ll use this model as a complicated sympy.Expr in the rest of this notebook.

result = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)(980)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
dot = qrules.io.asdot(result, collapse_graphs=True, render_final_state_id=False)
graphviz.Source(dot)

model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.generate()
free_symbols = sorted(model.expression.free_symbols, key=lambda s: s.name)
free_symbols
[C[J/\psi(1S) \to f_{0}(980)_{0} \gamma_{+1}; f_{0}(980) \to \pi^{0}_{0} \pi^{0}_{0}],
 Gamma_f(0)(980),
 d_f(0)(980),
 m_1,
 m_12,
 m_2,
 m_f(0)(980),
 phi_1+2,
 phi_1,1+2,
 theta_1+2,
 theta_1,1+2]
Helicity model components#

A HelicityModel has the benefit that it comes with components (intensities and amplitudes) that together form its expression. Let’s separate these components into amplitudes and intensities.

amplitudes = {
    name: expr for name, expr in model.components.items() if name.startswith("A")
}
sorted(amplitudes)
['A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{+1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{+1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{-1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{+1} \\to f_{0}(980)_{0} \\gamma_{-1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{+1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{+1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{-1,L=0,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]',
 'A[J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{-1,L=2,S=1}; f_{0}(980)_{0} \\to \\pi^{0}_{0} \\pi^{0}_{0,L=0,S=0}]']
intensities = {
    name: expr for name, expr in model.components.items() if name.startswith("I")
}
assert len(amplitudes) + len(intensities) == len(model.components)
Component structure#

Note that each intensity consists of a subset of these amplitudes. This means that intensities have a larger expression tree than amplitudes.

amplitude_to_symbol = {
    expr: sp.Symbol(f"A{i}") for i, expr in enumerate(amplitudes.values(), 1)
}
intensity_to_symbol = {
    expr: sp.Symbol(f"I{i}") for i, expr in enumerate(intensities.values(), 1)
}
intensity_expr = model.expression.subs(intensity_to_symbol, simultaneous=True)
intensity_expr
\[\displaystyle I_{1} + I_{2} + I_{3} + I_{4}\]
dot = sp.dotprint(intensity_expr)
graphviz.Source(dot)

amplitude_expr = model.expression.subs(amplitude_to_symbol, simultaneous=True)
amplitude_expr
\[\displaystyle \left|{A_{1} + A_{2}}\right|^{2} + \left|{A_{3} + A_{4}}\right|^{2} + \left|{A_{5} + A_{6}}\right|^{2} + \left|{A_{7} + A_{8}}\right|^{2}\]
dot = sp.dotprint(amplitude_expr)
graphviz.Source(dot)

Performance check#

Lambdifying the whole HelicityModel.expression is slowest. The lambdify() function first prints the expression as a str (!) with (in this case) numpy syntax and then uses eval() to convert that back to actual numpy objects:

Hide code cell content
runtime = {}
start = timeit.default_timer()
%%time
np_complete_model = sp.lambdify(free_symbols, model.expression.doit(), "numpy")
CPU times: user 1.46 s, sys: 703 Âľs, total: 1.46 s
Wall time: 1.46 s
Hide code cell content
stop = timeit.default_timer()
runtime["complete model"] = stop - start

Printing to str and converting back with eval() becomes increasingly slow as the expression tree grows larger. This means that it’s more efficient to lambdify sub-trees of the expression tree separately. When lambdifying the four intensities of this model separately, the effect is not noticeable:

%%time
for expr, symbol in intensity_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start
CPU times: user 1.56 s, sys: 4.94 ms, total: 1.56 s
Wall time: 1.56 s

…but lambdifying each of the eight amplitudes separately does result in a significant speed-up:

%%time
np_amplitudes = {}
for expr, symbol in amplitude_to_symbol.items():
    logging.info(f"Lambdifying {symbol.name}")
    start = timeit.default_timer()
    np_expr = sp.lambdify(free_symbols, expr.doit(), "numpy")
    stop = timeit.default_timer()
    runtime[symbol.name] = stop - start
    np_amplitudes[symbol] = np_expr
CPU times: user 547 ms, sys: 3.85 ms, total: 550 ms
Wall time: 547 ms
Recombining components#

Recall what the amplitude model looks like when expressed in terms of its amplitude components:

amplitude_expr
\[\displaystyle \left|{A_{1} + A_{2}}\right|^{2} + \left|{A_{3} + A_{4}}\right|^{2} + \left|{A_{5} + A_{6}}\right|^{2} + \left|{A_{7} + A_{8}}\right|^{2}\]

We have to lambdify that top expression as well:

sorted_amplitude_symbols = sorted(np_amplitudes, key=lambda s: s.name)
np_amplitude_expr = sp.lambdify(sorted_amplitude_symbols, amplitude_expr, "numpy")
source = inspect.getsource(np_amplitude_expr)
print(source)
def _lambdifygenerated(A1, A2, A3, A4, A5, A6, A7, A8):
    return (abs(A1 + A2)**2 + abs(A3 + A4)**2 + abs(A5 + A6)**2 + abs(A7 + A8)**2)

We now have a lambdified expression for the complete amplitude model, as well as lambdified expressions that are to be plugged into its arguments.

def componentwise_lambdified(*args):
    """Lambdified amplitude model, recombined from its amplitude components.

    .. warning:: Order of the ``args`` has to be the same as that
        of the ``args`` of the lambdified amplitude components.
    """
    amplitude_values = []
    for amp_symbol in sorted_amplitude_symbols:
        np_amplitude = np_amplitudes[amp_symbol]
        values = np_amplitude(*args)
        amplitude_values.append(values)
    return np_amplitude_expr(*amplitude_values)
Test with data#

Okay, so does all this work? Let’s first generate a phase space sample with good-old tensorwaves. We can then use this sample as input to the component-wise lambdified function.

sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameter_defaults,
)
intensity = LambdifiedFunction(sympy_model, backend="jax")
data_converter = HelicityTransformer(model.adapter)
rng = TFUniformRealNumberGenerator(seed=0)
phsp_sample = generate_phsp(
    10_000, model.adapter.reaction_info, random_generator=rng
)
phsp_set = data_converter.transform(phsp_sample)
Hide code cell source
fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(phsp_set["m_12"], bins=50, alpha=0.5, density=True)
ax.hist(
    phsp_set["m_12"],
    bins=50,
    alpha=0.5,
    density=True,
    weights=np.array(intensity(phsp_set)),
)
plt.show()

The arguments of the component-wise lambdified amplitude model should be covered by the entries in the phase space set and the provided parameter defaults:

kinematic_variable_names = set(phsp_set)
parameter_names = {symbol.name for symbol in model.parameter_defaults}
free_symbol_names = {symbol.name for symbol in free_symbols}
assert free_symbol_names <= kinematic_variable_names ^ parameter_names

That allows us to sort the input arrays and parameter defaults so that they can be used as positional argument input to the component-wise lambdified amplitude model:

merged_par_var_values = {
    symbol.name: value for symbol, value in model.parameter_defaults.items()
}
merged_par_var_values.update(phsp_set)
args_values = [merged_par_var_values[symbol.name] for symbol in free_symbols]

Finally, here’s the result of plugging that back into the component-wise lambdified expression:

componentwise_result = componentwise_lambdified(*args_values)
componentwise_result
array([0.00048765, 0.00033425, 0.00524706, ..., 0.00140122, 0.00714365,
       0.00030117])

And it is indeed the same as the intensity computed by tensorwaves (direct lambdify):

tensorwaves_result = np.array(intensity(phsp_set))
mean_difference = (componentwise_result - tensorwaves_result).mean()
mean_difference
-7.307471250984975e-11
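Since a mean can hide cancellations between positive and negative deviations, it is also worth checking the largest absolute difference (check added here):

np.abs(componentwise_result - tensorwaves_result).max()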
Arbitrary expressions#

The problem with Test with data is that it requires a HelicityModel. In tensorwaves, we want to work with general sympy.Exprs though (see SympyModel), where we don’t have ampform.helicity.HelicityModel.components available.

Instead, we have to split up the lambdifying in a more general way that can handle arbitrary sympy.core.expr.Exprs. For that we need:

  1. A general method of traversing through a SymPy expression tree. This can be done with Advanced Expression Manipulation.

  2. A fast method to estimate the complexity of a model, so that we can decide whether a node in the expression tree is small enough to be lambdified without much runtime. The best measure for complexity is count_ops() (“count operations”), see notes under Simplify.
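As a small illustration of the second point, count_ops() simply counts the number of operations (Add, Mul, Pow, function calls, …) in an expression tree:

a, b, c = sp.symbols("a b c")
assert sp.count_ops(a + b) == 1  # one Add
assert sp.count_ops(a * b + c**2) == 3  # one Mul, one Add, one Pow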

Expression complexity#

Let’s tackle 2. first and use the HelicityModel.expression and its components that we lambdified earlier on. Here’s an overview of the number of operations versus the time it took to lambdify each component:

Hide code cell source
df = pd.DataFrame(runtime.values(), index=runtime, columns=["runtime (s)"])
operations = [sp.count_ops(model.expression)]
operations.extend(sp.count_ops(expr) for expr in intensity_to_symbol)
operations.extend(sp.count_ops(expr) for expr in amplitude_to_symbol)
df.insert(0, "operations", operations)
df
operations runtime (s)
complete model 823 0.980456
I1 209 0.279897
I2 203 0.235227
I3 207 0.215937
I4 201 0.233635
A1 103 0.045300
A2 103 0.040710
A3 100 0.039767
A4 100 0.035684
A5 102 0.036551
A6 102 0.036198
A7 99 0.042208
A8 99 0.040205

From this we can already see that the lambdify runtime scales roughly with the number of SymPy operations.

To better visualize this, we can lambdify the expressions in BlattWeisskopfSquared for each angular momentum and compute their runtime a number of times with timeit. Note that BlattWeisskopfSquared becomes increasingly complex the higher the angular momentum.

Hide code cell source
from ampform.dynamics import BlattWeisskopfSquared

angular_momentum, z = sp.symbols("L z")
BlattWeisskopfSquared(angular_momentum, z).doit()
\[\begin{split}\displaystyle \begin{cases} 1 & \text{for}\: L = 0 \\\frac{2 z}{z + 1} & \text{for}\: L = 1 \\\frac{13 z^{2}}{9 z + \left(z - 3\right)^{2}} & \text{for}\: L = 2 \\\frac{277 z^{3}}{z \left(z - 15\right)^{2} + \left(2 z - 5\right) \left(18 z - 45\right)} & \text{for}\: L = 3 \\\frac{12746 z^{4}}{25 z \left(2 z - 21\right)^{2} + \left(z^{2} - 45 z + 105\right)^{2}} & \text{for}\: L = 4 \\\frac{998881 z^{5}}{z^{5} + 15 z^{4} + 315 z^{3} + 6300 z^{2} + 99225 z + 893025} & \text{for}\: L = 5 \\\frac{118394977 z^{6}}{z^{6} + 21 z^{5} + 630 z^{4} + 18900 z^{3} + 496125 z^{2} + 9823275 z + 108056025} & \text{for}\: L = 6 \\\frac{19727003738 z^{7}}{z^{7} + 28 z^{6} + 1134 z^{5} + 47250 z^{4} + 1819125 z^{3} + 58939650 z^{2} + 1404728325 z + 18261468225} & \text{for}\: L = 7 \\\frac{4392846440677 z^{8}}{z^{8} + 36 z^{7} + 1890 z^{6} + 103950 z^{5} + 5457375 z^{4} + 255405150 z^{3} + 9833098275 z^{2} + 273922023375 z + 4108830350625} & \text{for}\: L = 8 \end{cases}\end{split}\]
Hide code cell content
operations = []
runtime = []
for angular_momentum in range(9):
    ff2 = BlattWeisskopfSquared(angular_momentum, z)
    operations.append(sp.count_ops(ff2.doit()))
    n_iterations = 10
    t = timeit.timeit(
        setup=f"""
import sympy as sp
from ampform.dynamics import BlattWeisskopfSquared
z = sp.Symbol("z")
ff2 = BlattWeisskopfSquared({angular_momentum}, z)
    """,
        stmt='sp.lambdify(z, ff2.doit(), "numpy")',
        number=n_iterations,
    )
    runtime.append(t / n_iterations * 1_000)
Hide code cell source
df = pd.DataFrame(
    {
        "operations": operations,
        "runtime (ms)": runtime,
    },
)
df
operations runtime (ms)
0 0 0.81877
1 3 1.24712
2 7 1.64094
3 12 2.52622
4 14 2.29422
5 16 1.88900
6 19 2.24741
7 22 2.72068
8 25 3.01171
Hide code cell source
fig, ax = plt.subplots(figsize=(8, 4))
plt.scatter(x=df["operations"], y=df["runtime (ms)"])
ax.set_ylim(bottom=0)
ax.set_xlabel("operations")
ax.set_ylabel("runtime (ms)")
plt.show()

Identifying nodes#

Now imagine that we don’t know anything about the expression that we created before other than that it is a sympy.Expr.

Approach 1: Generator#

A first attempt is to use a generator to recursively identify components in the expression that lie within a certain ‘complexity’ (as computed by count_ops()).

def recurse_tree(
    expression: sp.Expr, *, min_complexity: int = 0, max_complexity: int
) -> Generator[sp.Expr, None, None]:
    for arg in expression.args:
        complexity = sp.count_ops(arg)
        if complexity < max_complexity and complexity > min_complexity:
            yield arg
        else:
            yield from recurse_tree(
                arg,
                min_complexity=min_complexity,
                max_complexity=max_complexity,
            )
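To get a feeling for what this generator yields, here is a small toy example (added for illustration):

a, b, c, d = sp.symbols("a b c d")
toy_expr = sp.Abs(a * b + c) ** 2 + sp.sin(d) ** 2
# both yielded sub-expressions (a*b + c and sin(d)**2) have complexity 2 < 3
list(recurse_tree(toy_expr, max_complexity=3))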

We can then use this generator function to create a mapping of these sub-expressions within the expression tree to Symbols. That mapping can then be used in xreplace() to replace the sub-expressions with those symbols.

%%time
expression = model.expression.doit()
sub_expressions = {}
for i, expr in enumerate(recurse_tree(expression, max_complexity=100)):
    symbol = sp.Symbol(f"f{i}")
    complexity = sp.count_ops(expr)
    sub_expressions[expr] = symbol
expression.xreplace(sub_expressions)
CPU times: user 314 ms, sys: 135 Âľs, total: 314 ms
Wall time: 313 ms
\[\displaystyle \left|{f_{0} + f_{1}}\right|^{2} + \left|{f_{2} + f_{3}}\right|^{2} + \left|{f_{4} + f_{5}}\right|^{2} + \left|{f_{6} + f_{7}}\right|^{2}\]
Approach 2: Direct substitution#

There is one problem though: xreplace() is not accurate for larger expressions. It would therefore be better to directly substitute the sub-expression with a symbol while we loop over the nodes in the expression tree. The following function can do that:

def split_expression(
    expression: sp.Expr,
    max_complexity: int,
    min_complexity: int = 0,
) -> tuple[sp.Expr, dict[sp.Symbol, sp.Expr]]:
    i = 0
    symbol_mapping = {}

    def recursive_split(sub_expression: sp.Expr) -> sp.Expr:
        nonlocal i
        for arg in sub_expression.args:
            complexity = sp.count_ops(arg)
            if complexity < max_complexity and complexity > min_complexity:
                symbol = sp.Symbol(f"f{i}")
                i += 1
                symbol_mapping[symbol] = arg
                sub_expression = sub_expression.xreplace({arg: symbol})
            else:
                new_arg = recursive_split(arg)
                sub_expression = sub_expression.xreplace({arg: new_arg})
        return sub_expression

    top_expression = recursive_split(expression)
    return top_expression, symbol_mapping

And indeed, this is much faster than Approach 1: Generator (it’s even possible to parallelize this for loop):

%time
top_expression, sub_expressions = split_expression(expression, max_complexity=100)
CPU times: user 7 Âľs, sys: 1 Âľs, total: 8 Âľs
Wall time: 15.5 Âľs
top_expression
\[\displaystyle \left|{f_{0} + f_{1}}\right|^{2} + \left|{f_{2} + f_{3}}\right|^{2} + \left|{f_{4} + f_{5}}\right|^{2} + \left|{f_{6} + f_{7}}\right|^{2}\]
sub_expressions[sp.Symbol("f0")]
\[\displaystyle \frac{C[J/\psi(1S) \to f_{0}(980)_{0} \gamma_{+1}; f_{0}(980) \to \pi^{0}_{0} \pi^{0}_{0}] \Gamma_{f(0)(980)} m_{f(0)(980)} \left(\frac{\cos{\left(\theta_{1+2} \right)}}{2} + \frac{1}{2}\right) e^{i \phi_{1+2}}}{- \frac{i \Gamma_{f(0)(980)} m_{f(0)(980)} \sqrt{\frac{\left(m_{12}^{2} - \left(m_{1} - m_{2}\right)^{2}\right) \left(m_{12}^{2} - \left(m_{1} + m_{2}\right)^{2}\right)}{m_{12}^{2}}} \sqrt{m_{f(0)(980)}^{2}}}{\sqrt{\frac{\left(m_{f(0)(980)}^{2} - \left(m_{1} - m_{2}\right)^{2}\right) \left(m_{f(0)(980)}^{2} - \left(m_{1} + m_{2}\right)^{2}\right)}{m_{f(0)(980)}^{2}}} \left|{m_{12}}\right|} - m_{12}^{2} + m_{f(0)(980)}^{2}}\]
Lambdify and combine#

Now that we have the machinery to split up arbitrary expressions by complexity, we need to lambdify the top expression as well as each of the sub-expressions and recombine them. The following function can do that and return a recombined Callable.

def optimized_lambdify(
    args: Sequence[sp.Symbol],
    expr: sp.Expr,
    modules: str | None = None,
    min_complexity: int = 0,
    max_complexity: int = 100,
) -> Callable:
    top_expression, definitions = split_expression(
        expr,
        min_complexity=min_complexity,
        max_complexity=max_complexity,
    )
    top_symbols = sorted(definitions, key=lambda s: s.name)
    top_lambdified = sp.lambdify(top_symbols, top_expression, modules)
    sub_lambdified = [
        sp.lambdify(args, definitions[symbol], modules) for symbol in top_symbols
    ]

    def recombined_function(*args):
        new_args = [sub_expr(*args) for sub_expr in sub_lambdified]
        return top_lambdified(*new_args)

    return recombined_function

We can use the same input values as in Test with data to check that the resulting lambdified expression results in the same output.

%time
treewise_lambdified = optimized_lambdify(free_symbols, expression, "numpy")
CPU times: user 8 Âľs, sys: 1 Âľs, total: 9 Âľs
Wall time: 17.4 Âľs
treewise_result = treewise_lambdified(*args_values)
treewise_result
array([0.00048765, 0.00033425, 0.00524706, ..., 0.00140122, 0.00714365,
       0.00030117])

And it is indeed the same as the intensity computed by tensorwaves (direct lambdify):

mean_difference = (treewise_result - tensorwaves_result).mean()
mean_difference
-7.307471274905997e-11
Comparison#

Now have a look at a slightly more complicated model:

Hide code cell source
result = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [+1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
    allowed_interaction_types=["strong", "EM"],
    formalism_type="canonical-helicity",
)
model_builder = ampform.get_builder(result)
for name in result.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
complex_model = model_builder.generate()
dot = qrules.io.asdot(result, collapse_graphs=True, render_final_state_id=False)
graphviz.Source(dot)

This makes it clear that the functions defined in Arbitrary expressions result in a huge speed-up!

new_expression = complex_model.expression.doit()
new_free_symbols = sorted(new_expression.free_symbols, key=lambda s: s.name)
%%time
np_expr = sp.lambdify(new_free_symbols, new_expression)
CPU times: user 4.57 s, sys: 3.16 ms, total: 4.57 s
Wall time: 4.57 s
%%time
np_expr = optimized_lambdify(new_free_symbols, new_expression)
CPU times: user 261 ms, sys: 87 Âľs, total: 262 ms
Wall time: 260 ms

Chew-Mandelstam#

This report is an attempt to formulate the Chew-Mandelstam function described in PDG2021, §Resonances, p.13 (Section 50.3.5) with SymPy, so that it can be implemented in AmpForm.

Hide code cell content
%config InlineBackend.figure_formats = ['svg']
import inspect
import warnings
from functools import partial

import black
import matplotlib.pyplot as plt
import mpl_interactions.ipyplot as iplt
import numpy as np
import qrules
import quadpy
import symplot
import sympy as sp
from ampform.dynamics import (
    BlattWeisskopfSquared,
    BreakupMomentumSquared,
    ComplexSqrt,
    PhaseSpaceFactor,
    PhaseSpaceFactorComplex,
)
from IPython.display import Math

warnings.filterwarnings("ignore")
PDG = qrules.load_pdg()
S-wave#

As can be seen in Eq. (50.40) on PDG2021, §Resonances, p.13, the Chew-Mandelstam function \(\Sigma_a\) for a particle \(a\) decaying to particles \(1, 2\) has a simple form for angular momentum \(L=0\) (\(S\)-wave):

(1)#\[\Sigma_a(s) = \frac{1}{16\pi^2} \left[ \frac{2q_a}{\sqrt{s}} \log\frac{m_1^2+m_2^2-s+2\sqrt{s}q_a}{2m_1m_2} - \left(m_1^2-m_2^2\right) \left(\frac{1}{s}-\frac{1}{(m_1+m_2)^2}\right) \log\frac{m_1}{m_2} \right]\]

The only question is how to deal with negative values for the squared break-up momentum \(q_a^2\). Here, we will use AmpForm’s ComplexSqrt:

\[\begin{split}\displaystyle q_a = \sqrt[\mathrm{c}]{q_a^{2}} = \begin{cases} i \sqrt{- q_a^{2}} & \text{for}\: q_a^{2} < 0 \\\sqrt{q_a^{2}} & \text{otherwise} \end{cases}\end{split}\]
def breakup_momentum(s, m1, m2):
    q_squared = BreakupMomentumSquared(s, m1, m2)
    return ComplexSqrt(q_squared)


def chew_mandelstam_s_wave(s, m1, m2):
    # evaluate=False in order to keep same style as PDG
    q = breakup_momentum(s, m1, m2)
    left_term = sp.Mul(
        2 * q / sp.sqrt(s),
        sp.log((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2)),
        evaluate=False,
    )
    right_term = (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
    return sp.Mul(
        1 / (16 * sp.pi**2),
        left_term - right_term,
        evaluate=False,
    )

To check whether this implementation is correct, let’s plug some Symbols into this function and compare it to Eq. (50.40) on PDG2021, §Resonances, p.13:

s, m1, m2 = sp.symbols("s m1 m2", real=True)
chew_mandelstam_s_wave_expr = chew_mandelstam_s_wave(s, m1, m2)
chew_mandelstam_s_wave_expr
\[\displaystyle \frac{1}{16 \pi^{2}} \left(\frac{2 \sqrt[\mathrm{c}]{q^2\left(s\right)}}{\sqrt{s}} \log{\left(\frac{m_{1}^{2} + m_{2}^{2} + 2 \sqrt{s} \sqrt[\mathrm{c}]{q^2\left(s\right)} - s}{2 m_{1} m_{2}} \right)} - \left(m_{1}^{2} - m_{2}^{2}\right) \left(- \frac{1}{\left(m_{1} + m_{2}\right)^{2}} + \frac{1}{s}\right) \log{\left(\frac{m_{1}}{m_{2}} \right)}\right)\]

It should be noted that this equation is not well-defined along the real axis, that is, for \(\mathrm{Im}(s) = 0\). For this reason, we evaluate it slightly above the real axis, at a real \(s'\) plus a small imaginary offset (the PDG indicates this with \(s+0i\)). For the interactive plot, we parametrize this offset with \(\epsilon\) as a negative power of \(10\):

epsilon = sp.Symbol("epsilon", positive=True)
s_prime = sp.Symbol(R"s^{\prime}", real=True)
s_plus = s_prime + sp.I * sp.Pow(10, -epsilon)
\[\displaystyle s \to s^{\prime} + 10^{- \epsilon} i\]

We are now ready to use mpl_interactions and AmpForm’s symplot to visualize this function:

chew_mandelstam_s_wave_prime = chew_mandelstam_s_wave_expr.subs(s, s_plus)
np_chew_mandelstam_s_wave, sliders = symplot.prepare_sliders(
    expression=chew_mandelstam_s_wave_prime.doit(),
    plot_symbol=s_prime,
)
np_phase_space_factor = sp.lambdify(
    args=(s_prime, m1, m2, epsilon),
    expr=PhaseSpaceFactorComplex(s_plus, m1, m2).doit(),
    modules="numpy",
)

As starting values for the interactive plot, we assume \(\pi\eta\) scattering (just like in the PDG section) and use the \(\pi^0\) and \(\eta\) masses as values for \(m_1\) and \(m_2\), respectively.

s_min, s_max = -0.15, 1.4
m1_val = PDG["pi0"].mass
m2_val = PDG["eta"].mass

plot_domain = np.linspace(s_min, s_max, 500)
sliders.set_ranges(
    m1=(0, 2, 200),
    m2=(0, 2, 200),
    epsilon=(1, 12),
)
sliders.set_values(
    m1=m1_val,
    m2=m2_val,
    epsilon=4,
)

For comparison, we plot the Chew-Mandelstam function for \(S\)-waves next to AmpForm’s PhaseSpaceFactorComplex. Have a look at the resulting plots and compare to Figure 50.4 on PDG2021, §Resonances, p.12.

Hide code cell content
fig, axes = plt.subplots(ncols=2, figsize=(11, 4.5), tight_layout=True)
ax1, ax2 = axes
for ax in axes:
    ax.axhline(0, linewidth=0.5, c="black")

real_style = {"label": "Real part", "c": "black", "linestyle": "dashed"}
imag_style = {"label": "Imag part", "c": "red"}
threshold_style = {"label": R"$s_\mathrm{thr}$", "c": "grey", "linewidth": 0.5}

ylim = (-1, +1)
y_factor = 16 * np.pi
controls = iplt.axvline(
    lambda *args, **kwargs: (kwargs["m1"] + kwargs["m2"]) ** 2,
    **sliders,
    ax=ax1,
    **threshold_style,
)
iplt.axvline(
    lambda *args, **kwargs: (kwargs["m1"] + kwargs["m2"]) ** 2,
    controls=controls,
    ax=ax2,
    **threshold_style,
)
iplt.plot(
    plot_domain,
    lambda *args, **kwargs: (
        y_factor * 1j * np_phase_space_factor(*args, **kwargs)
    ).real,
    controls=controls,
    ylim=ylim,
    alpha=0.7,
    ax=ax1,
    **real_style,
)
iplt.plot(
    plot_domain,
    lambda *args, **kwargs: (
        y_factor * 1j * np_phase_space_factor(*args, **kwargs)
    ).imag,
    controls=controls,
    ylim=ylim,
    alpha=0.7,
    ax=ax1,
    **imag_style,
)

iplt.plot(
    plot_domain,
    lambda *args, **kwargs: y_factor
    * np_chew_mandelstam_s_wave(*args, **kwargs).real,
    controls=controls,
    ylim=ylim,
    alpha=0.7,
    ax=ax2,
    **real_style,
)
iplt.plot(
    plot_domain,
    lambda *args, **kwargs: y_factor
    * np_chew_mandelstam_s_wave(*args, **kwargs).imag,
    controls=controls,
    ylim=ylim,
    alpha=0.7,
    ax=ax2,
    **imag_style,
)

for ax in axes:
    ax.legend(loc="lower right")
    ax.set_xticks(np.arange(0, 1.21, 0.3))
    ax.set_yticks(np.arange(-1, 1.1, 0.5))
    ax.set_xlabel("$s$ (GeV$^2$)")

ax1.set_ylabel(R"$16\pi \; i\rho(s)$")
ax2.set_ylabel(R"$16\pi \; \Sigma(s)$")
ax1.set_title(R"Complex phase space factor $\rho$")
ax2.set_title("Chew-Mandelstam $S$-wave ($L=0$)")
plt.show()
https://user-images.githubusercontent.com/29308176/164984924-764a9558-6afd-46a9-8f24-8cc92ce1bc49.svg
General dispersion integral#

For higher angular momenta, the PDG notes that one has to compute the dispersion integral given by Eq. (50.41) on PDG2021, §Resonances, p.13:

(2)#\[ \Sigma_a(s+0i) = \frac{s-s_{\mathrm{thr}_a}}{\pi} \int^\infty_{s_{\mathrm{thr}_a}} \frac{ \rho_a(s')n_a^2(s') }{ (s' - s_{\mathrm{thr}_a})(s'-s-i0) } \mathop{}\!\mathrm{d}s' \]

Equation (1) is the analytic solution for \(L=0\).

From Equations (50.26-27) on PDG2021, §Resonances, p.9, it can be deduced that the function \(n_a^2\) is the same as AmpForm’s BlattWeisskopfSquared (note that this function is normalized, whereas the PDG’s \(F_j\) function has \(1\) in the numerator). Furthermore, the PDG seems to suggest that \(z = q_a/q_0\), but this is an unconventional choice and is probably a mistake. For this reason, we simply use BlattWeisskopfSquared for the definition of \(n_a^2\):

def na2(s, m1, m2, L, q0):
    q_squared = BreakupMomentumSquared(s, m1, m2)
    return BlattWeisskopfSquared(
        z=q_squared / (q0**2),
        angular_momentum=L,
    )

For \(\rho_a\), we use AmpForm’s PhaseSpaceFactor:

q0 = sp.Symbol("q0", real=True)
L = sp.Symbol("L", integer=True, positive=True)
s_thr = (m1 + m2) ** 2
integrand = (PhaseSpaceFactor(s_prime, m1, m2) * na2(s_prime, m1, m2, L, q0)) / (
    (s_prime - s_thr) * (s_prime - s - epsilon * sp.I)
)
integrand
\[\displaystyle \frac{B_{L}^2\left(\frac{q^2\left(s^{\prime}\right)}{q_{0}^{2}}\right) \rho\left(s^{\prime}\right)}{\left(s^{\prime} - \left(m_{1} + m_{2}\right)^{2}\right) \left(- i \epsilon - s + s^{\prime}\right)}\]

Next, we lambdify() this integrand to a numpy expression so that we can integrate it efficiently:

np_integrand = sp.lambdify(
    args=(s_prime, s, L, epsilon, m1, m2, q0),
    expr=integrand.doit(),
    modules="numpy",
)

As discussed in TR-016, scipy.integrate.quad() cannot integrate over complex-valued functions, while quadpy runs into trouble with vectorized input to the integrand. The following function, from Vectorized input, offers a quick solution:

@np.vectorize
def vectorized_quad(func, a, b, **func_kwargs):
    values, _ = quadpy.quad(partial(func, **func_kwargs), a, b)
    return values
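As a quick check of this helper (example added here), we integrate a simple complex-valued function with a known antiderivative for several parameter values at once:

# ∫₀¹ exp(i·ω·x) dx = (exp(i·ω) - 1) / (i·ω): 2i/π ≈ 0.637i for ω = π and 0 for ω = 2π
omega_values = np.array([np.pi, 2 * np.pi])
vectorized_quad(
    lambda x, omega: np.exp(1j * omega * x),
    a=0,
    b=1,
    omega=omega_values,
)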

Note

Integrals can be expressed with SymPy, with some caveats. See SymPy integral.

Now, for comparison, we compute this integral for a few values of \(L>0\):

s_domain = np.linspace(s_min, s_max, num=50)
max_L = 3
l_values = list(range(1, max_L + 1))
print("Computing for L ∈", l_values)
Computing for L ∈ [1, 2, 3]

It is handy to store the resulting values of each dispersion integral in a dict with \(L\) as keys:

%%time
s_thr_val = float(s_thr.subs({m1: m1_val, m2: m2_val}))
integral_values = {
    l_val: vectorized_quad(
        np_integrand,
        a=s_thr_val,
        b=np.inf,
        s=s_domain,
        L=l_val,
        epsilon=1e-3,
        m1=m1_val,
        m2=m2_val,
        q0=1.0,
    )
    for l_val in l_values
}
CPU times: user 2.61 s, sys: 7.39 ms, total: 2.62 s
Wall time: 2.62 s

Finally, as can be seen from Eq. (2), the resulting values from the integral have to be multiplied by a factor \(\frac{s-s_{\mathrm{thr}_a}}{\pi}\) to get \(\Sigma_a\). We also scale the values with \(16\pi\) so that they can be compared with the plot generated in S-wave.

sigma = {
    l_val: (s_domain - s_thr_val) / np.pi * integral_values[l_val]
    for l_val in l_values
}
sigma_scaled = {l_val: 16 * np.pi * sigma[l_val] for l_val in l_values}

Chew-Mandelstam for higher angular momenta

Note

In SymPy expressions we’ll see that the dispersion integral indeed reproduces the same shape as the analytic expression from S-wave.

SymPy expressions#

In the following, we attempt to implement Equation (2) using SymPy integral.

Hide code cell content
from sympy.printing.pycode import _unpack_integral_limits


class UnevaluatableIntegral(sp.Integral):
    abs_tolerance = 1e-5
    rel_tolerance = 1e-5
    limit = 50

    def doit(self, **hints):
        args = [arg.doit(**hints) for arg in self.args]
        return self.func(*args)

    def _numpycode(self, printer, *args):
        integration_vars, limits = _unpack_integral_limits(self)
        if len(limits) != 1:
            msg = f"Cannot handle {len(limits)}-dimensional integrals"
            raise ValueError(msg)
        integrate = "quadpy_quad"
        printer.module_imports["quadpy"].update({f"quad as {integrate}"})
        limit_str = "{}, {}".format(*tuple(map(printer._print, limits[0])))
        args = ", ".join(map(printer._print, integration_vars))
        expr = printer._print(self.args[0])
        return (
            f"{integrate}(lambda {args}: {expr}, {limit_str},"
            f" epsabs={self.abs_tolerance}, epsrel={self.abs_tolerance},"
            f" limit={self.limit})[0]"
        )
def dispersion_integral(
    s,
    m1,
    m2,
    angular_momentum,
    meson_radius=1,
    s_prime=sp.Symbol("x", real=True),
    epsilon=sp.Symbol("epsilon", positive=True),
):
    s_thr = (m1 + m2) ** 2
    q_squared = BreakupMomentumSquared(s_prime, m1, m2)
    ff_squared = BlattWeisskopfSquared(
        angular_momentum=angular_momentum, z=q_squared * meson_radius**2
    )
    phsp_factor = PhaseSpaceFactor(s_prime, m1, m2)
    return sp.Mul(
        (s - s_thr) / sp.pi,
        UnevaluatableIntegral(
            (phsp_factor * ff_squared)
            / (s_prime - s_thr)
            / (s_prime - s - sp.I * epsilon),
            (s_prime, s_thr, sp.oo),
        ),
        evaluate=False,
    )


x = sp.Symbol("x", real=True)
integral_expr = dispersion_integral(s, m1, m2, angular_momentum=L, s_prime=x)
integral_expr
\[\displaystyle \frac{s - \left(m_{1} + m_{2}\right)^{2}}{\pi} \int\limits_{\left(m_{1} + m_{2}\right)^{2}}^{\infty} \frac{B_{L}^2\left(q^2\left(x\right)\right) \rho\left(x\right)}{\left(x - \left(m_{1} + m_{2}\right)^{2}\right) \left(- i \epsilon - s + x\right)}\, dx\]
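Note that the overridden doit() evaluates the sub-expressions of the integrand, but deliberately keeps the integral itself unevaluated, so that it survives until lambdification. A small added illustration:

y = sp.Symbol("y", positive=True)
UnevaluatableIntegral(sp.sqrt(y), (y, 0, 1)).doit()  # remains an Integral, not 2/3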

Warning

We have to keep track of the integration variable (\(s'\) in Equation (2)), so that we don’t run into trouble if we use lambdify() with common sub-expressions. The problem is that the integration variable should not be extracted as a common sub-expression, otherwise the lambdified quadpy.quad() expression cannot handle vectorized input.
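A minimal illustration (added here) of what the ignore argument to cse() does:

a, b = sp.symbols("a b")
demo_expr = sp.sin(a * b) + sp.cos(a * b)
sp.cse(demo_expr)  # extracts a*b:  ([(x0, a*b)], [sin(x0) + cos(x0)])
sp.cse(demo_expr, ignore=[a])  # keeps it in place:  ([], [sin(a*b) + cos(a*b)])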

To keep the function under the integral simple, we substitute angular momentum \(L\) with a definite value before we lambdify:

UnevaluatableIntegral.abs_tolerance = 1e-4
UnevaluatableIntegral.rel_tolerance = 1e-4
integral_func_s_wave = sp.lambdify(
    [s, m1, m2, epsilon],
    integral_expr.subs(L, 0).doit(),
    # integration symbol should not be extracted as common sub-expression!
    cse=partial(sp.cse, ignore=[x], list=False),
)
integral_func_s_wave = np.vectorize(integral_func_s_wave)

integral_func_p_wave = sp.lambdify(
    [s, m1, m2, epsilon],
    integral_expr.subs(L, 1).doit(),
    cse=partial(sp.cse, ignore=[x], list=False),
)
integral_func_p_wave = np.vectorize(integral_func_p_wave)
Hide code cell source
src = inspect.getsource(integral_func_s_wave.pyfunc)
src = black.format_str(src, mode=black.FileMode())
print(src)
def _lambdifygenerated(s, m1, m2, epsilon):
    x0 = pi ** (-1.0)
    x1 = (m1 + m2) ** 2
    x2 = -x1
    return (
        x0
        * (s + x2)
        * quadpy_quad(
            lambda x: (1 / 16)
            * x0
            * sqrt((x + x2) * (x - (m1 - m2) ** 2) / x)
            / (sqrt(x) * (x + x2) * (-1j * epsilon - s + x)),
            x1,
            PINF,
            epsabs=0.0001,
            epsrel=0.0001,
            limit=50,
        )[0]
    )
s_values = np.linspace(-0.15, 1.4, num=200)
%time s_wave_values = integral_func_s_wave(s_values, m1_val, m2_val, epsilon=1e-5)
%time p_wave_values = integral_func_p_wave(s_values, m1_val, m2_val, epsilon=1e-5)
CPU times: user 5.13 s, sys: 0 ns, total: 5.13 s
Wall time: 5.13 s
CPU times: user 2.41 s, sys: 0 ns, total: 2.41 s
Wall time: 2.41 s

Note that the dispersion integral for \(L=0\) indeed reproduces the same shape as in S-wave!

Hide code cell source
s_wave_values *= 16 * np.pi
p_wave_values *= 16 * np.pi

s_values = np.linspace(-0.15, 1.4, num=200)
fig, axes = plt.subplots(nrows=2, figsize=(6, 7), sharex=True)
ax1, ax2 = axes
fig.suptitle(
    f"Symbolic dispersion integrals for $m_1={m1_val:.2f}, m_2={m2_val:.2f}$"
)
for ax in axes:
    ax.axhline(0, linewidth=0.5, c="black")
    ax.axvline(s_thr_val, **threshold_style)
    ax.set_ylabel(R"$16\pi \; \Sigma(s)$")
axes[-1].set_xlabel("$s$ (GeV$^2$)")

ax1.set_title("$S$-wave ($L=0$)")
ax1.plot(s_values, s_wave_values.real, **real_style)
ax1.plot(s_values, s_wave_values.imag, **imag_style)

ax2.set_title("$P$-wave ($L=1$)")
ax2.plot(s_values, p_wave_values.real, **real_style)
ax2.plot(s_values, p_wave_values.imag, **imag_style)

ax1.legend()
fig.tight_layout()
plt.show()

Symbolic Chew-Mandelstam plots

Analyticity#

Hide code cell content
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.dynamics import PhaseSpaceFactor, relativistic_breit_wigner_with_ff
from IPython.display import Math, display
from ipywidgets import widgets
from matplotlib import cm
from mpl_interactions import heatmap_slicer

warnings.filterwarnings("ignore")

plt.rcParams.update({"font.size": 14})

STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)
Branch points of \(\rho(s)\)#

Investigation of Section 2.1.2 in [Aitchison, 2015].

Hide code cell source
s = sp.Symbol("s")
m1, m2 = sp.symbols("m1 m2", real=True)
rho = 16 * sp.pi * PhaseSpaceFactor(s, m1, m2).doit()
rho
\[\displaystyle \frac{\sqrt{\frac{\left(s - \left(m_{1} - m_{2}\right)^{2}\right) \left(s - \left(m_{1} + m_{2}\right)^{2}\right)}{s}}}{\sqrt{s}}\]

Or, assuming both decay products to be of unit mass:

Hide code cell source
rho.subs({
    m1: 1,
    m2: 1,
})
\[\displaystyle \frac{\sqrt{s - 4}}{\sqrt{s}}\]
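The branch points of \(\rho(s)\) sit where the radicands vanish. A quick check (added here) with solve() confirms the two threshold values; the \(1/\sqrt{s}\) factor adds a third branch point at \(s=0\). These are exactly the points marked by the vertical lines in the plot below:

sp.solve((s - (m1 - m2) ** 2) * (s - (m1 + m2) ** 2), s)  # [(m1 - m2)**2, (m1 + m2)**2]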
Hide code cell content
np_rho = sp.lambdify((s, m1, m2), rho, "numpy")

m1_val = 1.8
m2_val = 0.5
s_thr = (m1_val + m2_val) ** 2
s_diff = abs(m1_val - m2_val) ** 2

x = np.linspace(-1, +7, num=100)
y = np.linspace(-2, +2, num=100)
X, Y = np.meshgrid(x, y)
s_values = X + Y * 1j
rho_values = np_rho(s_values, m1=m1_val, m2=m2_val)
Hide code cell source
fig, axes = heatmap_slicer(
    x,
    y,
    (rho_values.real, rho_values.imag),
    heatmap_names=(R"Re($\rho$)", R"Im($\rho$)"),
    labels=("Re($s$)", "Im($s$)"),
    interaction_type="move",
    slices="both",
    vmin=-5,
    vmax=5,
    figsize=(12, 3),
)
rho_min, rho_max = -5, +5  # y-limits for the slice panels (assumed here; matches vmin/vmax above)
for ax in axes[2:]:
    ax.set_ylim(rho_min, rho_max)
    tick_width = 5
    tick_min = np.around(rho_min / tick_width, decimals=0) * tick_width
    ax.set_yticks(np.arange(tick_min, rho_max + 0.1, 5))
axes[2].set_title("Re($s$)")
axes[3].set_title("Im($s$)")
for ax in axes[:3]:
    ax.axvline(s_diff, c="black", linewidth=0.3, linestyle="dotted")
    ax.axvline(s_thr, c="black", linewidth=0.3, linestyle="dotted")
for ax in axes:
    ax.axvline(0, c="black", linewidth=0.5)
    ax.axhline(0, c="black", linewidth=0.5)
axes[3].axvline(0, c="black", linewidth=0.5)
plt.show()
Physical vs. unphysical sheet#

Interactive reproduction of Figure 49.1 on PDG2020, §Resonances, p.2. The formulas below come from a relativistic_breit_wigner_with_ff() with \(L=0\). As phase space factor, we used the square root of BreakupMomentumSquared instead of the default PhaseSpaceFactor, because this introduces only one branch point in the \(s\)-plane (namely the one from the numerator).

Hide code cell source
from ampform.dynamics import BreakupMomentumSquared


def breakup_momentum(s: sp.Symbol, m_a: sp.Symbol, m_b: sp.Symbol) -> sp.Expr:
    return sp.sqrt(BreakupMomentumSquared(s, m_a, m_b).doit())


s = sp.Symbol("s")
m0, gamma0, m1, m2 = sp.symbols("m0 Gamma0 m1 m2", real=True, positive=True)

unphysical_amp = relativistic_breit_wigner_with_ff(
    s,
    m0,
    gamma0,
    m_a=m1,
    m_b=m2,
    angular_momentum=0,
    meson_radius=1,
    phsp_factor=breakup_momentum,
).doit()

sqrt_term = unphysical_amp.args[2].args[0].args[2]
physical_amp = unphysical_amp.subs(sqrt_term, sp.sqrt(sqrt_term**2))

display(
    Math(R"\mathrm{Physical:} \quad " + sp.latex(physical_amp)),
    Math(R"\mathrm{Unphysical:} \quad " + sp.latex(unphysical_amp)),
)
\[\displaystyle \mathrm{Physical:} \quad \frac{\Gamma_{0} m_{0}}{\Gamma_{0} m_{0}^{2} \sqrt{- \frac{\left(s - \left(m_{1} - m_{2}\right)^{2}\right) \left(s - \left(m_{1} + m_{2}\right)^{2}\right)}{s \left(m_{0}^{2} - \left(m_{1} - m_{2}\right)^{2}\right) \left(m_{0}^{2} - \left(m_{1} + m_{2}\right)^{2}\right)}} + m_{0}^{2} - s}\]
\[\displaystyle \mathrm{Unphysical:} \quad \frac{\Gamma_{0} m_{0}}{- \frac{i \Gamma_{0} m_{0}^{2} \sqrt{\frac{\left(s - \left(m_{1} - m_{2}\right)^{2}\right) \left(s - \left(m_{1} + m_{2}\right)^{2}\right)}{s}}}{\sqrt{\left(m_{0}^{2} - \left(m_{1} - m_{2}\right)^{2}\right) \left(m_{0}^{2} - \left(m_{1} + m_{2}\right)^{2}\right)}} + m_{0}^{2} - s}\]
Hide code cell source
args = (s, m0, gamma0, m1, m2)
np_amp_physical = sp.lambdify(args, physical_amp, "numpy")
np_amp_unphysical = sp.lambdify(args, unphysical_amp, "numpy")

x_min, x_max = -0.2, 1.3
y_min, y_max = -1.8, +1.8
z_min, z_max = -2.5, +2.5

x = np.linspace(x_min, x_max, num=50)
y_neg = np.linspace(y_min, -1e-4, num=30)
y_pos = np.linspace(1e-4, y_max, num=30)

X, Y_neg = np.meshgrid(x, y_neg)
X, Y_pos = np.meshgrid(x, y_pos)
s_values_neg = X + Y_neg * 1j
s_values_pos = X + Y_pos * 1j

z_cut_min = 0.75 * z_min
z_cut_max = 0.75 * z_max
cut_off_min = np.vectorize(lambda z: z if z > z_cut_min else z_cut_min)
cut_off_max = np.vectorize(lambda z: z if z < z_cut_max else z_cut_max)

plot_style = {
    "linewidth": 0,
    "alpha": 0.7,
    "antialiased": True,
    "rstride": 1,
    "cstride": 1,
}
axis_style = {
    "c": "black",
    "linewidth": 0.7,
    "linestyle": "dashed",
}

fig, axes = plt.subplots(
    ncols=2,
    figsize=(10, 6),
    subplot_kw={"projection": "3d"},
    tight_layout=True,
)
ax1, ax2 = axes
fig.suptitle("$S$-wave Breit-Wigner ($L=0$) plotted over the complex $s$-plane")

m0_min = np.sign(x_min) * np.sqrt(np.abs(x_min))
m0_max = np.sign(x_max) * np.sqrt(np.abs(x_max))

sliders = {
    "m0": widgets.FloatSlider(
        min=m0_min,
        max=m0_max,
        value=0.8,
        step=0.01,
        description="$m_0$",
    ),
    "gamma0": widgets.FloatSlider(
        min=0.0,
        max=y_max,
        value=0.3,
        step=0.01,
        description=R"$\Gamma_0$",
    ),
    "m1": widgets.FloatSlider(
        min=1e-4,
        max=m0_max / 2,
        step=0.01,
        description="$m_1$",
    ),
    "m2": widgets.FloatSlider(
        min=1e-4,
        max=m0_max / 2,
        step=0.01,
        description="$m_2$",
    ),
}


@widgets.interact(**sliders)
def plot(m0, gamma0, m1, m2):
    def plot_expression(ax, amp, neg_color="green"):
        ax.clear()
        z_values_neg = amp(s_values_neg, m0, gamma0, m1, m2).imag
        z_values_pos = amp(s_values_pos, m0, gamma0, m1, m2).imag
        Z_neg = cut_off_min(cut_off_max(z_values_neg))
        Z_pos = cut_off_min(cut_off_max(z_values_pos))

        s_thr = (m1 + m2) ** 2
        x0 = x[x >= s_thr] + 1e-4j
        y0 = np.zeros(len(x0))
        z0 = amp(x0, m0, gamma0, m1, m2).imag

        ax.plot_surface(X, Y_neg, Z_neg, **plot_style, color=neg_color)
        ax.plot_surface(X, Y_pos, Z_pos, **plot_style, color="green")
        ax.plot(x0.real, y0, z0, linewidth=2.5, c="darkred", zorder=8)
        ax.scatter([x0[0].real], [0], [z0[0]], c="darkred", s=20, zorder=9)

        ax.set_xlabel("Re($s$)", labelpad=-15)
        ax.set_ylabel("Im($s$)", labelpad=-15)
        ax.set_zlabel("Im($A$)", labelpad=-15)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        ax.set_zlim(z_min, z_max)

    plot_expression(ax1, np_amp_physical)
    plot_expression(ax2, np_amp_unphysical, neg_color="gold")

    ax1.text(x_min, y_max, z_max / 2, "physical sheet", c="green")
    ax2.text(x_min, y_min, -z_max, "unphysical sheet", c="gold")

    fig.canvas.draw_idle()

K-matrix#

This report investigates how to implement \(K\)-matrix dynamics with SymPy. We only describe the version that is not Lorentz-invariant here, because it is the simplest and allows us to check whether the case \(n_R=1, n=1\) (single resonance, single channel) reduces to a Breit-Wigner function. We followed the physics as described by PDG2020, §Resonances and [Chung et al., 1995, Peters, 2004, Meyer, 2008]. For the Lorentz-invariant version, see TR-009.

A brief overview of the origin of the \(\boldsymbol{K}\)-matrix is given first. This overview follows [Chung et al., 1995], but skips over quite a few details, as this is only an attempt to provide some context of what is going on.

from __future__ import annotations

import os
import warnings

import graphviz
import matplotlib.pyplot as plt
import mpl_interactions.ipyplot as iplt
import numpy as np
import symplot
import sympy as sp
from IPython.display import Math, display
from ipywidgets import widgets as ipywidgets
from matplotlib import cm
from mpl_interactions.controller import Controls

warnings.filterwarnings("ignore")
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)
Physics#

The \(\boldsymbol{K}\)-matrix formalism is used to describe coupled two-body scattering processes of the type \(c_id_i \to R \to a_ib_i\), with \(i\) representing each separate channel and \(R\) a number of resonances that these channels have in common.

dot = """
digraph {
    rankdir=LR;
    node [shape=point, width=0];
    edge [arrowhead=none];
    "Na" [shape=none, label="aᾢ"];
    "Nb" [shape=none, label="bᾢ"];
    "Nc" [shape=none, label="cᾢ"];
    "Nd" [shape=none, label="dᾢ"];
    { rank=same "Nc", "Nd" };
    { rank=same "Na", "Nb" };
    "Nc" -> "N0";
    "Nd" -> "N0";
    "N1" -> "Na";
    "N1" -> "Nb";
    "N0" -> "N1" [label="R"];
    "N0" [shape=none, label=""];
    "N1" [shape=none, label=""];
}
"""
graph = graphviz.Source(dot)
graph
Partial wave expansion#

In amplitude analysis, the main aim is to describe the differential cross section \(\frac{d\sigma}{d\Omega}\), that is, the intensity distribution in each spherical direction \(\Omega=(\phi,\theta)\) that we observe in experiments. This differential cross section can be expressed in terms of the scattering amplitude \(A\) as:

(1)#\[ \frac{d\sigma}{d\Omega} = \left|A(\Omega)\right|^2 \]

We can now further express \(A\) in terms of partial wave amplitudes by splitting it up in terms of its angular momentum components \(J\):

(2)#\[ A(\Omega) = \frac{1}{2q_i}\sum_J\left(2J+1\right) T^J(s) {D^J_{\lambda\mu}}^*\left(\phi,\theta,0\right) \]

with \(\lambda=\lambda_a-\lambda_b\) and \(\mu=\lambda_c-\lambda_d\) the helicity differences of the final and initial states \(ab,cd\).

The above sketch is written with just one channel in mind, but the same holds true for a number of channels \(n\); the only difference is that the \(T\) operator becomes a \(\boldsymbol{T}\)-matrix of rank \(n\).
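To make Eq. (2) more concrete, here is a minimal SymPy sketch (not part of the original report) of a single-\(J\), single-channel term, with arbitrary example quantum numbers and the dynamical part \(T^J(s)\) left as an undefined function:

import sympy as sp
from sympy.physics.quantum.spin import Rotation

s, q = sp.symbols("s q_i", positive=True)
phi, theta = sp.symbols("phi theta", real=True)
J, lam, mu = 2, 1, 0  # arbitrary example quantum numbers
T_J = sp.Function("T")(s)  # dynamical part, e.g. a K-matrix element
wigner_d = Rotation.D(J, lam, mu, phi, theta, 0)  # D^J_{λμ}(φ, θ, 0)
single_term = (2 * J + 1) / (2 * q) * T_J * sp.conjugate(wigner_d)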

Transition operator#

The important point is that we have now expressed \(A\) in terms of an angular part (depending on \(\Omega\)) and a dynamical part \(\boldsymbol{T}\) that depends on the Mandelstam variable \(s\).

The dynamical part \(\boldsymbol{T}\) is usually called the transition operator, because it describes the interacting part of the scattering operator \(\boldsymbol{S}\). The scattering operator gives the (complex) amplitude \(\langle f|\boldsymbol{S}|i\rangle\) of an initial state \(|i\rangle\) transitioning to a final state \(|f\rangle\). Since it describes both the non-interacting amplitude and the transition amplitude, it relates to the transition operator as:[1]

(3)#\[ \boldsymbol{S} = \boldsymbol{I} + 2i\boldsymbol{T} \]

with \(\boldsymbol{I}\) the identity operator. With this in mind, there is an important restriction that the \(T\)-operator needs to comply with: unitarity. This means that \(\boldsymbol{S}\) should conserve probability, namely \(\boldsymbol{S}^\dagger\boldsymbol{S} = \boldsymbol{I}\).

K-matrix formalism#

Now there is a trick to ensure unitarity of \(\boldsymbol{S}\). We can express \(\boldsymbol{S}\) in terms of an operator \(\boldsymbol{K}\) by applying a Cayley transformation:

(4)#\[ \boldsymbol{S} = (\boldsymbol{I} + i\boldsymbol{K})(\boldsymbol{I} - i\boldsymbol{K})^{-1} \]

Unitarity is conserved if \(\boldsymbol{K}\) is real. Finally, the \(\boldsymbol{T}\)-matrix can be expressed in terms of \(\boldsymbol{K}\) as follows:

(5)#\[ \boldsymbol{T} = \boldsymbol{K} \left(\boldsymbol{I} - i\boldsymbol{K}\right)^{-1} \]
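As a quick numerical sanity check (not part of the original report), a randomly chosen real, symmetric \(\boldsymbol{K}\) indeed yields a unitary \(\boldsymbol{S}\) through Eqs. (3) and (5):

import numpy as np

rng = np.random.default_rng(seed=0)
K = rng.normal(size=(3, 3))
K = (K + K.T) / 2  # real and symmetric
identity = np.eye(3)
T = K @ np.linalg.inv(identity - 1j * K)  # Eq. (5)
S = identity + 2j * T  # Eq. (3)
assert np.allclose(S.conj().T @ S, identity)  # S†S = I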
Resonances#

The challenge is now to choose a parametrization for the elements of \(\boldsymbol{K}\) that correctly describes the resonances we observe. There are several choices, but a common one is the following summation over the resonances \(R\):

(6)#\[ K_{ij} = \sum_R\frac{g_{R,i}^*g_{R,j}}{m_R^2-m^2} \]

with \(g_{R,i}\) the residue functions that can be further expressed as

(7)#\[ g_{R,i}=\gamma_{R,i}\sqrt{m_R\Gamma_R} \]
Implementation#

The challenge is to generate a correct parametrization for an arbitrary number of coupled channels \(n\) and an arbitrary number of resonances \(n_R\). Our approach is to construct an \(n \times n\) sympy.Matrix with Symbols as its elements. We then substitute these Symbols with certain parametrizations using subs(). In order to generate symbols for \(n_R\) resonances and \(n\) channels, we use indexed symbols.

This approach is less elegant and (theoretically) slower than using MatrixSymbols. That approach is explored in TR-007.

It would be nice to use a Symbol to represent the number of channels \(n\) and specify its value later.

n_channels = sp.Symbol("n", integer=True, positive=True)

Unfortunately, this does not work well in the Matrix class. We therefore set the variable \(n\) to a specific int value and define some other Symbols for the rest of the implementation.[2] The value we choose in this example is n_channels=1, because we want to see if this reproduces a Breit-Wigner function.[3]
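For instance, the following short demonstration (not in the original report) shows what goes wrong: building an explicit Matrix requires iterating over its indices, which is not possible with a Symbol.

n_symbolic = sp.Symbol("n", integer=True, positive=True)
try:
    sp.Matrix([
        [sp.Symbol(f"K{i}{j}") for j in range(n_symbolic)]
        for i in range(n_symbolic)
    ])
except TypeError as exception:
    print(exception)  # 'Symbol' object cannot be interpreted as an integer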

n_channels = 1
i, j, R, n_resonances = sp.symbols("i j R n_R", integer=True, negative=False)
m = sp.Symbol("m", real=True)
M = sp.IndexedBase("m", shape=(n_resonances,))
Gamma = sp.IndexedBase("Gamma", shape=(n_resonances,))
gamma = sp.IndexedBase("gamma", shape=(n_resonances, n_channels))

The parametrization of \(K_{ij}\) from Eq. (6) can be expressed as follows:

def Kij(
    m: sp.Symbol,
    M: sp.IndexedBase,
    Gamma: sp.IndexedBase,
    gamma: sp.IndexedBase,
    i: int,
    j: int,
    n_resonances: int | sp.Symbol,
) -> sp.Expr:
    g_i = gamma[R, i] * sp.sqrt(M[R] * Gamma[R])
    g_j = gamma[R, j] * sp.sqrt(M[R] * Gamma[R])
    parametrization = (g_i * g_j) / (M[R] ** 2 - m**2)
    return sp.Sum(parametrization, (R, 0, n_resonances - 1))
n_R = sp.Symbol("n_R")
kij = Kij(m, M, Gamma, gamma, i, j, n_R)
Math("K_{ij} = " + f"{sp.latex(kij)}")
\[\displaystyle K_{ij} = \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,i} {\gamma}_{R,j} {m}_{R}}{- m^{2} + {m}_{R}^{2}}\]

We now define the \(\boldsymbol{K}\)-matrix in terms of a Matrix with IndexedBase instances as elements that can serve as Symbols. These Symbols will be substituted with the parametrization later. We could of course have inserted the parametrization directly, but this slows down matrix multiplication in the following steps.

K_symbol = sp.IndexedBase("K", shape=(n_channels, n_channels))
K = sp.Matrix(
    [[K_symbol[i, j] for j in range(n_channels)] for i in range(n_channels)]
)
display(K_symbol[i, j], K)
\[\displaystyle {K}_{i,j}\]
\[\displaystyle \left[\begin{matrix}{K}_{0,0}\end{matrix}\right]\]

The \(\boldsymbol{T}\)-matrix can now be computed from Eq. (5):

T = K * (sp.eye(n_channels) - sp.I * K).inv()
T
\[\displaystyle \left[\begin{matrix}\frac{{K}_{0,0}}{- i {K}_{0,0} + 1}\end{matrix}\right]\]

Next, we need to substitute the elements \(K_{i,j}\) with the parametrization we defined above:

T_subs = T.subs({
    K[i, j]: Kij(m, M, Gamma, gamma, i, j, n_resonances)
    for i in range(n_channels)
    for j in range(n_channels)
})
T_subs
\[\displaystyle \left[\begin{matrix}\frac{\sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}}{- i \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + 1}\end{matrix}\right]\]

Warning

It is important to perform doit() after subs(), otherwise the Sum cannot be evaluated and there will be no warning of a failed substitution.
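The following minimal sketch (not part of the original report) illustrates the pitfall: if doit() is called before subs(), the Sum that is inserted by the substitution is never evaluated, and nothing signals that this has happened.

subs_then_doit = T.subs(K[0, 0], Kij(m, M, Gamma, gamma, 0, 0, 1)).doit()
doit_then_subs = T.doit().subs(K[0, 0], Kij(m, M, Gamma, gamma, 0, 0, 1))
assert not subs_then_doit[0, 0].atoms(sp.Sum)  # Sum has been evaluated
assert doit_then_subs[0, 0].atoms(sp.Sum)  # Sum is left unevaluated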

Now indeed, when taking \(n_R=1\), the resulting element from the \(\boldsymbol{T}\)-matrix looks like a Breit-Wigner function (compare relativistic_breit_wigner())!

n_resonances_val = 1
rel_bw = T_subs[0, 0].subs(n_resonances, n_resonances_val).doit()
if n_resonances_val == 1 or n_channels == 2:
    rel_bw = rel_bw.simplify()
rel_bw
\[\displaystyle - \frac{{\Gamma}_{0} {\gamma}_{0,0}^{2} {m}_{0}}{m^{2} + i {\Gamma}_{0} {\gamma}_{0,0}^{2} {m}_{0} - {m}_{0}^{2}}\]
Generalization#

The above procedure has been condensed into a function that can handle an arbitrary number of resonances and an arbitrary number of channels.

def create_symbol_matrix(name: str, n: int) -> sp.Matrix:
    symbol = sp.IndexedBase(name, shape=(n, n))
    return sp.Matrix([[symbol[i, j] for j in range(n)] for i in range(n)])


def k_matrix(n_resonances: int, n_channels: int) -> sp.Matrix:
    # Define symbols
    m = sp.Symbol("m", real=True)
    M = sp.IndexedBase("m", shape=(n_resonances,))
    Gamma = sp.IndexedBase("Gamma", shape=(n_resonances,))
    gamma = sp.IndexedBase("gamma", shape=(n_resonances, n_channels))
    # Define K-matrix and T-matrix
    K = create_symbol_matrix("K", n_channels)
    T = K * (sp.eye(n_channels) - sp.I * K).inv()
    # Substitute elements
    return T.subs({
        K[i, j]: Kij(m, M, Gamma, gamma, i, j, n_resonances)
        for i in range(n_channels)
        for j in range(n_channels)
    })

Single channel, single resonance:

k_matrix(n_resonances=1, n_channels=1)[0, 0].doit().simplify()
\[\displaystyle - \frac{{\Gamma}_{0} {\gamma}_{0,0}^{2} {m}_{0}}{m^{2} + i {\Gamma}_{0} {\gamma}_{0,0}^{2} {m}_{0} - {m}_{0}^{2}}\]

Single channel, \(n_R\) resonances:

k_matrix(n_resonances=sp.Symbol("n_R"), n_channels=1)[0, 0]
\[\displaystyle \frac{\sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}}{- i \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + 1}\]

Two channels, one resonance (Flatté function):

k_matrix(n_resonances=1, n_channels=2)[0, 0].doit().simplify()
\[\displaystyle - \frac{{\Gamma}_{0} {\gamma}_{0,0}^{2} {m}_{0}}{m^{2} + i {\Gamma}_{0} {\gamma}_{0,0}^{2} {m}_{0} + i {\Gamma}_{0} {\gamma}_{0,1}^{2} {m}_{0} - {m}_{0}^{2}}\]

Two channels, \(n_R\) resonances:

expr = k_matrix(n_resonances=sp.Symbol("n_R"), n_channels=2)[0, 0]
Math(sp.multiline_latex("", expr))
\[\begin{split}\displaystyle \begin{align*} \mathtt{\text{}} = & \frac{\left(i \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} - 1\right) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}}{\left(\sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}\right) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + i \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + i \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} - \left(\sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0} {\gamma}_{R,1} {m}_{R}}{- m^{2} + {m}_{R}^{2}}\right)^{2} - 1} \\ & + \frac{i \left(\sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0} {\gamma}_{R,1} {m}_{R}}{- m^{2} + {m}_{R}^{2}}\right)^{2}}{- \left(\sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}\right) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} - i \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} - i \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + \left(\sum_{R=0}^{n_{R} - 1} \frac{{\Gamma}_{R} {\gamma}_{R,0} {\gamma}_{R,1} {m}_{R}}{- m^{2} + {m}_{R}^{2}}\right)^{2} + 1} \end{align*}\end{split}\]
Visualization#

Now, let’s use matplotlib, mpl_interactions, and symplot to visualize the \(\boldsymbol{K}\)-matrix for arbitrary \(n\) and \(n_R\).

def plot_k_matrix(
    n_channels: int,
    n_resonances: int,
    title: str = "",
) -> None:
    # Convert to Symbol: symplot cannot handle IndexedBase
    i = sp.Symbol("i", integer=True, negative=False)
    expr = k_matrix(n_resonances, n_channels)[i, i].doit()
    expr = symplot.substitute_indexed_symbols(expr)
    np_expr, sliders = symplot.prepare_sliders(expr, plot_symbol=m)
    symbol_to_arg = {symbol: arg for arg, symbol in sliders._arg_to_symbol.items()}

    # Set plot domain
    x_min, x_max = 1e-3, 3
    y_min, y_max = -0.5, +0.5

    plot_domain = np.linspace(x_min, x_max, num=500)
    x_values = np.linspace(x_min, x_max, num=160)
    y_values = np.linspace(y_min, y_max, num=80)
    X, Y = np.meshgrid(x_values, y_values)
    plot_domain_complex = X + Y * 1j

    # Set slider values and ranges
    m0_values = np.linspace(x_min, x_max, num=n_resonances + 2)
    m0_values = m0_values[1:-1]
    for R in range(n_resonances):
        for i in range(n_channels):
            sliders.set_ranges({
                "i": (0, n_channels - 1),
                f"m{R}": (0, 3, 100),
                f"Gamma{R}": (-1, 1, 100),
                Rf"\gamma_{{{R},{i}}}": (0, 2, 100),
            })
            sliders.set_values({
                f"m{R}": m0_values[R],
                f"Gamma{R}": (R + 1) * 0.1,
                Rf"\gamma_{{{R},{i}}}": 1 - 0.1 * R + 0.1 * i,
            })

    # Create interactive plots
    controls = Controls(**sliders)
    fig, (ax_2d, ax_3d) = plt.subplots(
        nrows=2,
        figsize=(8, 6),
        sharex=True,
        tight_layout=True,
    )

    fig.canvas.toolbar_visible = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False
    if not title:
        title = (
            Rf"${n_channels} \times {n_channels}$ $K$-matrix"
            f" with {n_resonances} resonances"
        )
    fig.suptitle(title)

    ax_2d.set_ylabel("$|T|^{2}$")
    ax_2d.set_yticks([])
    ax_3d.set_xlabel("Re $m$")
    ax_3d.set_ylabel("Im $m$")
    ax_3d.set_xticks([])
    ax_3d.set_yticks([])
    ax_3d.set_facecolor("white")

    ax_3d.axhline(0, linewidth=0.5, c="black", linestyle="dotted")

    # 2D plot
    def plot(channel: int):
        def wrapped(*args, **kwargs) -> sp.Expr:
            kwargs["i"] = channel
            return np.abs(np_expr(*args, **kwargs)) ** 2

        return wrapped

    for i in range(n_channels):
        iplt.plot(
            plot_domain,
            plot(i),
            ax=ax_2d,
            controls=controls,
            ylim="auto",
            label=f"channel {i}",
        )
    if n_channels > 1:
        ax_2d.legend(loc="upper right")
    mass_line_style = {
        "c": "red",
        "alpha": 0.3,
    }
    for name in controls.params:
        if not name.startswith("m"):
            continue
        iplt.axvline(controls[name], ax=ax_2d, **mass_line_style)

    # 3D plot
    color_mesh = None
    resonances_indicators = []

    def plot3(*, z_cutoff, complex_rendering, **kwargs):
        nonlocal color_mesh
        Z = np_expr(plot_domain_complex, **kwargs)
        if complex_rendering == "imag":
            Z_values = Z.imag
            ax_title = "Re $T$"
        elif complex_rendering == "real":
            Z_values = Z.real
            ax_title = "Im $T$"
        elif complex_rendering == "abs":
            Z_values = np.abs(Z)
            ax_title = "$|T|$"
        else:
            raise NotImplementedError

        if n_channels == 1:
            ax_3d.set_title(ax_title)
        else:
            i = kwargs["i"]
            ax_3d.set_title(f"{ax_title}, channel {i}")

        if color_mesh is None:
            color_mesh = ax_3d.pcolormesh(X, Y, Z_values, cmap=cm.coolwarm)
        else:
            color_mesh.set_array(Z_values)
        color_mesh.set_clim(vmin=-z_cutoff, vmax=+z_cutoff)

        if resonances_indicators:
            for R, (line, text) in enumerate(resonances_indicators):
                mass = kwargs[f"m{R}"]
                line.set_xdata(mass)
                text.set_x(mass + (x_max - x_min) * 0.008)
        else:
            for R in range(n_resonances):
                mass = kwargs[f"m{R}"]
                resonances_indicators.append(
                    (
                        ax_3d.axvline(mass, **mass_line_style),
                        ax_3d.text(
                            x=mass + (x_max - x_min) * 0.008,
                            y=0.95 * y_min,
                            s=f"$m_{R}$",
                            c="red",
                        ),
                    ),
                )

    # Create switch for imag/real/abs
    name = "complex_rendering"
    sliders._sliders[name] = ipywidgets.RadioButtons(
        options=["imag", "real", "abs"],
        description=R"\(s\)-plane plot",
    )
    sliders._arg_to_symbol[name] = name

    # Create cut-off slider for z-direction
    name = "z_cutoff"
    sliders._sliders[name] = ipywidgets.FloatSlider(
        value=1.5,
        min=0.01,
        max=10,
        step=0.1,
        description=R"\(z\)-cutoff",
    )
    sliders._arg_to_symbol[name] = name

    # Create GUI
    sliders_copy = dict(sliders)
    h_boxes = []
    for R in range(n_resonances):
        buttons = [
            sliders_copy.pop(f"m{R}"),
            sliders_copy.pop(f"Gamma{R}"),
        ]
        if n_channels == 1:
            dummy_name = symbol_to_arg[Rf"\gamma_{{{R},0}}"]
            buttons.append(sliders_copy.pop(dummy_name))
        h_box = ipywidgets.HBox(buttons)
        h_boxes.append(h_box)
    remaining_sliders = sorted(sliders_copy.values(), key=lambda s: s.description)
    if n_channels == 1:
        remaining_sliders.remove(sliders["i"])
    ui = ipywidgets.VBox(h_boxes + remaining_sliders)
    output = ipywidgets.interactive_output(plot3, controls=sliders)
    display(ui, output)
plot_k_matrix(n_resonances=3, n_channels=1)


plot_k_matrix(n_resonances=2, n_channels=2)


Interactive 3D plots#

import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import sympy as sp
from IPython.display import display
from ipywidgets import widgets as ipywidgets
from matplotlib import cm
from matplotlib import widgets as mpl_widgets

STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

This report illustrates how to interact with matplotlib 3D plots through Matplotlib sliders and ipywidgets. This might be implemented later on in symplot and/or mpl_interactions (see ianhi/mpl-interactions#89).

In this example, we create a surface plot (see plot_surface()) for the following function.

x, y, a, b = sp.symbols("x y a b")
expression = sp.sqrt(x**a + sp.sin(y / b) ** 2)
expression
\[\displaystyle \sqrt{x^{a} + \sin^{2}{\left(\frac{y}{b} \right)}}\]

The function is formulated with sympy, but we use lambdify() to express it as a numpy function.

numpy_function = sp.lambdify(
    args=(x, y, a, b),
    expr=expression,
    modules="numpy",
)

A surface plot has to be generated over a numpy.meshgrid(). This defines the \(xy\)-plane over which we want to plot our function.

x_min, x_max = 0.1, 2
y_min, y_max = -50, +50
x_values = np.linspace(x_min, x_max, num=20)
y_values = np.linspace(y_min, y_max, num=40)
X, Y = np.meshgrid(x_values, y_values)

The \(z\)-values for plot_surface() can now be simply computed as follows:

a_init = -0.5
b_init = 20
Z = numpy_function(X, Y, a=a_init, b=b_init)

We now want to create sliders for \(a\) and \(b\), so that we can live-update the surface plot through those sliders.

Matplotlib widgets#

Matplotlib provides its own way to define matplotlib.widgets.

fig1, ax1 = plt.subplots(ncols=1, subplot_kw={"projection": "3d"})

# Create sliders and insert them within the figure
plt.subplots_adjust(bottom=0.25)
a_slider = mpl_widgets.Slider(
    ax=plt.axes([0.2, 0.1, 0.65, 0.03]),
    label=f"${sp.latex(a)}$",
    valmin=-2,
    valmax=2,
    valinit=a_init,
)
b_slider = mpl_widgets.Slider(
    ax=plt.axes([0.2, 0.05, 0.65, 0.03]),
    label=f"${sp.latex(b)}$",
    valmin=10,
    valmax=50,
    valinit=b_init,
    valstep=1,
)


# Define what to do when a slider changes
def update_plot(val=None):
    a = a_slider.val
    b = b_slider.val
    ax1.clear()
    Z = numpy_function(X, Y, a, b)
    ax1.plot_surface(
        X,
        Y,
        Z,
        rstride=3,
        cstride=1,
        cmap=cm.coolwarm,
        antialiased=False,
    )
    ax1.set_xlabel(f"${sp.latex(x)}$")
    ax1.set_ylabel(f"${sp.latex(y)}$")
    ax1.set_zlabel(f"${sp.latex(expression)}$")
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_zticks([])
    ax1.set_facecolor("white")
    fig1.canvas.draw_idle()


a_slider.on_changed(update_plot)
b_slider.on_changed(update_plot)

# Plot the surface as initialization
update_plot()
plt.show()

Interactive inline matplotlib output

ipywidgets#

As an alternative, you can use ipywidgets. This package has a lot more sliders to offer than Matplotlib, and they look nicer, but they only work within a Jupyter notebook.

For more info, see Using Interact.

Using interact#

The simplest option is to use the ipywidgets.interact() function:

fig2, ax2 = plt.subplots(ncols=1, subplot_kw={"projection": "3d"})


@ipywidgets.interact(a=(-2.0, 2.0), b=(10, 50))
def plot2(a=a_init, b=b_init):
    ax2.clear()
    Z = numpy_function(X, Y, a, b)
    ax2.plot_surface(
        X,
        Y,
        Z,
        rstride=3,
        cstride=1,
        cmap=cm.coolwarm,
        antialiased=False,
    )
    ax2.set_xlabel(f"${sp.latex(x)}$")
    ax2.set_ylabel(f"${sp.latex(y)}$")
    ax2.set_zlabel(f"${sp.latex(expression)}$")
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_zticks([])
    ax2.set_facecolor("white")
    fig2.canvas.draw_idle()
Using interactive_output#

You can have more control with ipywidgets.interactive_output(). That allows defining the sliders independently, so that you can arrange them as a user interface:

fig3, ax3 = plt.subplots(ncols=1, subplot_kw={"projection": "3d"})
a_ipyslider = ipywidgets.FloatSlider(
    description=f"${sp.latex(a)}$",
    value=a_init,
    min=-2,
    max=2,
    step=0.1,
    readout_format=".1f",
)
b_ipyslider = ipywidgets.IntSlider(
    description=f"${sp.latex(b)}$",
    value=b_init,
    min=10,
    max=50,
)


def plot3(a=a_init, b=b_init):
    ax3.clear()
    Z = numpy_function(X, Y, a, b)
    ax3.plot_surface(
        X,
        Y,
        Z,
        rstride=3,
        cstride=1,
        cmap=cm.coolwarm,
        antialiased=False,
    )
    ax3.set_xlabel(f"${sp.latex(x)}$")
    ax3.set_ylabel(f"${sp.latex(y)}$")
    ax3.set_zlabel(f"${sp.latex(expression)}$")
    ax3.set_xticks([])
    ax3.set_yticks([])
    ax3.set_zticks([])
    ax3.set_facecolor("white")
    fig3.canvas.draw_idle()


ui = ipywidgets.HBox([a_ipyslider, b_ipyslider])
output = ipywidgets.interactive_output(
    plot3, controls={"a": a_ipyslider, "b": b_ipyslider}
)
display(ui, output)

ipywidgets interactive output with interactive_output()

Plotly with ipywidgets#

3D plots with Plotly look a lot nicer and make it possible for the user to pan and zoom the 3D object. As an added bonus, Plotly figures render as interactive 3D objects in the static HTML Sphinx build.

Making 3D Plotly plots interactive with ipywidgets is quite similar to the previous examples with matplotlib. Three recommendations are:

  1. Set continuous_update=False, because plotly is slower than matplotlib in updating the figure.

  2. Save the camera orientation and update it after calling Figure.show().

  3. When embedding the notebook in a static webpage with MyST-NB, avoid calling Figure.show() through ipywidgets.interactive_output(), because it causes the notebook to hang in an endless loop (see CI for ComPWA/compwa.github.io@d9240f1). In the example below, the update_plotly() function is aborted if the notebook is run through Sphinx.

plotly_a = ipywidgets.FloatSlider(
    description=f"${sp.latex(a)}$",
    value=a_init,
    min=-2,
    max=2,
    step=0.1,
    continuous_update=False,
    readout_format=".1f",
)
plotly_b = ipywidgets.IntSlider(
    description=f"${sp.latex(b)}$",
    value=b_init,
    min=10,
    max=50,
    continuous_update=False,
)
plotly_controls = {"a": plotly_a, "b": plotly_b}

plotly_surface = go.Surface(
    x=X,
    y=Y,
    z=Z,
    surfacecolor=Z,
    colorscale="RdBu_r",
    name="Surface",
)
plotly_fig = go.Figure(data=[plotly_surface])
plotly_fig.update_layout(height=500)
if STATIC_WEB_PAGE:
    plotly_fig.show()


def update_plotly(a, b):
    if STATIC_WEB_PAGE:
        return
    Z = numpy_function(X, Y, a, b)
    camera_orientation = plotly_fig.layout.scene.camera
    plotly_fig.update_traces(
        x=X,
        y=Y,
        z=Z,
        surfacecolor=Z,
        selector=dict(name="Surface"),
    )
    plotly_fig.show()
    plotly_fig.update_layout(scene=dict(camera=camera_orientation))


plotly_ui = ipywidgets.HBox([plotly_a, plotly_b])
plotly_output = ipywidgets.interactive_output(update_plotly, plotly_controls)
display(plotly_ui, plotly_output)

MatrixSymbols#

import os

import sympy as sp
from IPython.display import display

STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

Here are some examples of computations with MatrixSymbol.

N = sp.Symbol("n", integer=True, positive=True)
i, j = sp.symbols("i j", integer=True, negative=False)
K = sp.MatrixSymbol("K", N, N)
display(K, K[i, j], K[0, 0])
\[\displaystyle K\]
\[\displaystyle K_{i, j}\]
\[\displaystyle K_{0, 0}\]
A = sp.MatrixSymbol("A", N, N)
(A * K)[0, 0]
\[\displaystyle \sum_{i_{1}=0}^{n - 1} A_{0, i_{1}} K_{i_{1}, 0}\]

The important thing is that elements of a MatrixSymbol can be substituted:

K[0, 0].subs(K[0, 0], i)
\[\displaystyle i\]
A * K
\[\displaystyle A K\]
(A * K)[0, 0]
\[\displaystyle \sum_{i_{1}=0}^{n - 1} A_{0, i_{1}} K_{i_{1}, 0}\]

Now make the matrices \(2 \times 2\) by specifying \(n\):

A_n2 = A.subs(N, 2)
K_n2 = K.subs(N, 2)
(A_n2 * K_n2)[0, 0]
\[\displaystyle A_{0, 0} K_{0, 0} + A_{0, 1} K_{1, 0}\]
v, w, x, y = sp.symbols("v, w, x, y", real=True)
substitutions = {
    A_n2[0, 0]: v,
    A_n2[0, 1]: w,
    K_n2[0, 0]: x,
    K_n2[1, 0]: y,
}
(A_n2 * K_n2)[0, 0].subs(substitutions)
\[\displaystyle v x + w y\]
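As a sketch of how this could be used (an assumption, not part of the original report; TR-007 explores this further), the \(\boldsymbol{T}\)-matrix definition from the \(K\)-matrix report can be written with a MatrixSymbol of fixed size and then made explicit:

K2 = K.subs(N, 2).as_explicit()  # 2×2 matrix with elements K[i, j]
T2 = K2 * (sp.eye(2) - sp.I * K2).inv()
T2[0, 0]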

Indexed free symbols#

In TR-005, we made use of indexed symbols to create a \(\boldsymbol{K}\)-matrix. The problem with that approach is that IndexedBase instances, and the Indexed objects that result from taking their indices, behave strangely in an expression tree.

The following Expr uses a Symbol and elements of two IndexedBase instances (that is, Indexed objects):

import sympy as sp

x = sp.Symbol("x")
c = sp.IndexedBase("c")
alpha = sp.IndexedBase("alpha")
expression = c[0, 1] + alpha[2] * x
expression
\[\displaystyle x {\alpha}_{2} + {c}_{0,1}\]

Although seemingly there are just three free_symbols, there are actually five:

expression.free_symbols
{alpha, alpha[2], c, c[0, 1], x}

This becomes problematic when using lambdify(), particularly through symplot.prepare_sliders().

In addition, while c[0, 1] and alpha[2] are Indexed as expected, alpha and c are Symbols, not IndexedBase:

{s: type(s) for s in expression.free_symbols}
{c: sympy.core.symbol.Symbol,
 alpha[2]: sympy.tensor.indexed.Indexed,
 x: sympy.core.symbol.Symbol,
 c[0, 1]: sympy.tensor.indexed.Indexed,
 alpha: sympy.core.symbol.Symbol}

The expression tree partially explains this behavior:

import graphviz

dot = sp.dotprint(expression)
graphviz.Source(dot);

We would like to collapse the nodes under c[0, 1] and alpha[2] into two single Symbol nodes that are still nicely rendered as \(c_{0,1}\) and \(\alpha_2\). The following function does this by converting the [] into subscripts, keeping the name of the Symbol as short as possible so that it still renders nicely as LaTeX:

from sympy.printing.latex import translate


def to_symbol(idx: sp.Indexed) -> sp.Symbol:
    base_name, _, _ = str(idx).rpartition("[")
    subscript = ",".join(map(str, idx.indices))
    if len(idx.indices) > 1:
        base_name = translate(base_name)
        subscript = "_{" + subscript + "}"
    return sp.Symbol(f"{base_name}{subscript}")

Next, we use subs() to substitute the nodes c[0, 1] and alpha[2] with these Symbols:

def replace_indexed_symbols(expression: sp.Expr) -> sp.Expr:
    return expression.subs({
        s: to_symbol(s) for s in expression.free_symbols if isinstance(s, sp.Indexed)
    })

And indeed, the expression tree has been simplified correctly!

new_expression = replace_indexed_symbols(expression)
dot = sp.dotprint(new_expression)
graphviz.Source(dot);
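As a final check (not part of the original report), the substituted expression now contains only the three Symbols we would expect:

assert len(new_expression.free_symbols) == 3
sorted(new_expression.free_symbols, key=str)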

Lorentz-invariant K-matrix#

Physics#

The Lorentz-invariant description \(\boldsymbol{\hat{T}}\) of the \(\boldsymbol{T}\)-matrix is:

(1)#\[ \boldsymbol{T} = \sqrt{\boldsymbol{\rho^\dagger}} \, \boldsymbol{\hat{T}} \sqrt{\boldsymbol{\rho}} \]

with the phase space factor matrix \(\boldsymbol{\rho}\) defined as:

(2)#\[\begin{split} \boldsymbol{\rho} = \begin{pmatrix} \rho_0 & \cdots & 0 \\ \vdots & \ddots & \vdots \\ 0 & \cdots & \rho_{n-1} \end{pmatrix} \end{split}\]

and

(3)#\[ \rho_i = \frac{2q_i}{m} = \sqrt{ \left[1-\left(\frac{m_{i,a}+m_{i,b}}{m}\right)^2\right] \left[1-\left(\frac{m_{i,a}-m_{i,b}}{m}\right)^2\right] } \]

This results in a similar transformation for the \(\boldsymbol{K}\)-matrix

(4)#\[ \boldsymbol{K} = \sqrt{\boldsymbol{\rho^\dagger}} \; \boldsymbol{\hat{K}} \sqrt{\boldsymbol{\rho}} \]

with (compare Eq. (5) in TR-005):

(5)#\[ \boldsymbol{\hat{T}} = \boldsymbol{\hat{K}}(\boldsymbol{I} - i\boldsymbol{\rho}\boldsymbol{\hat{K}})^{-1} \]

It’s common to integrate these phase space factors into the parametrization of \(K_{ij}\) as well:

(6)#\[ K_{ij} = \sum_R \frac{g_{R,i}(m)g_{R,j}(m)}{\left(m_R^2-m^2\right)\sqrt{\rho_i\rho_j}} \]

Compare this with Eq. (6) in TR-005.

In addition, one often uses an “energy dependent” coupled_width() \(\Gamma_R(m)\) instead of a fixed width \(\Gamma_R\) as done in TR-005.
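As a small consistency check (not part of the original report), the two forms of \(\rho_i\) in Eq. (3) are indeed identical; here this is verified numerically with the standard two-body breakup momentum \(q_i\) and arbitrary mass values:

import sympy as sp

m, m_a, m_b = sp.symbols("m m_a m_b", positive=True)
q = sp.sqrt((m**2 - (m_a + m_b) ** 2) * (m**2 - (m_a - m_b) ** 2)) / (2 * m)
rho = sp.sqrt((1 - ((m_a + m_b) / m) ** 2) * (1 - ((m_a - m_b) / m) ** 2))
values = {m: 1.5, m_a: 0.3, m_b: 0.4}
assert abs((rho - 2 * q / m).subs(values).evalf()) < 1e-10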

from __future__ import annotations

import os
import re
import warnings
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import mpl_interactions.ipyplot as iplt
import numpy as np
import symplot
import sympy as sp
from ampform.dynamics import coupled_width, phase_space_factor_complex
from ampform.dynamics.decorator import (
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
)
from IPython.display import Math, display
from ipywidgets import widgets as ipywidgets
from matplotlib import cm
from mpl_interactions.controller import Controls

if TYPE_CHECKING:
    from sympy.printing.latex import LatexPrinter

warnings.filterwarnings("ignore")
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)
Implementation#
Wrapping expressions#

To keep a nice rendering, we wrap the expressions for phase_space_factor() and coupled_width() into a class that derives from Expr (see e.g. the implementation of BlattWeisskopfSquared). Note that we need to use partial_doit() to keep these expression symbols after evaluating the Sum.

@implement_doit_method()
class PhaseSpaceFactor(UnevaluatedExpression):
    is_commutative = True

    def __new__(
        cls,
        s: sp.Symbol,
        m_a: sp.Symbol,
        m_b: sp.Symbol,
        i: int,
        **hints,
    ) -> PhaseSpaceFactor:
        return create_expression(cls, s, m_a, m_b, i, **hints)

    def evaluate(self) -> sp.Expr:
        s, m_a, m_b, *_ = self.args
        return phase_space_factor_complex(s, m_a, m_b)

    def _latex(self, printer: LatexPrinter, *args) -> str:
        s = printer._print(self.args[0])
        i = self.args[-1]
        return Rf"\rho_{{{i}}}({s})"


@implement_doit_method()
class CoupledWidth(UnevaluatedExpression):
    is_commutative = True

    def __new__(
        cls,
        s: sp.Symbol,
        mass0: sp.IndexedBase,
        gamma0: sp.IndexedBase,
        m_a: sp.IndexedBase,
        m_b: sp.IndexedBase,
        angular_momentum: int,
        R: int | sp.Symbol,
        i: int,
        **hints,
    ) -> CoupledWidth:
        return create_expression(
            cls, s, mass0, gamma0, m_a, m_b, angular_momentum, R, i, **hints
        )

    def evaluate(self) -> sp.Expr:
        s, mass0, gamma0, m_a, m_b, angular_momentum, R, i = self.args

        def phsp_factor(s, m_a, m_b):
            return PhaseSpaceFactor(s, m_a, m_b, i)

        return coupled_width(
            s,
            mass0[R],
            gamma0[R, i],
            m_a[i],
            m_b[i],
            angular_momentum=angular_momentum,
            meson_radius=1,
            phsp_factor=phsp_factor,
        )

    def _latex(self, printer: LatexPrinter, *args) -> str:
        s = printer._print(self.args[0])
        R = self.args[-2]
        i = self.args[-1]
        return Rf"{{\Gamma_{{{R},{i}}}}}({s})"

And here is what the equations look like:

n_channels = 2
n_resonances, i, R, L = sp.symbols("n_R, i, R, L", integer=True, negative=False)
m = sp.Symbol("m", real=True)
M = sp.IndexedBase("m", shape=(n_resonances,))
Gamma = sp.IndexedBase("Gamma", shape=(n_resonances, n_channels))
gamma = sp.IndexedBase("gamma", shape=(n_resonances, n_channels))
m_a = sp.IndexedBase("m_a", shape=(n_channels,))
m_b = sp.IndexedBase("m_b", shape=(n_channels,))
width_expr = CoupledWidth(m**2, M, Gamma, m_a, m_b, 0, R, i)
phsp_expr = PhaseSpaceFactor(m**2, m_a[i], m_b[i], i)
Math(
    sp.multiline_latex(
        lhs=width_expr,
        rhs=width_expr.evaluate(),
    )
)
\[\displaystyle \begin{align*} {\Gamma_{R,i}}(m^{2}) = & \frac{B_{0}^2\left(\frac{\left(m^{2} - \left({m_{a}}_{i} - {m_{b}}_{i}\right)^{2}\right) \left(m^{2} - \left({m_{a}}_{i} + {m_{b}}_{i}\right)^{2}\right)}{4 m^{2}}\right) {\Gamma}_{R,i} \rho_{i}(m^{2})}{B_{0}^2\left(\frac{\left(- \left({m_{a}}_{i} - {m_{b}}_{i}\right)^{2} + {m}_{R}^{2}\right) \left(- \left({m_{a}}_{i} + {m_{b}}_{i}\right)^{2} + {m}_{R}^{2}\right)}{4 {m}_{R}^{2}}\right) \rho_{i}({m}_{R}^{2})} \end{align*}\]
Math(
    sp.multiline_latex(
        lhs=phsp_expr,
        rhs=phsp_expr.doit().simplify().subs(sp.Abs(m), m),
    )
)
\[\displaystyle \begin{align*} \rho_{i}(m^{2}) = & \frac{\sqrt[\mathrm{c}]{\frac{\left(m^{2} - \left({m_{a}}_{i} - {m_{b}}_{i}\right)^{2}\right) \left(m^{2} - \left({m_{a}}_{i} + {m_{b}}_{i}\right)^{2}\right)}{4 m^{2}}}}{8 \pi m} \end{align*}\]

Note

In PhaseSpaceFactor, we use phase_space_factor_complex() instead of phase_space_factor(), meaning that we choose the positive square root when values under the square root are negative. The only reason for doing this is so that there is output in the figure under Visualization. The choice of which square root to take has to do with analyticity (see TR-004) and with choosing which Riemann sheet to connect to. This issue is ignored in this report.

Generalization#

The implementation is quite similar to that of TR-005, the only difference being the additional \(\boldsymbol{\rho}\)-matrix and the insertion of a coupled width. Don't forget to convert back to \(\boldsymbol{T}\) from \(\boldsymbol{\hat{T}}\) with Eq. (1).

def Kij_relativistic(
    m: sp.Symbol,
    M: sp.IndexedBase,
    Gamma: sp.IndexedBase,
    gamma: sp.IndexedBase,
    i: int,
    j: int,
    n_resonances: int | sp.Symbol,
    angular_momentum: int | sp.Symbol = 0,
) -> sp.Expr:
    def residue_function(i):
        return gamma[R, i] * sp.sqrt(
            M[R] * CoupledWidth(m**2, M, Gamma, m_a, m_b, angular_momentum, R, i)
        )

    g_i = residue_function(i)
    g_j = residue_function(j)
    parametrization = (g_i * g_j) / (M[R] ** 2 - m**2)
    return sp.Sum(parametrization, (R, 0, n_resonances - 1))


def relativistic_k_matrix(
    n_resonances: int,
    n_channels: int,
    angular_momentum: int | sp.Symbol = 0,
) -> sp.Matrix:
    # Define symbols
    m = sp.Symbol("m", real=True)
    M = sp.IndexedBase("m", shape=(n_resonances,))
    Gamma = sp.IndexedBase("Gamma", shape=(n_resonances, n_channels))
    gamma = sp.IndexedBase("gamma", shape=(n_resonances, n_channels))
    m_a = sp.IndexedBase("m_a", shape=(n_channels,))
    m_b = sp.IndexedBase("m_b", shape=(n_channels,))
    # Define phase space matrix
    rho = sp.zeros(n_channels, n_channels)
    sqrt_rho = sp.zeros(n_channels, n_channels)
    sqrt_rho_dagger = sp.zeros(n_channels, n_channels)
    for i in range(n_channels):
        rho[i, i] = PhaseSpaceFactor(m**2, m_a[i], m_b[i], i)
        sqrt_rho[i, i] = sp.sqrt(rho[i, i])
        sqrt_rho_dagger[i, i] = 1 / sp.conjugate(sqrt_rho[i, i])
    # Define K-matrix and T-matrix
    K = create_symbol_matrix("K", n_channels)
    T_hat = K * (sp.eye(n_channels) - sp.I * rho * K).inv()
    T = sqrt_rho_dagger * T_hat * sqrt_rho
    # Substitute elements
    return T.subs({
        K[i, j]: Kij_relativistic(
            m=m,
            M=M,
            Gamma=Gamma,
            gamma=gamma,
            i=i,
            j=j,
            n_resonances=n_resonances,
            angular_momentum=angular_momentum,
        )
        for i in range(n_channels)
        for j in range(n_channels)
    })


def create_symbol_matrix(name: str, n: int) -> sp.Matrix:
    symbol = sp.IndexedBase(name, shape=(n, n))
    return sp.Matrix([[symbol[i, j] for j in range(n)] for i in range(n)])

Single channel, one resonance (compare relativistic_breit_wigner_with_ff()):

expr = relativistic_k_matrix(n_resonances=1, n_channels=1)[0, 0]
Math(
    sp.multiline_latex(
        lhs=expr,
        rhs=symplot.partial_doit(expr, sp.Sum).simplify(doit=False),
    )
)
\[\displaystyle \begin{align*} \frac{\sqrt{\rho_{0}(m^{2})} \sum_{R=0}^{0} \frac{{\Gamma_{R,0}}(m^{2}) {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}}{\left(- i \rho_{0}(m^{2}) \sum_{R=0}^{0} \frac{{\Gamma_{R,0}}(m^{2}) {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + 1\right) \overline{\sqrt{\rho_{0}(m^{2})}}} = &- \frac{{\Gamma_{0,0}}(m^{2}) {\gamma}_{0,0}^{2} {m}_{0} \sqrt{\rho_{0}(m^{2})}}{\left(m^{2} + i {\Gamma_{0,0}}(m^{2}) {\gamma}_{0,0}^{2} {m}_{0} \rho_{0}(m^{2}) - {m}_{0}^{2}\right) \overline{\sqrt{\rho_{0}(m^{2})}}} \end{align*}\]

Two channels, one resonance (‘Flatté’):

expr = relativistic_k_matrix(n_resonances=1, n_channels=2)[0, 0]
symplot.partial_doit(expr, sp.Sum).simplify(doit=False)
\[\displaystyle \frac{{\Gamma_{0,0}}(m^{2}) {\gamma}_{0,0}^{2} {m}_{0} \sqrt{\rho_{0}(m^{2})}}{\left(- m^{2} - i {\Gamma_{0,0}}(m^{2}) {\gamma}_{0,0}^{2} {m}_{0} \rho_{1}(m^{2}) - i {\Gamma_{0,1}}(m^{2}) {\gamma}_{0,1}^{2} {m}_{0} \rho_{1}(m^{2}) + {m}_{0}^{2}\right) \overline{\sqrt{\rho_{0}(m^{2})}}}\]

Single channel, \(n_R\) resonances:

relativistic_k_matrix(n_resonances, n_channels=1)[0, 0]
\[\displaystyle \frac{\sqrt{\rho_{0}(m^{2})} \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,0}}(m^{2}) {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}}{\left(- i \rho_{0}(m^{2}) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,0}}(m^{2}) {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + 1\right) \overline{\sqrt{\rho_{0}(m^{2})}}}\]

Two channels, \(n_R\) resonances:

expr = relativistic_k_matrix(n_resonances, n_channels=2)[0, 0]
Math(sp.multiline_latex("", expr))
\[\displaystyle \begin{align*} \mathtt{\text{}} = & \frac{\left(\frac{\left(- i \rho_{1}(m^{2}) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,1}}(m^{2}) {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + 1\right) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,0}}(m^{2}) {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}}{- \rho_{1}(m^{2})^{2} \left(\sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,0}}(m^{2}) {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}\right) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,1}}(m^{2}) {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + \rho_{1}(m^{2})^{2} \left(\sum_{R=0}^{n_{R} - 1} \frac{\sqrt{{\Gamma_{R,0}}(m^{2}) {m}_{R}} \sqrt{{\Gamma_{R,1}}(m^{2}) {m}_{R}} {\gamma}_{R,0} {\gamma}_{R,1}}{- m^{2} + {m}_{R}^{2}}\right)^{2} - i \rho_{1}(m^{2}) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,0}}(m^{2}) {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} - i \rho_{1}(m^{2}) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,1}}(m^{2}) {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + 1} + \frac{i \rho_{1}(m^{2}) \left(\sum_{R=0}^{n_{R} - 1} \frac{\sqrt{{\Gamma_{R,0}}(m^{2}) {m}_{R}} \sqrt{{\Gamma_{R,1}}(m^{2}) {m}_{R}} {\gamma}_{R,0} {\gamma}_{R,1}}{- m^{2} + {m}_{R}^{2}}\right)^{2}}{- \rho_{1}(m^{2})^{2} \left(\sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,0}}(m^{2}) {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}}\right) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,1}}(m^{2}) {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + \rho_{1}(m^{2})^{2} \left(\sum_{R=0}^{n_{R} - 1} \frac{\sqrt{{\Gamma_{R,0}}(m^{2}) {m}_{R}} \sqrt{{\Gamma_{R,1}}(m^{2}) {m}_{R}} {\gamma}_{R,0} {\gamma}_{R,1}}{- m^{2} + {m}_{R}^{2}}\right)^{2} - i \rho_{1}(m^{2}) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,0}}(m^{2}) {\gamma}_{R,0}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} - i \rho_{1}(m^{2}) \sum_{R=0}^{n_{R} - 1} \frac{{\Gamma_{R,1}}(m^{2}) {\gamma}_{R,1}^{2} {m}_{R}}{- m^{2} + {m}_{R}^{2}} + 1}\right) \sqrt{\rho_{0}(m^{2})}}{\overline{\sqrt{\rho_{0}(m^{2})}}} \end{align*}\]
Visualization#
def plot_relativistic_k_matrix(
    n_channels: int,
    n_resonances: int,
    angular_momentum: int | sp.Symbol = 0,
    title: str = "",
) -> None:
    # Convert to Symbol: symplot cannot handle IndexedBase
    epsilon = sp.Symbol("epsilon")
    i, j = sp.symbols("i, j", integer=True, negative=False)
    j = i
    expr = relativistic_k_matrix(
        n_resonances, n_channels, angular_momentum=angular_momentum
    ).doit()[i, j]
    expr = symplot.substitute_indexed_symbols(expr)
    expr = expr.subs(m, m + epsilon * sp.I)
    np_expr, sliders = symplot.prepare_sliders(expr, m)
    symbol_to_arg = {symbol: arg for arg, symbol in sliders._arg_to_symbol.items()}

    # Set plot domain
    x_min, x_max = 1e-3, 3
    y_min, y_max = -0.5, +0.5

    plot_domain = np.linspace(x_min, x_max, num=500)
    x_values = np.linspace(x_min, x_max, num=160)
    y_values = np.linspace(y_min, y_max, num=80)
    X, Y = np.meshgrid(x_values, y_values)
    plot_domain_complex = X + Y * 1j

    # Set slider values and ranges
    m0_values = np.linspace(x_min, x_max, num=n_resonances + 2)
    m0_values = m0_values[1:-1]
    if "L" in sliders:
        sliders.set_ranges(L=(0, 8))
    for R in range(n_resonances):
        for i in range(n_channels):
            sliders.set_ranges({
                "i": (0, n_channels - 1),
                "epsilon": (y_min * 0.2, y_max * 0.2, 0.01),
                f"m{R}": (0, 3, 100),
                Rf"\Gamma_{{{R},{i}}}": (-2, +2, 100),
                Rf"\gamma_{{{R},{i}}}": (0, 10, 100),
                f"m_a{i}": (0, 1, 0.01),
                f"m_b{i}": (0, 1, 0.01),
            })
            sliders.set_values({
                f"m{R}": m0_values[R],
                Rf"\Gamma_{{{R},{i}}}": 2.0 * (0.4 + R * 0.2 - i * 0.3),
                Rf"\gamma_{{{R},{i}}}": 0.25 * (10 - R + i),
                f"m_a{i}": (i + 1) * 0.25,
                f"m_b{i}": (i + 1) * 0.25,
            })

    # Create interactive plots
    controls = Controls(**sliders)
    fig, axes = plt.subplots(
        nrows=2,
        figsize=(8, 6),
        sharex=True,
        tight_layout=True,
    )
    fig.canvas.toolbar_visible = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False
    if not title:
        title = (
            Rf"${n_channels} \times {n_channels}$ $K$-matrix"
            f" with {n_resonances} resonances"
        )
    fig.suptitle(title)

    for ax in axes:
        ax.set_xlim(x_min, x_max)
    ax_2d, ax_3d = axes
    ax_2d.set_ylabel("$|T|^{2}$")
    ax_2d.set_yticks([])
    ax_3d.set_xlabel("Re $m$")
    ax_3d.set_ylabel("Im $m$")
    ax_3d.set_xticks([])
    ax_3d.set_yticks([])
    ax_3d.set_facecolor("white")

    ax_3d.axhline(0, linewidth=0.5, c="black", linestyle="dotted")

    # 2D plot
    def plot(channel: int):
        def wrapped(*args, **kwargs) -> sp.Expr:
            kwargs["i"] = channel
            return np.abs(np_expr(*args, **kwargs)) ** 2

        return wrapped

    for i in range(n_channels):
        iplt.plot(
            plot_domain,
            plot(i),
            ax=axes[0],
            controls=controls,
            ylim="auto",
            label=f"channel {i}",
        )
    if n_channels > 1:
        axes[0].legend(loc="upper right")
    mass_line_style = {
        "c": "red",
        "alpha": 0.3,
    }
    for name in controls.params:
        if not re.match(r"^m[0-9]+$", name):
            continue
        iplt.axvline(controls[name], ax=axes[0], **mass_line_style)

    # 3D plot
    color_mesh = None
    epsilon_indicator = None
    resonances_indicators = []
    threshold_indicators = []

    def plot3(*, z_cutoff, complex_rendering, **kwargs):
        nonlocal color_mesh, epsilon_indicator
        epsilon = kwargs["epsilon"]
        kwargs["epsilon"] = 0
        Z = np_expr(plot_domain_complex, **kwargs)
        if complex_rendering == "imag":
            Z_values = Z.imag
            ax_title = "Re $T$"
        elif complex_rendering == "real":
            Z_values = Z.real
            ax_title = "Im $T$"
        elif complex_rendering == "abs":
            Z_values = np.abs(Z)
            ax_title = "$|T|$"
        else:
            raise NotImplementedError

        if n_channels == 1:
            axes[-1].set_title(ax_title)
        else:
            i = kwargs["i"]
            axes[-1].set_title(f"{ax_title}, channel {i}")

        if color_mesh is None:
            color_mesh = ax_3d.pcolormesh(X, Y, Z_values, cmap=cm.coolwarm)
        else:
            color_mesh.set_array(Z_values)
        color_mesh.set_clim(vmin=-z_cutoff, vmax=+z_cutoff)

        if resonances_indicators:
            for R, (line, text) in enumerate(resonances_indicators):
                mass = kwargs[f"m{R}"]
                line.set_xdata(mass)
                text.set_x(mass + (x_max - x_min) * 0.008)
        else:
            for R in range(n_resonances):
                mass = kwargs[f"m{R}"]
                line = ax_3d.axvline(mass, **mass_line_style)
                text = ax_3d.text(
                    x=mass + (x_max - x_min) * 0.008,
                    y=0.95 * y_min,
                    s=f"$m_{R}$",
                    c="red",
                )
                resonances_indicators.append((line, text))

        if epsilon_indicator is None:
            line = ax.axhline(
                epsilon,
                linewidth=0.5,
                c="blue",
                linestyle="dotted",
                label=R"$\epsilon$",
            )
            text = axes[-1].text(
                x=x_min + 0.008,
                y=epsilon + 0.01,
                s=R"$\epsilon$",
                c="blue",
            )
            epsilon_indicator = line, text
        else:
            line, text = epsilon_indicator
            line.set_ydata(epsilon)
            text.set_y(epsilon + 0.01)

        x_offset = (x_max - x_min) * 0.015
        if threshold_indicators:
            for i, (line_thr, line_diff, text_thr, text_diff) in enumerate(
                threshold_indicators
            ):
                m_a = kwargs[f"m_a{i}"]
                m_b = kwargs[f"m_b{i}"]
                s_thr = m_a + m_b
                m_diff = m_a - m_b
                line_thr.set_xdata(s_thr)
                line_diff.set_xdata(m_diff)
                text_thr.set_x(s_thr)
                text_diff.set_x(m_diff - x_offset)
        else:
            colors = cm.plasma(np.linspace(0, 1, n_channels))
            for i, color in enumerate(colors):
                m_a = kwargs[f"m_a{i}"]
                m_b = kwargs[f"m_b{i}"]
                s_thr = m_a + m_b
                m_diff = m_a - m_b
                line_thr = ax.axvline(s_thr, c=color, linestyle="dotted")
                line_diff = ax.axvline(m_diff, c=color, linestyle="dashed")
                text_thr = ax.text(
                    x=s_thr,
                    y=0.95 * y_min,
                    s=f"$m_{{a{i}}}+m_{{b{i}}}$",
                    c=color,
                    rotation=-90,
                )
                text_diff = ax.text(
                    x=m_diff - x_offset,
                    y=0.95 * y_min,
                    s=f"$m_{{a{i}}}-m_{{b{i}}}$",
                    c=color,
                    rotation=+90,
                )
                threshold_indicators.append(
                    (line_thr, line_diff, text_thr, text_diff)
                )
        for i, (_, line_diff, _, text_diff) in enumerate(threshold_indicators):
            m_a = kwargs[f"m_a{i}"]
            m_b = kwargs[f"m_b{i}"]
            s_thr = m_a + m_b
            m_diff = m_a - m_b
            if m_diff > x_offset + 0.01 and s_thr - abs(m_diff) > x_offset:
                line_diff.set_alpha(0.5)
                text_diff.set_alpha(0.5)
            else:
                line_diff.set_alpha(0)
                text_diff.set_alpha(0)

    # Create switch for imag/real/abs
    name = "complex_rendering"
    sliders._sliders[name] = ipywidgets.RadioButtons(
        options=["imag", "real", "abs"],
        description=R"\(s\)-plane plot",
    )
    sliders._arg_to_symbol[name] = name

    # Create cut-off slider for z-direction
    name = "z_cutoff"
    sliders._sliders[name] = ipywidgets.IntSlider(
        value=30,
        min=+1,
        max=+100,
        description=R"\(z\)-cutoff",
    )
    sliders._arg_to_symbol[name] = name

    # Create GUI
    sliders_copy = dict(sliders)
    h_boxes = []
    for R in range(n_resonances):
        buttons = [sliders_copy.pop(f"m{R}")]
        if n_channels == 1:
            buttons.append(sliders_copy.pop(symbol_to_arg[Rf"\Gamma_{{{R},0}}"]))
            buttons.append(sliders_copy.pop(symbol_to_arg[Rf"\gamma_{{{R},0}}"]))
        h_box = ipywidgets.HBox(buttons)
        h_boxes.append(h_box)
    remaining_sliders = sorted(
        sliders_copy.values(), key=lambda s: (str(type(s)), s.description)
    )
    if n_channels == 1:
        remaining_sliders.remove(sliders["i"])
    ui = ipywidgets.VBox(h_boxes + remaining_sliders)
    output = ipywidgets.interactive_output(plot3, controls=sliders)
    display(ui, output)
plot_relativistic_k_matrix(
    n_resonances=2,
    n_channels=1,
    angular_momentum=L,
    title="Relativistic $K$-matrix, single channel",
)

P-vector#

from __future__ import annotations

import os
import re
import warnings
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import mpl_interactions.ipyplot as iplt
import numpy as np
import symplot
import sympy as sp
from ampform.dynamics import (
    BlattWeisskopfSquared,
    breakup_momentum_squared,
    coupled_width,
    phase_space_factor_complex,
)
from ampform.dynamics.decorator import (
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
)
from IPython.display import display
from ipywidgets import widgets as ipywidgets
from matplotlib import cm
from mpl_interactions.controller import Controls

if TYPE_CHECKING:
    from sympy.printing.latex import LatexPrinter

warnings.filterwarnings("ignore")
STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)
Physics#

As described in TR-005, the \(\boldsymbol{K}\)-matrix describes scattering processes of the type \(cd \to ab\). The \(P\)-vector approach is one of two generalizations for production processes of the type \(c \to ab\). For more details on this approach, [Chung et al., 1995] refers to [Aitchison, 1972].

If we take the production vector \(P\) to be:

(1)#\[ P_i = \sum_R \frac{\beta^0_R\,g_{R,i}(m)}{m_R^2-m^2} \]

and, in its invariant form,

(2)#\[ \hat{P}_i = \sum_R \frac{\beta^0_R\,g_{R,i}(m)}{\left(m_R^2-m^2\right)\sqrt{\rho_i}} \]

with \(g_{R,i}(m)\) given by Eq. (7) (possibly with coupled_width()), then the vector \(F\) describes the resulting amplitudes by

(3)#\[\begin{split} \begin{eqnarray} F & = & \left(\boldsymbol{I}-i\boldsymbol{K}\right)^{-1}P \\ \hat{F} & = & \left(\boldsymbol{I}-i\boldsymbol{\hat{K}}\boldsymbol{\rho}\right)^{-1}\hat{P} \end{eqnarray} \end{split}\]

with, from Eq. (4):

(4)#\[ \hat{\boldsymbol{K}} = \sqrt{\left(\boldsymbol{\rho}^\dagger\right)^{-1}} \boldsymbol{K} \sqrt{\boldsymbol{\rho}^{-1}} \]

Just like with the residue functions in TR-005 and TR-009, \(\beta\) is often expressed in terms of resonance mass and ‘width’:

(5)#\[ \beta^0_R = \beta_R\sqrt{m_R\Gamma^0_R} \]

When, in addition, we use a coupled_width(), the \(\hat{P}\)-vector becomes:

\[ \hat{P}_i = \sum_R \frac{\beta_R\gamma_{R,i}m_R\Gamma^0_R B_{R,i}(m)}{m_R^2-m^2} \]

with \(B_{R,i}(m)\) the ratio of Blatt-Weisskopf barrier factors (BlattWeisskopfSquared) for channel \(i\).
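As a quick plausibility check (not part of the original report), the single-channel, single-resonance case of Eqs. (1), (3) and (5), combined with the non-invariant \(K\)-matrix from TR-005, reduces to a Breit-Wigner-like shape with a \(\beta\)-weighted numerator:

import sympy as sp

m, m_R, Gamma_R, gamma, beta = sp.symbols("m m_R Gamma_R gamma beta", positive=True)
g = gamma * sp.sqrt(m_R * Gamma_R)  # residue function, Eq. (7) in TR-005
K = g**2 / (m_R**2 - m**2)  # single-channel, single-resonance K-matrix
P = beta * sp.sqrt(m_R * Gamma_R) * g / (m_R**2 - m**2)  # Eqs. (1) and (5)
F = sp.simplify(P / (1 - sp.I * K))  # Eq. (3) for n = 1
F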

Implementation#
@implement_doit_method()
class PhaseSpaceFactor(UnevaluatedExpression):
    is_commutative = True

    def __new__(
        cls,
        s: sp.Symbol,
        m_a: sp.Symbol,
        m_b: sp.Symbol,
        i: int,
        **hints,
    ) -> PhaseSpaceFactor:
        return create_expression(cls, s, m_a, m_b, i, **hints)

    def evaluate(self) -> sp.Expr:
        s, m_a, m_b, *_ = self.args
        return phase_space_factor_complex(s, m_a, m_b)

    def _latex(self, printer: LatexPrinter, *args) -> str:
        s = printer._print(self.args[0])
        i = self.args[-1]
        return Rf"\rho_{{{i}}}({s})"


@implement_doit_method()
class CoupledWidth(UnevaluatedExpression):
    is_commutative = True

    def __new__(
        cls,
        s: sp.Symbol,
        mass0: sp.IndexedBase,
        gamma0: sp.IndexedBase,
        m_a: sp.IndexedBase,
        m_b: sp.IndexedBase,
        angular_momentum: int,
        R: int | sp.Symbol,
        i: int,
        **hints,
    ) -> CoupledWidth:
        return create_expression(
            cls, s, mass0, gamma0, m_a, m_b, angular_momentum, R, i, **hints
        )

    def evaluate(self) -> sp.Expr:
        s, mass0, gamma0, m_a, m_b, angular_momentum, R, i = self.args

        def phsp_factor(s, m_a, m_b):
            return PhaseSpaceFactor(s, m_a, m_b, i)

        return coupled_width(
            s,
            mass0[R],
            gamma0[R, i],
            m_a[i],
            m_b[i],
            angular_momentum=angular_momentum,
            meson_radius=1,
            phsp_factor=phsp_factor,
        )

    def _latex(self, printer: LatexPrinter, *args) -> str:
        s = printer._print(self.args[0])
        R = self.args[-2]
        i = self.args[-1]
        return Rf"{{\Gamma_{{{R},{i}}}}}({s})"
def Pi_relativistic(
    m: sp.Symbol,
    M: sp.IndexedBase,
    Gamma: sp.IndexedBase,
    gamma: sp.IndexedBase,
    beta: sp.IndexedBase,
    m_a: sp.IndexedBase,
    m_b: sp.IndexedBase,
    R: int | sp.Symbol,
    i: int,
    n_resonances: int | sp.Symbol,
    angular_momentum: int | sp.Symbol = 0,
) -> sp.Expr:
    q_squared = breakup_momentum_squared(m**2, m_a[i], m_b[i])
    ff2 = BlattWeisskopfSquared(z=q_squared, angular_momentum=angular_momentum)
    parametrization = (beta[R] * gamma[R, i] * M[R] * Gamma[R, i] * sp.sqrt(ff2)) / (
        M[R] ** 2 - m**2
    )
    return sp.Sum(parametrization, (R, 0, n_resonances - 1))


def Kij_relativistic(
    m: sp.Symbol,
    M: sp.IndexedBase,
    Gamma: sp.IndexedBase,
    gamma: sp.IndexedBase,
    m_a: sp.IndexedBase,
    m_b: sp.IndexedBase,
    R: sp.IndexedBase,
    i: int,
    j: int,
    n_resonances: int | sp.Symbol,
    angular_momentum: int | sp.Symbol = 0,
) -> sp.Expr:
    def residue_function(i):
        return gamma[R, i] * sp.sqrt(
            M[R] * CoupledWidth(m**2, M, Gamma, m_a, m_b, angular_momentum, R, i)
        )

    g_i = residue_function(i)
    g_j = residue_function(j)
    parametrization = (g_i * g_j) / (M[R] ** 2 - m**2)
    return sp.Sum(parametrization, (R, 0, n_resonances - 1))


def f_vector(
    n_resonances: int,
    n_channels: int,
    angular_momentum: int | sp.Symbol = 0,
) -> sp.Matrix:
    # Define symbols
    R = sp.Symbol("R")
    m = sp.Symbol("m", real=True)
    M = sp.IndexedBase("m", shape=(n_resonances,))
    Gamma = sp.IndexedBase("Gamma", shape=(n_resonances, n_channels))
    gamma = sp.IndexedBase("gamma", shape=(n_resonances, n_channels))
    beta = sp.IndexedBase("beta", shape=(n_resonances,))
    m_a = sp.IndexedBase("m_a", shape=(n_channels,))
    m_b = sp.IndexedBase("m_b", shape=(n_channels,))
    # Define phase space matrix
    rho = sp.zeros(n_channels, n_channels)
    for i in range(n_channels):
        rho[i, i] = PhaseSpaceFactor(m**2, m_a[i], m_b[i], i)
    # Define P-vector, K-matrix and T-matrix
    P = create_symbol_matrix("P", n_channels, 1)
    K = create_symbol_matrix("K", n_channels, n_channels)
    F = (sp.eye(n_channels) - sp.I * K * rho).inv() * P
    # Substitute elements
    return F.subs({
        K[i, j]: Kij_relativistic(
            m=m,
            M=M,
            Gamma=Gamma,
            gamma=gamma,
            m_a=m_a,
            m_b=m_b,
            i=i,
            j=j,
            R=R,
            n_resonances=n_resonances,
            angular_momentum=angular_momentum,
        )
        for i in range(n_channels)
        for j in range(n_channels)
    }).subs({
        P[i]: Pi_relativistic(
            m=m,
            M=M,
            Gamma=Gamma,
            gamma=gamma,
            beta=beta,
            i=i,
            m_a=m_a,
            m_b=m_b,
            R=R,
            n_resonances=n_resonances,
            angular_momentum=angular_momentum,
        )
        for i in range(n_channels)
    })


def create_symbol_matrix(name: str, m: int, n: int) -> sp.Matrix:
    symbol = sp.IndexedBase(name, shape=(m, n))
    return sp.Matrix([[symbol[i, j] for j in range(n)] for i in range(m)])
L = sp.Symbol("L", integer=True)
symplot.partial_doit(
    f_vector(n_resonances=1, n_channels=1, angular_momentum=L)[0, 0], sp.Sum
)
\[\displaystyle \frac{\sqrt{B_{L}^2\left(\frac{\left(m^{2} - \left({m_{a}}_{0} - {m_{b}}_{0}\right)^{2}\right) \left(m^{2} - \left({m_{a}}_{0} + {m_{b}}_{0}\right)^{2}\right)}{4 m^{2}}\right)} {\Gamma}_{0,0} {\beta}_{0} {\gamma}_{0,0} {m}_{0}}{\left(1 - \frac{i {\Gamma_{0,0}}(m^{2}) {\gamma}_{0,0}^{2} {m}_{0} \rho_{0}(m^{2})}{- m^{2} + {m}_{0}^{2}}\right) \left(- m^{2} + {m}_{0}^{2}\right)}\]
Visualization#
def plot_f_vector(
    n_channels: int,
    n_resonances: int,
    angular_momentum: int | sp.Symbol = 0,
    title: str = "",
) -> None:
    # Convert to Symbol: symplot cannot handle indexed symbols
    m = sp.Symbol("m", real=True)
    epsilon = sp.Symbol("epsilon", real=True)
    i = sp.Symbol("i", integer=True, negative=False)
    expr = f_vector(
        n_resonances, n_channels, angular_momentum=angular_momentum
    ).doit()[i, 0]
    expr = symplot.substitute_indexed_symbols(expr)
    expr = expr.subs(m, m + epsilon * sp.I)
    np_expr, sliders = symplot.prepare_sliders(expr, m)
    symbol_to_arg = {symbol: arg for arg, symbol in sliders._arg_to_symbol.items()}

    # Set plot domain
    x_min, x_max = 1e-3, 3
    y_min, y_max = -0.5, +0.5

    plot_domain = np.linspace(x_min, x_max, num=500)
    x_values = np.linspace(x_min, x_max, num=160)
    y_values = np.linspace(y_min, y_max, num=80)
    X, Y = np.meshgrid(x_values, y_values)
    plot_domain_complex = X + Y * 1j

    # Set slider values and ranges
    m0_values = np.linspace(x_min, x_max, num=n_resonances + 2)
    m0_values = m0_values[1:-1]

    def set_default_values():
        if "L" in sliders:
            sliders.set_ranges(L=(0, 8))
        sliders.set_ranges(
            i=(0, n_channels - 1),
            epsilon=(y_min * 0.2, y_max * 0.2, 0.01),
        )
        for R in range(n_resonances):
            for i in range(n_channels):
                sliders.set_ranges({
                    f"m{R}": (0, 3, 100),
                    f"beta{R}": (0, 5, 0.1),
                    Rf"\Gamma_{{{R},{i}}}": (-5, +5, 100),
                    Rf"\gamma_{{{R},{i}}}": (0, 20, 100),
                    f"m_a{i}": (0, 1, 0.01),
                    f"m_b{i}": (0, 1, 0.01),
                })
                sliders.set_values({
                    f"m{R}": m0_values[R],
                    f"beta{R}": 1,
                    Rf"\Gamma_{{{R},{i}}}": 3 * (0.4 + R * 0.2 - i * 0.3),
                    Rf"\gamma_{{{R},{i}}}": 0.2 * (10 - R + i),
                    f"m_a{i}": (i + 1) * 0.25,
                    f"m_b{i}": (i + 1) * 0.25,
                })

    set_default_values()

    # Create interactive plots
    controls = Controls(**sliders)
    nrows = 2  # set to 3 for imag+real
    fig, axes = plt.subplots(
        nrows=nrows,
        figsize=(8, nrows * 3.0),
        sharex=True,
        tight_layout=True,
    )
    fig.canvas.toolbar_visible = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False
    for ax in axes:
        ax.set_xlim(x_min, x_max)
    if not title:
        title = f"{n_channels}-channel $F$-vector with {n_resonances} resonances"
    fig.suptitle(title)

    # 2D plot
    axes[0].set_ylabel("$|T|^{2}$")
    axes[0].set_yticks([])

    def plot(channel: int):
        def wrapped(*args, **kwargs) -> sp.Expr:
            kwargs["i"] = channel
            return np.abs(np_expr(*args, **kwargs)) ** 2

        return wrapped

    for i in range(n_channels):
        iplt.plot(
            plot_domain,
            plot(i),
            ax=axes[0],
            controls=controls,
            ylim="auto",
            label=f"channel {i}",
        )
    if n_channels > 1:
        axes[0].legend(loc="upper right")
    mass_line_style = {
        "c": "red",
        "alpha": 0.3,
    }
    for name in controls.params:
        if not re.match(r"^m[0-9]+$", name):
            continue
        iplt.axvline(controls[name], ax=axes[0], **mass_line_style)

    # 3D plot
    def plot3(**kwargs):
        z_cutoff = kwargs.pop("z_cutoff")
        epsilon = kwargs["epsilon"]
        kwargs["epsilon"] = 0
        imag_real = kwargs.pop("imag_real")
        Z = np_expr(plot_domain_complex, **kwargs)
        if imag_real == "imag":
            Z_values = Z.imag
            ax_title = "Im $T$"
        elif imag_real == "real":
            Z_values = Z.real
            ax_title = "Re $T$"
        elif imag_real == "abs":
            Z_values = np.abs(Z)
            ax_title = "$|T|$"
        else:
            raise NotImplementedError
        for ax in axes[1:]:
            ax.clear()
        axes[-1].pcolormesh(
            X, Y, Z_values, cmap=cm.coolwarm, vmin=-z_cutoff, vmax=+z_cutoff
        )
        i = kwargs["i"]
        if n_channels == 1:
            axes[-1].set_title(ax_title)
        else:
            axes[-1].set_title(f"{ax_title}, channel {i}")
        for ax in axes[1:]:
            ax.axhline(0, linewidth=0.5, c="black", linestyle="dotted")
            if epsilon != 0.0:
                ax.axhline(
                    epsilon,
                    linewidth=0.5,
                    c="blue",
                    linestyle="dotted",
                    label=R"$\epsilon$",
                )
                axes[-1].text(
                    x=x_min + 0.008,
                    y=epsilon + 0.01,
                    s=R"$\epsilon$",
                    c="blue",
                )
            for R in range(n_resonances):
                mass = kwargs[f"m{R}"]
                ax.axvline(mass, **mass_line_style)
            if "m_a0" in kwargs:
                colors = cm.plasma(np.linspace(0, 1, n_channels))
                for i, color in enumerate(colors):
                    m_a = kwargs[f"m_a{i}"]
                    m_b = kwargs[f"m_b{i}"]
                    s_thr = m_a + m_b
                    ax.axvline(s_thr, c=color, linestyle="dotted")
                    ax.text(
                        x=s_thr,
                        y=0.95 * y_min,
                        s=f"$m_{{a{i}}}+m_{{b{i}}}$",
                        c=color,
                        rotation=-90,
                    )
                    m_diff = m_a - m_b
                    x_offset = (x_max - x_min) * 0.015
                    if m_diff > x_offset + 0.01 and s_thr - abs(m_diff) > x_offset:
                        ax.axvline(
                            m_diff,
                            c=color,
                            linestyle="dashed",
                            alpha=0.5,
                        )
                        ax.text(
                            x=m_diff - x_offset,
                            y=0.95 * y_min,
                            s=f"$m_{{a{i}}}-m_{{b{i}}}$",
                            c=color,
                            rotation=+90,
                        )
            ax.set_ylabel("Im $m$")
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_facecolor("white")
        for R in range(n_resonances):
            mass = kwargs[f"m{R}"]
            axes[-1].text(
                x=mass + (x_max - x_min) * 0.008,
                y=0.95 * y_min,
                s=f"$m_{R}$",
                c="red",
            )
        axes[-1].set_xlabel("Re $m$")
        fig.canvas.draw_idle()

    # Create switch for imag/real/abs
    name = "imag_real"
    sliders._sliders[name] = ipywidgets.RadioButtons(
        options=["imag", "real", "abs"],
        description=R"\(s\)-plane plot",
    )
    sliders._arg_to_symbol[name] = name

    # Create cut-off slider for z-direction
    name = "z_cutoff"
    sliders._sliders[name] = ipywidgets.IntSlider(
        value=10,
        min=+1,
        max=+50,
        description=R"\(z\)-cutoff",
    )
    sliders._arg_to_symbol[name] = name

    # Create GUI
    sliders_copy = dict(sliders)
    h_boxes = []
    for R in range(n_resonances):
        buttons = [sliders_copy.pop(f"m{R}")]
        if n_channels == 1:
            buttons.append(sliders_copy.pop(symbol_to_arg[Rf"\Gamma_{{{R},0}}"]))
            buttons.append(sliders_copy.pop(symbol_to_arg[Rf"\gamma_{{{R},0}}"]))
        h_box = ipywidgets.HBox(buttons)
        h_boxes.append(h_box)
    remaining_sliders = sorted(
        sliders_copy.values(), key=lambda s: (str(type(s)), s.description)
    )
    if n_channels == 1:
        remaining_sliders.remove(sliders["i"])
    ui = ipywidgets.VBox(h_boxes + remaining_sliders)
    output = ipywidgets.interactive_output(plot3, controls=sliders)
    display(ui, output)
plot_f_vector(
    n_resonances=2,
    n_channels=1,
    angular_momentum=L,
    title="Relativistic $F$-vector, single channel",
)

Symbolic kinematics#

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING

import black
import graphviz
import numpy as np
import qrules
import sympy as sp
from ampform.data import EventCollection
from ampform.kinematics import (
    _compute_helicity_angles,
    determine_attached_final_state,
    get_helicity_angle_label,
)
from ampform.sympy import (
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
)
from IPython.display import Math, display
from qrules.topology import Topology, create_isobar_topologies
from sympy.printing.numpy import NumPyPrinter
from sympy.tensor.array.expressions.array_expressions import ArraySlice, ArraySymbol

if TYPE_CHECKING:
    from sympy.printing.printer import Printer

This report investigates issue compwa.github.io#56. The ideal solution would be to use only SymPy in the existing ampform.kinematics module. This has two benefits:

  1. It allows computing kinematic variables from four-momenta with different computational back-ends.

  2. Expressions for kinematic variables can be inspected through their LaTeX representation.

To simplify things, we investigate point 1 by lambdifying to NumPy only. It should be relatively straightforward to lambdify to other back-ends like TensorFlow (as long as they support Einstein summation).

Test sample#

The data sample is taken from this test in AmpForm; the topology and expected angles are taken from here.

topologies = create_isobar_topologies(4)
topology = topologies[1]
dot = qrules.io.asdot(topology)
graphviz.Source(dot)

events = EventCollection({
    0: np.array([  # pi0
        (1.35527, 0.514208, -0.184219, 1.23296),
        (0.841933, 0.0727385, -0.0528868, 0.826163),
        (0.550927, -0.162529, 0.29976, -0.411133),
    ]),
    1: np.array([  # gamma
        (0.755744, -0.305812, 0.284, -0.630057),
        (1.02861, 0.784483, 0.614347, -0.255334),
        (0.356875, -0.20767, 0.272796, 0.0990739),
    ]),
    2: np.array([  # pi0
        (0.208274, -0.061663, -0.0211864, 0.144596),
        (0.461193, -0.243319, -0.283044, -0.234866),
        (1.03294, 0.82872, -0.0465425, -0.599834),
    ]),
    3: np.array([  # pi0
        (0.777613, -0.146733, -0.0785946, -0.747499),
        (0.765168, -0.613903, -0.278416, -0.335962),
        (1.15616, -0.458522, -0.526014, 0.911894),
    ]),
})
angles = _compute_helicity_angles(events, topology)
angles._DataSet__data
{'phi_1+2+3': ScalarSequence([ 2.79758029  2.51292308 -1.07396684]),
 'theta_1+2+3': ScalarSequence([2.72456853 3.03316287 0.69240082]),
 'phi_2+3,1+2+3': ScalarSequence([1.0436215  1.8734936  0.16073833]),
 'theta_2+3,1+2+3': ScalarSequence([2.45361589 1.40639741 0.98079245]),
 'phi_2,2+3,1+2+3': ScalarSequence([ 0.36955786 -1.68820498  0.63063002]),
 'theta_2,2+3,1+2+3': ScalarSequence([1.0924374  1.99375767 1.31959621])}
Einstein summation#

The first challenge is to express the Einstein summation in the existing implementation in terms of SymPy. The aim is to render the resulting expression nicely as LaTeX, while at the same time being able to lambdify it to efficient NumPy code. We do this by deriving from UnevaluatedExpression and using the decorator functions provided by the ampform.sympy module.

Define boost and rotation classes#

First, we wrap rotations and boosts in classes with a nice LaTeX printer. Later on, a NumPy printer method will be defined externally for each of them.

class BoostZ(UnevaluatedExpression):
    def __new__(cls, beta: sp.Symbol, **hints) -> BoostZ:
        return create_expression(cls, beta, **hints)

    def as_explicit(self) -> sp.Expr:
        beta = self.args[0]
        gamma = 1 / sp.sqrt(1 - beta**2)
        return sp.Matrix([
            [gamma, 0, 0, -gamma * beta],
            [0, 1, 0, 0],
            [0, 0, 1, 0],
            [-gamma * beta, 0, 0, gamma],
        ])

    def _latex(self, printer, *args) -> str:
        beta, *_ = self.args
        beta = printer._print(beta)
        return Rf"\boldsymbol{{B_z}}\left({beta}\right)"
class RotationY(UnevaluatedExpression):
    def __new__(cls, angle: sp.Symbol, **hints) -> RotationY:
        return create_expression(cls, angle, **hints)

    def as_explicit(self) -> sp.Expr:
        angle = self.args[0]
        return sp.Matrix([
            [1, 0, 0, 0],
            [0, sp.cos(angle), 0, sp.sin(angle)],
            [0, 0, 1, 0],
            [0, -sp.sin(angle), 0, sp.cos(angle)],
        ])

    def _latex(self, printer, *args) -> str:
        angle, *_ = self.args
        angle = printer._print(angle)
        return Rf"\boldsymbol{{R_y}}\left({angle}\right)"
class RotationZ(UnevaluatedExpression):
    def __new__(cls, angle: sp.Symbol, **hints) -> RotationZ:
        return create_expression(cls, angle, **hints)

    def as_explicit(self) -> sp.Expr:
        angle = self.args[0]
        return sp.Matrix([
            [1, 0, 0, 0],
            [0, sp.cos(angle), -sp.sin(angle), 0],
            [0, sp.sin(angle), sp.cos(angle), 0],
            [0, 0, 0, 1],
        ])

    def _latex(self, printer, *args) -> str:
        angle, *_ = self.args
        angle = printer._print(angle)
        return Rf"\boldsymbol{{R_z}}\left({angle}\right)"
Define Einstein summation class#

Similarly, we define an ArrayMultiplication class that will eventually be lambdified to numpy.einsum().

class ArrayMultiplication(sp.Expr):
    def __new__(cls, *tensors: sp.Symbol, **hints):
        return create_expression(cls, *tensors, **hints)

    def _latex(self, printer, *args) -> str:
        tensors_latex = map(printer._print, self.args)
        return " ".join(tensors_latex)

Indeed, an expression involving these classes looks nice at the top level:

n_events = 3
momentum = sp.MatrixSymbol("p", m=n_events, n=4)
beta = sp.Symbol("beta")
phi = sp.Symbol("phi")
theta = sp.Symbol("theta")

boosted_momentum = ArrayMultiplication(
    BoostZ(beta),
    RotationY(-theta),
    RotationZ(-phi),
    momentum,
)
boosted_momentum
\[\displaystyle \boldsymbol{B_z}\left(\beta\right) \boldsymbol{R_y}\left(- \theta\right) \boldsymbol{R_z}\left(- \phi\right) p\]

Note

It could be that the above can be achieved with SymPy’s ArrayTensorProduct and ArrayContraction. See for instance sympy/sympy#22279.

Define lambdification#
# Small helper function that prints the source code generated by sp.lambdify()


def print_lambdify(symbols, expr):
    np_expr = sp.lambdify(symbols, expr)
    src = inspect.getsource(np_expr)
    src = black.format_str(src, mode=black.Mode(line_length=79))
    print(src)

Now we have a problem: lambdification does not work…

print_lambdify([beta, theta, momentum], boosted_momentum)
def _lambdifygenerated(beta, theta, p):
    return (  # Not supported in Python with SciPy and NumPy:
        # ArrayMultiplication
        ArrayMultiplication(
            BoostZ(beta), RotationY(-theta), RotationZ(-phi), p
        )
    )

But lambdification can be defined externally to both the SymPy library and the expression classes. Here’s an implementation for NumPy where we define the lambdification through the expression class (with a _numpycode method):

def print_as_numpy(self, printer: Printer, *args) -> str:
    def multiply(matrix, vector):
        return (
            'einsum("ij...,j...",'
            f" transpose({matrix}, axes=(1, 2, 0)),"
            f" transpose({vector}))"
        )

    def recursive_multiply(tensors):
        if len(tensors) < 2:
            msg = "Need at least two tensors"
            raise ValueError(msg)
        if len(tensors) == 2:
            return multiply(tensors[0], tensors[1])
        return multiply(tensors[0], recursive_multiply(tensors[1:]))

    printer.module_imports["numpy"].update({"einsum", "transpose"})
    tensors = list(map(printer._print, self.args))
    if len(tensors) == 0:
        return ""
    if len(tensors) == 1:
        return tensors[0]
    return recursive_multiply(tensors)


ArrayMultiplication._numpycode = print_as_numpy
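
The einsum pattern that print_as_numpy generates is intended to perform an event-wise matrix-vector product: each event's \(4\times4\) transformation matrix acts on that event's four-momentum. The following stand-alone NumPy sketch shows that contraction with explicit subscripts; the array shapes and the index labels n, i, j are assumptions for illustration only.

import numpy as np

n_events = 3
rng = np.random.default_rng(seed=0)
matrices = rng.normal(size=(n_events, 4, 4))  # one 4x4 transformation per event
vectors = rng.normal(size=(n_events, 4))  # one four-momentum per event
# Event-wise matrix-vector product, written with explicit subscripts
transformed = np.einsum("nij,nj->ni", matrices, vectors)
# Same result as looping over the events one by one
assert np.allclose(
    transformed, np.array([m @ v for m, v in zip(matrices, vectors)])
)

With the _numpycode method in place, lambdifying an ArrayMultiplication of plain Symbols already produces these nested einsum calls: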
print_lambdify(
    symbols=[beta, theta, momentum],
    expr=ArrayMultiplication(beta, theta, momentum),
)
def _lambdifygenerated(beta, theta, p):
    return einsum(
        "ij...,j...",
        transpose(beta, axes=(1, 2, 0)),
        transpose(
            einsum(
                "ij...,j...", transpose(theta, axes=(1, 2, 0)), transpose(p)
            )
        ),
    )

This also needs to be done for the rotation and boost classes:

print_lambdify([beta, theta, momentum], boosted_momentum)
def _lambdifygenerated(beta, theta, p):
    return (  # Not supported in Python with SciPy and NumPy:
        # BoostZ
        # RotationY
        # RotationZ
        einsum(
            "ij...,j...",
            transpose(BoostZ(beta), axes=(1, 2, 0)),
            transpose(
                einsum(
                    "ij...,j...",
                    transpose(RotationY(-theta), axes=(1, 2, 0)),
                    transpose(
                        einsum(
                            "ij...,j...",
                            transpose(RotationZ(-phi), axes=(1, 2, 0)),
                            transpose(p),
                        )
                    ),
                )
            ),
        )
    )

This time, we define the lambdification through the printer class:

def _print_BoostZ(self: NumPyPrinter, expr: BoostZ) -> str:
    self.module_imports["numpy"].update({"array", "ones", "zeros", "sqrt"})
    arg = expr.args[0]
    beta = self._print(arg)
    gamma = f"1 / sqrt(1 - ({beta}) ** 2)"
    n_events = f"len({beta})"
    zeros = f"zeros({n_events})"
    ones = f"ones({n_events})"
    return f"""array(
        [
            [{gamma}, {zeros}, {zeros}, -{gamma} * {beta}],
            [{zeros}, {ones}, {zeros}, {zeros}],
            [{zeros}, {zeros}, {ones}, {zeros}],
            [-{gamma} * {beta}, {zeros}, {zeros}, {gamma}],
        ]
    ).transpose(2, 0, 1)"""


NumPyPrinter._print_BoostZ = _print_BoostZ
def _print_RotationY(self: NumPyPrinter, expr: RotationY) -> str:
    self.module_imports["numpy"].update({"array", "cos", "ones", "zeros", "sin"})
    arg = expr.args[0]
    angle = self._print(arg)
    n_events = f"len({angle})"
    zeros = f"zeros({n_events})"
    ones = f"ones({n_events})"
    return f"""array(
        [
            [{ones}, {zeros}, {zeros}, {zeros}],
            [{zeros}, cos({angle}), {zeros}, sin({angle})],
            [{zeros}, {zeros}, {ones}, {zeros}],
            [{zeros}, -sin({angle}), {zeros}, cos({angle})],
        ]
    ).transpose(2, 0, 1)"""


NumPyPrinter._print_RotationY = _print_RotationY
def _print_RotationZ(self: NumPyPrinter, expr: RotationZ) -> str:
    self.module_imports["numpy"].update({"array", "cos", "ones", "zeros", "sin"})
    arg = expr.args[0]
    angle = self._print(arg)
    n_events = f"len({angle})"
    zeros = f"zeros({n_events})"
    ones = f"ones({n_events})"
    return f"""array(
        [
            [{ones}, {zeros}, {zeros}, {zeros}],
            [{zeros}, cos({angle}), -sin({angle}), {zeros}],
            [{zeros}, sin({angle}), cos({angle}), {zeros}],
            [{zeros}, {zeros}, {zeros}, {ones}],
        ]
    ).transpose(2, 0, 1)"""


NumPyPrinter._print_RotationZ = _print_RotationZ
print_lambdify([beta, theta, momentum], boosted_momentum)
def _lambdifygenerated(beta, theta, p):
    return einsum(
        "ij...,j...",
        transpose(
            array(
                [
                    [
                        1 / sqrt(1 - (beta) ** 2),
                        zeros(len(beta)),
                        zeros(len(beta)),
                        -1 / sqrt(1 - (beta) ** 2) * beta,
                    ],
                    [
                        zeros(len(beta)),
                        ones(len(beta)),
                        zeros(len(beta)),
                        zeros(len(beta)),
                    ],
                    [
                        zeros(len(beta)),
                        zeros(len(beta)),
                        ones(len(beta)),
                        zeros(len(beta)),
                    ],
                    [
                        -1 / sqrt(1 - (beta) ** 2) * beta,
                        zeros(len(beta)),
                        zeros(len(beta)),
                        1 / sqrt(1 - (beta) ** 2),
                    ],
                ]
            ).transpose(2, 0, 1),
            axes=(1, 2, 0),
        ),
        transpose(
            einsum(
                "ij...,j...",
                transpose(
                    array(
                        [
                            [
                                ones(len(-theta)),
                                zeros(len(-theta)),
                                zeros(len(-theta)),
                                zeros(len(-theta)),
                            ],
                            [
                                zeros(len(-theta)),
                                cos(-theta),
                                zeros(len(-theta)),
                                sin(-theta),
                            ],
                            [
                                zeros(len(-theta)),
                                zeros(len(-theta)),
                                ones(len(-theta)),
                                zeros(len(-theta)),
                            ],
                            [
                                zeros(len(-theta)),
                                -sin(-theta),
                                zeros(len(-theta)),
                                cos(-theta),
                            ],
                        ]
                    ).transpose(2, 0, 1),
                    axes=(1, 2, 0),
                ),
                transpose(
                    einsum(
                        "ij...,j...",
                        transpose(
                            array(
                                [
                                    [
                                        ones(len(-phi)),
                                        zeros(len(-phi)),
                                        zeros(len(-phi)),
                                        zeros(len(-phi)),
                                    ],
                                    [
                                        zeros(len(-phi)),
                                        cos(-phi),
                                        -sin(-phi),
                                        zeros(len(-phi)),
                                    ],
                                    [
                                        zeros(len(-phi)),
                                        sin(-phi),
                                        cos(-phi),
                                        zeros(len(-phi)),
                                    ],
                                    [
                                        zeros(len(-phi)),
                                        zeros(len(-phi)),
                                        zeros(len(-phi)),
                                        ones(len(-phi)),
                                    ],
                                ]
                            ).transpose(2, 0, 1),
                            axes=(1, 2, 0),
                        ),
                        transpose(p),
                    )
                ),
            )
        ),
    )

Note

The code above contains a lot of duplicate code, such as len(-phi). This could possibly be improved with CodeBlock. See ampform#166.

Angle computation#
Computing phi#

The simplest angle to compute is \(\phi\), because it is simply \(\phi=\operatorname{atan2}\left(p_y, p_x\right)\), with \(p\) a four-momentum (see existing implementation). This means we need a way to represent \(p_x\) and \(p_y\), as well as some container for \(\phi\) itself. For convenience, we define \(p_z\) and \(E_p\) as well.
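
As a plain NumPy reference (using the events sample from the Test sample section), the computation for a single \((n, 4)\) array of four-momenta in \((E, p_x, p_y, p_z)\) order would be:

p_array = np.array(events[0])  # shape (n, 4), columns (E, p_x, p_y, p_z)
np.arctan2(p_array[:, 2], p_array[:, 1])

The symbolic classes defined below should lambdify to exactly this kind of expression.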

@implement_doit_method()
class FourMomentumX(UnevaluatedExpression):
    momentum: ArraySymbol = property(lambda self: self.args[0])

    def __new__(cls, momentum: ArraySymbol, **hints) -> FourMomentumX:
        return create_expression(cls, momentum, **hints)

    def evaluate(self) -> ArraySlice:
        return ArraySlice(self.momentum, (slice(None), 1))

    def _latex(self, printer, *args) -> str:
        momentum = printer._print(self.momentum)
        return f"{{{momentum}}}_{{x}}"


@implement_doit_method()
class FourMomentumY(UnevaluatedExpression):
    momentum: ArraySymbol = property(lambda self: self.args[0])

    def __new__(cls, momentum: ArraySymbol, **hints) -> FourMomentumY:
        return create_expression(cls, momentum, **hints)

    def evaluate(self) -> ArraySlice:
        return ArraySlice(self.momentum, (slice(None), 2))

    def _latex(self, printer, *args) -> str:
        momentum = printer._print(self.momentum)
        return f"{{{momentum}}}_{{y}}"


@implement_doit_method()
class FourMomentumZ(UnevaluatedExpression):
    momentum: ArraySymbol = property(lambda self: self.args[0])

    def __new__(cls, momentum: ArraySymbol, **hints) -> FourMomentumZ:
        return create_expression(cls, momentum, **hints)

    def evaluate(self) -> ArraySlice:
        return ArraySlice(self.momentum, (slice(None), 3))

    def _latex(self, printer, *args) -> str:
        momentum = printer._print(self.momentum)
        return f"{{{momentum}}}_{{y}}"


@implement_doit_method()
class Energy(UnevaluatedExpression):
    momentum: ArraySymbol = property(lambda self: self.args[0])

    def __new__(cls, momentum: ArraySymbol, **hints) -> Energy:
        return create_expression(cls, momentum, **hints)

    def evaluate(self) -> ArraySlice:
        return ArraySlice(self.momentum, (slice(None), 0))

    def _latex(self, printer, *args) -> str:
        momentum = printer._print(self.momentum)
        return f"{{E}}_{{{momentum}}}"


@implement_doit_method()
class Phi(UnevaluatedExpression):
    momentum: ArraySymbol = property(lambda self: self.args[0])

    def __new__(cls, momentum: ArraySymbol, **hints) -> Phi:
        return create_expression(cls, momentum, **hints)

    def evaluate(self) -> sp.Expr:
        p = self.momentum
        return sp.atan2(FourMomentumY(p), FourMomentumX(p))

    def _latex(self, printer, *args) -> str:
        momentum = printer._print(self.momentum)
        return Rf"\phi\left({momentum}\right)"

The classes indeed render nicely as LaTeX:

p = ArraySymbol("p_1")
phi = Phi(p)
p_x = FourMomentumX(p)
p_y = FourMomentumY(p)
energy = Energy(p)

math_style = {"environment": "eqnarray"}
display(
    p,
    Math(sp.multiline_latex(Phi(p), phi.evaluate(), **math_style)),
    Math(sp.multiline_latex(p_x, p_x.doit(), **math_style)),
    Math(sp.multiline_latex(p_y, p_y.doit(), **math_style)),
    Math(sp.multiline_latex(energy, energy.doit(), **math_style)),
)
\[\displaystyle p_{1}\]
\[\displaystyle \begin{eqnarray} \phi\left(p_{1}\right) & = & \operatorname{atan_{2}}{\left({p_{1}}_{y},{p_{1}}_{x} \right)} \end{eqnarray}\]
\[\displaystyle \begin{eqnarray} {p_{1}}_{x} & = & p_{1}\left[:, 1\right] \end{eqnarray}\]
\[\displaystyle \begin{eqnarray} {p_{1}}_{y} & = & p_{1}\left[:, 2\right] \end{eqnarray}\]
\[\displaystyle \begin{eqnarray} {E}_{p_{1}} & = & p_{1}\left[:, 0\right] \end{eqnarray}\]

Note that the four component classes like FourMomentumX expect an ArraySymbol as input. This requires sympy/sympy#22265, which makes it possible to lambdify the above expressions to valid NumPy code. Let's compare this with the existing implementation using the Test sample:

momentum_sample = events[0]
np.array(momentum_sample)
array([[ 1.35527  ,  0.514208 , -0.184219 ,  1.23296  ],
       [ 0.841933 ,  0.0727385, -0.0528868,  0.826163 ],
       [ 0.550927 , -0.162529 ,  0.29976  , -0.411133 ]])
np_expr = sp.lambdify(p, p_x.doit())
np_expr(momentum_sample)
array([ 0.514208 ,  0.0727385, -0.162529 ])
np_expr = sp.lambdify(p, energy.doit())
np_expr(momentum_sample)
array([1.35527 , 0.841933, 0.550927])
np_expr = sp.lambdify(p, phi.doit())
display(np.array(momentum_sample.phi()))
np_expr(momentum_sample)
array([-0.34401236, -0.62867104,  2.06762909])
array([-0.34401236, -0.62867104,  2.06762909])
Computing theta#

Computing \(\theta\) is more complicated, because it requires the norm of the three-momentum of \(p\) (see existing implementation). In other words:

\[ \theta = \arccos\left(\frac{p_z}{\left|\vec{p}\right|}\right) \quad \mathrm{with} \quad \left|\vec{p}\right| = \sqrt{p_x^2+p_y^2+p_z^2} \]

The complication here is that \(\left|\vec{p}\right|\) needs to be computed with np.sum over a slice of the arrays. As of writing, it is not yet possible to write a SymPy expression that can lambdify to this or an equivalent with np.einsum (sympy/sympy#22279). So for now, an intermediate ArrayAxisSum class has to be written for this. See also this remark.
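
As a plain NumPy reference, again using the events sample from above, this norm reads:

p_array = np.array(events[0])  # shape (n, 4), columns (E, p_x, p_y, p_z)
np.sqrt(np.sum(p_array[:, 1:] ** 2, axis=1))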

class ArrayAxisSum(sp.Expr):
    array: ArraySymbol = property(lambda self: self.args[0])
    axis: int | None = property(lambda self: self.args[1])

    def __new__(
        cls, array: ArraySymbol, axis: int | None = None, **hints
    ) -> ArrayAxisSum:
        if axis is not None and not isinstance(axis, (int, sp.Integer)):
            msg = "Only single digits allowed for axis"
            raise TypeError(msg)
        return create_expression(cls, array, axis, **hints)

    def _latex(self, printer, *args) -> str:
        A = printer._print(self.array)
        if self.axis is None:
            return Rf"\sum{{{A}}}"
        axis = printer._print(self.axis)
        return Rf"\sum_{{\mathrm{{axis{axis}}}}}{{{A}}}"

Looks nice as LaTeX:

A = ArraySymbol("A")
display(
    ArrayAxisSum(A, axis=1),
    ArrayAxisSum(A),
)
\[\displaystyle \sum_{\mathrm{axis1}}{A}\]
\[\displaystyle \sum{A}\]

Now let’s define a printer method for NumPy:

def _print_ArrayAxisSum(self: NumPyPrinter, expr: ArrayAxisSum) -> str:
    self.module_imports["numpy"].add("sum")
    array = self._print(expr.array)
    axis = self._print(expr.axis)
    return f"sum({array}, axis={axis})"


NumPyPrinter._print_ArrayAxisSum = _print_ArrayAxisSum
print_lambdify(A, ArrayAxisSum(A, axis=1))
def _lambdifygenerated(A):
    return sum(A, axis=1)

…and let’s check whether it works as expected for a 3-dimensional array:

array = np.array(range(12)).reshape(2, 3, 2)
array
array([[[ 0,  1],
        [ 2,  3],
        [ 4,  5]],

       [[ 6,  7],
        [ 8,  9],
        [10, 11]]])
np_expr = sp.lambdify(A, ArrayAxisSum(A))
np_expr(array)
66
np_expr = sp.lambdify(A, ArrayAxisSum(A, axis=0))
np_expr(array)
array([[ 6,  8],
       [10, 12],
       [14, 16]])
np_expr = sp.lambdify(A, ArrayAxisSum(A, axis=1))
np_expr(array)
array([[ 6,  9],
       [24, 27]])
np_expr = sp.lambdify(A, ArrayAxisSum(A, axis=2))
np_expr(array)
array([[ 1,  5,  9],
       [13, 17, 21]])

Now we’re ready to define a class that can represent \(\left|\vec{p}\right|\) and that lambdifies to the expressions given by the existing implementation.

@implement_doit_method()
class ThreeMomentumNorm(UnevaluatedExpression):
    momentum: ArraySymbol = property(lambda self: self.args[0])

    def __new__(cls, momentum: ArraySymbol, **hints) -> ThreeMomentumNorm:
        return create_expression(cls, momentum, **hints)

    def evaluate(self) -> ArraySlice:
        three_momentum = ArraySlice(self.momentum, (slice(None), slice(1, None)))
        norm_squared = ArrayAxisSum(three_momentum**2, axis=1)
        return sp.sqrt(norm_squared)

    def _latex(self, printer, *args) -> str:
        three_momentum = printer._print(self.momentum)
        return Rf"\left|\vec{{{three_momentum}}}\right|"

    def _numpycode(self, printer, *args) -> str:
        return printer._print(self.evaluate())
p_norm = ThreeMomentumNorm(p)
Math(sp.multiline_latex(p_norm, p_norm.doit(), **math_style))
\[\displaystyle \begin{eqnarray} \left|\vec{p_{1}}\right| & = & \sqrt{\sum_{\mathrm{axis1}}{p_{1}\left[:, 1:\right]^{2}}} \end{eqnarray}\]
np_expr = sp.lambdify(p, p_norm.doit())
np_expr(momentum_sample)
array([1.34853137, 0.83104344, 0.53413676])

With that, we’re ready to define the Theta class!

@implement_doit_method()
class Theta(UnevaluatedExpression):
    momentum: ArraySymbol = property(lambda self: self.args[0])

    def __new__(cls, momentum: ArraySymbol, **hints) -> Theta:
        return create_expression(cls, momentum, **hints)

    def evaluate(self) -> sp.Expr:
        p = self.momentum
        return sp.acos(FourMomentumZ(p) / ThreeMomentumNorm(p))

    def _latex(self, printer, *args) -> str:
        momentum = printer._print(self.momentum)
        return Rf"\theta\left({momentum}\right)"

The math doesn’t look the best, but this is due to the ArrayMultiplication class.

Math(sp.multiline_latex(Theta(p), Theta(p).doit(), **math_style))
\[\displaystyle \begin{eqnarray} \theta\left(p_{1}\right) & = & \operatorname{acos}{\left(p_{1}\left[:, 3\right] \frac{1}{\sqrt{\sum_{\mathrm{axis1}}{p_{1}\left[:, 1:\right]^{2}}}} \right)} \end{eqnarray}\]

At any rate, when we lambdify the whole thing, we get exactly the same result as the original theta() method!

np_expr = sp.lambdify(p, Theta(p).doit())
display(np.array(momentum_sample.theta()))
np_expr(momentum_sample)
array([0.41702412, 0.10842903, 2.4491907 ])
array([0.41702412, 0.10842903, 2.4491907 ])
Recursive angle computation#

Finally, we are ready to compute kinematic angles recursively from a Topology, just as in the existing implementation, but now with SymPy only.

def create_event_collection_from_topology(
    topology: Topology,
) -> dict[int, ArraySymbol]:
    n_final_states = len(topology.outgoing_edge_ids)
    return {i: ArraySymbol(f"p{i}") for i in range(n_final_states)}


momentum_symbols = create_event_collection_from_topology(topology)
momentum_symbols
{0: p0, 1: p1, 2: p2, 3: p3}

Now it’s quite trivial to rewrite the existing implementation with the classes defined above.

def compute_helicity_angles(
    events: dict[int, ArraySymbol], topology: Topology
) -> dict[str, sp.Expr]:
    if topology.outgoing_edge_ids != set(events):
        msg = (
            f"Momentum IDs {set(events)} do not match final state edge IDs"
            f" {set(topology.outgoing_edge_ids)}"
        )
        raise ValueError(msg)

    def __recursive_helicity_angles(
        events: dict[int, ArraySymbol], node_id: int
    ) -> dict[str, sp.Expr]:
        helicity_angles: dict[str, sp.Expr] = {}
        child_state_ids = sorted(topology.get_edge_ids_outgoing_from_node(node_id))
        if all(topology.edges[i].ending_node_id is None for i in child_state_ids):
            state_id = child_state_ids[0]
            four_momentum = events[state_id]
            phi_label, theta_label = get_helicity_angle_label(topology, state_id)
            helicity_angles[phi_label] = Phi(four_momentum)
            helicity_angles[theta_label] = Theta(four_momentum)
        for state_id in child_state_ids:
            edge = topology.edges[state_id]
            if edge.ending_node_id is not None:
                # recursively determine all momenta ids in the list
                sub_momenta_ids = determine_attached_final_state(topology, state_id)
                if len(sub_momenta_ids) > 1:
                    # add all of these momenta together -> defines new subsystem
                    four_momentum = sum(events[i] for i in sub_momenta_ids)

                    # boost all of those momenta into this new subsystem
                    phi = Phi(four_momentum)
                    theta = Theta(four_momentum)
                    p3_norm = ThreeMomentumNorm(four_momentum)
                    beta = p3_norm / Energy(four_momentum)
                    new_momentum_pool = {
                        k: ArrayMultiplication(
                            BoostZ(beta),
                            RotationY(-theta),
                            RotationZ(-phi),
                            p,
                        )
                        for k, p in events.items()
                        if k in sub_momenta_ids
                    }

                    # register current angle variables
                    phi_label, theta_label = get_helicity_angle_label(
                        topology, state_id
                    )
                    helicity_angles[phi_label] = Phi(four_momentum)
                    helicity_angles[theta_label] = Theta(four_momentum)

                    # call next recursion
                    angles = __recursive_helicity_angles(
                        new_momentum_pool,
                        edge.ending_node_id,
                    )
                    helicity_angles.update(angles)

        return helicity_angles

    initial_state_id = next(iter(topology.incoming_edge_ids))
    initial_state_edge = topology.edges[initial_state_id]
    assert initial_state_edge.ending_node_id is not None
    return __recursive_helicity_angles(events, initial_state_edge.ending_node_id)

The computation indeed works, and the resulting expressions can be rendered to both LaTeX and NumPy code!

symbolic_angles = compute_helicity_angles(momentum_symbols, topology)
list(symbolic_angles)
['phi_1+2+3',
 'theta_1+2+3',
 'phi_2+3,1+2+3',
 'theta_2+3,1+2+3',
 'phi_2,2+3,1+2+3',
 'theta_2,2+3,1+2+3']
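
As a final numerical check (a minimal sketch, assuming the symbolic_angles, momentum_symbols, and events objects defined in the cells above), one of the symbolic angles can be lambdified and evaluated on the Test sample; the result should agree with the values obtained earlier from _compute_helicity_angles().

theta_expr = symbolic_angles["theta_1+2+3"].doit()
np_theta = sp.lambdify(list(momentum_symbols.values()), theta_expr)
momentum_arrays = [np.array(events[i]) for i in sorted(events)]
np_theta(*momentum_arrays)  # should match the theta_1+2+3 values computed above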
def display_kinematic_variable(name: str) -> None:
    expr = symbolic_angles[name]
    display(Math(sp.multiline_latex(expr, expr.doit(), **math_style)))
    print_lambdify(momentum_symbols.values(), expr.doit())


for angle_name in list(symbolic_angles)[:3]:
    display_kinematic_variable(angle_name)
    print()
\[\displaystyle \begin{eqnarray} \phi\left(p_{1} + p_{2} + p_{3}\right) & = & \operatorname{atan_{2}}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 2\right],\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1\right] \right)} \end{eqnarray}\]
def _lambdifygenerated(p0, p1, p2, p3):
    return arctan2((p1 + p2 + p3)[:, 2], (p1 + p2 + p3)[:, 1])
\[\displaystyle \begin{eqnarray} \theta\left(p_{1} + p_{2} + p_{3}\right) & = & \operatorname{acos}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 3\right] \frac{1}{\sqrt{\sum_{\mathrm{axis1}}{\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1:\right]^{2}}}} \right)} \end{eqnarray}\]
def _lambdifygenerated(p0, p1, p2, p3):
    return arccos(
        (p1 + p2 + p3)[:, 3]
        * sum((p1 + p2 + p3)[:, 1:] ** 2, axis=1) ** (-1 / 2)
    )
\[\displaystyle \begin{eqnarray} \phi\left(\boldsymbol{B_z}\left(\left|\vec{p_{1} + p_{2} + p_{3}}\right| {E}_{p_{1} + p_{2} + p_{3}}^{-1}\right) \boldsymbol{R_y}\left(- \theta\left(p_{1} + p_{2} + p_{3}\right)\right) \boldsymbol{R_z}\left(- \phi\left(p_{1} + p_{2} + p_{3}\right)\right) p_{2} + \boldsymbol{B_z}\left(\left|\vec{p_{1} + p_{2} + p_{3}}\right| {E}_{p_{1} + p_{2} + p_{3}}^{-1}\right) \boldsymbol{R_y}\left(- \theta\left(p_{1} + p_{2} + p_{3}\right)\right) \boldsymbol{R_z}\left(- \phi\left(p_{1} + p_{2} + p_{3}\right)\right) p_{3}\right) & = & \operatorname{atan_{2}}{\left(\left(\boldsymbol{B_z}\left(\sqrt{\sum_{\mathrm{axis1}}{\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1:\right]^{2}}} \left(p_{1} + p_{2} + p_{3}\right)\left[:, 0\right]^{-1}\right) \boldsymbol{R_y}\left(- \operatorname{acos}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 3\right] \frac{1}{\sqrt{\sum_{\mathrm{axis1}}{\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1:\right]^{2}}}} \right)}\right) \boldsymbol{R_z}\left(- \operatorname{atan_{2}}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 2\right],\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1\right] \right)}\right) p_{2} + \boldsymbol{B_z}\left(\sqrt{\sum_{\mathrm{axis1}}{\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1:\right]^{2}}} \left(p_{1} + p_{2} + p_{3}\right)\left[:, 0\right]^{-1}\right) \boldsymbol{R_y}\left(- \operatorname{acos}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 3\right] \frac{1}{\sqrt{\sum_{\mathrm{axis1}}{\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1:\right]^{2}}}} \right)}\right) \boldsymbol{R_z}\left(- \operatorname{atan_{2}}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 2\right],\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1\right] \right)}\right) p_{3}\right)\left[:, 2\right],\left(\boldsymbol{B_z}\left(\sqrt{\sum_{\mathrm{axis1}}{\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1:\right]^{2}}} \left(p_{1} + p_{2} + p_{3}\right)\left[:, 0\right]^{-1}\right) \boldsymbol{R_y}\left(- \operatorname{acos}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 3\right] \frac{1}{\sqrt{\sum_{\mathrm{axis1}}{\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1:\right]^{2}}}} \right)}\right) \boldsymbol{R_z}\left(- \operatorname{atan_{2}}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 2\right],\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1\right] \right)}\right) p_{2} + \boldsymbol{B_z}\left(\sqrt{\sum_{\mathrm{axis1}}{\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1:\right]^{2}}} \left(p_{1} + p_{2} + p_{3}\right)\left[:, 0\right]^{-1}\right) \boldsymbol{R_y}\left(- \operatorname{acos}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 3\right] \frac{1}{\sqrt{\sum_{\mathrm{axis1}}{\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1:\right]^{2}}}} \right)}\right) \boldsymbol{R_z}\left(- \operatorname{atan_{2}}{\left(\left(p_{1} + p_{2} + p_{3}\right)\left[:, 2\right],\left(p_{1} + p_{2} + p_{3}\right)\left[:, 1\right] \right)}\right) p_{3}\right)\left[:, 1\right] \right)} \end{eqnarray}\]
def _lambdifygenerated(p0, p1, p2, p3):
    return arctan2(
        (
            einsum(
                "ij...,j...",
                transpose(
                    array(
                        [
                            [
                                1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                -1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                )
                                * sqrt(sum((p1 + p2 + p3)[:, 1:] ** 2, axis=1))
                                * (p1 + p2 + p3)[:, 0] ** (-1.0),
                            ],
                            [
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                ones(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                            ],
                            [
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                ones(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                            ],
                            [
                                -1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                )
                                * sqrt(sum((p1 + p2 + p3)[:, 1:] ** 2, axis=1))
                                * (p1 + p2 + p3)[:, 0] ** (-1.0),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                ),
                            ],
                        ]
                    ).transpose(2, 0, 1),
                    axes=(1, 2, 0),
                ),
                transpose(
                    einsum(
                        "ij...,j...",
                        transpose(
                            array(
                                [
                                    [
                                        ones(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        cos(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        sin(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        ones(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        -sin(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        cos(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                    ],
                                ]
                            ).transpose(2, 0, 1),
                            axes=(1, 2, 0),
                        ),
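                        # innermost step: rotate p2 about the z-axis by -phi,
                        # with phi = arctan2(p_y, p_x) of p1 + p2 + p3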
                        transpose(
                            einsum(
                                "ij...,j...",
                                transpose(
                                    array(
                                        [
                                            [
                                                ones(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                cos(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                -sin(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                sin(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                cos(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                ones(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                        ]
                                    ).transpose(2, 0, 1),
                                    axes=(1, 2, 0),
                                ),
                                transpose(p2),
                            )
                        ),
                    )
                ),
            )
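            # the same boost-and-rotation chain applied to p3, added to the
            # transformed p2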
            + einsum(
                "ij...,j...",
                transpose(
                    array(
                        [
                            [
                                1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                -1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                )
                                * sqrt(sum((p1 + p2 + p3)[:, 1:] ** 2, axis=1))
                                * (p1 + p2 + p3)[:, 0] ** (-1.0),
                            ],
                            [
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                ones(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                            ],
                            [
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                ones(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                            ],
                            [
                                -1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                )
                                * sqrt(sum((p1 + p2 + p3)[:, 1:] ** 2, axis=1))
                                * (p1 + p2 + p3)[:, 0] ** (-1.0),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                ),
                            ],
                        ]
                    ).transpose(2, 0, 1),
                    axes=(1, 2, 0),
                ),
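                # p3 rotated by the same R_z(-phi) and R_y(-theta) rotations
                # before being boosted into the p1 + p2 + p3 rest frame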
                transpose(
                    einsum(
                        "ij...,j...",
                        transpose(
                            array(
                                [
                                    [
                                        ones(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        cos(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        sin(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        ones(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        -sin(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        cos(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                    ],
                                ]
                            ).transpose(2, 0, 1),
                            axes=(1, 2, 0),
                        ),
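                        # innermost step for p3: rotation about the z-axis by -phi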
                        transpose(
                            einsum(
                                "ij...,j...",
                                transpose(
                                    array(
                                        [
                                            [
                                                ones(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                cos(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                -sin(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                sin(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                cos(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                ones(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                        ]
                                    ).transpose(2, 0, 1),
                                    axes=(1, 2, 0),
                                ),
                                transpose(p3),
                            )
                        ),
                    )
                ),
            )
        )[:, 2],
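        # index 2 above selects the y-component of the transformed p2 + p3;
        # the parenthesized expression below builds another component with the
        # same boost-and-rotation chain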
        (
            einsum(
                "ij...,j...",
                transpose(
                    array(
                        [
                            [
                                1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                -1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                )
                                * sqrt(sum((p1 + p2 + p3)[:, 1:] ** 2, axis=1))
                                * (p1 + p2 + p3)[:, 0] ** (-1.0),
                            ],
                            [
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                ones(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                            ],
                            [
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                ones(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                            ],
                            [
                                -1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                )
                                * sqrt(sum((p1 + p2 + p3)[:, 1:] ** 2, axis=1))
                                * (p1 + p2 + p3)[:, 0] ** (-1.0),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                ),
                            ],
                        ]
                    ).transpose(2, 0, 1),
                    axes=(1, 2, 0),
                ),
                transpose(
                    einsum(
                        "ij...,j...",
                        transpose(
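                            # Rotation about the y axis by -theta, with theta = arccos(p_z / |p|) of (p1 + p2 + p3)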
                            array(
                                [
                                    [
                                        ones(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        cos(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        sin(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        ones(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        -sin(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        cos(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                    ],
                                ]
                            ).transpose(2, 0, 1),
                            axes=(1, 2, 0),
                        ),
                        transpose(
                            einsum(
                                "ij...,j...",
                                transpose(
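                                    # Rotation about the z axis by -phi, with phi = arctan2(p_y, p_x) of (p1 + p2 + p3)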
                                    array(
                                        [
                                            [
                                                ones(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                cos(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                -sin(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                sin(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                cos(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                ones(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                        ]
                                    ).transpose(2, 0, 1),
                                    axes=(1, 2, 0),
                                ),
                                transpose(p2),
                            )
                        ),
                    )
                ),
            )
            + einsum(
                "ij...,j...",
                transpose(
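                    # Pure Lorentz boost along the z axis: gamma = 1 / sqrt(1 - beta**2), beta = |p| / E of (p1 + p2 + p3)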
                    array(
                        [
                            [
                                1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                -1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                )
                                * sqrt(sum((p1 + p2 + p3)[:, 1:] ** 2, axis=1))
                                * (p1 + p2 + p3)[:, 0] ** (-1.0),
                            ],
                            [
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                ones(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                            ],
                            [
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                ones(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                            ],
                            [
                                -1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                )
                                * sqrt(sum((p1 + p2 + p3)[:, 1:] ** 2, axis=1))
                                * (p1 + p2 + p3)[:, 0] ** (-1.0),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                zeros(
                                    len(
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                ),
                                1
                                / sqrt(
                                    1
                                    - (
                                        sqrt(
                                            sum(
                                                (p1 + p2 + p3)[:, 1:] ** 2,
                                                axis=1,
                                            )
                                        )
                                        * (p1 + p2 + p3)[:, 0] ** (-1.0)
                                    )
                                    ** 2
                                ),
                            ],
                        ]
                    ).transpose(2, 0, 1),
                    axes=(1, 2, 0),
                ),
                transpose(
                    einsum(
                        "ij...,j...",
                        transpose(
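                            # Rotation about the y axis by -theta, with theta = arccos(p_z / |p|) of (p1 + p2 + p3)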
                            array(
                                [
                                    [
                                        ones(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        cos(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        sin(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        ones(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                    ],
                                    [
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        -sin(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                        zeros(
                                            len(
                                                -arccos(
                                                    (p1 + p2 + p3)[:, 3]
                                                    * sum(
                                                        (p1 + p2 + p3)[:, 1:]
                                                        ** 2,
                                                        axis=1,
                                                    )
                                                    ** (-1 / 2)
                                                )
                                            )
                                        ),
                                        cos(
                                            -arccos(
                                                (p1 + p2 + p3)[:, 3]
                                                * sum(
                                                    (p1 + p2 + p3)[:, 1:] ** 2,
                                                    axis=1,
                                                )
                                                ** (-1 / 2)
                                            )
                                        ),
                                    ],
                                ]
                            ).transpose(2, 0, 1),
                            axes=(1, 2, 0),
                        ),
                        transpose(
                            einsum(
                                "ij...,j...",
                                transpose(
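                                    # Rotation about the z axis by -phi, with phi = arctan2(p_y, p_x) of (p1 + p2 + p3)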
                                    array(
                                        [
                                            [
                                                ones(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                cos(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                -sin(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                sin(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                cos(
                                                    -arctan2(
                                                        (p1 + p2 + p3)[:, 2],
                                                        (p1 + p2 + p3)[:, 1],
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                            [
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                zeros(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                                ones(
                                                    len(
                                                        -arctan2(
                                                            (p1 + p2 + p3)[
                                                                :, 2
                                                            ],
                                                            (p1 + p2 + p3)[
                                                                :, 1
                                                            ],
                                                        )
                                                    )
                                                ),
                                            ],
                                        ]
                                    ).transpose(2, 0, 1),
                                    axes=(1, 2, 0),
                                ),
                                transpose(p3),
                            )
                        ),
                    )
                ),
            )
        )[:, 1],
    )
Comparison with computed values#

Finally, what’s left is to compare the computed data with the Test sample. Here’s a small comparison function that can be called for each angle:

def compare(angle_name: str) -> None:
    np_expr = sp.lambdify(
        args=momentum_symbols.values(),
        expr=symbolic_angles[angle_name].doit(),
        modules="numpy",
    )
    computed = np_expr(*events.values())
    expected = np.array(angles[angle_name])
    display(computed, expected)
    np.testing.assert_allclose(computed, expected)
for angle_name in symbolic_angles:
    print(angle_name)
    %time compare(angle_name)
    print()
phi_1+2+3
array([ 2.79758029,  2.51292308, -1.07396684])
array([ 2.79758029,  2.51292308, -1.07396684])
CPU times: user 2.48 ms, sys: 6.56 ms, total: 9.04 ms
Wall time: 7.77 ms

theta_1+2+3
array([2.72456853, 3.03316287, 0.69240082])
array([2.72456853, 3.03316287, 0.69240082])
CPU times: user 5.13 ms, sys: 297 µs, total: 5.43 ms
Wall time: 4.93 ms

phi_2+3,1+2+3
array([1.0436215 , 1.8734936 , 0.16073833])
array([1.0436215 , 1.8734936 , 0.16073833])
CPU times: user 34.1 ms, sys: 0 ns, total: 34.1 ms
Wall time: 33.1 ms

theta_2+3,1+2+3
array([2.45361589, 1.40639741, 0.98079245])
array([2.45361589, 1.40639741, 0.98079245])
CPU times: user 36.1 ms, sys: 451 µs, total: 36.6 ms
Wall time: 35.5 ms

phi_2,2+3,1+2+3
array([ 0.36955786, -1.68820498,  0.63063002])
array([ 0.36955786, -1.68820498,  0.63063002])
CPU times: user 1.25 s, sys: 162 ms, total: 1.42 s
Wall time: 1.41 s

theta_2,2+3,1+2+3
array([1.0924374 , 1.99375767, 1.31959621])
array([1.0924374 , 1.99375767, 1.31959621])
CPU times: user 1.22 s, sys: 166 ms, total: 1.38 s
Wall time: 1.38 s

Extended DataSample performance#

import logging

import ampform
import numpy as np
import qrules
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)
from tensorwaves.data import (
    IntensityDistributionGenerator,
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)
from tensorwaves.function.sympy import create_parametrized_function

LOGGER = logging.getLogger("absl")
LOGGER.setLevel(logging.ERROR)
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)
Generate amplitude model#

Formulate a HelicityModel just like in the usual workflow:

reaction = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
    allowed_interaction_types=["strong", "EM"],
    formalism="helicity",
)

builder = ampform.get_builder(reaction)
builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
for name in reaction.get_intermediate_particles().names:
    builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = builder.formulate()

Now register more topologies with HelicityAdapter.permutate_registered_topologies() and formulate a new ‘extended’ model:

builder.adapter.permutate_registered_topologies()
extended_model = builder.formulate()
Create computational functions#

Now, create ParametrizedFunctions for the normal model and the extended model:

intensity = create_parametrized_function(
    expression=model.expression.doit(),
    parameters=model.parameter_defaults,
    backend="jax",
)
helicity_transformer = SympyDataTransformer.from_sympy(
    model.kinematic_variables, backend="jax"
)
extended_intensity = create_parametrized_function(
    expression=extended_model.expression.doit(),
    parameters=extended_model.parameter_defaults,
    backend="jax",
)
extended_helicity_transformer = SympyDataTransformer.from_sympy(
    extended_model.kinematic_variables, backend="jax"
)
Generate data#

Generate phase space domain and hit-and-miss data sample with the normal intensity function and helicity transformer…

phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
data_generator = IntensityDistributionGenerator(
    function=intensity,
    domain_generator=phsp_generator,
    domain_transformer=helicity_transformer,
)
rng = TFUniformRealNumberGenerator(seed=0)
phsp_momenta = phsp_generator.generate(100_000, rng)
data_momenta = data_generator.generate(10_000, rng)
phsp = helicity_transformer(phsp_momenta)
data = helicity_transformer(data_momenta)

…and with the extended function and transformer:

extended_phsp_generator = TFPhaseSpaceGenerator(
    # actually same as phsp_generator
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
extended_data_generator = IntensityDistributionGenerator(
    function=extended_intensity,
    domain_generator=phsp_generator,  # same domain generator as before
    domain_transformer=helicity_transformer,  # same transformer as before
)
rng = TFUniformRealNumberGenerator(seed=0)
phsp_momenta = extended_phsp_generator.generate(100_000, rng)
data_momenta = extended_data_generator.generate(10_000, rng)
extended_phsp = extended_helicity_transformer(phsp_momenta)
extended_data = extended_helicity_transformer(data_momenta)
Conclusion#
intensities = intensity(phsp)
extended_intensities = extended_intensity(extended_phsp)
extended_intensities.shape
(100000,)

Computation time per iteration is the same:

%timeit -n10 intensity(phsp)
%timeit -n10 extended_intensity(extended_phsp)
14.7 ms ± 761 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
14.7 ms ± 669 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Output arrays are also the same:

np.testing.assert_allclose(intensities, extended_intensities)

assert set(data) < set(extended_data)
assert set(phsp) < set(extended_phsp)
for var in data:
    np.testing.assert_allclose(phsp[var], extended_phsp[var])
    np.testing.assert_allclose(data[var], extended_data[var])

Spin alignment with data#

%config InlineBackend.figure_formats = ['svg']
import logging
import warnings

from IPython.display import display

LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)
warnings.filterwarnings("ignore")
Phase space sample#
import qrules

PDG = qrules.load_pdg()
from tensorwaves.data import TFPhaseSpaceGenerator, TFUniformRealNumberGenerator

phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=PDG["Lambda(c)+"].mass,
    final_state_masses={
        0: PDG["p"].mass,
        1: PDG["K-"].mass,
        2: PDG["pi+"].mass,
    },
)
rng = TFUniformRealNumberGenerator(seed=0)
phsp_momenta = phsp_generator.generate(1_000_000, rng)
Generate transitions#
from qrules.particle import ParticleCollection, create_particle

particle_db = ParticleCollection()
particle_db.add(PDG["Lambda(c)+"])
particle_db.add(PDG["p"])
particle_db.add(PDG["K-"])
particle_db.add(PDG["pi+"])

particle_db.add(
    create_particle(
        PDG["K*(892)0"],
        name="K*",
        latex="K^*",
    )
)
particle_db.add(
    create_particle(
        PDG["Lambda(1405)"],
        name="Lambda*",
        latex=R"\Lambda^*",
    )
)
particle_db.add(
    create_particle(
        PDG["Delta(1232)++"],
        name="Delta*++",
        latex=R"\Delta^*",
    )
)
reaction = qrules.generate_transitions(
    initial_state=("Lambda(c)+", [-0.5, +0.5]),
    final_state=["p", "K-", "pi+"],
    formalism="helicity",
    particle_db=particle_db,
)
import graphviz

n = len(reaction.transitions)
for i, t in enumerate(reaction.transitions[:: n // 3]):
    dot = qrules.io.asdot([t], collapse_graphs=True, size=3.5)
    graph = graphviz.Source(dot)
    graph.render(f"013-graph{i}", format="svg")
    display(graph)


Distribution without alignment#

Amplitude model formulated following Appendix C:

import ampform
from ampform.dynamics.builder import RelativisticBreitWignerBuilder

builder = ampform.get_builder(reaction)
builder.align_spin = False
builder.stable_final_state_ids = list(reaction.final_state)
builder.scalar_initial_state_mass = True
bw_builder = RelativisticBreitWignerBuilder()
for name in reaction.get_intermediate_particles().names:
    builder.set_dynamics(name, bw_builder)
standard_model = builder.formulate()
standard_model.intensity
\[\displaystyle \sum_{m_{A}=-1/2}^{1/2} \sum_{m_{0}=-1/2}^{1/2} \sum_{m_{1}=0} \sum_{m_{2}=0}{\left|{{A^{01}}_{m_{A},m_{0},m_{1},m_{2}} + {A^{02}}_{m_{A},m_{0},m_{1},m_{2}} + {A^{12}}_{m_{A},m_{0},m_{1},m_{2}}}\right|^{2}}\]
import sympy as sp
from IPython.display import Math, display

for i, (symbol, expr) in enumerate(standard_model.amplitudes.items()):
    if i == 3:
        display(Math(R"\dots"))
        break
    latex = sp.multiline_latex(symbol, expr, environment="eqnarray")
    display(Math(latex))

Importing the parameter values given by Table 1:

from ampform.helicity import HelicityModel

# fmt: off
parameter_table = {
    # K*
    R"C_{\Lambda_{c}^{+} \to K^*_{0} p_{+1/2}; K^* \to K^{-}_{0} \pi^{+}_{0}}": 1,
    R"C_{\Lambda_{c}^{+} \to K^*_{+1} p_{+1/2}; K^* \to K^{-}_{0} \pi^{+}_{0}}": 0.5 + 0.5j,
    R"C_{\Lambda_{c}^{+} \to K^*_{-1} p_{-1/2}; K^* \to K^{-}_{0} \pi^{+}_{0}}": 1j,
    R"C_{\Lambda_{c}^{+} \to K^*_{0} p_{-1/2}; K^* \to K^{-}_{0} \pi^{+}_{0}}": -0.5 - 0.5j,
    "m_{K^*}": 0.9,  # GeV
    R"\Gamma_{K^*}": 0.2,  # GeV
    # Λ*
    R"C_{\Lambda_{c}^{+} \to \Lambda^*_{-1/2} \pi^{+}_{0}; \Lambda^* \to K^{-}_{0} p_{+1/2}}": 1j,
    R"C_{\Lambda_{c}^{+} \to \Lambda^*_{+1/2} \pi^{+}_{0}; \Lambda^* \to K^{-}_{0} p_{+1/2}}": 0.8 - 0.4j,
    R"m_{\Lambda^*}": 1.6,  # GeV
    R"\Gamma_{\Lambda^*}": 0.2,  # GeV
    # Δ*
    R"C_{\Lambda_{c}^{+} \to \Delta^*_{+1/2} K^{-}_{0}; \Delta^* \to p_{+1/2} \pi^{+}_{0}}": 0.6 - 0.4j,
    R"C_{\Lambda_{c}^{+} \to \Delta^*_{-1/2} K^{-}_{0}; \Delta^* \to p_{+1/2} \pi^{+}_{0}}": 0.1j,
    R"m_{\Delta^*}": 1.4,  # GeV
    R"\Gamma_{\Delta^*}": 0.2,  # GeV
}
# fmt: on


def set_coefficients(model: HelicityModel) -> None:
    for name, value in parameter_table.items():
        model.parameter_defaults[name] = value
set_coefficients(standard_model)

latex = R"\begin{array}{lc}" + "\n"
for par_name, value in parameter_table.items():
    value = str(value).lstrip("(").rstrip(")").replace("j", "i")
    symbol = sp.Symbol(par_name)
    latex += Rf"  {sp.latex(symbol)} & {value} \\" + "\n"
latex += R"\end{array}"
Math(latex)
\[\begin{split}\displaystyle \begin{array}{lc} C_{\Lambda_{c}^{+} \to K^*_{0} p_{+1/2}; K^* \to K^{-}_{0} \pi^{+}_{0}} & 1 \\ C_{\Lambda_{c}^{+} \to K^*_{+1} p_{+1/2}; K^* \to K^{-}_{0} \pi^{+}_{0}} & 0.5+0.5i \\ C_{\Lambda_{c}^{+} \to K^*_{-1} p_{-1/2}; K^* \to K^{-}_{0} \pi^{+}_{0}} & 1i \\ C_{\Lambda_{c}^{+} \to K^*_{0} p_{-1/2}; K^* \to K^{-}_{0} \pi^{+}_{0}} & -0.5-0.5i \\ m_{K^*} & 0.9 \\ \Gamma_{K^*} & 0.2 \\ C_{\Lambda_{c}^{+} \to \Lambda^*_{-1/2} \pi^{+}_{0}; \Lambda^* \to K^{-}_{0} p_{+1/2}} & 1i \\ C_{\Lambda_{c}^{+} \to \Lambda^*_{+1/2} \pi^{+}_{0}; \Lambda^* \to K^{-}_{0} p_{+1/2}} & 0.8-0.4i \\ m_{\Lambda^*} & 1.6 \\ \Gamma_{\Lambda^*} & 0.2 \\ C_{\Lambda_{c}^{+} \to \Delta^*_{+1/2} K^{-}_{0}; \Delta^* \to p_{+1/2} \pi^{+}_{0}} & 0.6-0.4i \\ C_{\Lambda_{c}^{+} \to \Delta^*_{-1/2} K^{-}_{0}; \Delta^* \to p_{+1/2} \pi^{+}_{0}} & 0.1i \\ m_{\Delta^*} & 1.4 \\ \Gamma_{\Delta^*} & 0.2 \\ \end{array}\end{split}\]
Generate data#
import matplotlib.pyplot as plt
import numpy as np
from tensorwaves.data import SympyDataTransformer
from tensorwaves.function.sympy import create_function


def compute_sub_intensities(
    model: HelicityModel, resonance_name: str, phsp, full_expression
) -> np.ndarray:
    parameter_values = {}
    for symbol, value in model.parameter_defaults.items():
        if resonance_name not in symbol.name and symbol.name.startswith("C"):
            parameter_values[symbol] = 0
        else:
            parameter_values[symbol] = value
    sub_expression = full_expression.subs(parameter_values)
    sub_intensity = create_function(sub_expression, backend="jax")
    return np.array(sub_intensity(phsp).real)


def plot_distributions(model: HelicityModel) -> None:
    helicity_transformer = SympyDataTransformer.from_sympy(
        model.kinematic_variables, backend="jax"
    )
    phsp = helicity_transformer(phsp_momenta)
    phsp = {k: v.real for k, v in phsp.items()}

    full_expression = model.expression.doit()
    substituted_expression = full_expression.xreplace(model.parameter_defaults)
    intensity_func = create_function(substituted_expression, backend="jax")
    intensities_all = np.array(intensity_func(phsp).real)
    intensities_k = compute_sub_intensities(model, "K^*", phsp, full_expression)
    intensities_delta = compute_sub_intensities(
        model, "Delta^*", phsp, full_expression
    )
    intensities_lambda = compute_sub_intensities(
        model, "Lambda^*", phsp, full_expression
    )

    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 5))
    hist_kwargs = {
        "bins": 80,
        "histtype": "step",
    }

    for x in ax.flatten():
        x.set_yticks([])

    ax[0, 0].set_xlabel("$m^2(pK^-)$ [GeV$^2/c^4$]")
    ax[0, 1].set_xlabel(R"$m^2(K^-\pi^+)$ [GeV$^2/c^4$]")
    ax[0, 2].set_xlabel(R"$m^2(p\pi^+)$ [GeV$^2/c^4$]")
    ax[1, 0].set_xlabel(R"$\cos\theta(p)$")
    ax[1, 1].set_xlabel(R"$\phi(p)$")
    ax[1, 2].set_xlabel(R"$\chi$")

    for x, xticks in {
        ax[0, 0]: [2, 2.5, 3, 3.5, 4, 4.5],
        ax[0, 1]: [0.4, 0.6, 0.8, 1, 1.2, 1.4, 1.6, 1.8, 2],
        ax[0, 2]: [1, 1.5, 2, 2.5, 3],
        ax[1, 0]: [-1, -0.5, 0, 0.5, 1],
        ax[1, 1]: [-3, -2, -1, 0, 1, 2, 3],
    }.items():
        x.set_xticks(xticks)
        x.set_xticklabels(xticks)

    for weights, color, label in [
        (intensities_all, "red", "Model"),
        (intensities_k, "orange", R"$K^*\to\,K^{^-}\pi^+$"),
        (intensities_delta, "brown", R"$\Delta^{*^{++}} \to\,p\pi^+$"),
        (intensities_lambda, "purple", R"$\Lambda^* \to\,p K^{^-}$"),
    ]:
        kwargs = dict(weights=weights, color=color, **hist_kwargs)
        ax[0, 0].hist(np.array(phsp["m_01"] ** 2), **kwargs)
        ax[0, 1].hist(np.array(phsp["m_12"] ** 2), **kwargs)
        ax[0, 2].hist(np.array(phsp["m_02"] ** 2), **kwargs)
        ax[1, 0].hist(np.array(np.cos(phsp["theta_01"])), **kwargs)
        ax[1, 1].hist(np.array(phsp["phi_01"]), **kwargs, label=label)

    ax[1, 2].remove()
    handles, labels = ax[1, 1].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower right")

    ax[0, 2].set_xlim(1, 3.4)
    ax[1, 0].set_xlim(-1, +1)
    ax[1, 1].set_xlim(-np.pi, +np.pi)

    fig.tight_layout()

    plt.show()

Warning

It takes several minutes to lambdify the full expression and expressions for the Wigner rotation angles.

plot_distributions(standard_model)


Spin alignment sum#

Now, with the spin alignment sum from ampform#245 inserted:

builder.align_spin = True
aligned_model = builder.formulate()
set_coefficients(aligned_model)
aligned_model.intensity
\[\displaystyle \sum_{m_{A}=-1/2}^{1/2} \sum_{m_{0}=-1/2}^{1/2} \sum_{m_{1}=0} \sum_{m_{2}=0}{\left|{\sum_{\lambda^{01}_{0}=-1/2}^{1/2} \sum_{\mu^{01}_{0}=-1/2}^{1/2} \sum_{\nu^{01}_{0}=-1/2}^{1/2} \sum_{\lambda^{01}_{1}=0} \sum_{\mu^{01}_{1}=0} \sum_{\nu^{01}_{1}=0} \sum_{\lambda^{01}_{2}=0}{{A^{01}}_{m_{A},\lambda^{01}_{0},- \lambda^{01}_{1},- \lambda^{01}_{2}} D^{0}_{m_{1},\nu^{01}_{1}}\left(\alpha^{01}_{1},\beta^{01}_{1},\gamma^{01}_{1}\right) D^{0}_{m_{2},\lambda^{01}_{2}}\left(\phi_{01},\theta_{01},0\right) D^{0}_{\mu^{01}_{1},\lambda^{01}_{1}}\left(\phi^{01}_{0},\theta^{01}_{0},0\right) D^{0}_{\nu^{01}_{1},\mu^{01}_{1}}\left(\phi_{01},\theta_{01},0\right) D^{\frac{1}{2}}_{m_{0},\nu^{01}_{0}}\left(\alpha^{01}_{0},\beta^{01}_{0},\gamma^{01}_{0}\right) D^{\frac{1}{2}}_{\mu^{01}_{0},\lambda^{01}_{0}}\left(\phi^{01}_{0},\theta^{01}_{0},0\right) D^{\frac{1}{2}}_{\nu^{01}_{0},\mu^{01}_{0}}\left(\phi_{01},\theta_{01},0\right)} + \sum_{\lambda^{02}_{0}=-1/2}^{1/2} \sum_{\mu^{02}_{0}=-1/2}^{1/2} \sum_{\nu^{02}_{0}=-1/2}^{1/2} \sum_{\lambda^{02}_{1}=0} \sum_{\lambda^{02}_{2}=0} \sum_{\mu^{02}_{2}=0} \sum_{\nu^{02}_{2}=0}{{A^{02}}_{m_{A},\lambda^{02}_{0},- \lambda^{02}_{1},- \lambda^{02}_{2}} D^{0}_{m_{1},\lambda^{02}_{1}}\left(\phi_{02},\theta_{02},0\right) D^{0}_{m_{2},\nu^{02}_{2}}\left(\alpha^{02}_{2},\beta^{02}_{2},\gamma^{02}_{2}\right) D^{0}_{\mu^{02}_{2},\lambda^{02}_{2}}\left(\phi^{02}_{0},\theta^{02}_{0},0\right) D^{0}_{\nu^{02}_{2},\mu^{02}_{2}}\left(\phi_{02},\theta_{02},0\right) D^{\frac{1}{2}}_{m_{0},\nu^{02}_{0}}\left(\alpha^{02}_{0},\beta^{02}_{0},\gamma^{02}_{0}\right) D^{\frac{1}{2}}_{\mu^{02}_{0},\lambda^{02}_{0}}\left(\phi^{02}_{0},\theta^{02}_{0},0\right) D^{\frac{1}{2}}_{\nu^{02}_{0},\mu^{02}_{0}}\left(\phi_{02},\theta_{02},0\right)} + \sum_{\lambda^{12}_{0}=-1/2}^{1/2} \sum_{\lambda^{12}_{1}=0} \sum_{\mu^{12}_{1}=0} \sum_{\nu^{12}_{1}=0} \sum_{\lambda^{12}_{2}=0} \sum_{\mu^{12}_{2}=0} \sum_{\nu^{12}_{2}=0}{{A^{12}}_{m_{A},\lambda^{12}_{0},\lambda^{12}_{1},- \lambda^{12}_{2}} D^{0}_{m_{1},\nu^{12}_{1}}\left(\alpha^{12}_{1},\beta^{12}_{1},\gamma^{12}_{1}\right) D^{0}_{m_{2},\nu^{12}_{2}}\left(\alpha^{12}_{2},\beta^{12}_{2},\gamma^{12}_{2}\right) D^{0}_{\mu^{12}_{1},\lambda^{12}_{1}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right) D^{0}_{\mu^{12}_{2},\lambda^{12}_{2}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right) D^{0}_{\nu^{12}_{1},\mu^{12}_{1}}\left(\phi_{0},\theta_{0},0\right) D^{0}_{\nu^{12}_{2},\mu^{12}_{2}}\left(\phi_{0},\theta_{0},0\right) D^{\frac{1}{2}}_{m_{0},\lambda^{12}_{0}}\left(\phi_{0},\theta_{0},0\right)}}\right|^{2}}\]
plot_distributions(aligned_model)


Compare with Figure 2. Note that the distributions differ close to threshold, because the distributions in the paper are produced with form factors and an energy-dependent width.
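
To move closer to the paper’s setup, one could switch on form factors and an energy-dependent width in the dynamics builder. The following is only a sketch: it assumes that RelativisticBreitWignerBuilder accepts energy_dependent_width and form_factor flags (check the ampform version you are using), and bw_builder_ff / aligned_model_ff are purely illustrative names.

bw_builder_ff = RelativisticBreitWignerBuilder(
    energy_dependent_width=True,  # assumed keyword
    form_factor=True,  # assumed keyword
)
for name in reaction.get_intermediate_particles().names:
    builder.set_dynamics(name, bw_builder_ff)
aligned_model_ff = builder.formulate()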

Amplitude model with sum notation#

from __future__ import annotations

import inspect
import itertools
import logging
from functools import lru_cache
from typing import TYPE_CHECKING, Iterable, Sequence

import ampform
import attrs
import graphviz
import qrules
import symplot
import sympy as sp
from ampform.dynamics.builder import (
    ResonanceDynamicsBuilder,
    create_non_dynamic,
    create_relativistic_breit_wigner,
)
from ampform.helicity import (
    _generate_kinematic_variable_set,
    _generate_kinematic_variables,
)
from ampform.helicity.decay import TwoBodyDecay
from ampform.sympy import (
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
)
from IPython.display import Math, display
from sympy.core.symbol import Str
from sympy.physics.quantum.spin import Rotation as Wigner
from sympy.printing.precedence import PRECEDENCE

if TYPE_CHECKING:
    import sys

    from qrules.topology import Topology
    from qrules.transition import StateTransition
    from sympy.printing.latex import LatexPrinter

LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)
Problem description#

Challenge

Formulate the HelicityModel in such a way that:

  1. The sum over the amplitudes is concise and expresses that the sum depends only on helicities.

  2. The sympy.Expr of the amplitude model can easily and uniquely be constructed from the data in the HelicityModel. (Currently, this is as simple as calling HelicityModel.expression.doit().)

  3. All parameters under parameter_defaults are of type sympy.Symbol. This is important for a correct lambdification of the arguments with sympy.utilities.lambdify.lambdify().

ampform#245 implements spin alignment, which results in large sum combinatorics over all helicity combinations. The resulting amplitude model expression is too large to be rendered as LaTeX.

To some extent, this is already the case with the current implementation of the ‘standard’ helicity formalism [Jacob and Wick, 1959, Richman, 1984, Kutschke, 1996, Chung, 2014]: many of the terms in the total intensity expression differ only by the helicities of the initial and final state.

This report presents two solutions: Solution 1: Indexed coefficients and Solution 2: Indexed amplitude components.

# Simplify resonance notation
PDG = qrules.load_pdg()
delta_res = PDG["Delta(1600)++"]
lambda_res = PDG["Lambda(1405)"]
particles = set(PDG)
particles.remove(delta_res)
particles.remove(lambda_res)
particles.add(attrs.evolve(delta_res, latex=R"\Delta"))
particles.add(attrs.evolve(lambda_res, latex=R"\Lambda"))
MODIFIED_PDG = qrules.ParticleCollection(particles)
reaction = qrules.generate_transitions(
    initial_state="Lambda(c)+",
    final_state=["K-", "p", "pi+"],
    formalism="helicity",
    allowed_intermediate_particles=["Delta(1600)++"],
    particle_db=MODIFIED_PDG,
)
display(*(graphviz.Source(qrules.io.asdot(t, size=3)) for t in reaction.transitions))

builder = ampform.get_builder(reaction)
model = builder.formulate()
def remove_coefficients(expr: sp.Expr) -> sp.Expr:
    coefficients = {s: 1 for s in expr.free_symbols if s.name.startswith("C_")}
    return expr.subs(coefficients)


model = builder.formulate()
full_expression = remove_coefficients(model.expression)
I = sp.Symbol("I")  # noqa: E741
latex = sp.multiline_latex(I, full_expression)
Math(latex)
\[\begin{split}\displaystyle \begin{align*} I = & \left|{D^{\frac{1}{2}}_{- \frac{1}{2},- \frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{- \frac{1}{2},- \frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right) + D^{\frac{1}{2}}_{- \frac{1}{2},\frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\frac{1}{2},- \frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}\right|^{2} \\ & + \left|{D^{\frac{1}{2}}_{- \frac{1}{2},- \frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{- \frac{1}{2},\frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right) + D^{\frac{1}{2}}_{- \frac{1}{2},\frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\frac{1}{2},\frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}\right|^{2} \\ & + \left|{D^{\frac{1}{2}}_{\frac{1}{2},- \frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{- \frac{1}{2},- \frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right) + D^{\frac{1}{2}}_{\frac{1}{2},\frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\frac{1}{2},- \frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}\right|^{2} \\ & + \left|{D^{\frac{1}{2}}_{\frac{1}{2},- \frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{- \frac{1}{2},\frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right) + D^{\frac{1}{2}}_{\frac{1}{2},\frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\frac{1}{2},\frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}\right|^{2} \end{align*}\end{split}\]

Note that we did not insert any dynamics here, but it is unusual for dynamics expressions to depend on helicity or spin projection anyway (see TwoBodyKinematicVariableSet).

Simplified notation with PoolSum#

Both Solution 1: Indexed coefficients and Solution 2: Indexed amplitude components require the definition of a special “PoolSum” class to simplify the summation over the amplitudes. The class mimics sympy.Sum in that it sums an expression by substituting certain Symbols with each value from a given range. The range of values in a PoolSum does not have to be sequential, however: it can be any collection of items (a small usage example follows the class definition below).

@implement_doit_method
class PoolSum(UnevaluatedExpression):
    precedence = PRECEDENCE["Add"]

    def __new__(
        cls,
        expression: sp.Expr,
        *indices: tuple[sp.Symbol, Iterable[sp.Float]],
        **hints,
    ) -> PoolSum:
        indices = tuple((s, tuple(v)) for s, v in indices)
        return create_expression(cls, expression, *indices, **hints)

    @property
    def expression(self) -> sp.Expr:
        return self.args[0]

    @property
    def indices(self) -> list[tuple[sp.Symbol, tuple[sp.Float, ...]]]:
        return self.args[1:]

    def evaluate(self) -> sp.Expr:
        indices = dict(self.indices)
        return sum(
            self.expression.subs(zip(indices, combi))
            for combi in itertools.product(*indices.values())
        )

    def _latex(self, printer: LatexPrinter, *args) -> str:
        indices = dict(self.indices)
        sum_symbols: list[str] = []
        for idx, values in indices.items():
            sum_symbols.append(_render_sum_symbol(printer, idx, values))
        expression = printer._print(self.expression)
        return R" ".join(sum_symbols) + f"{{{expression}}}"

    def cleanup(self) -> sp.Expr | PoolSum:
        substitutions = {}
        new_indices = []
        for idx, values in self.indices:
            if idx not in self.expression.free_symbols:
                continue
            if len(values) == 0:
                continue
            if len(values) == 1:
                substitutions[idx] = values[0]
            else:
                new_indices.append((idx, values))
        new_expression = self.expression.xreplace(substitutions)
        if len(new_indices) == 0:
            return new_expression
        return PoolSum(new_expression, *new_indices)


def _render_sum_symbol(
    printer: LatexPrinter, idx: sp.Symbol, values: Sequence[float]
) -> str:
    if len(values) == 0:
        return ""
    idx = printer._print(idx)
    if len(values) == 1:
        value = values[0]
        return Rf"\sum_{{{idx}={value}}}"
    if _is_regular_series(values):
        sorted_values = sorted(values)
        first_value = sorted_values[0]
        last_value = sorted_values[-1]
        return Rf"\sum_{{{idx}={first_value}}}^{{{last_value}}}"
    idx_values = ",".join(map(printer._print, values))
    return Rf"\sum_{{{idx}\in\left\{{{idx_values}\right\}}}}"


def _is_regular_series(values: Sequence[float]) -> bool:
    if len(values) <= 1:
        return False
    sorted_values = sorted(values)
    for val, next_val in zip(sorted_values, sorted_values[1:]):
        difference = float(next_val - val)
        if difference != 1.0:
            return False
    return True
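
Here is a minimal usage example of the PoolSum class defined above, illustrating that the summation range need not be sequential (x and i are purely illustrative symbols):

import sympy as sp

x, i = sp.symbols("x i")
small_sum = PoolSum(x**i, (i, (0, 2, 5)))  # sum x**i over the set {0, 2, 5}
assert small_sum.doit() == 1 + x**2 + x**5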

Here’s a sketch of how to construct the amplitude model with a PoolSum:

half = sp.S.Half

spin_parent = sp.Symbol(R"s_{\Lambda_c}", real=True)
spin_resonance = sp.Symbol(R"s_\Delta", real=True)

phi_12, theta_12 = sp.symbols("phi_12 theta_12", real=True)
phi_1_12, theta_1_12 = sp.symbols(R"phi_1^12 theta_1^12", real=True)

lambda_parent = sp.Symbol(R"\lambda_{\Lambda_c}", real=True)
lambda_resonance = sp.Symbol(R"\lambda_\Delta", real=True)
lambda_p = sp.Symbol(R"\lambda_p", real=True)
lambda_k = sp.Symbol(R"\lambda_K", real=True)
lambda_pi = sp.Symbol(R"\lambda_\pi", real=True)
sum_expr = sp.Subs(
    PoolSum(
        sp.Abs(
            PoolSum(
                Wigner.D(
                    spin_parent,
                    lambda_parent,
                    lambda_k - lambda_resonance,
                    phi_12,
                    theta_12,
                    0,
                )
                * Wigner.D(
                    spin_resonance,
                    lambda_resonance,
                    lambda_p - lambda_pi,
                    phi_1_12,
                    theta_1_12,
                    0,
                ),
                (lambda_resonance, (-half, +half)),
            )
        )
        ** 2,
        (lambda_parent, (-half, +half)),
        (lambda_p, (-half, +half)),
        (lambda_pi, (0,)),
        (lambda_k, (0,)),
    ),
    (spin_parent, spin_resonance),
    (half, 3 * half),
)
display(
    sum_expr,
    sum_expr.expr.cleanup(),
    Math(sp.multiline_latex(I, sum_expr.doit(deep=False))),
    sum_expr.doit(deep=False).doit(deep=True),
)
\[\begin{split}\displaystyle \left. \sum_{\lambda_{\Lambda_c}=-1/2}^{1/2} \sum_{\lambda_{p}=-1/2}^{1/2} \sum_{\lambda_{\pi}=0} \sum_{\lambda_{K}=0}{\left|{\sum_{\lambda_{\Delta}=-1/2}^{1/2}{D^{s_{\Delta}}_{\lambda_{\Delta},- \lambda_{\pi} + \lambda_{p}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right) D^{s_{\Lambda_c}}_{\lambda_{\Lambda_c},\lambda_{K} - \lambda_{\Delta}}\left(\phi_{12},\theta_{12},0\right)}}\right|^{2}} \right|_{\substack{ s_{\Lambda_c}=\frac{1}{2}\\ s_{\Delta}=\frac{3}{2} }}\end{split}\]
\[\displaystyle \sum_{\lambda_{\Lambda_c}=-1/2}^{1/2} \sum_{\lambda_{p}=-1/2}^{1/2}{\left|{\sum_{\lambda_{\Delta}=-1/2}^{1/2}{D^{s_{\Delta}}_{\lambda_{\Delta},\lambda_{p}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right) D^{s_{\Lambda_c}}_{\lambda_{\Lambda_c},- \lambda_{\Delta}}\left(\phi_{12},\theta_{12},0\right)}}\right|^{2}}\]
\[\begin{split}\displaystyle \begin{align*} I = & \left|{\sum_{\lambda_{\Delta}=-1/2}^{1/2}{D^{\frac{1}{2}}_{- \frac{1}{2},- \lambda_{\Delta}}\left(\phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{\Delta},- \frac{1}{2}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2} \\ & + \left|{\sum_{\lambda_{\Delta}=-1/2}^{1/2}{D^{\frac{1}{2}}_{- \frac{1}{2},- \lambda_{\Delta}}\left(\phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{\Delta},\frac{1}{2}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2} \\ & + \left|{\sum_{\lambda_{\Delta}=-1/2}^{1/2}{D^{\frac{1}{2}}_{\frac{1}{2},- \lambda_{\Delta}}\left(\phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{\Delta},- \frac{1}{2}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2} \\ & + \left|{\sum_{\lambda_{\Delta}=-1/2}^{1/2}{D^{\frac{1}{2}}_{\frac{1}{2},- \lambda_{\Delta}}\left(\phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{\Delta},\frac{1}{2}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2} \end{align*}\end{split}\]
\[\displaystyle \frac{\sin^{2}{\left(\frac{\theta_{12}}{2} \right)} \sin^{2}{\left(\frac{\theta^{12}_{1}}{2} \right)}}{8} - \frac{3 \sin^{2}{\left(\frac{\theta_{12}}{2} \right)} \sin{\left(\frac{\theta^{12}_{1}}{2} \right)} \sin{\left(\frac{3 \theta^{12}_{1}}{2} \right)}}{4} + \frac{9 \sin^{2}{\left(\frac{\theta_{12}}{2} \right)} \sin^{2}{\left(\frac{3 \theta^{12}_{1}}{2} \right)}}{8} + \frac{\sin^{2}{\left(\frac{\theta_{12}}{2} \right)} \cos^{2}{\left(\frac{\theta^{12}_{1}}{2} \right)}}{8} + \frac{3 \sin^{2}{\left(\frac{\theta_{12}}{2} \right)} \cos{\left(\frac{\theta^{12}_{1}}{2} \right)} \cos{\left(\frac{3 \theta^{12}_{1}}{2} \right)}}{4} + \frac{9 \sin^{2}{\left(\frac{\theta_{12}}{2} \right)} \cos^{2}{\left(\frac{3 \theta^{12}_{1}}{2} \right)}}{8} + \frac{\sin^{2}{\left(\frac{\theta^{12}_{1}}{2} \right)} \cos^{2}{\left(\frac{\theta_{12}}{2} \right)}}{8} - \frac{3 \sin{\left(\frac{\theta^{12}_{1}}{2} \right)} \sin{\left(\frac{3 \theta^{12}_{1}}{2} \right)} \cos^{2}{\left(\frac{\theta_{12}}{2} \right)}}{4} + \frac{9 \sin^{2}{\left(\frac{3 \theta^{12}_{1}}{2} \right)} \cos^{2}{\left(\frac{\theta_{12}}{2} \right)}}{8} + \frac{\cos^{2}{\left(\frac{\theta_{12}}{2} \right)} \cos^{2}{\left(\frac{\theta^{12}_{1}}{2} \right)}}{8} + \frac{3 \cos^{2}{\left(\frac{\theta_{12}}{2} \right)} \cos{\left(\frac{\theta^{12}_{1}}{2} \right)} \cos{\left(\frac{3 \theta^{12}_{1}}{2} \right)}}{4} + \frac{9 \cos^{2}{\left(\frac{\theta_{12}}{2} \right)} \cos^{2}{\left(\frac{3 \theta^{12}_{1}}{2} \right)}}{8}\]
Solution 1: Indexed coefficients#

The current implementation of the HelicityAmplitudeBuilder has to be changed quite a bit to produce an amplitude model with PoolSums. First of all, we have to introduce special Symbols for the helicities, \(\lambda_i\), with \(i\) the state ID (taking a sum of attached final state IDs in case of a resonance ID). Next, formulate_wigner_d() has to be modified to insert these Symbols into the WignerD:

def formulate_wigner_d(transition: StateTransition, node_id: int) -> sp.Expr:
    from sympy.physics.quantum.spin import Rotation as Wigner

    decay = TwoBodyDecay.from_transition(transition, node_id)
    topology = transition.topology
    parent_helicity = create_helicity_symbol(topology, decay.parent.id)
    child1_helicity = create_helicity_symbol(topology, decay.children[0].id)
    child2_helicity = create_helicity_symbol(topology, decay.children[1].id)
    _, phi, theta = _generate_kinematic_variables(transition, node_id)
    return Wigner.D(
        j=sp.Rational(decay.parent.particle.spin),
        m=parent_helicity,
        mp=child1_helicity - child2_helicity,
        alpha=-phi,
        beta=theta,
        gamma=0,
    )


def create_helicity_symbol(topology: Topology, state_id: int) -> sp.Symbol:
    if state_id in topology.incoming_edge_ids:
        suffix = ""
    else:
        suffix = f"_{state_id}"
    return sp.Symbol(f"lambda{suffix}", rational=True)
wigner_functions = {
    sp.Mul(*[
        formulate_wigner_d(transition, node_id)
        for node_id in transition.topology.nodes
    ])
    for transition in reaction.transitions
}
display(*wigner_functions)
\[\displaystyle D^{\frac{1}{2}}_{\lambda,- \lambda_{0} + \lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},\lambda_{1} - \lambda_{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)\]

We also have to collect the allowed helicity values for each of these helicity symbols.

from collections import defaultdict

from qrules import ReactionInfo

if TYPE_CHECKING:
    if sys.version_info >= (3, 8):
        from typing import Literal
    else:
        from typing_extensions import Literal


@lru_cache(maxsize=None)
def get_helicities(
    reaction: ReactionInfo, which: Literal["inner", "outer"]
) -> dict[int, set[sp.Rational]]:
    helicities = defaultdict(set)
    initial_state_ids = set(reaction.initial_state)
    final_state_ids = set(reaction.final_state)
    intermediate_state_ids = (
        set(reaction.transitions[0].states) - initial_state_ids - final_state_ids
    )
    if which == "inner":
        state_ids = sorted(intermediate_state_ids)
    elif which == "outer":
        state_ids = sorted(initial_state_ids | final_state_ids)
    for transition in reaction.transitions:
        for state_id in state_ids:
            state = transition.states[state_id]
            helicity = sp.Rational(state.spin_projection)
            symbol = create_helicity_symbol(transition.topology, state_id)
            helicities[symbol].add(helicity)
    return dict(helicities)
inner_helicities = get_helicities(reaction, which="inner")
outer_helicities = get_helicities(reaction, which="outer")
display(inner_helicities, outer_helicities)
{lambda_3: {-1/2, 1/2}}
{lambda: {-1/2, 1/2}, lambda_0: {0}, lambda_1: {-1/2, 1/2}, lambda_2: {0}}

These collected helicity values can then be combined with the Wigner-\(D\) expressions through a PoolSum:

def formulate_intensity(reaction: ReactionInfo):
    wigner_functions = {
        sp.Mul(*[
            formulate_wigner_d(transition, node_id)
            for node_id in transition.topology.nodes
        ])
        for transition in reaction.transitions
    }
    inner_helicities = get_helicities(reaction, which="inner")
    outer_helicities = get_helicities(reaction, which="outer")
    return PoolSum(
        sp.Abs(
            PoolSum(
                sum(wigner_functions),
                *inner_helicities.items(),
            )
        )
        ** 2,
        *outer_helicities.items(),
    )


formulate_intensity(reaction)
\[\displaystyle \sum_{\lambda=-1/2}^{1/2} \sum_{\lambda_{0}=0} \sum_{\lambda_{1}=-1/2}^{1/2} \sum_{\lambda_{2}=0}{\left|{\sum_{\lambda_{3}=-1/2}^{1/2}{D^{\frac{1}{2}}_{\lambda,- \lambda_{0} + \lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},\lambda_{1} - \lambda_{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2}}\]

This is indeed identical to the model as formulated with the existing implementation:

assert formulate_intensity(reaction).doit() == full_expression.doit()

Note how this approach also works in case there are two decay topologies:

reaction_two_resonances = qrules.generate_transitions(
    initial_state="Lambda(c)+",
    final_state=["K-", "p", "pi+"],
    formalism="helicity",
    allowed_intermediate_particles=["Lambda(1405)", "Delta(1600)++"],
    particle_db=MODIFIED_PDG,
)
assert len(reaction_two_resonances.transition_groups) == 2
dot = qrules.io.asdot(reaction, collapse_graphs=True)
display(*[
    graphviz.Source(
        qrules.io.asdot(g, collapse_graphs=True, size=4, render_resonance_id=True)
    )
    for g in reaction_two_resonances.transition_groups
])

formulate_intensity(reaction_two_resonances).cleanup()
\[\displaystyle \sum_{\lambda=-1/2}^{1/2} \sum_{\lambda_{1}=-1/2}^{1/2}{\left|{\sum_{\lambda_{3}=-1/2}^{1/2}{D^{\frac{1}{2}}_{\lambda,\lambda_{3}}\left(- \phi_{01},\theta_{01},0\right) D^{\frac{1}{2}}_{\lambda_{3},- \lambda_{1}}\left(- \phi^{01}_{0},\theta^{01}_{0},0\right) + D^{\frac{1}{2}}_{\lambda,\lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},\lambda_{1}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2}}\]
Inserting coefficients#

There is a problem, though: the sums above could only be written in this compact form because the values over which we sum appear as arguments (args) of the WignerD functions. That was only possible because we previously set all coefficients to \(1\). The coefficient names, however, also ‘depend’ on the helicities of the final state over which we sum:

sorted(model.parameter_defaults, key=str)
[C_{\Lambda_{c}^{+} \to \Delta_{-1/2} K^{-}_{0}; \Delta \to p_{+1/2} \pi^{+}_{0}},
 C_{\Lambda_{c}^{+} \to \Delta_{+1/2} K^{-}_{0}; \Delta \to p_{+1/2} \pi^{+}_{0}}]

We therefore have to somehow introduce a dependence of these Symbols on the helicity values. One idea is to use an IndexedBase. Modifying the formulate_intensity() function defined above accordingly:

C = sp.IndexedBase("C")


@lru_cache(maxsize=None)
def formulate_intensity_with_coefficient(reaction: ReactionInfo):
    amplitudes = {
        sp.Mul(
            C[[
                create_helicity_symbol(transition.topology, state_id)
                for state_id in transition.final_states
            ]],
            *[
                formulate_wigner_d(transition, node_id)
                for node_id in transition.topology.nodes
            ],
        )
        for transition in reaction.transitions
    }
    inner_helicities = get_helicities(reaction, which="inner")
    outer_helicities = get_helicities(reaction, which="outer")
    return PoolSum(
        sp.Abs(
            PoolSum(
                sum(amplitudes),
                *inner_helicities.items(),
            )
        )
        ** 2,
        *outer_helicities.items(),
    )


indexed_coefficient_expr = formulate_intensity_with_coefficient(reaction)
indexed_coefficient_expr
\[\displaystyle \sum_{\lambda=-1/2}^{1/2} \sum_{\lambda_{0}=0} \sum_{\lambda_{1}=-1/2}^{1/2} \sum_{\lambda_{2}=0}{\left|{\sum_{\lambda_{3}=-1/2}^{1/2}{{C}_{\lambda_{0},\lambda_{1},\lambda_{2}} D^{\frac{1}{2}}_{\lambda,- \lambda_{0} + \lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},\lambda_{1} - \lambda_{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2}}\]
latex = sp.multiline_latex(I, indexed_coefficient_expr.doit(deep=False))
Math(latex)
\[\begin{split}\displaystyle \begin{align*} I = & \left|{\sum_{\lambda_{3}=-1/2}^{1/2}{{C}_{0,- \frac{1}{2},0} D^{\frac{1}{2}}_{- \frac{1}{2},\lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},- \frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2} \\ & + \left|{\sum_{\lambda_{3}=-1/2}^{1/2}{{C}_{0,- \frac{1}{2},0} D^{\frac{1}{2}}_{\frac{1}{2},\lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},- \frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2} \\ & + \left|{\sum_{\lambda_{3}=-1/2}^{1/2}{{C}_{0,\frac{1}{2},0} D^{\frac{1}{2}}_{- \frac{1}{2},\lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},\frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2} \\ & + \left|{\sum_{\lambda_{3}=-1/2}^{1/2}{{C}_{0,\frac{1}{2},0} D^{\frac{1}{2}}_{\frac{1}{2},\lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},\frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}}\right|^{2} \end{align*}\end{split}\]

Caveat

Using IndexedBase makes the coefficient names concise, but harder to understand.

This seems to work rather well, but there is a subtle problem introduced by writing the coefficients as an IndexedBase: the IndexedBase itself is listed under the free_symbols of the expression.

free_symbols = sorted(indexed_coefficient_expr.doit().free_symbols, key=str)
free_symbols
[C, C[0, -1/2, 0], C[0, 1/2, 0], phi_12, phi_1^12, theta_12, theta_1^12]

In addition, not all symbols in the expression are of type Symbol anymore:

{s: (type(s), isinstance(s, sp.Symbol)) for s in free_symbols}
{C: (sympy.core.symbol.Symbol, True),
 C[0, -1/2, 0]: (sympy.tensor.indexed.Indexed, False),
 C[0, 1/2, 0]: (sympy.tensor.indexed.Indexed, False),
 phi_12: (sympy.core.symbol.Symbol, True),
 phi_1^12: (sympy.core.symbol.Symbol, True),
 theta_12: (sympy.core.symbol.Symbol, True),
 theta_1^12: (sympy.core.symbol.Symbol, True)}

This will become problematic when lambdifying, because it results in an additional argument in the signature of the generated function:

func = sp.lambdify(free_symbols, indexed_coefficient_expr.doit())
inspect.signature(func)
<Signature (C, Dummy_24, Dummy_23, phi_12, Dummy_26, theta_12, Dummy_25)>

A solution may be to use symplot.substitute_indexed_symbols():

indexed_coefficient_expr_symbols_only = symplot.substitute_indexed_symbols(
    indexed_coefficient_expr.doit()
)
indexed_coefficient_expr_symbols_only.free_symbols
{C_{0,-1/2,0}, C_{0,1/2,0}, phi_12, phi_1^12, theta_12, theta_1^12}
args = sorted(indexed_coefficient_expr_symbols_only.free_symbols, key=str)
func = sp.lambdify(args, indexed_coefficient_expr_symbols_only)
inspect.signature(func)
<Signature (Dummy_30, Dummy_29, phi_12, Dummy_28, theta_12, Dummy_27)>

Caveat

This seems clumsy, because substitute_indexed_symbols() would have to be actively called before creating a computational function with TensorWaves. It also becomes a hassle to keep track of the correct Symbol names in HelicityModel.parameter_defaults.
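
One way to reduce this bookkeeping could be a small helper that maps an Indexed coefficient to a plain Symbol with the same name as shown in the output above. This is only a sketch: indexed_to_symbol is a hypothetical name, and it is an assumption that this naming matches what substitute_indexed_symbols() produces.

def indexed_to_symbol(indexed: sp.Indexed) -> sp.Symbol:
    # render the indices the same way they appear in the Symbol names above
    index_str = ",".join(str(idx) for idx in indexed.indices)
    return sp.Symbol(f"{indexed.base}_{{{index_str}}}")


assert indexed_to_symbol(C[0, -sp.S.Half, 0]) == sp.Symbol("C_{0,-1/2,0}")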

# One topology
expr = ampform.get_builder(reaction).formulate().expression
expr = remove_coefficients(expr.doit())
sum_expr = formulate_intensity_with_coefficient(reaction)
sum_expr = symplot.substitute_indexed_symbols(sum_expr.doit())
sum_expr = remove_coefficients(sum_expr)
assert sum_expr == expr

# Two topologies
expr = ampform.get_builder(reaction_two_resonances).formulate().expression
expr = remove_coefficients(expr.doit())
sum_expr = formulate_intensity_with_coefficient(reaction_two_resonances)
sum_expr = symplot.substitute_indexed_symbols(sum_expr.doit())
sum_expr = remove_coefficients(sum_expr)
assert sum_expr == expr
Inserting dynamics#

Dynamics pose a challenge that is similar to Inserting coefficients in that we have to introduce expressions that are dependent on spin. Still, as can be seen from the available attributes on a TwoBodyKinematicVariableSet (which serves as input to ResonanceDynamicsBuilders), dynamics (currently) cannot depend on helicities.

What may become a problem are \(LS\)-combinations. So far we have only considered a ReactionInfo that was created with formalism="helicity", but we also have to sum over \(LS\)-combinations when using formalism="canonical-helicity". This is particularly important when using dynamics with form factors, which depend on angular_momentum.
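
As a starting point for studying this, the same reaction can be generated with the canonical formalism (a standard qrules option), so that the transitions carry LS-couplings; how those couplings would enter the PoolSum indices is left open here. canonical_reaction is just an illustrative name.

canonical_reaction = qrules.generate_transitions(
    initial_state="Lambda(c)+",
    final_state=["K-", "p", "pi+"],
    formalism="canonical-helicity",
    allowed_intermediate_particles=["Delta(1600)++"],
    particle_db=MODIFIED_PDG,
)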

Note

The sympy.tensor.indexed.Indexed now also contains the names of the resonances.

def formulate_intensity_with_dynamics(
    reaction: ReactionInfo,
    dynamics_choices: dict[str, ResonanceDynamicsBuilder],
):
    amplitudes = set()
    for transition in reaction.transitions:
        final_state_helicities = [
            create_helicity_symbol(transition.topology, state_id)
            for state_id in transition.final_states
        ]
        resonances = [
            Str(s.particle.latex) for s in transition.intermediate_states.values()
        ]
        indices = [*final_state_helicities, *resonances]
        coefficient = C[indices]
        expr: sp.Expr = coefficient
        for node_id in sorted(transition.topology.nodes):
            expr *= formulate_wigner_d(transition, node_id)
            decay = TwoBodyDecay.from_transition(transition, node_id)
            parent_particle = decay.parent.particle
            dynamics_builder = dynamics_choices.get(
                parent_particle.name, create_non_dynamic
            )
            variables = _generate_kinematic_variable_set(transition, node_id)
            dynamics, _ = dynamics_builder(parent_particle, variables)
            expr *= dynamics
        amplitudes.add(expr)
    inner_helicities = get_helicities(reaction, which="inner")
    outer_helicities = get_helicities(reaction, which="outer")
    return PoolSum(
        sp.Abs(
            PoolSum(sum(amplitudes), *inner_helicities.items()),
            evaluate=False,
        )
        ** 2,
        *outer_helicities.items(),
    )
formulate_intensity_with_dynamics(
    reaction,
    dynamics_choices={
        resonance.name: create_relativistic_breit_wigner
        for resonance in reaction.get_intermediate_particles()
    },
)
\[\displaystyle \sum_{\lambda=-1/2}^{1/2} \sum_{\lambda_{0}=0} \sum_{\lambda_{1}=-1/2}^{1/2} \sum_{\lambda_{2}=0}{\left|{\sum_{\lambda_{3}=-1/2}^{1/2}{\frac{\Gamma_{\Delta} m_{\Delta} {C}_{\lambda_{0},\lambda_{1},\lambda_{2},\Delta} D^{\frac{1}{2}}_{\lambda,- \lambda_{0} + \lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},\lambda_{1} - \lambda_{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}{- i \Gamma_{\Delta} m_{\Delta} - m_{12}^{2} + m_{\Delta}^{2}}}}\right|^{2}}\]
formulate_intensity_with_dynamics(
    reaction_two_resonances,
    dynamics_choices={
        resonance.name: create_relativistic_breit_wigner
        for resonance in reaction_two_resonances.get_intermediate_particles()
    },
)
\[\displaystyle \sum_{\lambda=-1/2}^{1/2} \sum_{\lambda_{0}=0} \sum_{\lambda_{1}=-1/2}^{1/2} \sum_{\lambda_{2}=0}{\left|{\sum_{\lambda_{3}=-1/2}^{1/2}{\frac{\Gamma_{\Delta} m_{\Delta} {C}_{\lambda_{0},\lambda_{1},\lambda_{2},\Delta} D^{\frac{1}{2}}_{\lambda,- \lambda_{0} + \lambda_{3}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\lambda_{3},\lambda_{1} - \lambda_{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}{- i \Gamma_{\Delta} m_{\Delta} - m_{12}^{2} + m_{\Delta}^{2}} + \frac{\Gamma_{\Lambda} m_{\Lambda} {C}_{\lambda_{0},\lambda_{1},\lambda_{2},\Lambda} D^{\frac{1}{2}}_{\lambda,- \lambda_{2} + \lambda_{3}}\left(- \phi_{01},\theta_{01},0\right) D^{\frac{1}{2}}_{\lambda_{3},\lambda_{0} - \lambda_{1}}\left(- \phi^{01}_{0},\theta^{01}_{0},0\right)}{- i \Gamma_{\Lambda} m_{\Lambda} - m_{01}^{2} + m_{\Lambda}^{2}}}}\right|^{2}}\]

The resulting amplitude is again identical to the original HelicityModel.expression:

# One topology
b = ampform.get_builder(reaction)
for p in reaction.get_intermediate_particles():
    b.set_dynamics(p.name, create_relativistic_breit_wigner)
expr = b.formulate().expression
expr = remove_coefficients(expr.doit())
sum_expr = formulate_intensity_with_dynamics(
    reaction,
    dynamics_choices={
        resonance.name: create_relativistic_breit_wigner
        for resonance in reaction.get_intermediate_particles()
    },
)
sum_expr = symplot.substitute_indexed_symbols(sum_expr.doit())
sum_expr = remove_coefficients(sum_expr)
assert sum_expr == expr

# Two topologies
b = ampform.get_builder(reaction_two_resonances)
for p in reaction_two_resonances.get_intermediate_particles():
    b.set_dynamics(p.name, create_relativistic_breit_wigner)
expr = b.formulate().expression
expr = remove_coefficients(expr)
sum_expr = formulate_intensity_with_dynamics(
    reaction_two_resonances,
    dynamics_choices={
        resonance.name: create_relativistic_breit_wigner
        for resonance in reaction_two_resonances.get_intermediate_particles()
    },
)
sum_expr = symplot.partial_doit(sum_expr, doit_classes=(PoolSum,))
sum_expr = symplot.partial_doit(sum_expr, doit_classes=(PoolSum,))  # recurse
sum_expr = symplot.substitute_indexed_symbols(sum_expr)
sum_expr = remove_coefficients(sum_expr)
assert sum_expr.free_symbols == expr.free_symbols
for intensity1, intensity2 in zip(sum_expr.args, expr.args):
    # Annoyingly, Abs is rewritten with conjugates when using PoolSum...
    amp1 = intensity1.args[0]
    amp2 = intensity2.rewrite(sp.conjugate).args[0]
    amp1 = sp.factor(amp1, deep=True, fraction=False)
    amp2 = sp.factor(amp2, deep=True, fraction=False)
    assert amp1 == amp2
Solution 2: Indexed amplitude components#

The main problem with Solution 1: Indexed coefficients is that it requires changing the coefficient Symbols to instances of Indexed, which have to be substituted with substitute_indexed_symbols() (after calling doit()).

An alternative to inserting dynamics (and coefficients) into the PoolSums over the helicities is to index the amplitude itself. The HelicityModel.expression would then contain Indexed symbols that represent specific amplitudes. A definition of these amplitudes can be provided through HelicityModel.components or an equivalent attribute.

Hide code cell content
from ampform.helicity.naming import HelicityAmplitudeNameGenerator

A = sp.IndexedBase(R"\mathcal{A}")


def formulate_intensity_indexed_amplitudes_only(
    reaction: ReactionInfo,
    dynamics_choices: dict[str, ResonanceDynamicsBuilder],
) -> tuple[sp.Expr, dict[sp.Indexed, sp.Expr]]:
    name_generator = HelicityAmplitudeNameGenerator()
    amplitudes = set()
    amplitude_definitions = {}
    for transition in reaction.transitions:
        name_generator.register_amplitude_coefficient_name(transition)
    for transition in reaction.transitions:
        suffix = name_generator.generate_sequential_amplitude_suffix(transition)
        expr: sp.Expr = sp.Symbol(f"C_{{{suffix}}}")
        for node_id in sorted(transition.topology.nodes):
            expr *= ampform.helicity.formulate_wigner_d(transition, node_id)
            decay = TwoBodyDecay.from_transition(transition, node_id)
            parent_particle = decay.parent.particle
            dynamics_builder = dynamics_choices.get(
                parent_particle.name, create_non_dynamic
            )
            variables = _generate_kinematic_variable_set(transition, node_id)
            dynamics, _ = dynamics_builder(parent_particle, variables)
            expr *= dynamics
        resonances = [
            Str(s.particle.latex) for s in transition.intermediate_states.values()
        ]
        helicity_symbols = [
            create_helicity_symbol(transition.topology, state_id)
            for state_id in sorted(transition.states)
        ]
        helicities = [
            sp.Rational(transition.states[state_id].spin_projection)
            for state_id in sorted(transition.states)
        ]
        amplitudes.add(A[[*helicity_symbols, *resonances]])
        amplitude_definitions[A[[*helicities, *resonances]]] = expr
    inner_helicities = get_helicities(reaction, which="inner")
    outer_helicities = get_helicities(reaction, which="outer")
    expression = PoolSum(
        sp.Abs(
            PoolSum(sum(amplitudes), *inner_helicities.items()),
            evaluate=False,
        )
        ** 2,
        *outer_helicities.items(),
    )
    return expression, amplitude_definitions
expression, amplitudes = formulate_intensity_indexed_amplitudes_only(
    reaction_two_resonances,
    dynamics_choices={
        resonance.name: create_relativistic_breit_wigner
        for resonance in reaction_two_resonances.get_intermediate_particles()
    },
)
Hide code cell source
display(Math(sp.multiline_latex(I, expression)))
display(Math(sp.multiline_latex(I, expression.doit())))
for i, (symbol, expr) in enumerate(amplitudes.items(), 1):
    latex = sp.multiline_latex(symbol, expr)
    display(Math(latex))
    if i == 3:
        break
\[\displaystyle \begin{align*} I = & \sum_{\lambda=-1/2}^{1/2} \sum_{\lambda_{0}=0} \sum_{\lambda_{1}=-1/2}^{1/2} \sum_{\lambda_{2}=0}{\left|{\sum_{\lambda_{3}=-1/2}^{1/2}{{\mathcal{A}}_{\lambda,\lambda_{0},\lambda_{1},\lambda_{2},\lambda_{3},\Delta} + {\mathcal{A}}_{\lambda,\lambda_{0},\lambda_{1},\lambda_{2},\lambda_{3},\Lambda}}}\right|^{2}} \end{align*}\]
\[\begin{split}\displaystyle \begin{align*} I = & \left|{{\mathcal{A}}_{- \frac{1}{2},0,- \frac{1}{2},0,- \frac{1}{2},\Delta} + {\mathcal{A}}_{- \frac{1}{2},0,- \frac{1}{2},0,- \frac{1}{2},\Lambda} + {\mathcal{A}}_{- \frac{1}{2},0,- \frac{1}{2},0,\frac{1}{2},\Delta} + {\mathcal{A}}_{- \frac{1}{2},0,- \frac{1}{2},0,\frac{1}{2},\Lambda}}\right|^{2} \\ & + \left|{{\mathcal{A}}_{- \frac{1}{2},0,\frac{1}{2},0,- \frac{1}{2},\Delta} + {\mathcal{A}}_{- \frac{1}{2},0,\frac{1}{2},0,- \frac{1}{2},\Lambda} + {\mathcal{A}}_{- \frac{1}{2},0,\frac{1}{2},0,\frac{1}{2},\Delta} + {\mathcal{A}}_{- \frac{1}{2},0,\frac{1}{2},0,\frac{1}{2},\Lambda}}\right|^{2} \\ & + \left|{{\mathcal{A}}_{\frac{1}{2},0,- \frac{1}{2},0,- \frac{1}{2},\Delta} + {\mathcal{A}}_{\frac{1}{2},0,- \frac{1}{2},0,- \frac{1}{2},\Lambda} + {\mathcal{A}}_{\frac{1}{2},0,- \frac{1}{2},0,\frac{1}{2},\Delta} + {\mathcal{A}}_{\frac{1}{2},0,- \frac{1}{2},0,\frac{1}{2},\Lambda}}\right|^{2} \\ & + \left|{{\mathcal{A}}_{\frac{1}{2},0,\frac{1}{2},0,- \frac{1}{2},\Delta} + {\mathcal{A}}_{\frac{1}{2},0,\frac{1}{2},0,- \frac{1}{2},\Lambda} + {\mathcal{A}}_{\frac{1}{2},0,\frac{1}{2},0,\frac{1}{2},\Delta} + {\mathcal{A}}_{\frac{1}{2},0,\frac{1}{2},0,\frac{1}{2},\Lambda}}\right|^{2} \end{align*}\end{split}\]
\[\displaystyle \begin{align*} {\mathcal{A}}_{- \frac{1}{2},0,- \frac{1}{2},0,- \frac{1}{2},\Delta} = & \frac{C_{\Lambda_{c}^{+} \to \Delta_{-1/2} K^{-}_{0}; \Delta \to p_{+1/2} \pi^{+}_{0}} \Gamma_{\Delta} m_{\Delta} D^{\frac{1}{2}}_{- \frac{1}{2},- \frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{- \frac{1}{2},- \frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}{- i \Gamma_{\Delta} m_{\Delta} - m_{12}^{2} + m_{\Delta}^{2}} \end{align*}\]
\[\displaystyle \begin{align*} {\mathcal{A}}_{- \frac{1}{2},0,- \frac{1}{2},0,\frac{1}{2},\Delta} = & \frac{C_{\Lambda_{c}^{+} \to \Delta_{+1/2} K^{-}_{0}; \Delta \to p_{+1/2} \pi^{+}_{0}} \Gamma_{\Delta} m_{\Delta} D^{\frac{1}{2}}_{- \frac{1}{2},\frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{\frac{1}{2},- \frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}{- i \Gamma_{\Delta} m_{\Delta} - m_{12}^{2} + m_{\Delta}^{2}} \end{align*}\]
\[\displaystyle \begin{align*} {\mathcal{A}}_{- \frac{1}{2},0,\frac{1}{2},0,- \frac{1}{2},\Delta} = & \frac{C_{\Lambda_{c}^{+} \to \Delta_{-1/2} K^{-}_{0}; \Delta \to p_{+1/2} \pi^{+}_{0}} \Gamma_{\Delta} m_{\Delta} D^{\frac{1}{2}}_{- \frac{1}{2},- \frac{1}{2}}\left(- \phi_{12},\theta_{12},0\right) D^{\frac{3}{2}}_{- \frac{1}{2},\frac{1}{2}}\left(- \phi^{12}_{1},\theta^{12}_{1},0\right)}{- i \Gamma_{\Delta} m_{\Delta} - m_{12}^{2} + m_{\Delta}^{2}} \end{align*}\]

\(\dots\)

The resulting amplitude is indeed identical to the original HelicityModel.expression:

Hide code cell content
b = ampform.get_builder(reaction_two_resonances)
for resonance in reaction_two_resonances.get_intermediate_particles():
    b.set_dynamics(resonance.name, create_relativistic_breit_wigner)
model_two_res = b.formulate()

assert model_two_res.expression == expression.doit().xreplace(amplitudes)

Tip

Currently, amplitudes with different resonances are put under different amplitude symbols, identified by that resonance. Such amplitudes can be combined, e.g. \(\mathcal{A}_{\lambda_i} = \mathcal{A}_{\lambda_i,\Delta} + \mathcal{A}_{\lambda_i,\Lambda}\). This would also make it easier to introduce correct interference terms through the \(K\)-matrix.
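A minimal sketch of what such a combination could look like (not part of the original report; it assumes the amplitudes dictionary created above, where the last index of each Indexed symbol is the resonance):

from collections import defaultdict

# Hypothetical sketch: sum the amplitude definitions over the resonance index,
# so that each key carries only the helicity indices
combined_amplitudes = defaultdict(lambda: sp.S.Zero)
for indexed_symbol, definition in amplitudes.items():
    *helicities, _resonance = indexed_symbol.indices
    combined_amplitudes[A[tuple(helicities)]] += definition

Each value in combined_amplitudes is then the coherent sum over the resonances for one helicity combination, which is the form in which interference terms (e.g. through the \(K\)-matrix) could be introduced.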

Question

The helicity of the intermediate state is also passed to the indexed amplitude. This is required for the coefficient name, which has a helicity subscript for the intermediate state, e.g. \(C_{\Lambda_{c}^{+} \to \Lambda_{\color{red}+1/2} \pi^{+}_{0}; \Lambda \to K^{-}_{0} p_{+1/2}}\). Does it really make sense to distinguish coefficients for different helicities of intermediate states?

Spin alignment implementation#

Hide code cell content
import logging
import warnings

import ampform
import graphviz
import qrules
import sympy as sp
from ampform.helicity import formulate_wigner_d
from IPython.display import Math, display

LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)
warnings.filterwarnings("ignore")


def show_transition(transition, **kwargs):
    if "size" not in kwargs:
        kwargs["size"] = 5
    dot = qrules.io.asdot(transition, **kwargs)
    display(graphviz.Source(dot))
Helicity formalism#

Imagine we want to formulate the amplitude for the following single StateTransition:

Hide code cell source
full_reaction = qrules.generate_transitions(
    initial_state="J/psi(1S)",
    final_state=["K0", ("Sigma+", [+0.5]), ("p~", [+0.5])],
    allowed_intermediate_particles=["Sigma(1660)~-", "N(1650)+"],
    allowed_interaction_types="strong",
    formalism="helicity",
)
graphs = full_reaction.to_graphs()
single_transition_reaction = full_reaction.from_graphs(
    [graphs[0]], formalism=full_reaction.formalism
)
transition = single_transition_reaction.transitions[0]
show_transition(transition)

The specific spin_projections for each particle only make sense given a specific reference frame. AmpForm’s HelicityAmplitudeBuilder interprets these projections as the helicity \(\lambda=\vec{S}\cdot\vec{p}\) of each particle in the rest frame of the parent particle. For example, the helicity \(\lambda_2=+\tfrac{1}{2}\) of \(\bar p\) is the helicity as measured in the rest frame of resonance \(\bar\Sigma(1660)^-\). The reason is that these helicities are needed when formulating the two-particle state for the decay node \(\bar\Sigma(1660)^- \to K^0\bar p\) (see Helicity versus canonical).

Ignoring dynamics and coefficients, the HelicityModel for this single transition is rather simple:

Hide code cell source
builder = ampform.get_builder(single_transition_reaction)
model = builder.formulate()
model.expression.subs(model.parameter_defaults).subs(1.0, 1)
\[\displaystyle \left|{D^{\frac{1}{2}}_{- \frac{1}{2},- \frac{1}{2}}\left(- \phi^{02}_{0},\theta^{02}_{0},0\right) D^{1}_{-1,-1}\left(- \phi_{02},\theta_{02},0\right)}\right|^{2}\]

The two Wigner-\(D\) functions come from the two two-body decay nodes that appear in the StateTransition above. They were formulated as follows:

sp.Mul(
    formulate_wigner_d(transition, node_id=0),
    formulate_wigner_d(transition, node_id=1),
)
\[\displaystyle D^{\frac{1}{2}}_{- \frac{1}{2},- \frac{1}{2}}\left(- \phi^{02}_{0},\theta^{02}_{0},0\right) D^{1}_{-1,-1}\left(- \phi_{02},\theta_{02},0\right)\]

Now, as formulate_wigner_d() explains, the numbers that appear in the Wigner-\(D\) functions here are computed from the helicities of the decay products. But there’s a subtle problem: these helicities are assumed to be in the rest frame of each parent particle. For the first node, this is fine, because the parent particle rest frame matches that of the initial state in the StateTransition above. In the second node, however, we are in a different rest frame. This can result in phase differences for the different amplitudes.

If there is only a single decay Topology in the ReactionInfo object for which we are formulating an amplitude model, the problem we identified here can be ignored. The reason is that the phase difference for each StateTransition (each with an identical decay Topology) is the same and does not introduce interference effects within the coherent sum. It again becomes a problem, however, when we are formulating an amplitude model with different topologies. An example would be the following reaction:

Hide code cell source
show_transition(full_reaction, collapse_graphs=True)

When formulating the amplitude model for this reaction, the HelicityAmplitudeBuilder implements the ‘standard’ helicity formalism as described in [Richman, 1984, Kutschke, 1996, Chung, 2014] and simply sums over the different amplitudes to get the full amplitude:

Hide code cell source
builder = ampform.get_builder(full_reaction)
model = builder.formulate()
latex = sp.multiline_latex(
    sp.Symbol("I"),
    model.expression.subs(model.parameter_defaults).subs(1.0, 1),
    environment="eqnarray",
)
Math(latex)

As pointed out in [Marangotto, 2020, Mikhasenko et al., 2020, Wang et al., 2020], this is wrong because of the mismatch in reference frames for the helicities.

Aligning reference frames#

In the rest of this document, we follow [Marangotto, 2020] to align all amplitudes in the different topologies back to the initial state reference frame \(A\), so that they can be correctly summed up. Specifically, we want to formulate a new, correctly aligned amplitude \(\mathcal{A}^{A\to 0,1,\dots}_{m_A,m_0,m_1,\dots}\) from the original amplitudes \(\mathcal{A}^{A\to R,S,i,...\to 0,1,\dots}_{\lambda_A,\lambda_0,\lambda_1,\dots}\) by applying Eqs. (45) and (47) for generic, multi-body decays. Here, the \(\lambda\) values are the helicities in the parent rest frame of each two-body decay and the \(m\) are the canonical[1] spin projections in the rest frame of the mother particle, which is the same no matter the Topology.

Just as in [Marangotto, 2020], we test the implementation with 1-to-3-body decays. We use the notation from get_boost_chain_suffix() to indicate the resonances \(R,S,U\). This results in the following figure for the two alignment sums of Equations (45) and (46) in [Marangotto, 2020]:

Hide code cell source
dot1 = """
digraph {
    bgcolor=none
    rankdir=LR
    edge [arrowhead=none]
    node [shape=none, width=0]
    A
    0 [fontcolor=red]
    1 [fontcolor=green, label=<<o>1</o>>]
    2 [fontcolor=blue, label=<<o>2</o>>]
    { rank=same A }
    { rank=same 0, 1, 2 }
    N0 [label=""]
    N1 [label=""]
    A -> N0 [style=dotted]
    N0 -> N1 [label="R = 01", fontcolor=orange]
    N1 -> 0
    N0 -> 2 [style=dashed]
    N1 -> 1 [style=dashed]
}
"""
dot2 = """
digraph {
    bgcolor=none
    rankdir=LR
    edge [arrowhead=none]
    node [shape=none, width=0]
    A
    0 [label=0, fontcolor=red]
    1 [label=1, fontcolor=green, label=<<o>1</o>>]
    2 [label=2, fontcolor=blue, label=<<o>2</o>>]
    { rank=same A }
    { rank=same 0, 1, 2 }
    N0 [label=""]
    N1 [label=""]
    A -> N0 [style=dotted]
    N0 -> N1 [label="S = 02", fontcolor=violet]
    N1 -> 0
    N0 -> 1 [style=dashed]
    N1 -> 2 [style=dashed]
}
"""
display(*map(graphviz.Source, [dot1, dot2]))

The dashed edges and bars above the state IDs indicate “opposite helicity” states. The helicity of an opposite helicity state gets a minus sign in the Wigner-\(D\) function for a two-body state as formulated by formulate_wigner_d() (see Helicity formalism) and therefore needs to be defined consistently. AmpForm does this with is_opposite_helicity_state().

Opposite helicity states are also of importance in the spin alignment procedure sketched by [Marangotto, 2020]. The Wigner-\(D\) functions that appear in Equations (45) and (46) from [Marangotto, 2020] operate on the spin of the final state, but the angles in the Wigner-\(D\) function are taken from the sibling state:

(1)#\[\begin{split} \begin{eqnarray} \mathcal{A}^{A \to {\color{orange}R},2 \to 0,1,2}_{m_A,m_0,m_1,m_2} &=& \sum_{\lambda_0^{01},\mu_0^{01},\nu_0^{01}} {\color{red}{D^{s_0}_{m_0,\nu_0^{01}}}}\!\left({\color{red}{\alpha_0^{01}, \beta_0^{01}, \gamma_0^{01}}}\right) {\color{red}{D^{s_0}_{\nu_0^{01},\mu_0^{01}}}}\!\left({\color{orange}{\phi_{_{01}}, \theta_{_{01}}}}, 0\right) {\color{red}{D^{s_0}_{\mu_0^{01},\lambda_0^{01}}}}\!\left({\color{red}{\phi_0^{01}, \theta_0^{01}}}\right) \\ &\times& \sum_{\lambda_1^{01},\mu_1^{01},\nu_1^{01}} {\color{green}{D^{s_1}_{m_1,\nu_1^{01}}}}\!\left({\color{green}{\alpha_1^{01}, \beta_1^{01}, \gamma_1^{01}}}\right) {\color{green}{D^{s_1}_{\nu_1^{01},\mu_1^{01}}}}\!\left({\color{orange}{\phi_{_{01}}, \theta_{_{01}}}}, 0\right) {\color{green}{D^{s_1}_{\mu_1^{01},\lambda_1^{01}}}}\!\left({\color{red}{\phi_0^{01}, \theta_0^{01}}}\right) \\ &\times& \sum_{\lambda_2^{01}} {\color{blue}{D^{s_2}_{m_2,\lambda_2^{01}}}}\!\left({\color{orange}{\phi_{_{01}}, \theta_{_{01}}}}, 0\right) \\ &\times& \mathcal{A}^{A \to {\color{orange}R},2 \to 0,1,2}_{m_A,\lambda_0^{01},\bar\lambda_1^{01},\bar\lambda_2^{01}} \end{eqnarray} \end{split}\]
(2)#\[\begin{split} \begin{eqnarray} \mathcal{A}^{A \to {\color{violet}S},1 \to 0,1,2}_{m_A,m_0,m_1,m_2} &=& \sum_{\lambda_0^{02},\mu_0^{02},\nu_0^{02}} {\color{red}{D^{s_0}_{m_0,\nu_0^{02}}}}\!\left({\color{red}{\alpha_0^{02}, \beta_0^{02}, \gamma_0^{02}}}\right) {\color{red}{D^{s_0}_{\nu_0^{02},\mu_0^{02}}}}\!\left({\color{violet}{\phi_{_{02}}, \theta_{_{02}}}}, 0\right) {\color{red}{D^{s_0}_{\mu_0^{02},\lambda_0^{02}}}}\!\left({\color{red}{\phi_0^{02}, \theta_0^{02}}}\right) \\ &\times& \sum_{\lambda_1^{02}} {\color{green}{D^{s_1}_{m_1,\lambda_1^{02}}}}\!\left({\color{violet}{\phi_{_{02}}, \theta_{_{02}}}}, 0\right) \\ &\times& \sum_{\lambda_2^{02},\mu_2^{02},\nu_2^{02}} {\color{blue}{D^{s_2}_{m_2,\nu_2^{02}}}}\!\left({\color{blue}{\alpha_2^{02}, \beta_2^{02}, \gamma_2^{02}}}\right) {\color{blue}{D^{s_2}_{\nu_2^{02},\mu_2^{02}}}}\!\left({\color{violet}{\phi_{_{02}}, \theta_{_{02}}}}, 0\right) {\color{blue}{D^{s_2}_{\mu_2^{02},\lambda_2^{02}}}}\!\left({\color{red}{\phi_0^{02}, \theta_0^{02}}}\right) \\ &\times& \mathcal{A}^{A \to {\color{violet}S},1 \to 0,1,2}_{m_A,\lambda_0^{02},\bar\lambda_1^{02},\bar\lambda_2^{02}} \end{eqnarray} \end{split}\]

This procedure also allows us to formulate the alignment summation for \(\mathcal{A}^{A \to {\color{turquoise}U},0 \to 0,1,2}_{m_A,m_0,m_1,m_2}\):

Hide code cell source
dot3 = """
digraph {
    bgcolor=none
    rankdir=LR
    edge [arrowhead=none]
    node [shape=none, width=0]
    0 [shape=none, label=0, fontcolor=red]
    1 [shape=none, label=1, fontcolor=green]
    2 [shape=none, label=2, fontcolor=blue, label=<<o>2</o>>]
    A [shape=none, label=A]
    { rank=same A }
    { rank=same 0, 1, 2 }
    N0 [label=""]
    N1 [label=""]
    A -> N0 [style=dotted]
    N0 -> N1 [label=<U =<o>12</o>>, fontcolor=turquoise, style=dashed]
    N0 -> 0
    N1 -> 1
    N1 -> 2 [style=dashed]
}
"""
graphviz.Source(dot3)

(3)#\[\begin{split} \begin{eqnarray} \mathcal{A}^{A \to {\color{turquoise}U},0 \to 0,1,2}_{m_A,m_0,m_1,m_2} &=& \sum_{\lambda_0^{12}} {\color{red}{D^{s_0}_{m_0,\lambda_0^{12}}}}\!\left({\color{red}{\phi_0, \theta_0}}, 0\right) \\ &\times& \sum_{\lambda_1^{12},\mu_1^{12},\nu_1^{12}} {\color{green}{D^{s_1}_{m_1,\nu_1^{12}}}}\!\left({\color{green}{\alpha_1^{12}, \beta_1^{12}, \gamma_1^{12}}}\right) {\color{green}{D^{s_1}_{\nu_1^{12},\mu_1^{12}}}}\!\left({\color{red}{\phi_0, \theta_0}}, 0\right) {\color{green}{D^{s_1}_{\mu_1^{12},\lambda_1^{12}}}}\!\left({\color{green}{\phi_1^{12}, \theta_1^{12}}}\right) \\ &\times& \sum_{\lambda_2^{12},\mu_2^{12},\nu_2^{12}} {\color{blue}{D^{s_2}_{m_2,\nu_2^{12}}}}\!\left({\color{blue}{\alpha_2^{12}, \beta_2^{12}, \gamma_2^{12}}}\right) {\color{blue}{D^{s_2}_{\nu_2^{12},\mu_2^{12}}}}\!\left({\color{red}{\phi_0, \theta_0}}, 0\right) {\color{blue}{D^{s_2}_{\mu_2^{12},\lambda_2^{12}}}}\!\left({\color{green}{\phi_1^{12}, \theta_1^{12}}}\right) \\ &\times& \mathcal{A}^{A \to {\color{turquoise}U},0 \to 0,1,2}_{m_A,\lambda_0^{12},\bar\lambda_1^{12},\bar\lambda_2^{12}} \end{eqnarray} \end{split}\]

Finally, the total intensity can be computed from these amplitudes by incoherently summing over the initial and final state canonical spin projections (see Equation (47) in [Marangotto, 2020]):

(4)#\[ I = \sum_{m_A,m_0,m_1,m_2}\left| \mathcal{A}^{A \to {\color{orange}R},2 \to 0,1,2}_{m_A,m_0,m_1,m_2} + \mathcal{A}^{A \to {\color{violet}S},1 \to 0,1,2}_{m_A,m_0,m_1,m_2} + \mathcal{A}^{A \to {\color{turquoise}U},0 \to 0,1,2}_{m_A,m_0,m_1,m_2} \right|^2 \]
\(J/\psi \to K^0 \Sigma^+ \bar{p}\)#
Hide code cell content
from ampform.helicity import (
    formulate_helicity_rotation_chain,
    formulate_rotation_chain,
    formulate_spin_alignment,
)


def show_all_spin_matrices(transition, functor, cleanup: bool) -> None:
    for i in transition.final_states:
        state = transition.states[i]
        particle_name = state.particle.latex
        s = sp.Rational(state.particle.spin)
        m = sp.Rational(state.spin_projection)
        display(
            Math(Rf"|s_{i},m_{i}\rangle=|{s},{m}\rangle \quad ({particle_name})")
        )
        if functor is formulate_rotation_chain:
            args = (transition, i)
        else:
            args = (transition, i, state.spin_projection)
        summation = functor(*args)
        if cleanup:
            summation = summation.cleanup()
        display(summation)

In this section, we test some of the functions from the helicity and kinematics modules to see if they reproduce Equations (1), (2), and (3). We perform this test on the channel \(J/\psi \to K^0 \Sigma^+ \bar{p}\) with resonances generated for each of the three allowed three-body topologies. The transition that corresponds to Equation (1) is shown below.

The first step is to use formulate_helicity_rotation_chain() to generate the Wigner-\(D\) functions for all helicity rotations for each final state. These helicity rotations “undo” all rotations that came from each Lorentz boost when boosting from the initial state \(J/\psi\) to each final state:

Hide code cell source
transition_r = full_reaction.transitions[-1]
show_all_spin_matrices(transition_r, formulate_helicity_rotation_chain, cleanup=True)
\[\displaystyle |s_0,m_0\rangle=|0,0\rangle \quad (K^{0})\]
\[\displaystyle D^{0}_{0,0}\left(\phi^{01}_{0},\theta^{01}_{0},0\right) D^{0}_{\nu^{01}_{0},0}\left(\phi_{01},\theta_{01},0\right)\]
\[\displaystyle |s_1,m_1\rangle=|1/2,1/2\rangle \quad (\Sigma^{+})\]
\[\displaystyle \sum_{\lambda^{01}_{1}=-1/2}^{1/2} \sum_{\mu^{01}_{1}=-1/2}^{1/2}{D^{\frac{1}{2}}_{\mu^{01}_{1},\lambda^{01}_{1}}\left(\phi^{01}_{0},\theta^{01}_{0},0\right) D^{\frac{1}{2}}_{\nu^{01}_{1},\mu^{01}_{1}}\left(\phi_{01},\theta_{01},0\right)}\]
\[\displaystyle |s_2,m_2\rangle=|1/2,1/2\rangle \quad (\overline{p})\]
\[\displaystyle \sum_{\lambda^{01}_{2}=-1/2}^{1/2}{D^{\frac{1}{2}}_{0.5,\lambda^{01}_{2}}\left(\phi_{01},\theta_{01},0\right)}\]
Hide code cell source
show_transition(transition_r)

The function formulate_rotation_chain() goes one step further. It adds a Wigner rotation to the generated list of helicity rotation Wigner-\(D\) functions in case there are resonances in between the initial state and rotated final state. If there are no resonances in between (here, state 2, the \(\bar p\)), there is only one helicity rotation and there is no need for a Wigner rotation.

Hide code cell source
show_all_spin_matrices(transition_r, formulate_rotation_chain, cleanup=False)
\[\displaystyle |s_0,m_0\rangle=|0,0\rangle \quad (K^{0})\]
\[\displaystyle \sum_{\lambda^{01}_{0}=0} \sum_{\mu^{01}_{0}=0} \sum_{\nu^{01}_{0}=0}{D^{0}_{m_{0},\nu^{01}_{0}}\left(\alpha^{01}_{0},\beta^{01}_{0},\gamma^{01}_{0}\right) D^{0}_{\mu^{01}_{0},\lambda^{01}_{0}}\left(\phi^{01}_{0},\theta^{01}_{0},0\right) D^{0}_{\nu^{01}_{0},\mu^{01}_{0}}\left(\phi_{01},\theta_{01},0\right)}\]
\[\displaystyle |s_1,m_1\rangle=|1/2,1/2\rangle \quad (\Sigma^{+})\]
\[\displaystyle \sum_{\lambda^{01}_{1}=-1/2}^{1/2} \sum_{\mu^{01}_{1}=-1/2}^{1/2} \sum_{\nu^{01}_{1}=-1/2}^{1/2}{D^{\frac{1}{2}}_{m_{1},\nu^{01}_{1}}\left(\alpha^{01}_{1},\beta^{01}_{1},\gamma^{01}_{1}\right) D^{\frac{1}{2}}_{\mu^{01}_{1},\lambda^{01}_{1}}\left(\phi^{01}_{0},\theta^{01}_{0},0\right) D^{\frac{1}{2}}_{\nu^{01}_{1},\mu^{01}_{1}}\left(\phi_{01},\theta_{01},0\right)}\]
\[\displaystyle |s_2,m_2\rangle=|1/2,1/2\rangle \quad (\overline{p})\]
\[\displaystyle \sum_{\lambda^{01}_{2}=-1/2}^{1/2}{D^{\frac{1}{2}}_{m_{2},\lambda^{01}_{2}}\left(\phi_{01},\theta_{01},0\right)}\]

These are indeed all the terms that we see in Equation (1)!

To create all sum combinations for all final states, we can use formulate_spin_alignment(). This should give the sum of Eq.(45):

alignment_summation = formulate_spin_alignment(transition_r)
alignment_summation.cleanup()
\[\displaystyle \sum_{\lambda^{01}_{1}=-1/2}^{1/2} \sum_{\mu^{01}_{1}=-1/2}^{1/2} \sum_{\nu^{01}_{1}=-1/2}^{1/2} \sum_{\lambda^{01}_{2}=-1/2}^{1/2}{D^{0}_{0,0}\left(\phi_{01},\theta_{01},0\right) D^{0}_{0,0}\left(\phi^{01}_{0},\theta^{01}_{0},0\right) D^{0}_{m_{0},0}\left(\alpha^{01}_{0},\beta^{01}_{0},\gamma^{01}_{0}\right) D^{\frac{1}{2}}_{m_{1},\nu^{01}_{1}}\left(\alpha^{01}_{1},\beta^{01}_{1},\gamma^{01}_{1}\right) D^{\frac{1}{2}}_{m_{2},\lambda^{01}_{2}}\left(\phi_{01},\theta_{01},0\right) D^{\frac{1}{2}}_{\mu^{01}_{1},\lambda^{01}_{1}}\left(\phi^{01}_{0},\theta^{01}_{0},0\right) D^{\frac{1}{2}}_{\nu^{01}_{1},\mu^{01}_{1}}\left(\phi_{01},\theta_{01},0\right)}\]

Finally, here are the generated spin alignment terms for the other two decay chains. Notice that the first is indeed the same as (2):

Hide code cell source
reaction_s = qrules.generate_transitions(
    initial_state="J/psi(1S)",
    final_state=["K0", ("Sigma+", [+0.5]), ("p~", [+0.5])],
    allowed_intermediate_particles=["N(1650)+"],
    allowed_interaction_types="strong",
    formalism="helicity",
)
transition_s = reaction_s.transitions[0]
show_all_spin_matrices(transition_s, formulate_rotation_chain, cleanup=False)
\[\displaystyle |s_0,m_0\rangle=|0,0\rangle \quad (K^{0})\]
\[\displaystyle \sum_{\lambda^{02}_{0}=0} \sum_{\mu^{02}_{0}=0} \sum_{\nu^{02}_{0}=0}{D^{0}_{m_{0},\nu^{02}_{0}}\left(\alpha^{02}_{0},\beta^{02}_{0},\gamma^{02}_{0}\right) D^{0}_{\mu^{02}_{0},\lambda^{02}_{0}}\left(\phi^{02}_{0},\theta^{02}_{0},0\right) D^{0}_{\nu^{02}_{0},\mu^{02}_{0}}\left(\phi_{02},\theta_{02},0\right)}\]
\[\displaystyle |s_1,m_1\rangle=|1/2,1/2\rangle \quad (\Sigma^{+})\]
\[\displaystyle \sum_{\lambda^{02}_{1}=-1/2}^{1/2}{D^{\frac{1}{2}}_{m_{1},\lambda^{02}_{1}}\left(\phi_{02},\theta_{02},0\right)}\]
\[\displaystyle |s_2,m_2\rangle=|1/2,1/2\rangle \quad (\overline{p})\]
\[\displaystyle \sum_{\lambda^{02}_{2}=-1/2}^{1/2} \sum_{\mu^{02}_{2}=-1/2}^{1/2} \sum_{\nu^{02}_{2}=-1/2}^{1/2}{D^{\frac{1}{2}}_{m_{2},\nu^{02}_{2}}\left(\alpha^{02}_{2},\beta^{02}_{2},\gamma^{02}_{2}\right) D^{\frac{1}{2}}_{\mu^{02}_{2},\lambda^{02}_{2}}\left(\phi^{02}_{0},\theta^{02}_{0},0\right) D^{\frac{1}{2}}_{\nu^{02}_{2},\mu^{02}_{2}}\left(\phi_{02},\theta_{02},0\right)}\]
Hide code cell source
show_transition(transition_s)

…and that the second matches Equation (3):

Hide code cell source
reaction_u = qrules.generate_transitions(
    initial_state="J/psi(1S)",
    final_state=["K0", ("Sigma+", [+0.5]), ("p~", [+0.5])],
    allowed_intermediate_particles=["K*(1680)~0"],
    allowed_interaction_types="strong",
    formalism="helicity",
)
transition_u = reaction_u.transitions[0]
show_all_spin_matrices(transition_u, formulate_rotation_chain, cleanup=False)
\[\displaystyle |s_0,m_0\rangle=|0,0\rangle \quad (K^{0})\]
\[\displaystyle \sum_{\lambda^{12}_{0}=0}{D^{0}_{m_{0},\lambda^{12}_{0}}\left(\phi_{0},\theta_{0},0\right)}\]
\[\displaystyle |s_1,m_1\rangle=|1/2,1/2\rangle \quad (\Sigma^{+})\]
\[\displaystyle \sum_{\lambda^{12}_{1}=-1/2}^{1/2} \sum_{\mu^{12}_{1}=-1/2}^{1/2} \sum_{\nu^{12}_{1}=-1/2}^{1/2}{D^{\frac{1}{2}}_{m_{1},\nu^{12}_{1}}\left(\alpha^{12}_{1},\beta^{12}_{1},\gamma^{12}_{1}\right) D^{\frac{1}{2}}_{\mu^{12}_{1},\lambda^{12}_{1}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right) D^{\frac{1}{2}}_{\nu^{12}_{1},\mu^{12}_{1}}\left(\phi_{0},\theta_{0},0\right)}\]
\[\displaystyle |s_2,m_2\rangle=|1/2,1/2\rangle \quad (\overline{p})\]
\[\displaystyle \sum_{\lambda^{12}_{2}=-1/2}^{1/2} \sum_{\mu^{12}_{2}=-1/2}^{1/2} \sum_{\nu^{12}_{2}=-1/2}^{1/2}{D^{\frac{1}{2}}_{m_{2},\nu^{12}_{2}}\left(\alpha^{12}_{2},\beta^{12}_{2},\gamma^{12}_{2}\right) D^{\frac{1}{2}}_{\mu^{12}_{2},\lambda^{12}_{2}}\left(\phi^{12}_{1},\theta^{12}_{1},0\right) D^{\frac{1}{2}}_{\nu^{12}_{2},\mu^{12}_{2}}\left(\phi_{0},\theta_{0},0\right)}\]
Hide code cell source
show_transition(transition_u)

Compute Wigner rotation angles#

Now it’s still a matter of computing the values of the angles \(\alpha,\beta,\gamma\) in the Wigner rotation matrices. These angles represent the difference between the canonical spin frame attained by a direct boost from the initial state and the frame attained by a chain of boosts through each resonance. See Equation (36) in [Marangotto, 2020].

The kinematics module can generate an expression for the chain of Lorentz boosts from the initial state to the final state with compute_boost_chain():

Hide code cell source
dot = qrules.io.asdot(transition_u)
topology = transition_u.topology
display(graphviz.Source(dot))

from ampform.kinematics import compute_boost_chain, create_four_momentum_symbols

momenta = create_four_momentum_symbols(topology)
for state_id in topology.outgoing_edge_ids:
    boosts = compute_boost_chain(topology, momenta, state_id)
    display(sp.Array(boosts))
\[\displaystyle \left[\begin{matrix}\boldsymbol{B}\left(p_{0}\right)\end{matrix}\right]\]
\[\displaystyle \left[\begin{matrix}\boldsymbol{B}\left({p}_{12}\right) & \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{1}\right)\end{matrix}\right]\]
\[\displaystyle \left[\begin{matrix}\boldsymbol{B}\left({p}_{12}\right) & \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{2}\right)\end{matrix}\right]\]

This chain of Lorentz boosts needs to be ‘undone’ with a direct Lorentz boost back to the initial state. The contraction of this inverse Lorentz boost with the chain of Lorentz boosts can be generated with compute_wigner_rotation_matrix():

from ampform.kinematics import compute_wigner_rotation_matrix

for state_id in topology.outgoing_edge_ids:
    expr = compute_wigner_rotation_matrix(topology, momenta, state_id)
    display(expr)
\[\displaystyle \boldsymbol{B}\left(-\left(p_{0}\right)\right) \boldsymbol{B}\left(p_{0}\right)\]
\[\displaystyle \boldsymbol{B}\left(-\left(p_{1}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{1}\right)\]
\[\displaystyle \boldsymbol{B}\left(-\left(p_{2}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{2}\right)\]

The result of this matrix product is the rotation matrix for the Wigner rotation. The function compute_wigner_angles() computes the required Euler angles from this rotation matrix by implementing Equations (B.2-3) from [Marangotto, 2020]:

from ampform.kinematics import compute_wigner_angles

angles = {}
for state_id in topology.outgoing_edge_ids:
    angle_definitions = compute_wigner_angles(topology, momenta, state_id)
    for name, expr in angle_definitions.items():
        angle_symbol = sp.Symbol(name, real=True)
        angles[angle_symbol] = expr
Hide code cell source
latex_lines = [R"\begin{eqnarray}"]
for symbol, expr in angles.items():
    latex_lines.append(Rf"{sp.latex(symbol)}&=&{sp.latex(expr)}\\")
latex_lines.append(R"\end{eqnarray}")
Math("\n".join(latex_lines))
\[\begin{split}\displaystyle \begin{eqnarray} \alpha^{12}_{0}&=&\operatorname{atan_{2}}{\left(\boldsymbol{B}\left(-\left(p_{0}\right)\right) \boldsymbol{B}\left(p_{0}\right)\left[:, 3, 2\right],\boldsymbol{B}\left(-\left(p_{0}\right)\right) \boldsymbol{B}\left(p_{0}\right)\left[:, 3, 1\right] \right)}\\ \beta^{12}_{0}&=&\operatorname{acos}{\left(\boldsymbol{B}\left(-\left(p_{0}\right)\right) \boldsymbol{B}\left(p_{0}\right)\left[:, 3, 3\right] \right)}\\ \gamma^{12}_{0}&=&\operatorname{atan_{2}}{\left(\boldsymbol{B}\left(-\left(p_{0}\right)\right) \boldsymbol{B}\left(p_{0}\right)\left[:, 2, 3\right],- \boldsymbol{B}\left(-\left(p_{0}\right)\right) \boldsymbol{B}\left(p_{0}\right)\left[:, 1, 3\right] \right)}\\ \alpha^{12}_{1}&=&\operatorname{atan_{2}}{\left(\boldsymbol{B}\left(-\left(p_{1}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{1}\right)\left[:, 3, 2\right],\boldsymbol{B}\left(-\left(p_{1}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{1}\right)\left[:, 3, 1\right] \right)}\\ \beta^{12}_{1}&=&\operatorname{acos}{\left(\boldsymbol{B}\left(-\left(p_{1}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{1}\right)\left[:, 3, 3\right] \right)}\\ \gamma^{12}_{1}&=&\operatorname{atan_{2}}{\left(\boldsymbol{B}\left(-\left(p_{1}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{1}\right)\left[:, 2, 3\right],- \boldsymbol{B}\left(-\left(p_{1}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{1}\right)\left[:, 1, 3\right] \right)}\\ \alpha^{12}_{2}&=&\operatorname{atan_{2}}{\left(\boldsymbol{B}\left(-\left(p_{2}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{2}\right)\left[:, 3, 2\right],\boldsymbol{B}\left(-\left(p_{2}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{2}\right)\left[:, 3, 1\right] \right)}\\ \beta^{12}_{2}&=&\operatorname{acos}{\left(\boldsymbol{B}\left(-\left(p_{2}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{2}\right)\left[:, 3, 3\right] \right)}\\ \gamma^{12}_{2}&=&\operatorname{atan_{2}}{\left(\boldsymbol{B}\left(-\left(p_{2}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{2}\right)\left[:, 2, 3\right],- \boldsymbol{B}\left(-\left(p_{2}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{2}\right)\left[:, 1, 3\right] \right)}\\ \end{eqnarray}\end{split}\]

Note

In the topology underlying (3), the Wigner rotation matrix with angles \(\alpha_0^{12}, \beta_0^{12}, \gamma_0^{12}\) is simply the identity matrix. This is the reason why it can be omitted in formulate_rotation_chain() and we only have one helicity rotation.

Computational code#

Classes like BoostMatrix have been split up into smaller unevaluated() expression classes so that lambdification to NumPy code results in relatively small and fast code, when using cse=True in lambdify() (see NumPyPrintable).

Hide code cell source
import inspect

beta = sp.Symbol("beta_1^12", real=True)
beta_expr = angles[beta]

func = sp.lambdify(momenta.values(), beta_expr.doit(), cse=True)
src = inspect.getsource(func)
n_characters = len(src)
latex = sp.latex(beta)
latex += Rf":\quad\text{{{n_characters:,} characters in generated code}}"
Math(latex)
\[\displaystyle \beta^{12}_{1}:\quad\text{2,147 characters in generated code}\]
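As a quick cross-check (not part of the original notebook), the same expression can be lambdified without common subexpression elimination; the generated source is then typically much longer:

# Hypothetical comparison: lambdify the same expression without CSE
func_no_cse = sp.lambdify(momenta.values(), beta_expr.doit(), cse=False)
len(inspect.getsource(func_no_cse))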
Test on data sample#

Tip

A test with a larger data distribution is being developed in TR-013.

The following phase space mini-sample of four-momenta has been generated for the decay \(J/\psi \to K^0 \Sigma^+ \bar{p}\) with the tensorwaves.data module.

import numpy as np

phsp = {
    "p0": np.array([
        [0.63140486, 0.13166435, -0.35734744, 0.07760603],
        [0.65169531, 0.37242432, 0.12027178, 0.15467675],
        [0.60647425, -0.22286205, -0.175258, 0.19952806],
        [0.72744323, 0.05529811, 0.30502402, -0.43064999],
        [0.76778868, -0.43557036, 0.35491651, -0.16185017],
    ]),
    "p1": np.array([
        [1.37017117, 0.173769668, 0.355893315, -0.553093198],
        [1.34556663, -5.272033e-04, -0.3074542, -0.54901747],
        [1.41660182, 0.634007973, -0.0457976658, -0.433700564],
        [1.38592340, 0.138369075, -0.258624859, 0.648189682],
        [1.37858847, 0.551405385, -0.338705615, 0.259105737],
    ]),
    "p2": np.array([
        [1.09532397, -0.30543402, 0.00145413, 0.47548716],
        [1.09963805, -0.37189712, 0.18718247, 0.39434072],
        [1.07382393, -0.41114592, 0.22105567, 0.2341725],
        [0.98353336, -0.19366719, -0.04639917, -0.21753969],
        [0.95052285, -0.11583502, -0.01621089, -0.09725557],
    ]),
}
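For reference, a sample of this form could be generated along the following lines. This is a sketch only: the exact TensorWaves API may differ between versions and the masses are approximate PDG values in GeV.

# Sketch: generate a small phase-space sample for J/psi -> K0 Sigma+ p~
from tensorwaves.data import TFPhaseSpaceGenerator, TFUniformRealNumberGenerator

phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=3.097,  # J/psi(1S)
    final_state_masses={0: 0.498, 1: 1.189, 2: 0.938},  # K0, Sigma+, p~
)
rng = TFUniformRealNumberGenerator(seed=0)
phsp_sample = phsp_generator.generate(5, rng)  # dict with keys "p0", "p1", "p2"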
matrix_expr = compute_wigner_rotation_matrix(topology, momenta, state_id=2)
matrix_expr
\[\displaystyle \boldsymbol{B}\left(-\left(p_{2}\right)\right) \boldsymbol{B}\left({p}_{12}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{12}\right) p_{2}\right)\]
matrix_func = sp.lambdify(momenta.values(), matrix_expr.doit(), cse=True)
matrix_array = matrix_func(*phsp.values())
np.round(matrix_array, decimals=2).real
array([[[ 1.  , -0.  , -0.  , -0.  ],
        [-0.  ,  1.  ,  0.02, -0.02],
        [-0.  , -0.02,  1.  ,  0.03],
        [ 0.  ,  0.02, -0.03,  1.  ]],

       [[ 1.  ,  0.  , -0.  ,  0.  ],
        [ 0.  ,  1.  , -0.02, -0.04],
        [ 0.  ,  0.02,  1.  , -0.  ],
        [ 0.  ,  0.04,  0.  ,  1.  ]],

       [[ 1.  ,  0.  ,  0.  ,  0.  ],
        [-0.  ,  1.  ,  0.02, -0.01],
        [ 0.  , -0.02,  1.  ,  0.02],
        [ 0.  ,  0.01, -0.02,  1.  ]],

       [[ 1.  ,  0.  , -0.  ,  0.  ],
        [ 0.  ,  1.  , -0.01,  0.02],
        [ 0.  ,  0.01,  1.  ,  0.02],
        [ 0.  , -0.02, -0.02,  1.  ]],

       [[ 1.  , -0.  ,  0.  , -0.  ],
        [-0.  ,  1.  , -0.01, -0.01],
        [ 0.  ,  0.01,  1.  ,  0.01],
        [-0.  ,  0.01, -0.01,  1.  ]]])
Hide code cell source
latex_lines = [R"\begin{eqnarray}"]
for angle_symbol, angle_expr in angles.items():
    angle_func = sp.lambdify(momenta.values(), angle_expr.doit(), cse=True)
    angle_array = angle_func(*phsp.values())
    rounded_values = np.round(angle_array, decimals=2).real
    latex_lines.append(
        Rf"{sp.latex(angle_symbol)}&=&{sp.latex(sp.Array(rounded_values))}\\"
    )
latex_lines.append(R"\end{eqnarray}")
Math("\n".join(latex_lines))
\[\begin{split}\displaystyle \begin{eqnarray} \alpha^{12}_{0}&=&\left[\begin{matrix}-1.25 & 0.0 & 0.79 & 1.33 & 2.36\end{matrix}\right]\\ \beta^{12}_{0}&=&\left[\begin{matrix}0.0 & \text{NaN} & 0.0 & 0.0 & 0.0\end{matrix}\right]\\ \gamma^{12}_{0}&=&\left[\begin{matrix}-1.89 & 3.14 & 2.36 & 1.82 & 0.79\end{matrix}\right]\\ \alpha^{12}_{1}&=&\left[\begin{matrix}2.03 & -3.04 & 1.9 & 0.74 & 2.14\end{matrix}\right]\\ \beta^{12}_{1}&=&\left[\begin{matrix}0.03 & 0.03 & 0.01 & 0.02 & 0.01\end{matrix}\right]\\ \gamma^{12}_{1}&=&\left[\begin{matrix}-2.05 & 3.06 & -1.92 & -0.73 & -2.13\end{matrix}\right]\\ \alpha^{12}_{2}&=&\left[\begin{matrix}-1.09 & 0.08 & -1.22 & -2.41 & -1.01\end{matrix}\right]\\ \beta^{12}_{2}&=&\left[\begin{matrix}0.04 & 0.04 & 0.02 & 0.03 & 0.01\end{matrix}\right]\\ \gamma^{12}_{2}&=&\left[\begin{matrix}1.11 & -0.1 & 1.25 & 2.4 & 1.0\end{matrix}\right]\\ \end{eqnarray}\end{split}\]
Hide code cell source
dot = qrules.io.asdot(transition_u, collapse_graphs=True)
graphviz.Source(dot)

Note

The NaN values above come from the fact that applying the inverse boost on top of a boost can result in negative values under the square root in \(\gamma=1/\sqrt{1-\beta^2}\). This can be ignored, because the Wigner rotation is simply omitted when formulating the chain of rotation matrices, as noted in Compute Wigner rotation angles.

Four-body decay#

The algorithm for computing the Euler angles of the Wigner rotation works for an arbitrary number of final states. Here, we illustrate this by formulating an expression for the Wigner rotation matrix in a four-body decay.

from qrules.topology import create_isobar_topologies

topology_4body = create_isobar_topologies(4)[1]
momenta_4body = create_four_momentum_symbols(topology_4body)
compute_wigner_rotation_matrix(topology_4body, momenta_4body, state_id=3)
\[\displaystyle \boldsymbol{B}\left(-\left(p_{3}\right)\right) \boldsymbol{B}\left({p}_{123}\right) \boldsymbol{B}\left(\boldsymbol{B}\left({p}_{123}\right) {p}_{23}\right) \boldsymbol{B}\left(\boldsymbol{B}\left(\boldsymbol{B}\left({p}_{123}\right) {p}_{23}\right) \boldsymbol{B}\left({p}_{123}\right) p_{3}\right)\]
Hide code cell source
dot = qrules.io.asdot(topology_4body)
graphviz.Source(dot)


Complex integral#

Complex integration#

SciPy cannot integrate complex functions:

from scipy.integrate import quad


def integrand(x):
    return x * (x + 1j)


quad(integrand, 0.0, 2.0)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [3], in <cell line: 8>()
      4 def integrand(x):
      5     return x * (x + 1j)
----> 8 quad(integrand, 0.0, 2.0)

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/scipy/integrate/_quadpack_py.py:351, in quad(func, a, b, args, full_output, epsabs, epsrel, limit, points, weight, wvar, wopts, maxp1, limlst)
    348 flip, a, b = b < a, min(a, b), max(a, b)
    350 if weight is None:
--> 351     retval = _quad(func, a, b, args, full_output, epsabs, epsrel, limit,
    352                    points)
    353 else:
    354     if points is not None:

File ~/miniconda3/envs/compwa-org/lib/python3.8/site-packages/scipy/integrate/_quadpack_py.py:463, in _quad(func, a, b, args, full_output, epsabs, epsrel, limit, points)
    461 if points is None:
    462     if infbounds == 0:
--> 463         return _quadpack._qagse(func,a,b,args,full_output,epsabs,epsrel,limit)
    464     else:
    465         return _quadpack._qagie(func,bound,infbounds,args,full_output,epsabs,epsrel,limit)

TypeError: can't convert complex to float
Split real and imaginary integral#

A proposed solution is to wrap the quad() function in a special integrate function that integrates the real and imaginary part of a function separately:

import numpy as np


def complex_integrate(func, a, b, **quad_kwargs):
    def real_func(x):
        return np.real(func(x))

    def imag_func(x):
        return np.imag(func(x))

    real_integral, real_integral_err = quad(real_func, a, b, **quad_kwargs)
    imag_integral, imag_integral_err = quad(imag_func, a, b, **quad_kwargs)
    return (
        real_integral + 1j * imag_integral,
        real_integral_err**2 + 1j * imag_integral_err,
    )
complex_integrate(integrand, 0.0, 2.0)
((2.666666666666667+2j), (8.765121169122355e-28+2.220446049250313e-14j))

Warning

The handling of uncertainties is incorrect.
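A sketch of a more sensible way to report the uncertainty (not part of the original report) is to combine the two absolute error estimates returned by quad() into a single magnitude:

def complex_integrate_v2(func, a, b, **quad_kwargs):
    real_integral, real_error = quad(lambda x: np.real(func(x)), a, b, **quad_kwargs)
    imag_integral, imag_error = quad(lambda x: np.imag(func(x)), a, b, **quad_kwargs)
    # Combine the absolute error estimates of both parts into one magnitude
    return real_integral + 1j * imag_integral, np.hypot(real_error, imag_error)


complex_integrate_v2(integrand, 0.0, 2.0)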

Integrate with quadpy#

Alternatively, one could use quadpy, which essentially does the same as in Split real and imaginary integral, but can also handle vectorized input (to a large degree) and handles uncertainties properly.

import quadpy

quadpy.quad(integrand, a=0.0, b=2.0)
((2.6666666666666665+2.0000000000000004j), 2.0082667671941473e-19)

Note

One may need to play around with the tolerance if the function is not smooth, see sigma-py/quadpy#255.
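For example, the tolerances can be tightened through the epsabs and epsrel arguments (these keyword arguments also appear in the quad() call generated in the SymPy integral section below):

# Tighter tolerances for a less smooth integrand
quadpy.quad(integrand, a=0.0, b=2.0, epsabs=1e-10, epsrel=1e-10, limit=100)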

Tip

quadpy raises exceptions with ModuleNotFoundErrors that are a bit unreadable. They are caused by the fact that orthopy and ndim need to be installed separately.

Vectorized input#

The dispersion integral from Eq. (2) in TR-003 features a variable \(s\) that is an argument to the function \(\Sigma_a\). This becomes a problem when \(s\) gets vectorized (in this case: gets an event-wise numpy.array of invariant masses). Here’s a simplified version of the problem:

from functools import partial


def parametrized_func(s_prime, s):
    return s_prime * (s_prime + s + 1j)


s_array = np.linspace(-1, 1, num=10)
quadpy.quad(
    partial(parametrized_func, s=s_array),
    a=0.0,
    b=2.0,
)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [7], in <cell line: 9>()
      5     return s_prime * (s_prime + s + 1j)
      8 s_array = np.linspace(-1, 1, num=10)
----> 9 quadpy.quad(
     10     partial(parametrized_func, s=s_array),
     11     a=0.0,
     12     b=2.0,
     13 )

File <string>:55, in quad(f, a, b, args, epsabs, epsrel, limit)

File <string>:45, in integrate_adaptive(f, intervals, eps_abs, eps_rel, criteria_connection, kronrod_degree, minimum_interval_length, max_num_subintervals, dot, domain_shape, range_shape)

File <string>:129, in _gauss_kronrod_integrate(k, f, intervals, dot, domain_shape, range_shape)

File <string>:49, in g(x)

Input In [7], in parametrized_func(s_prime, s)
      4 def parametrized_func(s_prime, s):
----> 5     return s_prime * (s_prime + s + 1j)

ValueError: operands could not be broadcast together with shapes (21,) (10,) 

The way out seems to be to vectorize the quadpy.quad() call itself and forward the function arguments through functools.partial():

from functools import partial


@np.vectorize
def vectorized_quad(func, a, b, **func_kwargs):
    return quadpy.quad(partial(func, **func_kwargs), a, b)


vectorized_quad(parametrized_func, a=0.0, b=2.0, s=s_array)
(array([0.66666667+2.j, 1.11111111+2.j, 1.55555556+2.j, 2.        +2.j,
        2.44444444+2.j, 2.88888889+2.j, 3.33333333+2.j, 3.77777778+2.j,
        4.22222222+2.j, 4.66666667+2.j]),
 array([1.94765926e-19, 2.69476631e-19, 2.24127752e-19, 2.79100064e-19,
        2.67216263e-19, 1.43065895e-19, 3.08910645e-19, 3.62329394e-19,
        4.86795288e-19, 2.21702097e-19]))

Note, however, that this becomes difficult to implement as lambdify() output for a sympy.Integral. An attempt at this is made in TR-003.

SymPy integral#

There is no good way to write integrals as SymPy expressions that correctly lambdify() to a vectorized integral that handles complex values. Here is a first step, however. Note that this integral expression class derives from sympy.Integral and:

  1. overwrites its doit() method, so that the integral cannot be evaluated by SymPy.

  2. provides a custom NumPy printer method (see TR-001) that lambdifies this expression node to quadpy.quad().

  3. adds class variables that can affect the behavior of quadpy.quad().

import sympy as sp
from sympy.printing.pycode import _unpack_integral_limits


class UnevaluatableIntegral(sp.Integral):
    abs_tolerance = 1e-5
    rel_tolerance = 1e-5
    limit = 50

    def doit(self, **hints):
        args = [arg.doit(**hints) for arg in self.args]
        return self.func(*args)

    def _numpycode(self, printer, *args):
        integration_vars, limits = _unpack_integral_limits(self)
        if len(limits) != 1:
            msg = f"Cannot handle {len(limits)}-dimensional integrals"
            raise ValueError(msg)
        integrate = "quadpy_quad"
        printer.module_imports["quadpy"].update({f"quad as {integrate}"})
        limit_str = "{}, {}".format(*tuple(map(printer._print, limits[0])))
        args = ", ".join(map(printer._print, integration_vars))
        expr = printer._print(self.args[0])
        return (
            f"{integrate}(lambda {args}: {expr}, {limit_str},"
            f" epsabs={self.abs_tolerance}, epsrel={self.abs_tolerance},"
            f" limit={self.limit})[0]"
        )

To test whether this works, we write the expression from Vectorized input as a sympy.Expr. Note that the integral indeed does not evaluate when calling doit():

s, s_prime, a, b = sp.symbols("s s_p a b")
integral_expr: sp.Expr = UnevaluatableIntegral(
    s_prime * (s_prime + s + sp.I),
    (s_prime, a, b),
)
integral_expr.doit()
\[\displaystyle \int\limits_{a}^{b} s_{p} \left(s + s_{p} + i\right)\, ds_{p}\]

Indeed, the expression lambdifies correctly:

import inspect

integral_func = sp.lambdify([s, a, b], integral_expr)
src = inspect.getsource(integral_func)
print(src)
def _lambdifygenerated(s, a, b):
    return quadpy_quad(lambda s_p: s_p*(s + s_p + 1j), a, b, epsabs=1e-05, epsrel=1e-05, limit=50)[0]

Note, however, that the lambdified function has to be vectorized before it can handle numpy.arrays:

vec_integral_func = np.vectorize(integral_func)
vec_integral_func(s_array, a=0.0, b=2.0)
array([0.66666667+2.j, 1.11111111+2.j, 1.55555556+2.j, 2.        +2.j,
       2.44444444+2.j, 2.88888889+2.j, 3.33333333+2.j, 3.77777778+2.j,
       4.22222222+2.j, 4.66666667+2.j])

Tip

For a more complicated and challenging expression, see SymPy expressions in TR-003.

Phase space for a three-body decay#

Hide code cell content
%config InlineBackend.figure_formats = ['svg']
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from ampform.sympy import (
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
    make_commutative,
)
from IPython.display import Math
from ipywidgets import FloatSlider, VBox, interactive_output
from matplotlib.patches import Patch
from tensorwaves.function.sympy import create_parametrized_function

if TYPE_CHECKING:
    from matplotlib.axis import Axis
    from matplotlib.contour import QuadContourSet
    from matplotlib.lines import Line2D

warnings.filterwarnings("ignore")

Kinematics for a three-body decay \(0 \to 123\) can be fully described by two Mandelstam variables \(\sigma_1, \sigma_2\), because the third variable \(\sigma_3\) can be expressed in terms of \(\sigma_1, \sigma_2\), the mass \(m_0\) of the initial state, and the masses \(m_1, m_2, m_3\) of the final state. As can be seen below, the roles of \(\sigma_1, \sigma_2, \sigma_3\) are interchangeable.

Hide code cell source
def compute_third_mandelstam(sigma1, sigma2, m0, m1, m2, m3) -> sp.Expr:
    return m0**2 + m1**2 + m2**2 + m3**2 - sigma1 - sigma2


m0, m1, m2, m3 = sp.symbols("m:4")
s1, s2, s3 = sp.symbols("sigma1:4")
computed_s3 = compute_third_mandelstam(s1, s2, m0, m1, m2, m3)
Math(Rf"{sp.latex(s3)} = {sp.latex(computed_s3)}")
\[\displaystyle \sigma_{3} = m_{0}^{2} + m_{1}^{2} + m_{2}^{2} + m_{3}^{2} - \sigma_{1} - \sigma_{2}\]

The phase space is defined by the closed area that satisfies the condition \(\phi(\sigma_1,\sigma_2) \leq 0\), where \(\phi\) is a Kibble function:

Hide code cell source
@make_commutative
@implement_doit_method
class Kibble(UnevaluatedExpression):
    def __new__(cls, sigma1, sigma2, sigma3, m0, m1, m2, m3, **hints) -> Kibble:
        return create_expression(
            cls, sigma1, sigma2, sigma3, m0, m1, m2, m3, **hints
        )

    def evaluate(self) -> sp.Expr:
        sigma1, sigma2, sigma3, m0, m1, m2, m3 = self.args
        return KällÊn(
            KällÊn(sigma2, m2**2, m0**2),
            KällÊn(sigma3, m3**2, m0**2),
            KällÊn(sigma1, m1**2, m0**2),
        )

    def _latex(self, printer, *args):
        sigma1, sigma2, *_ = map(printer._print, self.args)
        return Rf"\phi\left({sigma1}, {sigma2}\right)"


@make_commutative
@implement_doit_method
class KällÊn(UnevaluatedExpression):
    def __new__(cls, x, y, z, **hints) -> KällÊn:
        return create_expression(cls, x, y, z, **hints)

    def evaluate(self) -> sp.Expr:
        x, y, z = self.args
        return x**2 + y**2 + z**2 - 2 * x * y - 2 * y * z - 2 * z * x

    def _latex(self, printer, *args):
        x, y, z = map(printer._print, self.args)
        return Rf"\lambda\left({x}, {y}, {z}\right)"


kibble = Kibble(s1, s2, s3, m0, m1, m2, m3)
Math(Rf"{sp.latex(kibble)} = {sp.latex(kibble.doit(deep=False))}")
\[\displaystyle \phi\left(\sigma_{1}, \sigma_{2}\right) = \lambda\left(\lambda\left(\sigma_{2}, m_{2}^{2}, m_{0}^{2}\right), \lambda\left(\sigma_{3}, m_{3}^{2}, m_{0}^{2}\right), \lambda\left(\sigma_{1}, m_{1}^{2}, m_{0}^{2}\right)\right)\]

and \(\lambda\) is the KällÊn function:

Hide code cell source
x, y, z = sp.symbols("x:z")
expr = KällÊn(x, y, z)
Math(f"{sp.latex(expr)} = {sp.latex(expr.doit())}")
\[\displaystyle \lambda\left(x, y, z\right) = x^{2} - 2 x y - 2 x z + y^{2} - 2 y z + z^{2}\]

Any distribution over the phase space can now be defined using a two-dimensional grid over a Mandelstam pair \(\sigma_1,\sigma_2\) of choice, with the condition \(\phi(\sigma_1,\sigma_2)<0\) selecting the values that are physically allowed.

Hide code cell source
def is_within_phasespace(
    sigma1, sigma2, m0, m1, m2, m3, outside_value=sp.nan
) -> sp.Piecewise:
    sigma3 = compute_third_mandelstam(sigma1, sigma2, m0, m1, m2, m3)
    kibble = Kibble(sigma1, sigma2, sigma3, m0, m1, m2, m3)
    return sp.Piecewise(
        (1, sp.LessThan(kibble, 0)),
        (outside_value, True),
    )


is_within_phasespace(s1, s2, m0, m1, m2, m3)
\[\begin{split}\displaystyle \begin{cases} 1 & \text{for}\: \phi\left(\sigma_{1}, \sigma_{2}\right) \leq 0 \\\text{NaN} & \text{otherwise} \end{cases}\end{split}\]
phsp_expr = is_within_phasespace(s1, s2, m0, m1, m2, m3, outside_value=0)
phsp_func = create_parametrized_function(
    phsp_expr.doit(),
    parameters={m0: 2.2, m1: 0.2, m2: 0.4, m3: 0.4},
    backend="numpy",
)
Hide code cell source
sliders = {
    "m0": FloatSlider(description="m0", max=3, value=2.1, step=0.01),
    "m1": FloatSlider(description="m1", max=2, value=0.2, step=0.01),
    "m2": FloatSlider(description="m2", max=2, value=0.4, step=0.01),
    "m3": FloatSlider(description="m3", max=2, value=0.4, step=0.01),
}

resolution = 300
X, Y = np.meshgrid(
    np.linspace(0, 4, num=resolution),
    np.linspace(0, 4, num=resolution),
)
data = {"sigma1": X, "sigma2": Y}

sidebar_ratio = 0.15
fig, ((ax1, _), (ax, ax2)) = plt.subplots(
    figsize=(7, 7),
    ncols=2,
    nrows=2,
    gridspec_kw={
        "height_ratios": [sidebar_ratio, 1],
        "width_ratios": [1, sidebar_ratio],
    },
)
_.remove()
ax.set_xlim(0, 4)
ax.set_ylim(0, 4)
ax.set_xlabel(R"$\sigma_1$")
ax.set_ylabel(R"$\sigma_2$")
ax.set_xticks(range(5))
ax.set_yticks(range(5))
ax1.set_xlim(0, 4)
ax2.set_ylim(0, 4)
for a in [ax1, ax2]:
    a.set_xticks([])
    a.set_yticks([])
    a.axis("off")
fig.tight_layout()
fig.subplots_adjust(wspace=0, hspace=0)

fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False

MESH: QuadContourSet | None = None
PROJECTIONS: tuple[Line2D, Line2D] = None
BOUNDARIES: list[Line2D] | None = None


def plot(**parameters):
    draw_boundaries(
        parameters["m0"],
        parameters["m1"],
        parameters["m2"],
        parameters["m3"],
    )
    global MESH, PROJECTIONS
    if MESH is not None:
        for coll in MESH.collections:
            ax.collections.remove(coll)
    phsp_func.update_parameters(parameters)
    Z = phsp_func(data)
    MESH = ax.contour(X, Y, Z, colors="black")
    contour = MESH.collections[0]
    contour.set_facecolor("lightgray")
    x = X[0]
    y = Y[:, 0]
    Zx = np.nansum(Z, axis=0)
    Zy = np.nansum(Z, axis=1)
    if PROJECTIONS is None:
        PROJECTIONS = (
            ax1.plot(x, Zx, c="black", lw=2)[0],
            ax2.plot(Zy, y, c="black", lw=2)[0],
        )
    else:
        PROJECTIONS[0].set_data(x, Zx)
        PROJECTIONS[1].set_data(Zy, y)
    ax1.relim()
    ax2.relim()
    ax1.autoscale_view(scalex=False)
    ax2.autoscale_view(scaley=False)
    create_legend(ax)
    fig.canvas.draw()


def draw_boundaries(m0, m1, m2, m3) -> None:
    global BOUNDARIES
    s1_min = (m2 + m3) ** 2
    s1_max = (m0 - m1) ** 2
    s2_min = (m1 + m3) ** 2
    s2_max = (m0 - m2) ** 2
    if BOUNDARIES is None:
        BOUNDARIES = [
            ax.axvline(s1_min, c="red", ls="dotted", label="$(m_2+m_3)^2$"),
            ax.axhline(s2_min, c="blue", ls="dotted", label="$(m_1+m_3)^2$"),
            ax.axvline(s1_max, c="red", ls="dashed", label="$(m_0-m_1)^2$"),
            ax.axhline(s2_max, c="blue", ls="dashed", label="$(m_0-m_2)^2$"),
        ]
    else:
        BOUNDARIES[0].set_data(get_line_data(s1_min))
        BOUNDARIES[1].set_data(get_line_data(s2_min, horizontal=True))
        BOUNDARIES[2].set_data(get_line_data(s1_max))
        BOUNDARIES[3].set_data(get_line_data(s2_max, horizontal=True))


def create_legend(ax: Axis):
    if ax.get_legend() is not None:
        return
    label = Rf"${sp.latex(kibble)}\leq0$"
    ax.legend(
        handles=[
            Patch(label=label, ec="black", fc="lightgray"),
            *BOUNDARIES,
        ],
        loc="upper right",
        facecolor="white",
        framealpha=1,
    )


def get_line_data(value, horizontal: bool = False) -> np.ndarray:
    pair = (value, value)
    if horizontal:
        return np.array([(0, 1), pair])
    return np.array([pair, (0, 1)])


output = interactive_output(plot, controls=sliders)
VBox([output, *sliders.values()])
Cell output - interactive Dalitz plot

The phase space boundary can be described analytically in terms of \(\sigma_1\) or \(\sigma_2\), in which case there are two solutions:

sol1, sol2 = sp.solve(kibble.doit().subs(s3, computed_s3), s2)
\[\begin{split}\displaystyle \begin{array}{c} \frac{- m_{0}^{2} m_{2}^{2} + m_{0}^{2} m_{3}^{2} + m_{0}^{2} \sigma_{1} + m_{1}^{2} m_{2}^{2} - m_{1}^{2} m_{3}^{2} + m_{1}^{2} \sigma_{1} + m_{2}^{2} \sigma_{1} + m_{3}^{2} \sigma_{1} - \sigma_{1}^{2} - \sqrt{\left(m_{0}^{2} - 2 m_{0} m_{1} + m_{1}^{2} - \sigma_{1}\right) \left(m_{0}^{2} + 2 m_{0} m_{1} + m_{1}^{2} - \sigma_{1}\right) \left(m_{2}^{2} - 2 m_{2} m_{3} + m_{3}^{2} - \sigma_{1}\right) \left(m_{2}^{2} + 2 m_{2} m_{3} + m_{3}^{2} - \sigma_{1}\right)}}{2 \sigma_{1}} \\ \frac{- m_{0}^{2} m_{2}^{2} + m_{0}^{2} m_{3}^{2} + m_{0}^{2} \sigma_{1} + m_{1}^{2} m_{2}^{2} - m_{1}^{2} m_{3}^{2} + m_{1}^{2} \sigma_{1} + m_{2}^{2} \sigma_{1} + m_{3}^{2} \sigma_{1} - \sigma_{1}^{2} + \sqrt{\left(m_{0}^{2} - 2 m_{0} m_{1} + m_{1}^{2} - \sigma_{1}\right) \left(m_{0}^{2} + 2 m_{0} m_{1} + m_{1}^{2} - \sigma_{1}\right) \left(m_{2}^{2} - 2 m_{2} m_{3} + m_{3}^{2} - \sigma_{1}\right) \left(m_{2}^{2} + 2 m_{2} m_{3} + m_{3}^{2} - \sigma_{1}\right)}}{2 \sigma_{1}} \\ \end{array}\end{split}\]
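As an illustration (a sketch that is not part of the original notebook), the two solutions can be lambdified to trace the lower and upper boundary of the physical region for the default slider masses \(m_0=2.1, m_1=0.2, m_2=m_3=0.4\):

# sol1 and sol2 carry the minus and plus sign in front of the square root (see output above)
boundary_masses = {m0: 2.1, m1: 0.2, m2: 0.4, m3: 0.4}
lower_boundary = sp.lambdify(s1, sol1.subs(boundary_masses))
upper_boundary = sp.lambdify(s1, sol2.subs(boundary_masses))
# Evaluate slightly inside the kinematic limits (m2+m3)^2 and (m0-m1)^2
s1_values = np.linspace(0.65, 3.60, num=500)
s2_lower = lower_boundary(s1_values)
s2_upper = upper_boundary(s1_values)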

The boundary cannot be parametrized analytically in polar coordinates, but there is a numeric solution. The idea is to solve the condition \(\phi(\sigma_1,\sigma_2)=0\) after the following substitutions:

Hide code cell source
T0, T1, T2, T3 = sp.symbols("T0:4")
r, theta = sp.symbols("r theta", nonnegative=True)
substitutions = {
    s1: (m2 + m3) ** 2 + T1,
    s2: (m1 + m3) ** 2 + T2,
    s3: (m1 + m2) ** 2 + T3,
    T1: T0 / 3 - r * sp.cos(theta),
    T2: T0 / 3 - r * sp.cos(theta + 2 * sp.pi / 3),
    T3: T0 / 3 - r * sp.cos(theta + 4 * sp.pi / 3),
    T0: (
        m0**2
        + m1**2
        + m2**2
        + m3**2
        - (m2 + m3) ** 2
        - (m1 + m3) ** 2
        - (m1 + m2) ** 2
    ),
}
\[\begin{split}\displaystyle \begin{array}{rcl} \sigma_{1} &=& T_{1} + \left(m_{2} + m_{3}\right)^{2} \\ \sigma_{2} &=& T_{2} + \left(m_{1} + m_{3}\right)^{2} \\ \sigma_{3} &=& T_{3} + \left(m_{1} + m_{2}\right)^{2} \\ T_{1} &=& \frac{T_{0}}{3} - r \cos{\left(\theta \right)} \\ T_{2} &=& \frac{T_{0}}{3} + r \sin{\left(\theta + \frac{\pi}{6} \right)} \\ T_{3} &=& \frac{T_{0}}{3} + r \cos{\left(\theta + \frac{\pi}{3} \right)} \\ T_{0} &=& m_{0}^{2} + m_{1}^{2} + m_{2}^{2} + m_{3}^{2} - \left(m_{1} + m_{2}\right)^{2} - \left(m_{1} + m_{3}\right)^{2} - \left(m_{2} + m_{3}\right)^{2} \\ \end{array}\end{split}\]

For every value of \(\theta \in [0, 2\pi)\), the value of \(r\) can now be found by solving the condition \(\phi(r, \theta)=0\). Note that \(\phi(r, \theta)\) is a cubic polynomial in \(r\). For instance, if we take \(m_0=5, m_1=2, m_{2,3}=1\):

phi_r = (
    kibble.doit()
    .subs(substitutions)  # substitute sigmas
    .subs(substitutions)  # substitute T123
    .subs(substitutions)  # substitute T0
    .subs({m0: 5, m1: 2, m2: 1, m3: 1})
    .simplify()
    .collect(r)
)
\[\begin{split}\displaystyle \begin{eqnarray} \phi(r, \theta) & = & r^{3} \cdot \left(56 \sqrt{3} \sin{\left(\theta \right)} + 400 \cos^{3}{\left(\theta \right)} - 356 \cos{\left(\theta \right)} + 112 \cos{\left(\theta + \frac{\pi}{3} \right)}\right) \nonumber\\ & & + r^{2} \cdot \left(2000 \cos^{2}{\left(\theta \right)} + 2100\right) \nonumber\\ & & - 4800 r \cos{\left(\theta \right)} \nonumber\\ & & - 25200 \end{eqnarray}\end{split}\]

The lowest value of \(r\) that satisfies \(\phi(r,\theta)=0\) defines the phase space boundary.
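A minimal sketch of that numerical solution, using the phi_r, r, and theta objects defined above: collect the polynomial coefficients in \(r\), evaluate them for each value of \(\theta\), and keep the smallest positive real root.

coefficient_funcs = [
    sp.lambdify(theta, c, "numpy") for c in sp.Poly(phi_r, r).all_coeffs()
]


def boundary_radius(theta_value: float) -> float:
    # np.roots() expects coefficients in order of decreasing degree,
    # which is exactly what Poly.all_coeffs() provides
    roots = np.roots([c(theta_value) for c in coefficient_funcs])
    real_roots = roots[np.isreal(roots)].real
    return real_roots[real_roots > 0].min()


theta_values = np.linspace(0, 2 * np.pi, num=200)
r_boundary = np.array([boundary_radius(t) for t in theta_values])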

Importance sampling#

Model definition#
from __future__ import annotations

import logging
import os
import warnings

import jax.numpy as jnp
import numpy as np

logging.getLogger("absl").setLevel(logging.ERROR)  # no JAX warnings
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # no TF warnings
warnings.filterwarnings("ignore")  # sqrt negative argument

We generate data for the reaction \(J/\psi \to \gamma \pi^0\pi^0\). We limit ourselves to two resonances, so that the amplitude model contains one narrow structure. This makes it hard to numerically compute the integral over the intensity distribution.

import qrules

reaction = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)(980)", "omega(782)"],
    allowed_interaction_types=["strong", "EM"],
    formalism="canonical-helicity",
)
import graphviz

src = qrules.io.asdot(reaction, collapse_graphs=True)
_ = graphviz.Source(src).render("018-graph", format="svg")

import ampform
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)

builder = ampform.get_builder(reaction)
builder.align_spin = False
builder.adapter.permutate_registered_topologies()
builder.scalar_initial_state_mass = True
builder.stable_final_state_ids = [0, 1, 2]
builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
for name in reaction.get_intermediate_particles().names:
    builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = builder.formulate()
Phase space distribution#

An evenly distributed phase space sample can be generated with a TFPhaseSpaceGenerator:

from tensorwaves.data import (
    SympyDataTransformer,
    TFPhaseSpaceGenerator,
    TFUniformRealNumberGenerator,
)

rng = TFUniformRealNumberGenerator(seed=0)
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
transformer = SympyDataTransformer.from_sympy(
    model.kinematic_variables, backend="jax"
)
phsp = phsp_generator.generate(1_000_000, rng)
phsp = transformer(phsp)
import matplotlib.pyplot as plt


def convert_zero_to_nan(array):
    array = np.array(array).astype("float")
    array[array == 0] = np.nan
    return jnp.array(array)


Z, x_edges, y_edges = jnp.histogram2d(
    phsp["m_01"].real ** 2,
    phsp["m_12"].real ** 2,
    bins=100,
)
X, Y = jnp.meshgrid(x_edges, y_edges)
Z = Z.T  # https://numpy.org/doc/stable/reference/generated/numpy.histogram2d.html
Z = convert_zero_to_nan(Z)

bin_width_x = X[0, 1] - X[0, 0]
bin_width_y = Y[1, 0] - Y[0, 0]
bar_title = (
    Rf"events per ${1e3*bin_width_x:.0f} \times {1e3*bin_width_y:.0f}$ MeV$^2/c^4$"
)
xlabel = R"$M^2\left(\gamma\pi^0\right)$"
ylabel = R"$M^2\left(\pi^0\pi^0\right)$"

plt.ioff()
fig, ax = plt.subplots(dpi=200, figsize=(4.5, 4))
ax.set_title("TFPhaseSpaceGenerator sample")
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
mesh = ax.pcolormesh(X, Y, Z)
c_bar = plt.colorbar(mesh, ax=ax)
c_bar.ax.set_ylabel(bar_title)
fig.savefig("018-TFPhaseSpaceGenerator.png")
plt.ion()
plt.close(fig)

Under the hood, this TFPhaseSpaceGenerator applies a hit-and-miss strategy to the distribution and weights produced by a TFWeightedPhaseSpaceGenerator, which interfaces to the phasespace package. Below, we take a short look at the distribution and weights generated by such a TFWeightedPhaseSpaceGenerator. The ‘unweighted’ distribution is uneven, because the four-momenta are generated with the decay algorithm of the phasespace package. The weights compensate for this, so that, when combined, we recover the evenly distributed phase space sample from above.
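The hit-and-miss (unweighting) step itself can be sketched as follows; this is illustrative only and not the actual TFPhaseSpaceGenerator implementation:

def hit_and_miss(weights: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Boolean mask that selects an unweighted subsample from a weighted one."""
    # accept an event if a uniform random number falls below its normalized weight
    return rng.uniform(size=len(weights)) < weights / weights.max()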

from tensorwaves.data import TFWeightedPhaseSpaceGenerator

weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
unweighted_phsp = weighted_phsp_generator.generate(1_000_000, rng)
phsp_weights = unweighted_phsp["weights"]
unweighted_phsp = transformer(unweighted_phsp)
from typing import TYPE_CHECKING

from scipy.interpolate import griddata

if TYPE_CHECKING:
    from tensorwaves.interface import DataSample


def plot_distribution_and_weights(phsp: DataSample, weights: np.ndarray) -> None:
    n_bins = 100
    x = phsp["m_01"].real ** 2
    y = phsp["m_12"].real ** 2
    X, Y = jnp.meshgrid(
        jnp.linspace(x.min(), x.max(), num=n_bins),
        jnp.linspace(y.min(), y.max(), num=n_bins),
    )

    Z_weights = griddata(np.transpose([x, y]), weights, (X, Y))
    Z_unweighted, x_edges, y_edges = jnp.histogram2d(x, y, bins=n_bins)
    Z_weighted, x_edges, y_edges = jnp.histogram2d(
        x, y, bins=n_bins, weights=weights
    )
    # https://numpy.org/doc/stable/reference/generated/numpy.histogram2d.html
    Z_unweighted = Z_unweighted.T
    Z_weighted = Z_weighted.T

    X_edges, Y_edges = jnp.meshgrid(x_edges, y_edges)
    Z_unweighted = convert_zero_to_nan(Z_unweighted)
    Z_weighted = convert_zero_to_nan(Z_weighted)

    _, axes = plt.subplots(
        dpi=200,
        figsize=(16, 5),
        ncols=3,
        tight_layout=True,
    )
    for ax in axes:
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
    axes[0].set_title("Unweighted distribution")
    axes[1].set_title("Weights")
    axes[2].set_title("Weighted phase space distribution")

    mesh = axes[0].pcolormesh(X_edges, Y_edges, Z_unweighted)
    c_bar = plt.colorbar(mesh, ax=axes[0])
    c_bar.ax.set_ylabel(bar_title)

    mesh = axes[1].pcolormesh(X, Y, Z_weights)
    c_bar = plt.colorbar(mesh, ax=axes[1])
    c_bar.ax.set_ylabel("phase space weight")

    mesh = axes[2].pcolormesh(X_edges, Y_edges, Z_weighted)
    c_bar = plt.colorbar(mesh, ax=axes[2])
    c_bar.ax.set_ylabel(bar_title)


plot_distribution_and_weights(unweighted_phsp, phsp_weights)
plt.gcf().suptitle("TFWeightedPhaseSpaceGenerator sample")
plt.savefig("018-TFWeightedPhaseSpaceGenerator.png")
plt.show()

Intensity distribution#

We now use an IntensityDistributionGenerator to generate a hit-and-miss data sample based on the amplitude model that we formulated for this \(J/\psi \to \gamma\pi^0\pi^0\) reaction.

from tensorwaves.function.sympy import create_parametrized_function

intensity_expr = model.expression.doit()
intensity_func = create_parametrized_function(
    expression=intensity_expr,
    parameters=model.parameter_defaults,
    backend="jax",
)
from tensorwaves.data import IntensityDistributionGenerator

data_generator = IntensityDistributionGenerator(
    domain_generator=weighted_phsp_generator,
    function=intensity_func,
    domain_transformer=transformer,
)
data = data_generator.generate(100_000, rng)
data = transformer(data)

Note that it takes a long time to generate a data sample for this amplitude model. This is because most phase space points lie outside the region where the intensity is highest and therefore result in a ‘miss’.
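A rough way to see this (an illustrative estimate, not something tensorwaves computes for you): the hit-and-miss efficiency is of the order of the mean intensity divided by the maximum intensity over the flat phase space sample.

intensities = intensity_func(phsp)
efficiency = float(jnp.mean(intensities) / jnp.max(intensities))
print(f"Estimated hit-and-miss efficiency: {efficiency:.1%}")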

Z, x_edges, y_edges = jnp.histogram2d(
    data["m_01"].real ** 2,
    data["m_12"].real ** 2,
    bins=100,
)
X, Y = jnp.meshgrid(x_edges, y_edges)
Z = Z.T  # https://numpy.org/doc/stable/reference/generated/numpy.histogram2d.html
Z = convert_zero_to_nan(Z)

plt.ioff()
fig, ax = plt.subplots(dpi=200, figsize=(4.5, 4))
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
mesh = ax.pcolormesh(X, Y, Z)
c_bar = plt.colorbar(mesh, ax=ax)
c_bar.ax.set_ylabel("intensity")
fig.savefig("018-intensity-distribution.png")
plt.ion()
plt.close(fig)

The \(\omega\) resonance appears as a narrow structure on the Dalitz plot. This is problematic when computing the integral over this distribution, which is required when performing an UnbinnedNLL fit: the integral that appears in the log-likelihood has to be recomputed in each fit iteration, and it can be estimated most efficiently when more of the points on which the amplitude model is evaluated lie in the phase space regions where the intensity is high.

The solution is to evaluate the intensity over an importance-sampled phase space sample, that is, a phase space sample with more events in the regions where the intensity is high. Each point \(\tau\) carries a weight that is set to \(1/I(\tau)\). In fact, this is nothing more than the intensity-based sample from the previous step, with the weights computed afterwards by evaluating the amplitude model over the sample and taking the inverse.

from copy import deepcopy

importance_phsp = deepcopy(data)
importance_weights = 1 / intensity_func(importance_phsp)

Of course, we could define a special class for this.
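For illustration, such a (hypothetical) helper class could simply bundle the sample with its inverse-intensity weights:

from dataclasses import dataclass

from tensorwaves.interface import DataSample, ParametrizedFunction


@dataclass
class ImportanceSample:
    sample: DataSample
    weights: jnp.ndarray

    @classmethod
    def from_intensity(
        cls, function: ParametrizedFunction, sample: DataSample
    ) -> "ImportanceSample":
        # weights are the inverse of the intensity with which the sample was generated
        return cls(sample=sample, weights=1 / function(sample))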

As expected, the inverse-intensity weights flatten the distribution back into an evenly distributed phase space sample:

plot_distribution_and_weights(importance_phsp, importance_weights)
plt.gcf().suptitle("Importance-sampled phase space distribution")
plt.savefig("018-importance-sampling.png")
plt.show()

Now, aren’t we duplicating things here? Not really. First, in an actual analysis, there would be no intensity-based data sample. Second, the importance-sampled phase space sample is generated with specific parameter values. During a fit, the parameters change and the integral over the (importance-sampled) phase space changes with them. So after updating the parameters in a fit iteration, we have to multiply the new intensities with the importance weights (the inverse of the original intensity distribution) in order to get the new distribution. This needs to be done in particular when computing the negative log likelihood (UnbinnedNLL).[1]
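Schematically, such a weighted estimator looks as follows (a sketch of the idea with a hypothetical function name, not the actual UnbinnedNLL implementation):

def weighted_negative_log_likelihood(function, data, phsp, phsp_weights) -> float:
    # normalization integral estimated over the importance-sampled phase space
    normalization = jnp.mean(function(phsp) * phsp_weights)
    return -float(jnp.sum(jnp.log(function(data) / normalization)))


weighted_negative_log_likelihood(
    intensity_func, data, importance_phsp, importance_weights
)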

In the following extreme example, we move the mass of the \(f_0(980)\) resonance far from its original position. As can be seen in the distribution below, the narrow structure has indeed moved, but it is still visible as a blur at the original position, because there are many more phase space points in that region.

intensity_func.update_parameters({"m_{f_{0}(980)}": 2.0})
new_intensities = intensity_func(importance_phsp)
Z, x_edges, y_edges = jnp.histogram2d(
    importance_phsp["m_01"].real ** 2,
    importance_phsp["m_12"].real ** 2,
    bins=100,
    weights=new_intensities * importance_weights,
)
X, Y = jnp.meshgrid(x_edges, y_edges)
Z = Z.T  # https://numpy.org/doc/stable/reference/generated/numpy.histogram2d.html
Z = convert_zero_to_nan(Z)

plt.ioff()
fig, ax = plt.subplots(dpi=200, figsize=(4.5, 4))
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
mesh = ax.pcolormesh(X, Y, Z)
c_bar = plt.colorbar(mesh, ax=ax)
c_bar.ax.set_ylabel(R"new intensity $\times$ importance weight")
fig.savefig("018-importance-sampling-after-modification.png")
plt.ion()
plt.close(fig)


Jupyter notebook with Julia kernel#

This notebook shows that the Julia installation and IJulia instructions work correctly. The cell outputs below are generated automatically with MyST-NB from the Julia code input.

Simple example:

println("Hello world!")
Hello world!

Here’s an example that prints a Mandelbrot set!

function mandelbrot(a)
    z = 0
    for i=1:50
        z = z^2 + a
    end
    return z
end

for y=1.0:-0.05:-1.0
    for x=-2.0:0.0315:0.5
        abs(mandelbrot(complex(x, y))) < 2 ? print("*") : print(" ")
    end
    println()
end
                                                                                
                                                                                
                                                                                
                                                           **                   
                                                         ******                 
                                                       ********                 
                                                         ******                 
                                                      ******** **   *           
                                              ***   *****************           
                                              ************************  ***     
                                              ****************************      
                                           ******************************       
                                            ******************************      
                                         ************************************   
                                *         **********************************    
                           ** ***** *     **********************************    
                           ***********   ************************************   
                         ************** ************************************    
                         ***************************************************    
                     *****************************************************      
 ***********************************************************************        
                     *****************************************************      
                         ***************************************************    
                         ************** ************************************    
                           ***********   ************************************   
                           ** ***** *     **********************************    
                                *         **********************************    
                                         ************************************   
                                            ******************************      
                                           ******************************       
                                              ****************************      
                                              ************************  ***     
                                              ***   *****************           
                                                      ******** **   *           
                                                         ******                 
                                                       ********                 
                                                         ******                 
                                                           **                   
                                                                                
                                                                                
                                                                                

It’s also possible to work with a local environment from the notebook. In this case, we activate the environment defined by the file 019/Project.toml and instantiate it so that the exact versions of the dependencies as defined in 019/Manifest.toml are installed.

using Pkg
Pkg.activate(joinpath(@__DIR__, "019"))
Pkg.instantiate()
using Images
 
@inline function hsv2rgb(h, s, v)
    c = v * s
    x = c * (1 - abs(((h/60) % 2) - 1))
    m = v - c
    r,g,b = if     h < 60   (c, x, 0)
            elseif h < 120  (x, c, 0)
            elseif h < 180  (0, c, x)
            elseif h < 240  (0, x, c)
            elseif h < 300  (x, 0, c)
            else            (c, 0, x) end
    (r + m), (b + m), (g + m)
end
 
function mandelbrot()
    w       = 1600
    h       = 1200
    zoom    = 0.5
    moveX   = -0.5
    moveY   = 0
    maxIter = 30
    img = Array{RGB{Float64},2}(undef,h,w)
    for x in 1:w
      for y in 1:h
        i = maxIter
        z = c = Complex( (2*x - w) / (w * zoom) + moveX,
                         (2*y - h) / (h * zoom) + moveY )
        while abs(z) < 2 && (i -= 1) > 0
            z = z^2 + c
        end
        r,g,b = hsv2rgb(i / maxIter * 360, 1, i / maxIter)
        img[y,x] = RGB{Float64}(r, g, b)
      end
    end
    return img
end
 
mandelbrot()

Amplitude analysis with zfit#

%config InlineBackend.figure_formats = ['svg']
import logging
import os
import warnings

JAX_LOGGER = logging.getLogger("absl")
JAX_LOGGER.setLevel(logging.ERROR)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore")
Formulating the model#
import qrules

reaction = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
    allowed_interaction_types=["strong", "EM"],
    formalism="helicity",
)
import graphviz

dot = qrules.io.asdot(reaction, collapse_graphs=True)
graphviz.Source(dot)


import ampform
from ampform.dynamics.builder import (
    create_non_dynamic_with_ff,
    create_relativistic_breit_wigner_with_ff,
)

model_builder = ampform.get_builder(reaction)
model_builder.scalar_initial_state_mass = True
model_builder.stable_final_state_ids = [0, 1, 2]
model_builder.set_dynamics("J/psi(1S)", create_non_dynamic_with_ff)
for name in reaction.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.formulate()
Generate data#
Phase space sample#
from tensorwaves.data import TFPhaseSpaceGenerator, TFUniformRealNumberGenerator

rng = TFUniformRealNumberGenerator(seed=0)
phsp_generator = TFPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
phsp_momenta = phsp_generator.generate(100_000, rng)
Intensity-based sample#
from tensorwaves.function.sympy import create_function

unfolded_expression = model.expression.doit()
fixed_intensity_func = create_function(
    unfolded_expression.xreplace(model.parameter_defaults),
    backend="jax",
)
from tensorwaves.data import SympyDataTransformer

transform_momenta = SympyDataTransformer.from_sympy(
    model.kinematic_variables, backend="jax"
)
from tensorwaves.data import (
    IntensityDistributionGenerator,
    TFWeightedPhaseSpaceGenerator,
)

weighted_phsp_generator = TFWeightedPhaseSpaceGenerator(
    initial_state_mass=reaction.initial_state[-1].mass,
    final_state_masses={i: p.mass for i, p in reaction.final_state.items()},
)
data_generator = IntensityDistributionGenerator(
    domain_generator=weighted_phsp_generator,
    function=fixed_intensity_func,
    domain_transformer=transform_momenta,
)
data_momenta = data_generator.generate(10_000, rng)
import pandas as pd

phsp = transform_momenta(phsp_momenta)
data = transform_momenta(data_momenta)
pd.DataFrame(data)
          m_12                 phi_0   phi_1^12   theta_0  theta_1^12
0       1.499845+0.000000j  2.941350  -0.984419  2.344617    1.064114
1       0.580070+0.000000j  1.422127   0.183725  1.086667    1.535691
2       1.495937+0.000000j  2.695585   3.063622  0.777978    1.730394
3       1.172263+0.000000j  0.527850   1.515685  1.343530    0.602596
4       1.581282+0.000000j -0.678981  -2.951556  2.987470    1.959462
...                    ...       ...        ...       ...         ...
9995    1.486016+0.000000j -1.271331  -1.387495  2.792571    2.565453
9996    0.584599+0.000000j -2.452912  -1.957086  1.070889    2.313677
9997    1.956302+0.000000j  0.378314   2.711496  0.588987    1.551541
9998    1.585024+0.000000j -0.816920  -1.166315  2.076068    1.807813
9999    1.712966+0.000000j  0.604657   0.553347  1.264140    2.079405

10000 rows × 5 columns

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

resonances = sorted(
    reaction.get_intermediate_particles(),
    key=lambda p: p.mass,
)
evenly_spaced_interval = np.linspace(0, 1, len(resonances))
colors = [cm.rainbow(x) for x in evenly_spaced_interval]
fig, ax = plt.subplots(figsize=(9, 4))
ax.hist(
    np.real(data["m_12"]),
    bins=100,
    alpha=0.5,
    density=True,
)
ax.set_xlabel("$m$ [GeV]")
for p, color in zip(resonances, colors):
    ax.axvline(x=p.mass, linestyle="dotted", label=p.name, color=color)
ax.legend()
plt.show()

Fit#
Determine free parameters#
initial_parameters = {
    R"C_{J/\psi(1S) \to {f_{0}(1500)}_{0} \gamma_{+1}; f_{0}(1500) \to \pi^{0}_{0} \pi^{0}_{0}}": (
        1.0 + 0.0j
    ),
    "m_{f_{0}(500)}": 0.4,
    "m_{f_{0}(980)}": 0.88,
    "m_{f_{0}(1370)}": 1.22,
    "m_{f_{0}(1500)}": 1.45,
    "m_{f_{0}(1710)}": 1.83,
    R"\Gamma_{f_{0}(500)}": 0.3,
    R"\Gamma_{f_{0}(980)}": 0.1,
    R"\Gamma_{f_{0}(1710)}": 0.3,
}
Parametrized function and caching#
from tensorwaves.function.sympy import create_parametrized_function

intensity_func = create_parametrized_function(
    expression=unfolded_expression,
    parameters=model.parameter_defaults,
    backend="jax",
)
from tensorwaves.estimator import create_cached_function

free_parameter_symbols = [
    symbol
    for symbol in model.parameter_defaults
    if symbol.name in set(initial_parameters)
]
cached_intensity_func, transform_to_cache = create_cached_function(
    unfolded_expression,
    parameters=model.parameter_defaults,
    free_parameters=free_parameter_symbols,
    backend="jax",
)
cached_data = transform_to_cache(data)
cached_phsp = transform_to_cache(phsp)
Estimator#
from tensorwaves.estimator import UnbinnedNLL

estimator = UnbinnedNLL(
    intensity_func,
    data=data,
    phsp=phsp,
    backend="jax",
)
estimator_with_caching = UnbinnedNLL(
    cached_intensity_func,
    data=cached_data,
    phsp=cached_phsp,
    backend="jax",
)
Optimize fit parameters#
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

reaction_info = model.reaction_info
resonances = sorted(
    reaction_info.get_intermediate_particles(),
    key=lambda p: p.mass,
)

evenly_spaced_interval = np.linspace(0, 1, len(resonances))
colors = [cm.rainbow(x) for x in evenly_spaced_interval]


def indicate_masses(ax):
    ax.set_xlabel("$m$ [GeV]")
    for color, resonance in zip(colors, resonances):
        ax.axvline(
            x=resonance.mass,
            linestyle="dotted",
            label=resonance.name,
            color=color,
        )


def compare_model(
    variable_name,
    data,
    phsp,
    function,
    bins=100,
):
    intensities = function(phsp)
    _, ax = plt.subplots(figsize=(9, 4))
    data_projection = np.real(data[variable_name])
    ax = plt.gca()
    ax.hist(
        data_projection,
        bins=bins,
        alpha=0.5,
        label="data",
        density=True,
    )
    phsp_projection = np.real(phsp[variable_name])
    ax.hist(
        phsp_projection,
        weights=np.array(intensities),
        bins=bins,
        histtype="step",
        color="red",
        label="fit model",
        density=True,
    )
    indicate_masses(ax)
    ax.legend()
original_parameters = intensity_func.parameters
intensity_func.update_parameters(initial_parameters)
compare_model("m_12", data, phsp, intensity_func)

from tensorwaves.optimizer import Minuit2
from tensorwaves.optimizer.callbacks import CSVSummary

minuit2 = Minuit2(
    callback=CSVSummary("fit_traceback.csv"),
    use_analytic_gradient=False,
)
fit_result = minuit2.optimize(estimator, initial_parameters)
fit_result
FitResult(
 minimum_valid=True,
 execution_time=7.060763359069824,
 function_calls=539,
 estimator_value=-4891.01730754809,
 parameter_values={
  'm_{f_{0}(500)}': 0.6102707294724865,
  'm_{f_{0}(980)}': 0.9902119846615327,
  'm_{f_{0}(1370)}': 1.3456300421915652,
  'm_{f_{0}(1500)}': 1.50502995100389,
  'm_{f_{0}(1710)}': 1.7096496843682751,
  '\\Gamma_{f_{0}(500)}': 0.4226040807774344,
  '\\Gamma_{f_{0}(980)}': 0.06479339507889993,
  '\\Gamma_{f_{0}(1710)}': 0.13301019075808046,
  'C_{J/\\psi(1S) \\to {f_{0}(1500)}_{0} \\gamma_{+1}; f_{0}(1500) \\to \\pi^{0}_{0} \\pi^{0}_{0}}': (1.0699249014701417-0.018664035501929042j),
 },
 parameter_errors={
  'm_{f_{0}(500)}': 0.006168655466103817,
  'm_{f_{0}(980)}': 0.0016283609785222876,
  'm_{f_{0}(1370)}': 0.005122588422790316,
  'm_{f_{0}(1500)}': 0.0033157863330869892,
  'm_{f_{0}(1710)}': 0.0025660827305775034,
  '\\Gamma_{f_{0}(500)}': 0.023838186430050128,
  '\\Gamma_{f_{0}(980)}': 0.003556673018336295,
  '\\Gamma_{f_{0}(1710)}': 0.007573518980113613,
  'C_{J/\\psi(1S) \\to {f_{0}(1500)}_{0} \\gamma_{+1}; f_{0}(1500) \\to \\pi^{0}_{0} \\pi^{0}_{0}}': (0.04106392764099969+0.07043808181098646j),
 },
)
minuit2 = Minuit2()
fit_result_with_caching = minuit2.optimize(
    estimator_with_caching, initial_parameters
)
fit_result_with_caching
FitResult(
 minimum_valid=True,
 execution_time=3.6658225059509277,
 function_calls=539,
 estimator_value=-4891.01730754809,
 parameter_values={
  'm_{f_{0}(500)}': 0.6102707294731716,
  'm_{f_{0}(980)}': 0.9902119846618569,
  'm_{f_{0}(1370)}': 1.3456300421927978,
  'm_{f_{0}(1500)}': 1.5050299510041418,
  'm_{f_{0}(1710)}': 1.7096496843680975,
  '\\Gamma_{f_{0}(500)}': 0.42260408077678696,
  '\\Gamma_{f_{0}(980)}': 0.06479339507977673,
  '\\Gamma_{f_{0}(1710)}': 0.13301019075895135,
  'C_{J/\\psi(1S) \\to {f_{0}(1500)}_{0} \\gamma_{+1}; f_{0}(1500) \\to \\pi^{0}_{0} \\pi^{0}_{0}}': (1.069924901473717-0.018664035486070114j),
 },
 parameter_errors={
  'm_{f_{0}(500)}': 0.006168655451483166,
  'm_{f_{0}(980)}': 0.0016283609759060128,
  'm_{f_{0}(1370)}': 0.005122588414282541,
  'm_{f_{0}(1500)}': 0.0033157863009583644,
  'm_{f_{0}(1710)}': 0.0025660827200538303,
  '\\Gamma_{f_{0}(500)}': 0.023838186345858253,
  '\\Gamma_{f_{0}(980)}': 0.00355667300785808,
  '\\Gamma_{f_{0}(1710)}': 0.007573518972833387,
  'C_{J/\\psi(1S) \\to {f_{0}(1500)}_{0} \\gamma_{+1}; f_{0}(1500) \\to \\pi^{0}_{0} \\pi^{0}_{0}}': (0.04106392765352627+0.07043808113241967j),
 },
)
Fit result analysis#
intensity_func.update_parameters(fit_result.parameter_values)
compare_model("m_12", data, phsp, intensity_func)

fit_traceback = pd.read_csv("fit_traceback.csv")
fig, (ax1, ax2) = plt.subplots(
    2, figsize=(7, 9), sharex=True, gridspec_kw={"height_ratios": [1, 2]}
)
fit_traceback.plot("function_call", "estimator_value", ax=ax1)
fit_traceback.plot("function_call", sorted(initial_parameters), ax=ax2)
fig.tight_layout()
ax2.set_xlabel("function call")
plt.show()

Zfit#
PDF definition#
import jax.numpy as jnp
import zfit  # suppress tf warnings
import zfit.z.numpy as znp
from zfit import supports, z

zfit.run.set_graph_mode(False)  # We cannot (yet) compile through the function
zfit.run.set_autograd_mode(False)


class TensorWavesPDF(zfit.pdf.BasePDF):
    def __init__(self, intensity, norm, obs, params=None, name="tensorwaves"):
        """tensorwaves intensity normalized over the *norm* dataset."""
        super().__init__(obs, params, name)
        self.intensity = intensity
        norm = {ob: jnp.asarray(ar) for ob, ar in zip(self.obs, z.unstack_x(norm))}
        self.norm_sample = norm

    @supports(norm=True)
    def _pdf(self, x, norm):
        # we can also use better mechanics, where it automatically normalizes or not
        # this here is rather to take full control, it is always possible

        # updating the parameters of the model. This seems not very TF compatible?
        self.intensity.update_parameters(
            {p.name: float(p) for p in self.params.values()}
        )

        # converting the data to a dict for tensorwaves
        data = {ob: jnp.asarray(ar) for ob, ar in zip(self.obs, z.unstack_x(x))}

        non_normalized_pdf = self.intensity(data)
        # this is not really needed, but can be useful for e.g. sampling with `pdf(..., norm_range=False)`
        if norm is False:
            out = non_normalized_pdf
        else:
            out = non_normalized_pdf / jnp.mean(self.intensity(self.norm_sample))
        return znp.asarray(out)
params = [
    zfit.param.convert_to_parameter(val, name, prefer_constant=False)
    for name, val in model.parameter_defaults.items()
]
def reset_parameters():
    for p in params_fit:
        if p.name in initial_parameters:
            p.set_value(initial_parameters[p.name])
obs = [
    zfit.Space(ob, limits=(np.min(data[ob]) - 1, np.max(data[ob]) + 1))
    for ob in pd.DataFrame(phsp)
]
obs_all = zfit.dimension.combine_spaces(*obs)
Data conversion#
phsp_zfit = zfit.Data.from_pandas(pd.DataFrame(phsp), obs=obs_all)
data_zfit = zfit.Data.from_pandas(pd.DataFrame(data), obs=obs_all)
Perform fit#

Complex parameters need to be removed first:

params_fit = [p for p in params if p.name in initial_parameters if p.independent]
jax_intensity_func = create_parametrized_function(
    expression=unfolded_expression,
    parameters=model.parameter_defaults,
    backend="jax",
)
reset_parameters()
pdf = TensorWavesPDF(
    obs=obs_all,
    intensity=jax_intensity_func,
    norm=phsp_zfit,
    params={f"{p.name}": p for i, p in enumerate(params_fit)},
)
loss = zfit.loss.UnbinnedNLL(pdf, data_zfit)
minimizer = zfit.minimize.Minuit(gradient=True, mode=0)

Note

You can also try different minimizers, like ScipyTrustConstrV1, but Minuit seems to perform best.
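For instance, the SciPy-based minimizer mentioned above can be instantiated like this (sketch; Minuit remains the minimizer used below):

scipy_minimizer = zfit.minimize.ScipyTrustConstrV1()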

%%time
result = minimizer.minimize(loss)
result
CPU times: user 22 s, sys: 188 ms, total: 22.2 s
Wall time: 8.56 s
FitResult of
<UnbinnedNLL model=[<zfit.<class '__main__.TensorWavesPDF'>  params=[\Gamma_{f_{0}(1710)}, \Gamma_{f_{0}(500)}, \Gamma_{f_{0}(980)}, m_{f_{0}(1370)}, m_{f_{0}(1500)}, m_{f_{0}(1710)}, m_{f_{0}(500)}, m_{f_{0}(980)}]] data=[<zfit.core.data.Data object at 0x7fdc203d0430>] constraints=[]> 
with
<Minuit Minuit tol=0.001>

╒═════════╤═════════════╤══════════════════╤═════════╤═════════════╕
│ valid   │ converged   │ param at limit   │ edm     │ min value   │
╞═════════╪═════════════╪══════════════════╪═════════╪═════════════╡
│ True    │ True        │ False            │ 0.00041 │ -1871.035   │
╘═════════╧═════════════╧══════════════════╧═════════╧═════════════╛

Parameters
name                    value  (rounded)    at limit
--------------------  ------------------  ----------
m_{f_{0}(500)}                  0.608864       False
\Gamma_{f_{0}(500)}             0.419716       False
m_{f_{0}(980)}                  0.990038       False
\Gamma_{f_{0}(980)}            0.0643328       False
m_{f_{0}(1370)}                  1.35137       False
m_{f_{0}(1500)}                  1.50627       False
m_{f_{0}(1710)}                  1.70956       False
\Gamma_{f_{0}(1710)}            0.132484       False
%%time
result.hesse(name="hesse")
result
CPU times: user 2.5 s, sys: 12.6 ms, total: 2.51 s
Wall time: 953 ms
FitResult of
<UnbinnedNLL model=[<zfit.<class '__main__.TensorWavesPDF'>  params=[\Gamma_{f_{0}(1710)}, \Gamma_{f_{0}(500)}, \Gamma_{f_{0}(980)}, m_{f_{0}(1370)}, m_{f_{0}(1500)}, m_{f_{0}(1710)}, m_{f_{0}(500)}, m_{f_{0}(980)}]] data=[<zfit.core.data.Data object at 0x7fdc203d0430>] constraints=[]> 
with
<Minuit Minuit tol=0.001>

╒═════════╤═════════════╤══════════════════╤═════════╤═════════════╕
│ valid   │ converged   │ param at limit   │ edm     │ min value   │
╞═════════╪═════════════╪══════════════════╪═════════╪═════════════╡
│ True    │ True        │ False            │ 0.00041 │ -1871.035   │
╘═════════╧═════════════╧══════════════════╧═════════╧═════════════╛

Parameters
name                    value  (rounded)        hesse    at limit
--------------------  ------------------  -----------  ----------
m_{f_{0}(500)}                  0.608864  +/-  0.0061       False
\Gamma_{f_{0}(500)}             0.419716  +/-   0.024       False
m_{f_{0}(980)}                  0.990038  +/-  0.0016       False
\Gamma_{f_{0}(980)}            0.0643328  +/-  0.0035       False
m_{f_{0}(1370)}                  1.35137  +/-  0.0039       False
m_{f_{0}(1500)}                  1.50627  +/-   0.002       False
m_{f_{0}(1710)}                  1.70956  +/-  0.0023       False
\Gamma_{f_{0}(1710)}            0.132484  +/-   0.007       False
%%time
result.errors(name="errors")
result
CPU times: user 45.3 s, sys: 393 ms, total: 45.7 s
Wall time: 17.2 s
FitResult of
<UnbinnedNLL model=[<zfit.<class '__main__.TensorWavesPDF'>  params=[\Gamma_{f_{0}(1710)}, \Gamma_{f_{0}(500)}, \Gamma_{f_{0}(980)}, m_{f_{0}(1370)}, m_{f_{0}(1500)}, m_{f_{0}(1710)}, m_{f_{0}(500)}, m_{f_{0}(980)}]] data=[<zfit.core.data.Data object at 0x7fdc203d0430>] constraints=[]> 
with
<Minuit Minuit tol=0.001>

╒═════════╤═════════════╤══════════════════╤═════════╤═════════════╕
│ valid   │ converged   │ param at limit   │ edm     │ min value   │
╞═════════╪═════════════╪══════════════════╪═════════╪═════════════╡
│ True    │ True        │ False            │ 0.00041 │ -1871.035   │
╘═════════╧═════════════╧══════════════════╧═════════╧═════════════╛

Parameters
name                    value  (rounded)        hesse               errors    at limit
--------------------  ------------------  -----------  -------------------  ----------
m_{f_{0}(500)}                  0.608864  +/-  0.0061  -  0.006   + 0.0063       False
\Gamma_{f_{0}(500)}             0.419716  +/-   0.024  -  0.024   +  0.023       False
m_{f_{0}(980)}                  0.990038  +/-  0.0016  - 0.0016   + 0.0016       False
\Gamma_{f_{0}(980)}            0.0643328  +/-  0.0035  - 0.0034   + 0.0036       False
m_{f_{0}(1370)}                  1.35137  +/-  0.0039  - 0.0039   + 0.0039       False
m_{f_{0}(1500)}                  1.50627  +/-   0.002  -  0.002   +  0.002       False
m_{f_{0}(1710)}                  1.70956  +/-  0.0023  - 0.0024   + 0.0024       False
\Gamma_{f_{0}(1710)}            0.132484  +/-   0.007  - 0.0068   + 0.0073       False
Statistical inference using the hepstats library#

hepstats is built on top of zfit-interface:

from hepstats.hypotests import ConfidenceInterval
from hepstats.hypotests.calculators import AsymptoticCalculator
from hepstats.hypotests.parameters import POIarray
calculator = AsymptoticCalculator(result, minimizer)

We take one of the parameters as POI:

poi = pdf.params[r"\Gamma_{f_{0}(500)}"]
poi
<zfit.Parameter '\Gamma_{f_{0}(500)}' floating=True value=0.4197>
poi_null = POIarray(poi, np.linspace(poi - 0.1, poi + 0.1, 50))
ci = ConfidenceInterval(calculator, poi_null)
alpha = 0.328
ci.interval(alpha=alpha);
Confidence interval on \Gamma_{f_{0}(500)}:
	0.3964206394323228 < \Gamma_{f_{0}(500)} < 0.44257337109434974 at 67.2% C.L.

A helper function to plot the result:

def one_minus_cl_plot(x, pvalues, alpha=None, ax=None):
    if alpha is None:
        alpha = [0.32]
    if isinstance(alpha, (float, int)):
        alpha = [alpha]
    if ax is None:
        ax = plt.gca()

    ax.plot(x, pvalues, ".--")
    for a in alpha:
        ax.axhline(a, color="red", label="$\\alpha = " + str(a) + "$")
    ax.set_ylabel("1-CL")

    return ax
plt.figure(figsize=(9, 8))
one_minus_cl_plot(poi_null.values, ci.pvalues(), alpha=alpha)
plt.xlabel(f"${poi.name}$")
plt.show()

Polarimeter vector field#

%matplotlib widget
from __future__ import annotations

import itertools
import logging
from typing import TYPE_CHECKING

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import qrules
import sympy as sp
from ampform.sympy import (
    PoolSum,
    UnevaluatedExpression,
    create_expression,
    implement_doit_method,
    make_commutative,
    perform_cached_doit,
)
from attrs import frozen
from IPython.display import HTML, Image, Math, display
from ipywidgets import (
    Button,
    Combobox,
    HBox,
    HTMLMath,
    Tab,
    VBox,
    interactive_output,
)
from matplotlib.colors import LogNorm
from symplot import create_slider
from sympy.core.symbol import Str
from sympy.physics.matrices import msigma
from sympy.physics.quantum.spin import Rotation as Wigner
from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import create_function, create_parametrized_function

if TYPE_CHECKING:
    from qrules.particle import Particle
    from tensorwaves.interface import DataSample, ParametrizedFunction

LOGGER = logging.getLogger()
LOGGER.setLevel(logging.ERROR)

PDG = qrules.load_pdg()


def display_definitions(definitions: dict[sp.Symbol, sp.Expr]) -> None:
    latex = R"\begin{array}{rcl}" + "\n"
    for symbol, expr in definitions.items():
        symbol = sp.sympify(symbol)
        expr = sp.sympify(expr)
        lhs = sp.latex(symbol)
        rhs = sp.latex(expr)
        latex += Rf"  {lhs} & = & {rhs} \\" + "\n"
    latex += R"\end{array}"
    display(Math(latex))


def display_doit(
    expr: UnevaluatedExpression, deep=False, terms_per_line: int = 10
) -> None:
    latex = sp.multiline_latex(
        lhs=expr,
        rhs=expr.doit(deep=deep),
        terms_per_line=terms_per_line,
        environment="eqnarray",
    )
    display(Math(latex))


# hack for moving Indexed indices below superscript of the base
def _print_Indexed_latex(self, printer, *args):
    base = printer._print(self.base)
    indices = ", ".join(map(printer._print, self.indices))
    return f"{base}_{{{indices}}}"


sp.Indexed._latex = _print_Indexed_latex
Amplitude model#

The helicity amplitude for the \(\Lambda_c \to p K \pi\) decay reads as a sum of three partial wave series, incorporating \(\Lambda^{**}\) resonances, \(\Delta^{**}\) resonances, and \(K^{**}\) resonances. The particles are ordered as \(\Lambda_c(\mathbf{0}) \to p(\mathbf{1}) \pi(\mathbf{2}) K(\mathbf{3})\).

(1)#\[\begin{split} \begin{align} \mathcal{A}_{\nu,\lambda}(m_{K\pi},m_{pK}) &= \sum_{\nu',\lambda'} \big[\\ &\qquad d_{\nu,\nu'}^{1/2}(\zeta_{1(1)}^0) \mathcal{A}_{\nu',\lambda'}^{K} d_{\lambda',\lambda}^{1/2}(\zeta_{1(1)}^1) &\mathbf{subsystem\,1\;}(\to 23)\\ &\qquad + d_{\nu,\nu'}^{1/2}(\zeta_{2(1)}^0) \mathcal{A}_{\nu',\lambda'}^{\Lambda} d_{\lambda',\lambda}^{1/2}(\zeta_{2(1)}^1) &\mathbf{subsystem\,2\;}(\to 31)\\ &\qquad + d_{\nu,\nu'}^{1/2}(\zeta_{3(1)}^0) \mathcal{A}_{\nu',\lambda'}^{\Delta} d_{\lambda',\lambda}^{1/2}(\zeta_{3(1)}^1)\big]\,. &\mathbf{subsystem\,3\;}(\to 12)\\ \end{align} \end{split}\]

where \(\zeta^{i}_{j(k)}\) is the Wigner rotation angle for particle \(i\) and decay chain \(j\), computed with respect to the reference chain \(k\) given in brackets. The chain in brackets defines the overall helicity states \(|1/2,\nu\rangle\) and \(|1/2,\lambda\rangle\) of the \(\Lambda_c\) and the proton, respectively.

We use the particle-2 convention for the helicity couplings, which leads to the phase factor in the transitions \(\Lambda_c\to K^{**}p\), and \(\Lambda\to K p\):

\[\begin{split} \begin{align} \mathcal{A}^{K}_{\nu,\lambda} &= \sum_{j,\tau} \delta_{\nu,\tau - \lambda}\mathcal{H}^{\Lambda_c \to K^{**} p}_{\tau,\lambda} (-)^{1/2 - \lambda} \,d^{j}_{\lambda,0} (\theta_{23}) \, \mathcal{H}^{K^{**} \to \pi K}_{0,0}\\ % \mathcal{A}^{\Lambda}_{\nu,\lambda} &= \sum_{j,\tau} \delta_{\nu,\tau} \mathcal{H}^{\Lambda_c \to \Lambda^{**} \pi}_{\tau,0} d^{j}_{\tau,-\lambda} (\theta_{31}) \mathcal{H}^{\Lambda^{**} \to K p}_{0,\lambda} (-)^{1/2-\lambda} \\ % \mathcal{A}^{\Delta}_{\nu,\lambda} &= \sum_{j,\tau} \delta_{\nu,\tau} \mathcal{H}^{\Lambda_c \to \Delta^{**} K}_{\tau,0} d^{j}_{\tau,\lambda}(\theta_{12}) \mathcal{H}^{\Delta^{**} \to p\pi}_{\lambda,0} \,. \end{align} \end{split}\]

The helicity couplings in the particle-2 convention obey simple properties with respect to the parity transformation:

\[ \begin{align} \mathcal{H}^{A\to BC}_{-\lambda,-\lambda'} = P_A P_B P_C (-)^{j_A-j_B-j_C} \mathcal{H}^{A\to BC}_{\lambda,\lambda'} \end{align} \]

This reduces the number of couplings in the strong decays of the isobars. Moreover, the magnitude of these couplings cannot be determined separately, so it is set to 1:

\[\begin{split} \begin{align} \mathcal{H}^{\Lambda^{**} \to K p}_{0,1/2} &= 1\,, & \mathcal{H}^{\Delta^{**} \to p\pi}_{1/2,0} &= 1\,,& \mathcal{H}^{K^{**} \to \pi K}_{0,0} &= 1, \\ \mathcal{H}^{\Lambda^{**} \to K p}_{0,-1/2} &= -P_\Lambda (-)^{j-1/2}\,, & \mathcal{H}^{\Delta^{**} \to p\pi}_{-1/2,0} &= -P_\Delta (-)^{j-1/2}\,, && \end{align} \end{split}\]
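As a small illustration (a hypothetical helper, not part of this notebook's code), the relation above fixes both decay couplings of an isobar once its spin \(j\) and parity \(P\) are known:

def fixed_strong_couplings(j: sp.Rational, parity: int) -> dict:
    """Couplings H_{0,±1/2} (Λ** chain) or H_{±1/2,0} (Δ** chain), fixed by parity."""
    half = sp.Rational(1, 2)
    return {
        +half: sp.Integer(1),
        -half: -parity * sp.Integer(-1) ** (j - half),
    }


fixed_strong_couplings(sp.Rational(3, 2), parity=-1)  # e.g. Λ(1520): {1/2: 1, -1/2: -1}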

The helicity couplings for the \(\Lambda_c^+\) decay are fit parameters. There are four of them for the \(K^{**}\) chain, and two each for the \(\Delta^{**}\) and \(\Lambda^{**}\) chains.

Resonances and LS-scheme#
Λc = PDG["Lambda(c)+"]
p = PDG["p"]
K = PDG["K-"]
π = PDG["pi+"]
decay_products = {
    1: (π, K),
    2: (p, K),
    3: (p, π),
}
siblings = {
    1: p,
    2: π,
    3: K,
}
chain_ids = {
    1: "K",
    2: "L",
    3: "D",
}
chain_labels = {
    1: "K^{**}",
    2: R"\Lambda^{**}",
    3: R"\Delta^{**}",
}

Resonance choices and their \(LS\)-couplings are defined as follows:

resonance_names = {
    1: ["K*(892)0"],
    2: ["Lambda(1520)", "Lambda(1670)"],
    3: ["Delta(1232)++"],
}
@frozen
class Resonance:
    particle: Particle
    l_R: int
    l_Λc: int

    @staticmethod
    def generate_ls(particle: Particle, chain_id: int) -> Resonance:
        LS_prod = generate_ls(Λc, particle, siblings[chain_id], strong=False)
        LS_prod = [L for L, S in LS_prod]
        LS_dec = generate_ls(particle, *decay_products[chain_id])
        LS_dec = [L for L, S in LS_dec]
        return Resonance(particle, l_R=min(LS_dec), l_Λc=min(LS_prod))


def generate_ls(
    parent: Particle,
    child1: Particle,
    child2: Particle,
    strong: bool = True,
    max_L: int = 3,
):
    s1 = child1.spin
    s2 = child2.spin
    s_values = arange(abs(s1 - s2), s1 + s2)
    LS_values = set()
    for S in s_values:
        for L in arange(0, max_L):
            if not abs(L - S) <= parent.spin <= L + S:
                continue
            Ρ0, Ρ1, Ρ2 = [
                int(parent.parity),
                int(child1.parity),
                int(child2.parity),
            ]
            if strong and Ρ0 != Ρ1 * Ρ2 * (-1) ** L:
                continue
            LS_values.add((L, S))
    return sorted(LS_values)


def arange(x1, x2):
    spin_range = np.arange(float(x1), +float(x2) + 0.5)
    return list(map(sp.Rational, spin_range))


resonance_particles = {
    chain_id: [PDG[name] for name in names]
    for chain_id, names in resonance_names.items()
}
ls_resonances = {
    chain_id: [Resonance.generate_ls(particle, chain_id) for particle in particles]
    for chain_id, particles in resonance_particles.items()
}


def jp(particle: Particle):
    p = "+" if particle.parity > 0 else "-"
    j = sp.Rational(particle.spin)
    return Rf"\({j}^{p}\)"


def create_html_table_row(*items, typ="td"):
    items = (f"<{typ}>{i}</{typ}>" for i in items)
    return "<tr>" + "".join(items) + "</tr>\n"


column_names = [
    "resonance",
    R"\(j^P\)",
    R"\(m\) (MeV)",
    R"\(\Gamma_0\) (MeV)",
    R"\(l_R\)",
    R"\(l_{\Lambda_c}^\mathrm{min}\)",
]
src = "<table>\n"
src += create_html_table_row(*column_names, typ="th")
for chain_id, resonance_list in ls_resonances.items():
    child1, child2 = decay_products[chain_id]
    for resonance in resonance_list:
        src += create_html_table_row(
            Rf"\({resonance.particle.latex} \to {child1.latex} {child2.latex}\)",
            jp(resonance.particle),
            int(1e3 * resonance.particle.mass),
            int(1e3 * resonance.particle.width),
            resonance.l_R,
            resonance.l_Λc,
        )
src += "</table>\n"
HTML(src)
resonance\(j^P\)\(m\) (MeV)\(\Gamma_0\) (MeV)\(l_R\)\(l_{\Lambda_c}^\mathrm{min}\)
\(K^{*}(892)^{0} \to \pi^{+} K^{-}\)\(1^-\)8954710
\(\Lambda(1520) \to p K^{-}\)\(3/2^-\)15191621
\(\Lambda(1670) \to p K^{-}\)\(1/2^-\)16743000
\(\Delta(1232)^{++} \to p \pi^{+}\)\(3/2^+\)123211711
Aligned amplitude#
A_K = sp.IndexedBase(R"A^K")
A_Λ = sp.IndexedBase(R"A^{\Lambda}")
A_Δ = sp.IndexedBase(R"A^{\Delta}")

half = sp.S.Half

ζ_0_11 = sp.Symbol(R"\zeta^0_{1(1)}", real=True)
ζ_0_21 = sp.Symbol(R"\zeta^0_{2(1)}", real=True)
ζ_0_31 = sp.Symbol(R"\zeta^0_{3(1)}", real=True)
ζ_1_11 = sp.Symbol(R"\zeta^1_{1(1)}", real=True)
ζ_1_21 = sp.Symbol(R"\zeta^1_{2(1)}", real=True)
ζ_1_31 = sp.Symbol(R"\zeta^1_{3(1)}", real=True)


def formulate_aligned_amplitude(λ_Λc, λ_p):
    _ν = sp.Symbol(R"\nu^{\prime}", rational=True)
    _λ = sp.Symbol(R"\lambda^{\prime}", rational=True)
    return PoolSum(
        A_K[_ν, _λ]
        * Wigner.d(half, λ_Λc, _ν, ζ_0_11)
        * Wigner.d(half, _λ, λ_p, ζ_1_11)
        + A_Λ[_ν, _λ]
        * Wigner.d(half, λ_Λc, _ν, ζ_0_21)
        * Wigner.d(half, _λ, λ_p, ζ_1_21)
        + A_Δ[_ν, _λ]
        * Wigner.d(half, λ_Λc, _ν, ζ_0_31)
        * Wigner.d(half, _λ, λ_p, ζ_1_31),
        (_λ, [-half, +half]),
        (_ν, [-half, +half]),
    )


ν = sp.Symbol("nu")
λ = sp.Symbol("lambda")
formulate_aligned_amplitude(λ_Λc=ν, λ_p=λ)
\[\displaystyle \sum_{\lambda^{\prime}=-1/2}^{1/2} \sum_{\nu^{\prime}=-1/2}^{1/2}{A^{K}_{\nu^{\prime}, \lambda^{\prime}} d^{\frac{1}{2}}_{\lambda^{\prime},\lambda}\left(\zeta^1_{1(1)}\right) d^{\frac{1}{2}}_{\nu,\nu^{\prime}}\left(\zeta^0_{1(1)}\right) + A^{\Delta}_{\nu^{\prime}, \lambda^{\prime}} d^{\frac{1}{2}}_{\lambda^{\prime},\lambda}\left(\zeta^1_{3(1)}\right) d^{\frac{1}{2}}_{\nu,\nu^{\prime}}\left(\zeta^0_{3(1)}\right) + A^{\Lambda}_{\nu^{\prime}, \lambda^{\prime}} d^{\frac{1}{2}}_{\lambda^{\prime},\lambda}\left(\zeta^1_{2(1)}\right) d^{\frac{1}{2}}_{\nu,\nu^{\prime}}\left(\zeta^0_{2(1)}\right)}\]
Dynamics#

The lineshape function is factored out of the \(\Lambda_c\) helicity coupling:

\[ \begin{align} \mathcal{H}^{\Lambda_c \to R x}_{\lambda,\lambda'} = \hat{\mathcal{H}}^{\Lambda_c \to R x}_{\lambda,\lambda'}\,\mathcal{R}(s)\,. \end{align} \]

The relativistic Breit-Wigner parametrization reads:

\[ \begin{align} \mathcal{R}(s) = \left(\frac{q}{q_0}\right)^{l_{\Lambda_c}^\text{min}} \frac{F_{l_{\Lambda_c}^\text{min}}(q R)}{F_{l_{\Lambda_c}^\text{min}}(q_0 R)}\, \frac{1}{m^2-s-im\Gamma(s)} \left(\frac{p}{p_0}\right)^{l_R} \frac{F_{l_R}(pR)}{F_{l_R}(p_0R)}, \end{align} \]

with energy-dependent width given by

\[ \begin{align} \Gamma(s) = \Gamma_0 \left(\frac{p}{p_0}\right)^{2l_R+1} \frac{m}{\sqrt{s}} \, \frac{F_{l_R}^2(pR)}{F_{l_R}^2(p_0R)}\,, \end{align} \]

The form-factor \(F\) is the Blatt-Weisskopf factor with the length factor \(R=5\,\)GeV\(^{-1}\):

\[ \begin{align} F_0(pR) &= 1\,,& F_1(pR) &= \sqrt{\frac{1}{1+(pR)^2}}\,,& F_2(pR) &= \sqrt{\frac{1}{9+3(pR)^2+(pR)^4}}\,. \end{align} \]

The break-up momenta are calculated for every decay chain separately. Using the notation \(0 \to R(\to ij)\, k\), one writes:

\[ \begin{align} p &= \lambda^{1/2}(s,m_i^2,m_j^2)/(2\sqrt{s})\,, & q &= \lambda^{1/2}(s,m_0^2,m_k^2)/(2m_0)\,. \end{align} \]

The momenta with subscript zero are computed at the nominal mass of the resonance, \(s=m^2\). The three-argument Källén function reads:

(2)#\[ \begin{align} \lambda(x,y,z) = x^2+y^2+z^2 - 2xy-2yz-2zx\,. \end{align} \]

Formulation with SymPy

@make_commutative
@implement_doit_method
class BlattWeisskopf(UnevaluatedExpression):
    def __new__(cls, z, L, **hints):
        return create_expression(cls, z, L, **hints)

    def evaluate(self):
        z, L = self.args
        cases = {
            0: 1,
            1: 1 / (1 + z**2),
            2: 1 / (9 + 3 * z**2 + z**4),
        }
        return sp.Piecewise(
            *[(sp.sqrt(expr), sp.Eq(L, l_val)) for l_val, expr in cases.items()]
        )

    def _latex(self, printer, *args):
        z, L = map(printer._print, self.args)
        return Rf"F_{{{L}}}\left({z}\right)"


z = sp.Symbol("z", positive=True)
L = sp.Symbol("L", integer=True, nonnegative=True)
latex = sp.multiline_latex(BlattWeisskopf(z, L), BlattWeisskopf(z, L).doit())
Math(latex)
\[\begin{split}\displaystyle \begin{align*} F_{L}\left(z\right) = & \begin{cases} 1 & \text{for}\: L = 0 \\\frac{1}{\sqrt{z^{2} + 1}} & \text{for}\: L = 1 \\\frac{1}{\sqrt{z^{4} + 3 z^{2} + 9}} & \text{for}\: L = 2 \end{cases} \end{align*}\end{split}\]
@make_commutative
@implement_doit_method
class Källén(UnevaluatedExpression):
    def __new__(cls, x, y, z, **hints):
        return create_expression(cls, x, y, z, **hints)

    def evaluate(self) -> sp.Expr:
        x, y, z = self.args
        return x**2 + y**2 + z**2 - 2 * x * y - 2 * y * z - 2 * z * x

    def _latex(self, printer, *args):
        x, y, z = map(printer._print, self.args)
        return Rf"\lambda\left({x}, {y}, {z}\right)"


x, y, z = sp.symbols("x:z")
display_doit(Källén(x, y, z))
\[\displaystyle \begin{eqnarray} \lambda\left(x, y, z\right) & = & x^{2} - 2 x y - 2 x z + y^{2} - 2 y z + z^{2} \end{eqnarray}\]
@make_commutative
@implement_doit_method
class P(UnevaluatedExpression):
    def __new__(cls, s, mi, mj, **hints):
        return create_expression(cls, s, mi, mj, **hints)

    def evaluate(self):
        s, mi, mj = self.args
        return sp.sqrt(Källén(s, mi**2, mj**2)) / (2 * sp.sqrt(s))

    def _latex(self, printer, *args):
        s = printer._print(self.args[0])
        return Rf"p_{{{s}}}"


@make_commutative
@implement_doit_method
class Q(UnevaluatedExpression):
    def __new__(cls, s, m0, mk, **hints):
        return create_expression(cls, s, m0, mk, **hints)

    def evaluate(self):
        s, m0, mk = self.args
        return sp.sqrt(Källén(s, m0**2, mk**2)) / (2 * m0)  # <-- not s!

    def _latex(self, printer, *args):
        s = printer._print(self.args[0])
        return Rf"q_{{{s}}}"


s, m0, mi, mj, mk = sp.symbols("s m0 m_i:k", nonnegative=True)
display_doit(P(s, mi, mj))
display_doit(Q(s, m0, mk))
\[\displaystyle \begin{eqnarray} p_{s} & = & \frac{\sqrt{\lambda\left(s, m_{i}^{2}, m_{j}^{2}\right)}}{2 \sqrt{s}} \end{eqnarray}\]
\[\displaystyle \begin{eqnarray} q_{s} & = & \frac{\sqrt{\lambda\left(s, m_{0}^{2}, m_{k}^{2}\right)}}{2 m_{0}} \end{eqnarray}\]
R = sp.Symbol("R")
parameter_defaults = {
    R: 5,  # GeV^{-1} (length factor)
}


@make_commutative
@implement_doit_method
class EnergyDependentWidth(UnevaluatedExpression):
    def __new__(cls, s, m0, Γ0, m1, m2, L, R):
        return create_expression(cls, s, m0, Γ0, m1, m2, L, R)

    def evaluate(self):
        s, m0, Γ0, m1, m2, L, R = self.args
        p = P(s, m1, m2)
        p0 = P(m0**2, m1, m2)
        ff = BlattWeisskopf(p * R, L) ** 2
        ff0 = BlattWeisskopf(p0 * R, L) ** 2
        return sp.Mul(
            Γ0,
            (p / p0) ** (2 * L + 1),
            m0 / sp.sqrt(s),
            ff / ff0,
            evaluate=False,
        )

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        return Rf"\Gamma\left({s}\right)"


l_R = sp.Symbol("l_R", integer=True, positive=True)
m, Γ0, m1, m2 = sp.symbols("m Γ0 m1 m2", nonnegative=True)
display_doit(EnergyDependentWidth(s, m, Γ0, m1, m2, l_R, R))
\[\displaystyle \begin{eqnarray} \Gamma\left(s\right) & = & Γ_{0} \frac{m}{\sqrt{s}} \frac{F_{l_{R}}\left(R p_{s}\right)^{2}}{F_{l_{R}}\left(R p_{m^{2}}\right)^{2}} \left(\frac{p_{s}}{p_{m^{2}}}\right)^{2 l_{R} + 1} \end{eqnarray}\]
@make_commutative
@implement_doit_method
class RelativisticBreitWigner(UnevaluatedExpression):
    def __new__(cls, s, m0, Γ0, m1, m2, l_R, l_Λc, R):
        return create_expression(cls, s, m0, Γ0, m1, m2, l_R, l_Λc, R)

    def evaluate(self):
        s, m0, Γ0, m1, m2, l_R, l_Λc, R = self.args
        q = Q(s, m1, m2)
        q0 = Q(m0**2, m1, m2)
        p = P(s, m1, m2)
        p0 = P(m0**2, m1, m2)
        width = EnergyDependentWidth(s, m0, Γ0, m1, m2, l_R, R)
        return sp.Mul(
            (q / q0) ** l_Λc,
            BlattWeisskopf(q * R, l_Λc) / BlattWeisskopf(q0 * R, l_Λc),
            1 / (m0**2 - s - sp.I * m0 * width),
            (p / p0) ** l_R,
            BlattWeisskopf(p * R, l_R) / BlattWeisskopf(p0 * R, l_R),
            evaluate=False,
        )

    def _latex(self, printer, *args) -> str:
        s = printer._print(self.args[0])
        return Rf"\mathcal{{R}}\left({s}\right)"


l_Λc = sp.Symbol(R"l_{\Lambda_c}", integer=True, positive=True)
display_doit(RelativisticBreitWigner(s, m, Γ0, m1, m2, l_R, l_Λc, R))
\[\displaystyle \begin{eqnarray} \mathcal{R}\left(s\right) & = & \frac{\frac{F_{l_{R}}\left(R p_{s}\right)}{F_{l_{R}}\left(R p_{m^{2}}\right)} \frac{F_{l_{\Lambda_c}}\left(R q_{s}\right)}{F_{l_{\Lambda_c}}\left(R q_{m^{2}}\right)} \left(\frac{p_{s}}{p_{m^{2}}}\right)^{l_{R}} \left(\frac{q_{s}}{q_{m^{2}}}\right)^{l_{\Lambda_c}}}{m^{2} - i m \Gamma\left(s\right) - s} \end{eqnarray}\]
Decay chain amplitudes#
def formulate_chain_amplitude(chain_id: int, λ_Λc, λ_p):
    resonances = ls_resonances[chain_id]
    if chain_id == 1:
        return formulate_K_amplitude(λ_Λc, λ_p, resonances)
    if chain_id == 2:
        return formulate_Λ_amplitude(λ_Λc, λ_p, resonances)
    if chain_id == 3:
        return formulate_Δ_amplitude(λ_Λc, λ_p, resonances)
    raise NotImplementedError


H_prod = sp.IndexedBase(R"\mathcal{H}^\mathrm{production}")
H_dec = sp.IndexedBase(R"\mathcal{H}^\mathrm{decay}")

θ23 = sp.Symbol("theta23", real=True)
θ31 = sp.Symbol("theta31", real=True)
θ12 = sp.Symbol("theta12", real=True)

σ1, σ2, σ3 = sp.symbols("sigma1:4", nonnegative=True)
m1, m2, m3 = sp.symbols(R"m_p m_pi m_K", nonnegative=True)


def formulate_K_amplitude(λ_Λc, λ_p, resonances: list[Resonance]):
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(*[
        PoolSum(
            sp.KroneckerDelta(λ_Λc, τ - λ_p)
            * H_prod[stringify(res), τ, λ_p]
            * formulate_dynamics(res, σ1, m2, m3)
            * (-1) ** (half - λ_p)
            * Wigner.d(sp.Rational(res.particle.spin), τ, 0, θ23)
            * H_dec[stringify(res), 0, 0],
            (τ, create_spin_range(res.particle.spin)),
        )
        for res in resonances
    ])


def formulate_Λ_amplitude(λ_Λc, λ_p, resonances: list[Resonance]):
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(*[
        PoolSum(
            sp.KroneckerDelta(λ_Λc, τ)
            * H_prod[stringify(res), τ, 0]
            * formulate_dynamics(res, σ2, m1, m3)
            * Wigner.d(sp.Rational(res.particle.spin), τ, -λ_p, θ31)
            * H_dec[stringify(res), 0, λ_p]
            * (-1) ** (half - λ_p),
            (τ, create_spin_range(res.particle.spin)),
        )
        for res in resonances
    ])


def formulate_Δ_amplitude(λ_Λc, λ_p, resonances: list[Resonance]):
    τ = sp.Symbol("tau", rational=True)
    return sp.Add(*[
        PoolSum(
            sp.KroneckerDelta(λ_Λc, τ)
            * H_prod[stringify(res), τ, 0]
            * formulate_dynamics(res, σ3, m1, m2)
            * Wigner.d(sp.Rational(res.particle.spin), τ, λ_p, θ12)
            * H_dec[stringify(res), λ_p, 0],
            (τ, create_spin_range(res.particle.spin)),
        )
        for res in resonances
    ])


def formulate_dynamics(decay: Resonance, s, m1, m2):
    l_R = sp.Rational(decay.l_R)
    l_Λc = sp.Rational(decay.l_Λc)
    mass = sp.Symbol(f"m_{{{decay.particle.latex}}}")
    width = sp.Symbol(Rf"\Gamma_{{{decay.particle.latex}}}")
    parameter_defaults[mass] = decay.particle.mass
    parameter_defaults[width] = decay.particle.width
    return RelativisticBreitWigner(s, mass, width, m1, m2, l_R, l_Λc, R)


def stringify(particle: Particle | Resonance) -> Str:
    if isinstance(particle, Resonance):
        particle = particle.particle
    return Str(particle.latex)


def create_spin_range(j):
    return arange(-j, +j)


display(
    formulate_chain_amplitude(1, ν, λ),
    formulate_chain_amplitude(2, ν, λ),
    formulate_chain_amplitude(3, ν, λ),
)
\[\displaystyle \sum_{\tau=-1}^{1}{\left(-1\right)^{\frac{1}{2} - \lambda} \delta_{\nu, - \lambda + \tau} \mathcal{H}^\mathrm{decay}_{K^{*}(892)^{0}, 0, 0} \mathcal{H}^\mathrm{production}_{K^{*}(892)^{0}, \tau, \lambda} \mathcal{R}\left(\sigma_{1}\right) d^{1}_{\tau,0}\left(\theta_{23}\right)}\]
\[\displaystyle \sum_{\tau=-3/2}^{3/2}{\left(-1\right)^{\frac{1}{2} - \lambda} \delta_{\nu \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1520), 0, \lambda} \mathcal{H}^\mathrm{production}_{\Lambda(1520), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{3}{2}}_{\tau,- \lambda}\left(\theta_{31}\right)} + \sum_{\tau=-1/2}^{1/2}{\left(-1\right)^{\frac{1}{2} - \lambda} \delta_{\nu \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1670), 0, \lambda} \mathcal{H}^\mathrm{production}_{\Lambda(1670), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{1}{2}}_{\tau,- \lambda}\left(\theta_{31}\right)}\]
\[\displaystyle \sum_{\tau=-3/2}^{3/2}{\delta_{\nu \tau} \mathcal{H}^\mathrm{decay}_{\Delta(1232)^{++}, \lambda, 0} \mathcal{H}^\mathrm{production}_{\Delta(1232)^{++}, \tau, 0} \mathcal{R}\left(\sigma_{3}\right) d^{\frac{3}{2}}_{\tau,\lambda}\left(\theta_{12}\right)}\]
Angle definitions#

Angles with repeated lower indices are trivial. The remaining angles are computed from the invariants; these are the angles that carry a sign:

\[\begin{split} \begin{align} \zeta_{1(1)}^{0} &= \hat{\theta}_{1(1)}^{0} = 0\,, & \zeta_{1(1)}^{1} &= 0\,,\\ \zeta_{2(1)}^0 &=\hat{\theta}_{2(1)} =-\hat{\theta}_{1(2)}\,,\\ \zeta_{3(1)}^0 &= \hat{\theta}_{3(1)}\,, &\zeta_{3(1)}^1 &= -\zeta_{1(3)}^1\,, \end{align} \end{split}\]

The cosines of the positive (anticlockwise) angles \(\theta_{12}, \theta_{23}, \theta_{31}\) and \(\hat\theta_{1(2)}, \hat\theta_{3(1)}, \zeta^1_{1(3)}\) can be expressed in terms of the Mandelstam variables \(\sigma_1, \sigma_2, \sigma_3\) using [Mikhasenko et al., 2020], Appendix A:

Hide code cell source
m0 = sp.Symbol(R"m_{\Lambda_c}", nonnegative=True)
angles = {
    θ12: sp.acos(
        (2 * σ3 * (σ2 - m3**2 - m1**2) - (σ3 + m1**2 - m2**2) * (m0**2 - σ3 - m3**2))
        / (sp.sqrt(Källén(m0**2, m3**2, σ3)) * sp.sqrt(Källén(σ3, m1**2, m2**2)))
    ),
    θ23: sp.acos(
        (2 * σ1 * (σ3 - m1**2 - m2**2) - (σ1 + m2**2 - m3**2) * (m0**2 - σ1 - m1**2))
        / (sp.sqrt(Källén(m0**2, m1**2, σ1)) * sp.sqrt(Källén(σ1, m2**2, m3**2)))
    ),
    θ31: sp.acos(
        (2 * σ2 * (σ1 - m2**2 - m3**2) - (σ2 + m3**2 - m1**2) * (m0**2 - σ2 - m2**2))
        / (sp.sqrt(Källén(m0**2, m2**2, σ2)) * sp.sqrt(Källén(σ2, m3**2, m1**2)))
    ),
    ζ_0_11: sp.S.Zero,  # = \hat\theta^0_{1(1)}
    ζ_0_21: -sp.acos(  # = -\hat\theta_{1(2)}
        (
            (m0**2 + m1**2 - σ1) * (m0**2 + m2**2 - σ2)
            - 2 * m0**2 * (σ3 - m1**2 - m2**2)
        )
        / (sp.sqrt(Källén(m0**2, m2**2, σ2)) * sp.sqrt(Källén(m0**2, σ1, m1**2)))
    ),
    ζ_0_31: sp.acos(  # = \hat\theta_{3(1)}
        (
            (m0**2 + m3**2 - σ3) * (m0**2 + m1**2 - σ1)
            - 2 * m0**2 * (σ2 - m3**2 - m1**2)
        )
        / (sp.sqrt(Källén(m0**2, m1**2, σ1)) * sp.sqrt(Källén(m0**2, σ3, m3**2)))
    ),
    ζ_1_11: sp.S.Zero,
    ζ_1_21: sp.acos(
        (
            2 * m1**2 * (σ3 - m0**2 - m3**2)
            + (m0**2 + m1**2 - σ1) * (σ2 - m1**2 - m3**2)
        )
        / (sp.sqrt(Källén(m0**2, m1**2, σ1)) * sp.sqrt(Källén(σ2, m1**2, m3**2)))
    ),
    ζ_1_31: -sp.acos(  # = -\zeta^1_{1(3)}
        (
            2 * m1**2 * (σ2 - m0**2 - m2**2)
            + (m0**2 + m1**2 - σ1) * (σ3 - m1**2 - m2**2)
        )
        / (sp.sqrt(Källén(m0**2, m1**2, σ1)) * sp.sqrt(Källén(σ3, m1**2, m2**2)))
    ),
}

display_definitions(angles)
\[\begin{split}\displaystyle \begin{array}{rcl} \theta_{12} & = & \operatorname{acos}{\left(\frac{2 \sigma_{3} \left(- m_{K}^{2} - m_{p}^{2} + \sigma_{2}\right) - \left(- m_{K}^{2} + m_{\Lambda_c}^{2} - \sigma_{3}\right) \left(m_{p}^{2} - m_{\pi}^{2} + \sigma_{3}\right)}{\sqrt{\lambda\left(m_{\Lambda_c}^{2}, m_{K}^{2}, \sigma_{3}\right)} \sqrt{\lambda\left(\sigma_{3}, m_{p}^{2}, m_{\pi}^{2}\right)}} \right)} \\ \theta_{23} & = & \operatorname{acos}{\left(\frac{2 \sigma_{1} \left(- m_{p}^{2} - m_{\pi}^{2} + \sigma_{3}\right) - \left(- m_{K}^{2} + m_{\pi}^{2} + \sigma_{1}\right) \left(- m_{p}^{2} + m_{\Lambda_c}^{2} - \sigma_{1}\right)}{\sqrt{\lambda\left(m_{\Lambda_c}^{2}, m_{p}^{2}, \sigma_{1}\right)} \sqrt{\lambda\left(\sigma_{1}, m_{\pi}^{2}, m_{K}^{2}\right)}} \right)} \\ \theta_{31} & = & \operatorname{acos}{\left(\frac{2 \sigma_{2} \left(- m_{K}^{2} - m_{\pi}^{2} + \sigma_{1}\right) - \left(m_{K}^{2} - m_{p}^{2} + \sigma_{2}\right) \left(- m_{\pi}^{2} + m_{\Lambda_c}^{2} - \sigma_{2}\right)}{\sqrt{\lambda\left(m_{\Lambda_c}^{2}, m_{\pi}^{2}, \sigma_{2}\right)} \sqrt{\lambda\left(\sigma_{2}, m_{K}^{2}, m_{p}^{2}\right)}} \right)} \\ \zeta^0_{1(1)} & = & 0 \\ \zeta^0_{2(1)} & = & - \operatorname{acos}{\left(\frac{- 2 m_{\Lambda_c}^{2} \left(- m_{p}^{2} - m_{\pi}^{2} + \sigma_{3}\right) + \left(m_{p}^{2} + m_{\Lambda_c}^{2} - \sigma_{1}\right) \left(m_{\pi}^{2} + m_{\Lambda_c}^{2} - \sigma_{2}\right)}{\sqrt{\lambda\left(m_{\Lambda_c}^{2}, m_{\pi}^{2}, \sigma_{2}\right)} \sqrt{\lambda\left(m_{\Lambda_c}^{2}, \sigma_{1}, m_{p}^{2}\right)}} \right)} \\ \zeta^0_{3(1)} & = & \operatorname{acos}{\left(\frac{- 2 m_{\Lambda_c}^{2} \left(- m_{K}^{2} - m_{p}^{2} + \sigma_{2}\right) + \left(m_{K}^{2} + m_{\Lambda_c}^{2} - \sigma_{3}\right) \left(m_{p}^{2} + m_{\Lambda_c}^{2} - \sigma_{1}\right)}{\sqrt{\lambda\left(m_{\Lambda_c}^{2}, m_{p}^{2}, \sigma_{1}\right)} \sqrt{\lambda\left(m_{\Lambda_c}^{2}, \sigma_{3}, m_{K}^{2}\right)}} \right)} \\ \zeta^1_{1(1)} & = & 0 \\ \zeta^1_{2(1)} & = & \operatorname{acos}{\left(\frac{2 m_{p}^{2} \left(- m_{K}^{2} - m_{\Lambda_c}^{2} + \sigma_{3}\right) + \left(- m_{K}^{2} - m_{p}^{2} + \sigma_{2}\right) \left(m_{p}^{2} + m_{\Lambda_c}^{2} - \sigma_{1}\right)}{\sqrt{\lambda\left(m_{\Lambda_c}^{2}, m_{p}^{2}, \sigma_{1}\right)} \sqrt{\lambda\left(\sigma_{2}, m_{p}^{2}, m_{K}^{2}\right)}} \right)} \\ \zeta^1_{3(1)} & = & - \operatorname{acos}{\left(\frac{2 m_{p}^{2} \left(- m_{\pi}^{2} - m_{\Lambda_c}^{2} + \sigma_{2}\right) + \left(- m_{p}^{2} - m_{\pi}^{2} + \sigma_{3}\right) \left(m_{p}^{2} + m_{\Lambda_c}^{2} - \sigma_{1}\right)}{\sqrt{\lambda\left(m_{\Lambda_c}^{2}, m_{p}^{2}, \sigma_{1}\right)} \sqrt{\lambda\left(\sigma_{3}, m_{p}^{2}, m_{\pi}^{2}\right)}} \right)} \\ \end{array}\end{split}\]

where \(m_0\) is the mass of the initial state \(\Lambda_c\) and \(m_1, m_2, m_3\) are the masses of \(p, \pi, K\), respectively:

Hide code cell source
masses = {
    m0: Λc.mass,
    m1: p.mass,
    m2: π.mass,
    m3: K.mass,
}
parameter_defaults.update(masses)
display_definitions(masses)
\[\begin{split}\displaystyle \begin{array}{rcl} m_{\Lambda_c} & = & 2.28646 \\ m_{p} & = & 0.938272081 \\ m_{\pi} & = & 0.13957039 \\ m_{K} & = & 0.493677 \\ \end{array}\end{split}\]
Helicity coupling values#
Hide code cell source
dec_couplings = {}
for res in ls_resonances[1]:
    i = stringify(res)
    dec_couplings[H_dec[i, 0, 0]] = 1
for res in ls_resonances[2]:
    i = stringify(res.particle)
    dec_couplings[H_dec[i, 0, half]] = 1
    dec_couplings[H_dec[i, 0, -half]] = int(
        int(res.particle.parity)
        * int(K.parity)
        * int(p.parity)
        * (-1) ** (res.particle.spin - K.spin - p.spin)
    )
for res in ls_resonances[3]:
    i = stringify(res.particle)
    dec_couplings[H_dec[i, half, 0]] = 1
    dec_couplings[H_dec[i, -half, 0]] = int(
        int(res.particle.parity)
        * int(p.parity)
        * int(π.parity)
        * (-1) ** (res.particle.spin - p.spin - π.spin)
    )
parameter_defaults.update(dec_couplings)
display_definitions(dec_couplings)
\[\begin{split}\displaystyle \begin{array}{rcl} \mathcal{H}^\mathrm{decay}_{K^{*}(892)^{0}, 0, 0} & = & 1 \\ \mathcal{H}^\mathrm{decay}_{\Lambda(1520), 0, \frac{1}{2}} & = & 1 \\ \mathcal{H}^\mathrm{decay}_{\Lambda(1520), 0, - \frac{1}{2}} & = & -1 \\ \mathcal{H}^\mathrm{decay}_{\Lambda(1670), 0, \frac{1}{2}} & = & 1 \\ \mathcal{H}^\mathrm{decay}_{\Lambda(1670), 0, - \frac{1}{2}} & = & 1 \\ \mathcal{H}^\mathrm{decay}_{\Delta(1232)^{++}, \frac{1}{2}, 0} & = & 1 \\ \mathcal{H}^\mathrm{decay}_{\Delta(1232)^{++}, - \frac{1}{2}, 0} & = & 1 \\ \end{array}\end{split}\]
Hide code cell source
prod_couplings = {
    # chain 23:
    H_prod[Str("K^{*}(892)^{0}"), 0, -half]: 1,
    H_prod[Str("K^{*}(892)^{0}"), -1, -half]: 1 - 1j,
    H_prod[Str("K^{*}(892)^{0}"), +1, +half]: -3 - 3j,
    H_prod[Str("K^{*}(892)^{0}"), 0, +half]: -1 - 4j,
    #
    H_prod[Str("K_{0}^{*}(1430)^{0}"), 0, -half]: 1,
    H_prod[Str("K_{0}^{*}(1430)^{0}"), -1, -half]: 1 - 1j,
    H_prod[Str("K_{0}^{*}(1430)^{0}"), +1, +half]: -3 - 3j,
    H_prod[Str("K_{0}^{*}(1430)^{0}"), 0, +half]: -1 - 4j,
    #
    H_prod[Str("K_{2}^{*}(1430)^{0}"), 0, -half]: 1,
    H_prod[Str("K_{2}^{*}(1430)^{0}"), -1, -half]: 1 - 1j,
    H_prod[Str("K_{2}^{*}(1430)^{0}"), +1, +half]: -3 - 3j,
    H_prod[Str("K_{2}^{*}(1430)^{0}"), 0, +half]: -1 - 4j,
    #
    # chain 31:
    H_prod[Str(R"\Lambda(1520)"), +half, 0]: 1.5,
    H_prod[Str(R"\Lambda(1520)"), -half, 0]: 0.3,
    H_prod[Str(R"\Lambda(1670)"), +half, 0]: -0.5 + 1j,
    H_prod[Str(R"\Lambda(1670)"), -half, 0]: -0.3 - 0.1j,
    # chain 12:
    H_prod[Str(R"\Delta(1232)^{++}"), +half, 0]: -13 + 5j,
    H_prod[Str(R"\Delta(1232)^{++}"), -half, 0]: -7 + 3j,
}
display_definitions(prod_couplings)
couplings = dict(dec_couplings)
couplings.update(prod_couplings)
parameter_defaults.update(prod_couplings)
\[\begin{split}\displaystyle \begin{array}{rcl} \mathcal{H}^\mathrm{production}_{K^{*}(892)^{0}, 0, - \frac{1}{2}} & = & 1 \\ \mathcal{H}^\mathrm{production}_{K^{*}(892)^{0}, -1, - \frac{1}{2}} & = & 1.0 - 1.0 i \\ \mathcal{H}^\mathrm{production}_{K^{*}(892)^{0}, 1, \frac{1}{2}} & = & -3.0 - 3.0 i \\ \mathcal{H}^\mathrm{production}_{K^{*}(892)^{0}, 0, \frac{1}{2}} & = & -1.0 - 4.0 i \\ \mathcal{H}^\mathrm{production}_{K_{0}^{*}(1430)^{0}, 0, - \frac{1}{2}} & = & 1 \\ \mathcal{H}^\mathrm{production}_{K_{0}^{*}(1430)^{0}, -1, - \frac{1}{2}} & = & 1.0 - 1.0 i \\ \mathcal{H}^\mathrm{production}_{K_{0}^{*}(1430)^{0}, 1, \frac{1}{2}} & = & -3.0 - 3.0 i \\ \mathcal{H}^\mathrm{production}_{K_{0}^{*}(1430)^{0}, 0, \frac{1}{2}} & = & -1.0 - 4.0 i \\ \mathcal{H}^\mathrm{production}_{K_{2}^{*}(1430)^{0}, 0, - \frac{1}{2}} & = & 1 \\ \mathcal{H}^\mathrm{production}_{K_{2}^{*}(1430)^{0}, -1, - \frac{1}{2}} & = & 1.0 - 1.0 i \\ \mathcal{H}^\mathrm{production}_{K_{2}^{*}(1430)^{0}, 1, \frac{1}{2}} & = & -3.0 - 3.0 i \\ \mathcal{H}^\mathrm{production}_{K_{2}^{*}(1430)^{0}, 0, \frac{1}{2}} & = & -1.0 - 4.0 i \\ \mathcal{H}^\mathrm{production}_{\Lambda(1520), \frac{1}{2}, 0} & = & 1.5 \\ \mathcal{H}^\mathrm{production}_{\Lambda(1520), - \frac{1}{2}, 0} & = & 0.3 \\ \mathcal{H}^\mathrm{production}_{\Lambda(1670), \frac{1}{2}, 0} & = & -0.5 + 1.0 i \\ \mathcal{H}^\mathrm{production}_{\Lambda(1670), - \frac{1}{2}, 0} & = & -0.3 - 0.1 i \\ \mathcal{H}^\mathrm{production}_{\Delta(1232)^{++}, \frac{1}{2}, 0} & = & -13.0 + 5.0 i \\ \mathcal{H}^\mathrm{production}_{\Delta(1232)^{++}, - \frac{1}{2}, 0} & = & -7.0 + 3.0 i \\ \end{array}\end{split}\]
Intensity expression#

Incoherent sum of the amplitudes defined by Aligned amplitude:

Hide code cell source
intensity_expr = PoolSum(
    sp.Abs(formulate_aligned_amplitude(ν, λ)) ** 2,
    (λ, [-half, +half]),
    (ν, [-half, +half]),
)
intensity_expr
\[\displaystyle \sum_{\lambda=-1/2}^{1/2} \sum_{\nu=-1/2}^{1/2}{\left|{\sum_{\lambda^{\prime}=-1/2}^{1/2} \sum_{\nu^{\prime}=-1/2}^{1/2}{A^{K}_{\nu^{\prime}, \lambda^{\prime}} d^{\frac{1}{2}}_{\lambda^{\prime},\lambda}\left(\zeta^1_{1(1)}\right) d^{\frac{1}{2}}_{\nu,\nu^{\prime}}\left(\zeta^0_{1(1)}\right) + A^{\Delta}_{\nu^{\prime}, \lambda^{\prime}} d^{\frac{1}{2}}_{\lambda^{\prime},\lambda}\left(\zeta^1_{3(1)}\right) d^{\frac{1}{2}}_{\nu,\nu^{\prime}}\left(\zeta^0_{3(1)}\right) + A^{\Lambda}_{\nu^{\prime}, \lambda^{\prime}} d^{\frac{1}{2}}_{\lambda^{\prime},\lambda}\left(\zeta^1_{2(1)}\right) d^{\frac{1}{2}}_{\nu,\nu^{\prime}}\left(\zeta^0_{2(1)}\right)}}\right|^{2}}\]

The remaining free_symbols of this expression are indeed the specific amplitudes defined under Decay chain amplitudes.

The specific amplitudes from Decay chain amplitudes need to be formulated for each value of \(\nu, \lambda\), so that they can be substituted in the top expression:

Hide code cell source
A = {1: A_K, 2: A_Λ, 3: A_Δ}
amp_definitions = {}
for chain_id in chain_ids:
    for Λc_heli, p_heli in itertools.product([-half, +half], [-half, +half]):
        symbol = A[chain_id][Λc_heli, p_heli]
        expr = formulate_chain_amplitude(chain_id, ν, λ)
        amp_definitions[symbol] = expr.subs({ν: Λc_heli, λ: p_heli})
display_definitions(amp_definitions)
\[\begin{split}\displaystyle \begin{array}{rcl} A^{K}_{- \frac{1}{2}, - \frac{1}{2}} & = & \sum_{\tau=-1}^{1}{- \delta_{- \frac{1}{2}, \tau + \frac{1}{2}} \mathcal{H}^\mathrm{decay}_{K^{*}(892)^{0}, 0, 0} \mathcal{H}^\mathrm{production}_{K^{*}(892)^{0}, \tau, - \frac{1}{2}} \mathcal{R}\left(\sigma_{1}\right) d^{1}_{\tau,0}\left(\theta_{23}\right)} \\ A^{K}_{- \frac{1}{2}, \frac{1}{2}} & = & \sum_{\tau=-1}^{1}{\delta_{- \frac{1}{2}, \tau - \frac{1}{2}} \mathcal{H}^\mathrm{decay}_{K^{*}(892)^{0}, 0, 0} \mathcal{H}^\mathrm{production}_{K^{*}(892)^{0}, \tau, \frac{1}{2}} \mathcal{R}\left(\sigma_{1}\right) d^{1}_{\tau,0}\left(\theta_{23}\right)} \\ A^{K}_{\frac{1}{2}, - \frac{1}{2}} & = & \sum_{\tau=-1}^{1}{- \delta_{\frac{1}{2}, \tau + \frac{1}{2}} \mathcal{H}^\mathrm{decay}_{K^{*}(892)^{0}, 0, 0} \mathcal{H}^\mathrm{production}_{K^{*}(892)^{0}, \tau, - \frac{1}{2}} \mathcal{R}\left(\sigma_{1}\right) d^{1}_{\tau,0}\left(\theta_{23}\right)} \\ A^{K}_{\frac{1}{2}, \frac{1}{2}} & = & \sum_{\tau=-1}^{1}{\delta_{\frac{1}{2}, \tau - \frac{1}{2}} \mathcal{H}^\mathrm{decay}_{K^{*}(892)^{0}, 0, 0} \mathcal{H}^\mathrm{production}_{K^{*}(892)^{0}, \tau, \frac{1}{2}} \mathcal{R}\left(\sigma_{1}\right) d^{1}_{\tau,0}\left(\theta_{23}\right)} \\ A^{\Lambda}_{- \frac{1}{2}, - \frac{1}{2}} & = & \sum_{\tau=-3/2}^{3/2}{- \delta_{- \frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1520), 0, - \frac{1}{2}} \mathcal{H}^\mathrm{production}_{\Lambda(1520), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{3}{2}}_{\tau,\frac{1}{2}}\left(\theta_{31}\right)} + \sum_{\tau=-1/2}^{1/2}{- \delta_{- \frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1670), 0, - \frac{1}{2}} \mathcal{H}^\mathrm{production}_{\Lambda(1670), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{1}{2}}_{\tau,\frac{1}{2}}\left(\theta_{31}\right)} \\ A^{\Lambda}_{- \frac{1}{2}, \frac{1}{2}} & = & \sum_{\tau=-3/2}^{3/2}{\delta_{- \frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1520), 0, \frac{1}{2}} \mathcal{H}^\mathrm{production}_{\Lambda(1520), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{3}{2}}_{\tau,- \frac{1}{2}}\left(\theta_{31}\right)} + \sum_{\tau=-1/2}^{1/2}{\delta_{- \frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1670), 0, \frac{1}{2}} \mathcal{H}^\mathrm{production}_{\Lambda(1670), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{1}{2}}_{\tau,- \frac{1}{2}}\left(\theta_{31}\right)} \\ A^{\Lambda}_{\frac{1}{2}, - \frac{1}{2}} & = & \sum_{\tau=-3/2}^{3/2}{- \delta_{\frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1520), 0, - \frac{1}{2}} \mathcal{H}^\mathrm{production}_{\Lambda(1520), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{3}{2}}_{\tau,\frac{1}{2}}\left(\theta_{31}\right)} + \sum_{\tau=-1/2}^{1/2}{- \delta_{\frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1670), 0, - \frac{1}{2}} \mathcal{H}^\mathrm{production}_{\Lambda(1670), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{1}{2}}_{\tau,\frac{1}{2}}\left(\theta_{31}\right)} \\ A^{\Lambda}_{\frac{1}{2}, \frac{1}{2}} & = & \sum_{\tau=-3/2}^{3/2}{\delta_{\frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1520), 0, \frac{1}{2}} \mathcal{H}^\mathrm{production}_{\Lambda(1520), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{3}{2}}_{\tau,- \frac{1}{2}}\left(\theta_{31}\right)} + \sum_{\tau=-1/2}^{1/2}{\delta_{\frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Lambda(1670), 0, \frac{1}{2}} \mathcal{H}^\mathrm{production}_{\Lambda(1670), \tau, 0} \mathcal{R}\left(\sigma_{2}\right) d^{\frac{1}{2}}_{\tau,- 
\frac{1}{2}}\left(\theta_{31}\right)} \\ A^{\Delta}_{- \frac{1}{2}, - \frac{1}{2}} & = & \sum_{\tau=-3/2}^{3/2}{\delta_{- \frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Delta(1232)^{++}, - \frac{1}{2}, 0} \mathcal{H}^\mathrm{production}_{\Delta(1232)^{++}, \tau, 0} \mathcal{R}\left(\sigma_{3}\right) d^{\frac{3}{2}}_{\tau,- \frac{1}{2}}\left(\theta_{12}\right)} \\ A^{\Delta}_{- \frac{1}{2}, \frac{1}{2}} & = & \sum_{\tau=-3/2}^{3/2}{\delta_{- \frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Delta(1232)^{++}, \frac{1}{2}, 0} \mathcal{H}^\mathrm{production}_{\Delta(1232)^{++}, \tau, 0} \mathcal{R}\left(\sigma_{3}\right) d^{\frac{3}{2}}_{\tau,\frac{1}{2}}\left(\theta_{12}\right)} \\ A^{\Delta}_{\frac{1}{2}, - \frac{1}{2}} & = & \sum_{\tau=-3/2}^{3/2}{\delta_{\frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Delta(1232)^{++}, - \frac{1}{2}, 0} \mathcal{H}^\mathrm{production}_{\Delta(1232)^{++}, \tau, 0} \mathcal{R}\left(\sigma_{3}\right) d^{\frac{3}{2}}_{\tau,- \frac{1}{2}}\left(\theta_{12}\right)} \\ A^{\Delta}_{\frac{1}{2}, \frac{1}{2}} & = & \sum_{\tau=-3/2}^{3/2}{\delta_{\frac{1}{2} \tau} \mathcal{H}^\mathrm{decay}_{\Delta(1232)^{++}, \frac{1}{2}, 0} \mathcal{H}^\mathrm{production}_{\Delta(1232)^{++}, \tau, 0} \mathcal{R}\left(\sigma_{3}\right) d^{\frac{3}{2}}_{\tau,\frac{1}{2}}\left(\theta_{12}\right)} \\ \end{array}\end{split}\]
Hide code cell content
unfolded_intensity_expr = perform_cached_doit(
    perform_cached_doit(intensity_expr).xreplace(amp_definitions)
)
expr = unfolded_intensity_expr.xreplace(angles).doit()
expr = expr.xreplace(parameter_defaults)
assert expr.free_symbols == {σ1, σ2, σ3}
del expr
Polarization sensitivity#

We introduce the polarimeter vector field (polarization sensitivity) of the \(\Lambda_c\) decay. It is defined by three quantities \((\alpha_x,\alpha_y,\alpha_z)\) forming a three-dimensional vector \(\vec\alpha\) dependent on just two decay variables, \(\sigma_1=m_{K\pi}^2\), and \(\sigma_2=m_{pK}^2\).

The polarimeter vector field is computed by averaging the Pauli matrices \(\vec\sigma\), contracted with the \(\Lambda_c^+\) helicity indices of the transition amplitude:

(3)#\[ \begin{align} \vec\alpha(m_{K\pi},m_{pK}) = \sum_{\lambda,\nu,\nu'} A^{*}_{\nu,\lambda}\vec\sigma_{\nu,\nu'} A_{\nu',\lambda} \,\big / \sum_{\lambda,\nu} \left|A_{\nu,\lambda}\right|^2 \end{align} \]

The quantities \(\vec\alpha(m_{K\pi},m_{pK})\) give a model-independent representation of the \(\Lambda_c^+\) decay process. They can be used to study the \(\Lambda_c^+\) production polarization using

(4)#\[ \begin{align} I(\alpha,\beta,\gamma,m_{K\pi},m_{pK}) = I_0(m_{K\pi},m_{pK})\, \left(1 + \sum_{i,j} P_i R_{ij}(\alpha,\beta,\gamma) \alpha_j(m_{K\pi},m_{pK}) \right)\,, \end{align} \]

where \(R_{ij}(\alpha,\beta,\gamma)\) is a three-dimensional rotation matrix:

\[ \begin{align} R(\alpha,\beta,\gamma) = R_z(\alpha)R_y(\beta)R_z(\gamma)\,, \end{align} \]

and \(I_0\) is the averaged decay rate

\[ \begin{align} I_0(m_{K\pi},m_{pK}) = \sum_{\lambda,\nu}\left|A_{\nu,\lambda}\right|^2\,. \end{align} \]
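As an aside, the rotation matrix \(R(\alpha,\beta,\gamma)\) of Eq. (4) can be written out explicitly. The following sketch is not part of the model code in this report; it only illustrates the \(z\)-\(y\)-\(z\) convention with standalone SymPy symbols:

import sympy as sp

euler_α, euler_β, euler_γ = sp.symbols("alpha beta gamma", real=True)


def R_y(angle):
    # active rotation about the y axis
    return sp.Matrix([
        [sp.cos(angle), 0, sp.sin(angle)],
        [0, 1, 0],
        [-sp.sin(angle), 0, sp.cos(angle)],
    ])


def R_z(angle):
    # active rotation about the z axis
    return sp.Matrix([
        [sp.cos(angle), -sp.sin(angle), 0],
        [sp.sin(angle), sp.cos(angle), 0],
        [0, 0, 1],
    ])


# R(α, β, γ) = R_z(α) R_y(β) R_z(γ), as in Eq. (4)
R_matrix = R_z(euler_α) * R_y(euler_β) * R_z(euler_γ)
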
def to_index(helicity):
    """Symbolic conversion of half-value helicities to Pauli matrix indices."""
    # https://github.com/ComPWA/compwa.github.io/pull/129#issuecomment-1096599896
    return sp.Piecewise(
        (1, sp.LessThan(helicity, 0)),
        (0, True),
    )


ν_prime = sp.Symbol(R"\nu^{\prime}")
polarimetry_exprs = tuple(
    PoolSum(
        formulate_aligned_amplitude(ν, λ).conjugate()
        * msigma(i)[to_index(ν), to_index(ν_prime)]
        * formulate_aligned_amplitude(ν_prime, λ),
        (λ, [-half, +half]),
        (ν, [-half, +half]),
        (ν_prime, [-half, +half]),
    )
    / intensity_expr
    for i in (1, 2, 3)
)
Hide code cell source
unfolded_polarimetry_exprs = tuple(
    perform_cached_doit(perform_cached_doit(x).xreplace(amp_definitions))
    for x in polarimetry_exprs
)
Properties of the vector \(\vec\alpha\)#

The vector \(\vec \alpha\) introduced in Eq. (3) obeys the following properties:

  1. It is a three-dimensional vector defined in the rest frame of the decaying particle. In particular, it transforms as a regular vector when the initial (alignment) configuration changes.

  2. The length of the vector is bounded by 1: \(|\vec{\alpha}| \leq 1\)

  3. \(\alpha_y=0\) for the decay of a fermion to a fermion and a (pseudo)scalar

Here is the proof of the second statement:

\[\begin{split} \begin{align} I_{\nu',\nu} = \sum_{\lambda} A_{\nu',\lambda}^* A_{\nu,\lambda} = \begin{pmatrix} a & c^*\\ c & b \end{pmatrix} = \frac{a+b}{2}\left( \mathbb{I} + (\vec{\sigma} \cdot \vec{\alpha}) \right)\,, \end{align} \end{split}\]

where

\[\begin{split} \begin{align} a &= \left|A_{+,+}\right|^2+\left|A_{+,-}\right|^2\,,\\ \nonumber b &= \left|A_{-,+}\right|^2+\left|A_{-,-}\right|^2\,,\\ \nonumber c &= A_{+,+}^*A_{-,+} + A_{+,-}^*A_{-,-}\,, \end{align} \end{split}\]

and

\[ \begin{align} \alpha_x &= \frac{2\,\text{Re}\,c}{a+b}\,,& \alpha_y &= \frac{2\,\text{Im}\,c}{a+b}\,,& \alpha_z &= \frac{a-b}{a+b}\,. \end{align} \]

To constrain the length of \(\vec\alpha\), note that \(a, b \geq 0\) and that \(|c|^2 \le ab\) by the Cauchy–Schwarz inequality. Therefore,

\[ \begin{align} |\vec\alpha|^2 &= \frac{(a-b)^2+4|c|^2}{(a+b)^2} \leq \frac{(a-b)^2+4ab}{(a+b)^2} = 1\,. \end{align} \]
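
A quick numerical cross-check of this bound, using randomly generated helicity amplitudes (a sketch that is independent of the model built in this report):

import numpy as np

rng = np.random.default_rng(seed=42)
# random complex helicity amplitudes A[ν, λ] with ν, λ = +1/2, -1/2
A_random = rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2))
a = np.sum(np.abs(A_random[0]) ** 2)  # |A_{+,+}|^2 + |A_{+,-}|^2
b = np.sum(np.abs(A_random[1]) ** 2)  # |A_{-,+}|^2 + |A_{-,-}|^2
c = np.sum(A_random[0].conj() * A_random[1])  # A*_{+,+}A_{-,+} + A*_{+,-}A_{-,-}
α_vector = np.array([2 * c.real, 2 * c.imag, a - b]) / (a + b)
assert np.linalg.norm(α_vector) <= 1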

Computations with TensorWaves#
Conversion to computational backend#

The full expression tree can be converted to a computational, parametrized function as follows. Note that the resonance masses, widths, and production couplings are interpreted as parameters. The remaining symbols (the angles and Mandelstam variables) become arguments to the function.

free_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol.name.startswith("m_")
    and symbol not in masses
    or symbol.name.startswith(R"\Gamma_")
    or symbol in prod_couplings
}
fixed_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol not in free_parameters
}
intensity_func = create_parametrized_function(
    unfolded_intensity_expr.xreplace(fixed_parameters),
    parameters=free_parameters,
    backend="jax",
)
polarimetry_funcs = tuple(
    create_parametrized_function(
        expr.xreplace(fixed_parameters),
        parameters=free_parameters,
        backend="jax",
    )
    for expr in unfolded_polarimetry_exprs
)
Phase space#

The \(\Lambda_c^+ \to p K \pi\) kinematics is fully described by two dynamic variables, \(m_{K\pi}\) and \(m_{pK}\) (see Phase space for a three-body decay). The third Mandelstam variable can be computed from the other two and the masses of the initial and final state:

Hide code cell source
computed_σ3 = m0**2 + m1**2 + m2**2 + m3**2 - σ1 - σ2
compute_third_mandelstam = create_function(computed_σ3.subs(masses), backend="jax")
display_definitions({σ3: computed_σ3})
\[\begin{split}\displaystyle \begin{array}{rcl} \sigma_{3} & = & m_{K}^{2} + m_{p}^{2} + m_{\pi}^{2} + m_{\Lambda_c}^{2} - \sigma_{1} - \sigma_{2} \\ \end{array}\end{split}\]

Values for the angles will be computed from the Mandelstam values with a data transformer based on the symbolic angle definitions:

kinematic_variables = {
    symbol: expression.doit().xreplace(masses).xreplace(fixed_parameters)
    for symbol, expression in angles.items()
}
kinematic_variables.update({s: s for s in [σ1, σ2, σ3]})  # include identity
transformer = SympyDataTransformer.from_sympy(kinematic_variables, backend="jax")

We now define the phase space over a grid that covers the part of the Dalitz plane that is kinematically available to the decay:

m0_val, m1_val, m2_val, m3_val = masses.values()
σ1_min = (m2_val + m3_val) ** 2
σ1_max = (m0_val - m1_val) ** 2
σ2_min = (m1_val + m3_val) ** 2
σ2_max = (m0_val - m2_val) ** 2


def generate_phsp_grid(resolution: int):
    x = np.linspace(σ1_min, σ1_max, num=resolution)
    y = np.linspace(σ2_min, σ2_max, num=resolution)
    X, Y = np.meshgrid(x, y)
    Z = compute_third_mandelstam.function(X, Y)
    phsp = {"sigma1": X, "sigma2": Y, "sigma3": Z}
    return X, Y, transformer(phsp)


X, Y, phsp = generate_phsp_grid(resolution=500)
Intensity distribution#

Finally, all intensities can be computed as follows:

Hide code cell source
%config InlineBackend.figure_formats = ['png']
s1_label = R"$\sigma_1=m^2\left(K\pi\right)$"
s2_label = R"$\sigma_2=m^2\left(pK\right)$"
s3_label = R"$\sigma_3=m^2\left(p\pi\right)$"

plt.rc("font", size=15)
fig, ax = plt.subplots(
    figsize=(9, 8),
    tight_layout=True,
)
ax.set_box_aspect(1)
ax.set_title("Intensity distribution")
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)

total_intensities = intensity_func(phsp)
mesh = ax.pcolormesh(X, Y, total_intensities, norm=LogNorm())
fig.colorbar(mesh, ax=ax, fraction=0.05, pad=0.02)
fig.savefig("021-intensity-distribution.png", dpi=200)
plt.show()

Hide code cell source
%config InlineBackend.figure_formats = ['svg']


def compute_sub_function(
    func: ParametrizedFunction, phsp: DataSample, non_zero_couplings: str
) -> jnp.ndarray:
    zero_couplings = {
        par: 0
        for par in func.parameters
        if par.startswith(R"\mathcal{H}")
        if "production" in par
        if not any(s in par for s in non_zero_couplings)
    }
    original_parameters = dict(func.parameters)
    func.update_parameters(zero_couplings)
    computed_values = func(phsp)
    func.update_parameters(original_parameters)
    return computed_values


def set_ylim_to_zero(ax):
    _, y_max = ax.get_ylim()
    ax.set_ylim(0, y_max)


fig, (ax1, ax2) = plt.subplots(
    ncols=2,
    figsize=(12, 5),
    sharey=True,
    tight_layout=True,
)
ax1.set_xlabel(s1_label)
ax2.set_xlabel(s2_label)
ax1.set_yticks([])

x = X[0]
y = Y[:, 0]
ax1.fill(x, np.nansum(total_intensities, axis=0), alpha=0.3)
ax2.fill(y, np.nansum(total_intensities, axis=1), alpha=0.3)
for chain_id, chain_name in chain_ids.items():
    label = f"${chain_labels[chain_id]}$"
    sub_intensities = compute_sub_function(
        intensity_func, phsp, non_zero_couplings=[chain_name]
    )
    ax1.plot(x, np.nansum(sub_intensities, axis=0), label=label)
    ax2.plot(y, np.nansum(sub_intensities, axis=1), label=label)
set_ylim_to_zero(ax1)
set_ylim_to_zero(ax2)
ax2.legend()
fig.savefig("021-intensity-projections.svg")
plt.show()

Fit fractions#

The total decay rate for \(\Lambda_c^+ \to pK\pi\) can be broken into fractions that correspond to the different decay chains and interference terms. The total rate is computed as an integral of the intensity over decay kinematics:

\[ \begin{align} I_\text{tot}(\{\mathcal{H}\}) = \int d m_{pK}^2 d m_{K\pi}^2\, I_0(m_{pK}, m_{K\pi} | \{\mathcal{H}\}) \approx \frac{\Phi_0}{N_\text{MC}} \sum_{e=1}^{N_\text{MC}}\,\,I_0(m_{pK,e}, m_{K\pi,e} | \{\mathcal{H}\})\,, \end{align} \]

where \(\Phi_0\) is an (irrelevant) constant equal to the flat phase-space integral and \((m_{pK,e}, m_{K\pi,e})\) are the kinematic variables of the \(e\)-th point in the MC sample.

The conditional argument \(\{\mathcal{H}\}\) indicates dependence of the rate on the value of the couplings. The individual fractions are found by computing the total rate for a subset of couplings set to zero,

\[\begin{split} \begin{align} I_\text{tot}^{K} &= I_\text{tot}\left(\{\mathcal{H}^{\Lambda_c^+\to\Delta^{**} K}, \mathcal{H}^{\Lambda_c^+\to\Lambda^{**} \pi} = 0\}\right)\,,\\ I_\text{tot}^{\Delta} &= I_\text{tot}\left(\{\mathcal{H}^{\Lambda_c^+\to K^{**} p}, \mathcal{H}^{\Lambda_c^+\to\Lambda^{**} \pi} = 0\}\right)\,,\\ I_\text{tot}^{\Lambda} &= I_\text{tot}\left(\{\mathcal{H}^{\Lambda_c^+\to\Delta^{**} K}, \mathcal{H}^{\Lambda_c^+\to K^{**} p} = 0\}\right)\,,\\ I_\text{tot}^{K/\Lambda} &= I_\text{tot}\left(\{\mathcal{H}^{\Lambda_c^+\to\Delta^{**} K} = 0\}\right) - I_\text{tot}^{K} - I_\text{tot}^{\Lambda}\,,\\ & \dots\,, \end{align} \end{split}\]

where the terms with a single chain index give the rates of the individual decay chains. The sum of all fractions should reproduce the total rate:

\[ \begin{align} I_\text{tot}\left(\{\mathcal{H}\}\right) = \sum_{R} I_\text{tot}^{R} + \sum_{R < R'} I_\text{tot}^{R/R'} \end{align} \]
Code for computing decay rates
def integrate_intensity(
    intensity_func: ParametrizedFunction,
    phsp: DataSample,
    non_zero_couplings: list[str] | None = None,
) -> float:
    if non_zero_couplings is None:
        intensities = intensity_func(phsp)
    else:
        intensities = compute_sub_function(intensity_func, phsp, non_zero_couplings)
    return np.nansum(intensities) / len(intensities)


def compute_interference(
    intensity_func: ParametrizedFunction,
    phsp: DataSample,
    chain1: list[str],
    chain2: list[str],
) -> float:
    I_interference = integrate_intensity(intensity_func, phsp, chain1 + chain2)
    I_chain1 = integrate_intensity(intensity_func, phsp, chain1)
    I_chain2 = integrate_intensity(intensity_func, phsp, chain2)
    return I_interference - I_chain1 - I_chain2


I_tot = integrate_intensity(intensity_func, phsp)
np.testing.assert_allclose(
    I_tot,
    integrate_intensity(intensity_func, phsp, ["K", R"\Lambda", R"\Delta"]),
)
I_K = integrate_intensity(intensity_func, phsp, non_zero_couplings=["K"])
I_Λ = integrate_intensity(intensity_func, phsp, non_zero_couplings=["Lambda"])
I_Δ = integrate_intensity(intensity_func, phsp, non_zero_couplings=["Delta"])
I_ΛΔ = compute_interference(intensity_func, phsp, ["Lambda"], ["Delta"])
I_KΔ = compute_interference(intensity_func, phsp, ["K"], ["Delta"])
I_KΛ = compute_interference(intensity_func, phsp, ["K"], ["Lambda"])
np.testing.assert_allclose(I_tot, I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ)
Hide code cell source
def render_resonance_row(chain_id):
    rows = [
        (
            Rf"\color{{gray}}{{{p.latex}}}",
            (
                Rf"\color{{gray}}{{{integrate_intensity(intensity_func, phsp, [p.name])/I_tot:.3f}}}"
            ),
        )
        for p in resonance_particles[chain_id]
    ]
    if len(rows) > 1:
        return rows
    return []


rows = [
    R"\hline",
    ("K^{**}", f"{I_K/I_tot:.3f}"),
    *render_resonance_row(chain_id=1),
    (R"\Lambda^{**}", f"{I_Λ/I_tot:.3f}"),
    *render_resonance_row(chain_id=2),
    (R"\Delta^{**}", f"{I_Δ/I_tot:.3f}"),
    *render_resonance_row(chain_id=3),
    (R"\Delta/\Lambda", f"{I_ΛΔ/I_tot:.3f}"),
    (R"K/\Delta", f"{I_KΔ/I_tot:.3f}"),
    (R"K/\Lambda", f"{I_KΛ/I_tot:.3f}"),
    R"\hline",
    (
        R"\mathrm{total}",
        f"{(I_K + I_Λ + I_Δ + I_ΛΔ + I_KΔ + I_KΛ) /I_tot:.3f}",
    ),
]

latex = R"\begin{array}{crr}" + "\n"
latex += R"& I_\mathrm{sub}\,/\,I \\" + "\n"
for row in rows:
    if row == R"\hline":
        latex += R"\hline"
    else:
        latex += "  " + " & ".join(row) + R" \\" + "\n"
latex += R"\end{array}"
Math(latex)
\[\begin{split}\displaystyle \begin{array}{crr} & I_\mathrm{sub}\,/\,I \\ \hline K^{**} & 0.371 \\ \Lambda^{**} & 0.051 \\ \color{gray}{\Lambda(1520)} & \color{gray}{0.031} \\ \color{gray}{\Lambda(1670)} & \color{gray}{0.020} \\ \Delta^{**} & 0.582 \\ \Delta/\Lambda & 0.013 \\ K/\Delta & -0.018 \\ K/\Lambda & 0.001 \\ \hline \mathrm{total} & 1.000 \\ \end{array}\end{split}\]
Polarimetry distributions#
Hide code cell source
def render_mean(array, plus=True):
    array = array.real
    mean = f"{np.nanmean(array):.3f}"
    std = f"{np.nanstd(array):.3f}"
    if plus and float(mean) > 0:
        mean = f"+{mean}"
    return Rf"{mean} \pm {std}"


latex = R"\begin{array}{cccc}" + "\n"
latex += R"& \bar{|\alpha|} & \bar\alpha_x & \bar\alpha_y & \bar\alpha_z \\" + "\n"
for chain_id, chain_name in chain_ids.items():
    latex += f"  {chain_labels[chain_id]} & "
    x, y, z = tuple(
        compute_sub_function(func, phsp, non_zero_couplings=[chain_name])
        for func in polarimetry_funcs
    )
    latex += render_mean(np.sqrt(x**2 + y**2 + z**2), plus=False) + " & "
    latex += " & ".join(map(render_mean, [x, y, z]))
    latex += R" \\" + "\n"
latex += R"\end{array}"
Math(latex)
\[\begin{split}\displaystyle \begin{array}{cccc} & \bar{|\alpha|} & \bar\alpha_x & \bar\alpha_y & \bar\alpha_z \\ K^{**} & 0.873 \pm 0.040 & 0.000 \pm 0.554 & 0.000 \pm 0.396 & +0.100 \pm 0.538 \\ \Lambda^{**} & 0.906 \pm 0.101 & -0.529 \pm 0.214 & 0.000 \pm 0.190 & -0.338 \pm 0.596 \\ \Delta^{**} & 0.540 \pm 0.000 & +0.320 \pm 0.153 & -0.000 \pm 0.000 & -0.324 \pm 0.246 \\ \end{array}\end{split}\]
Hide code cell content
# Slider construction
sliders = {}
for symbol, _value in free_parameters.items():
    if symbol.name.startswith(R"\mathcal{H}"):
        real_slider = create_slider(symbol)
        imag_slider = create_slider(symbol)
        sliders[f"{symbol.name}_real"] = real_slider
        sliders[f"{symbol.name}_imag"] = imag_slider
        real_slider.description = R"\(\mathrm{Re}\)"
        imag_slider.description = R"\(\mathrm{Im}\)"
    else:
        slider = create_slider(symbol)
        sliders[symbol.name] = slider

# Slider ranges
σ3_max = (m0_val - m3_val) ** 2
σ3_min = (m1_val + m2_val) ** 2

for name, slider in sliders.items():
    slider.continuous_update = True
    slider.step = 0.01
    if name.startswith("m_"):
        if "K" in name:
            slider.min = np.sqrt(σ1_min)
            slider.max = np.sqrt(σ1_max)
        elif R"\Lambda" in name:
            slider.min = np.sqrt(σ2_min)
            slider.max = np.sqrt(σ2_max)
        elif R"\Delta" in name:
            slider.min = np.sqrt(σ3_min)
            slider.max = np.sqrt(σ3_max)
    elif name.startswith(R"\Gamma_"):
        slider.min = 0
        slider.max = max(0.5, 2 * slider.value)
    elif name.startswith(R"\mathcal{H}"):
        slider.min = -15
        slider.max = +15


# Slider values
def reset_sliders(click_event):
    for symbol, value in free_parameters.items():
        if symbol.name.startswith(R"\mathcal{H}"):
            set_slider(sliders[symbol.name + "_real"], value)
            set_slider(sliders[symbol.name + "_imag"], value)
        else:
            set_slider(sliders[symbol.name], value)


def set_coupling_to_zero(filter_pattern):
    if isinstance(filter_pattern, Combobox):
        filter_pattern = filter_pattern.value
    for name, _slider in sliders.items():
        if not name.startswith(R"\mathcal{H}"):
            continue
        if filter_pattern not in name:
            continue
        set_slider(sliders[name], 0)


def set_slider(slider, value):
    if slider.description == R"\(\mathrm{Im}\)":
        value = complex(value).imag
    else:
        value = complex(value).real
    n_decimals = -round(np.log10(slider.step))
    if slider.value != round(value, n_decimals):  # widget performance
        slider.value = value


reset_sliders(click_event=None)
reset_button = Button(description="Reset slider values")
reset_button.on_click(reset_sliders)

all_resonances = [r.latex for r_list in resonance_particles.values() for r in r_list]
filter_button = Combobox(
    placeholder="Enter coupling filter pattern",
    options=all_resonances,
    description=R"$\mathcal{H}=0$",
)
filter_button.on_submit(set_coupling_to_zero)

# UI design
latex = {symbol.name: sp.latex(symbol) for symbol in free_parameters}
mass_sliders = [sliders[n] for n in sliders if n.startswith("m_")]
width_sliders = [sliders[n] for n in sliders if n.startswith(R"\Gamma_")]
coupling_sliders = {}
for res_list in resonance_particles.values():
    for res in res_list:
        coupling_sliders[res.name] = (
            [
                s
                for n, s in sliders.items()
                if n.endswith("_real") and res.latex in n
            ],
            [
                s
                for n, s in sliders.items()
                if n.endswith("_imag") and res.latex in n
            ],
            [
                HTMLMath(f"${latex[n[:-5]]}$")
                for n in sliders
                if n.endswith("_real") and res.latex in n
            ],
        )
slider_tabs = Tab(
    children=[
        Tab(
            children=[
                VBox([HBox(s) for s in zip(*pair)])
                for pair in coupling_sliders.values()
            ],
            titles=tuple(coupling_sliders),
        ),
        VBox([HBox([r, i]) for r, i in zip(mass_sliders, width_sliders)]),
    ],
    titles=("Couplings", "Masses and widths"),
)
ui = VBox([slider_tabs, HBox([reset_button, filter_button])])
Hide code cell source
%config InlineBackend.figure_formats = ['png']
fig, axes = plt.subplots(
    figsize=(12, 6.2),
    ncols=2,
    sharey=True,
)
ax1, ax2 = axes
ax1.set_title("Intensity distribution")
ax2.set_title("Polarimeter vector field")
ax1.set_xlabel(Rf"${s1_label[1:-1]}, \alpha_x$")
ax2.set_xlabel(Rf"${s1_label[1:-1]}, \alpha_x$")
ax1.set_ylabel(Rf"${s2_label[1:-1]}, \alpha_z$")
for ax in axes:
    ax.set_box_aspect(1)
fig.canvas.toolbar_visible = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False

mesh = None
quiver = None
Xα, Yα, phsp_α = generate_phsp_grid(resolution=35)
XI, YI, phsp_I = generate_phsp_grid(resolution=200)


def plot3(**kwargs):
    global quiver, mesh
    kwargs = to_complex_kwargs(**kwargs)
    for func in [*list(polarimetry_funcs), intensity_func]:
        func.update_parameters(kwargs)
    intensity = intensity_func(phsp_I)
    αx, αy, αz = tuple(func(phsp_α).real for func in polarimetry_funcs)
    abs_α = jnp.sqrt(αx**2 + αy**2 + αz**2)
    if mesh is None:
        mesh = ax1.pcolormesh(XI, YI, intensity, cmap=plt.cm.Reds)
        c_bar = fig.colorbar(mesh, ax=ax1, pad=0.01, fraction=0.0473)
        c_bar.ax.set_yticks([])
    else:
        mesh.set_array(intensity)
    if quiver is None:
        quiver = ax2.quiver(
            Xα, Yα, αx, αz, abs_α, cmap=plt.cm.viridis_r, clim=(0, 1)
        )
        c_bar = fig.colorbar(quiver, ax=ax2, pad=0.01, fraction=0.0473)
        c_bar.ax.set_ylabel(R"$\left|\vec\alpha\right|$")
    else:
        quiver.set_UVC(αx, αz, abs_α)
    fig.canvas.draw_idle()


def to_complex_kwargs(**kwargs):
    complex_valued_kwargs = {}
    for key, value in dict(kwargs).items():
        if key.endswith("real"):
            symbol_name = key[:-5]
            imag = kwargs[f"{symbol_name}_imag"]
            complex_valued_kwargs[symbol_name] = complex(value, imag)
        elif key.endswith("imag"):
            continue
        else:
            complex_valued_kwargs[key] = value
    return complex_valued_kwargs


output = interactive_output(plot3, controls=sliders)
fig.tight_layout()
display(ui, output)

B-matrix extension of polarimeter#

Hide code cell content
from __future__ import annotations

import logging
from pathlib import Path
from warnings import filterwarnings

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import polarimetry
import sympy as sp
from ampform.sympy import PoolSum
from IPython.display import display
from matplotlib import cm
from polarimetry import _to_index
from polarimetry.data import create_data_transformer, generate_meshgrid_sample
from polarimetry.io import (
    mute_jax_warnings,
    perform_cached_doit,
    perform_cached_lambdify,
)
from polarimetry.lhcb import load_model_builder, load_model_parameters
from polarimetry.lhcb.particle import load_particles
from sympy.physics.matrices import msigma
from tqdm.auto import tqdm

filterwarnings("ignore")
logging.getLogger("polarimetry.function").setLevel(logging.INFO)
mute_jax_warnings()
POLARIMETRY_DIR = Path(polarimetry.__file__).parent
Formulate expressions#

Reference subsystem 1 is defined as:

Hide code cell source
MODEL_CHOICE = 0
MODEL_FILE = POLARIMETRY_DIR / "lhcb/model-definitions.yaml"
PARTICLES = load_particles(POLARIMETRY_DIR / "lhcb/particle-definitions.yaml")
BUILDER = load_model_builder(MODEL_FILE, PARTICLES, model_id=MODEL_CHOICE)
IMPORTED_PARAMETER_VALUES = load_model_parameters(
    MODEL_FILE, BUILDER.decay, MODEL_CHOICE, PARTICLES
)
REFERENCE_SUBSYSTEM = 1
MODEL = BUILDER.formulate(REFERENCE_SUBSYSTEM, cleanup_summations=True)
MODEL.parameter_defaults.update(IMPORTED_PARAMETER_VALUES)
\[\begin{split} \vec\alpha = \sum_{\nu',\nu,\lambda} A^*_{\nu',\lambda}\vec\sigma_{\nu',\nu} A_{\nu,\lambda} / I_0 \\ \vec\beta = \sum_{\nu,\lambda',\lambda} A^*_{\nu,\lambda'} \vec\sigma_{\lambda',\lambda} A^*_{\nu,\lambda} / I_0 \\ B_{\tau,\rho} = \sum_{\nu,\nu',\lambda',\lambda} A^*_{\nu',\lambda'} \sigma_{\nu',\nu}^\tau A_{\nu,\lambda} \sigma_{\lambda',\lambda}^\rho \end{split}\]
Hide code cell source
half = sp.Rational(1, 2)
λ, λp = sp.symbols(R"lambda \lambda^{\prime}", rational=True)
v, vp = sp.symbols(R"nu \nu^{\prime}", rational=True)
σ = [sp.Matrix([[1, 0], [0, 1]])]
σ.extend(msigma(i) for i in (1, 2, 3))
ref = REFERENCE_SUBSYSTEM
B = tuple(
    tuple(
        PoolSum(
            BUILDER.formulate_aligned_amplitude(vp, λp, 0, 0, ref)[0].conjugate()
            * σ[τ][_to_index(vp), _to_index(v)]
            * BUILDER.formulate_aligned_amplitude(v, λ, 0, 0, ref)[0]
            * σ[ρ][_to_index(λp), _to_index(λ)],
            (v, [-half, +half]),
            (vp, [-half, +half]),
            (λ, [-half, +half]),
            (λp, [-half, +half]),
        ).cleanup()
        for ρ in range(4)
    )
    for τ in range(4)
)
del ref
B = sp.Matrix(B)
Functions and data#
Hide code cell content
progress_bar = tqdm(desc="Unfolding expressions", total=16)
B_exprs = []
for τ in range(4):
    row = []
    for ρ in range(4):
        expr = perform_cached_doit(B[τ, ρ].doit().xreplace(MODEL.amplitudes))
        progress_bar.update()
        row.append(expr)
    B_exprs.append(row)
progress_bar.close()
B_exprs = np.array(B_exprs)
B_exprs.shape
(4, 4)
Hide code cell content
progress_bar = tqdm(desc="Lambdifying", total=16)
B_funcs = []
for τ in range(4):
    row = []
    for ρ in range(4):
        func = perform_cached_lambdify(
            B_exprs[τ, ρ].xreplace(MODEL.parameter_defaults),
            backend="jax",
        )
        progress_bar.update()
        row.append(func)
    B_funcs.append(row)
progress_bar.close()
B_funcs = np.array(B_funcs)
Hide code cell content
transformer = create_data_transformer(MODEL)
GRID_SAMPLE = generate_meshgrid_sample(MODEL.decay, resolution=400)
GRID_SAMPLE.update(transformer(GRID_SAMPLE))
X = GRID_SAMPLE["sigma1"]
Y = GRID_SAMPLE["sigma2"]
del transformer
Hide code cell content
B_arrays = jnp.array(
    [[B_funcs[τ, ρ](GRID_SAMPLE) for ρ in range(4)] for τ in range(4)]
)
B_norm = B_arrays / B_arrays[0, 0]
B_arrays.shape
(4, 4, 400, 400)
Plots#
Hide code cell source
%config InlineBackend.figure_formats = ['png']
plt.rcdefaults()
plt.rc("font", size=16)
s1_label = R"$m^2\left(K^-\pi^+\right)$ [GeV$^2$]"
s2_label = R"$m^2\left(pK^-\right)$ [GeV$^2$]"
fig, ax = plt.subplots(figsize=(8, 6.8))
ax.set_title("$I_0 = B_{0, 0}$")
ax.set_xlabel(s1_label)
ax.set_ylabel(s2_label)
ax.set_box_aspect(1)
ax.pcolormesh(X, Y, B_arrays[0, 0].real)
fig.savefig("022/b00-is-intensity.png")
plt.show()

Hide code cell source
%config InlineBackend.figure_formats = ['png']
plt.rcdefaults()
plt.rc("font", size=10)
fig, axes = plt.subplots(
    dpi=200,
    figsize=(11, 10),
    ncols=4,
    nrows=4,
    sharex=True,
    sharey=True,
)
fig.suptitle(
    R"$B_{\tau,\rho} = \sum_{\nu,\nu',\lambda',\lambda} A^*_{\nu',\lambda'}"
    R" \sigma_{\nu',\nu}^\tau A_{\nu,\lambda} \sigma_{\lambda',\lambda}^\rho$"
)
progress_bar = tqdm(total=16)
for ρ in range(4):
    for τ in range(4):
        ax = axes[τ, ρ]
        ax.set_box_aspect(1)
        if τ == 0 and ρ == 0:
            Z = B_arrays[τ, ρ].real
            ax.set_title(f"$B_{{{τ}{ρ}}}$")
            cmap = cm.viridis
        else:
            Z = B_norm[τ, ρ].real
            ax.set_title(f"$B_{{{τ}{ρ}}} / B_{{00}}$")
            cmap = cm.coolwarm
        mesh = ax.pcolormesh(X, Y, Z, cmap=cmap)
        cbar = fig.colorbar(mesh, ax=ax, fraction=0.047, pad=0.01)
        if τ != 0 or ρ != 0:
            mesh.set_clim(vmin=-1, vmax=+1)
            cbar.set_ticks([-1, 0, +1])
            cbar.set_ticklabels(["-1", "0", "+1"])
        if τ == 3:
            ax.set_xlabel(s1_label)
        if ρ == 0:
            ax.set_ylabel(s2_label)
        progress_bar.update()
progress_bar.close()
fig.tight_layout()
fig.savefig("022/b-matrix-elements.png")
plt.show()

Hypothesis:

\[\begin{split} B_{0,\rho} = \vec\beta B_{00} \\ B_{\tau,0} = \vec\alpha B_{00} \\ B_{00} = I_0 \end{split}\]
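
As a quick numerical sanity check of the last relation (a sketch on the B_arrays grid computed above): if \(B_{00} = I_0\), its imaginary part should vanish up to floating-point noise.

B00 = B_arrays[0, 0]
scale = jnp.nanmax(jnp.abs(B00))
# tolerance is illustrative; grid points outside the Dalitz region are NaN
assert jnp.nanmax(jnp.abs(B00.imag)) < 1e-8 * scale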
Hide code cell content
def plot_field(vx, vy, v_abs, ax, strides=12, cmap=cm.viridis_r):
    mesh = ax.quiver(
        X[::strides, ::strides],
        Y[::strides, ::strides],
        vx[::strides, ::strides].real,
        vy[::strides, ::strides].real,
        v_abs[::strides, ::strides],
        cmap=cmap,
    )
    mesh.set_clim(vmin=0, vmax=+1)
    return mesh


def plot(x, y, z, strides=14):
    plt.rcdefaults()
    plt.rc("font", size=16)
    fig, ax = plt.subplots(figsize=(8, 6.8), tight_layout=True)
    ax.set_box_aspect(1)
    v_abs = jnp.sqrt(x.real**2 + y.real**2 + z.real**2)
    mesh = plot_field(x, y, v_abs, ax, strides)
    color_bar = fig.colorbar(mesh, ax=ax, pad=0.01)
    return fig, ax, color_bar
Hide code cell source
%config InlineBackend.figure_formats = ['svg']
fig, ax, cbar = plot(
    x=B_norm[3, 0],
    y=B_norm[1, 0],
    z=B_norm[2, 0],
    strides=10,
)
ax.set_title(
    R"$B_{\tau, 0} / B_{00} = \sum_{\nu',\nu,\lambda}"
    R" A^*_{\nu',\lambda}\vec\sigma_{\nu',\nu} A_{\nu,\lambda} / I_0$"
)
ax.set_xlabel(Rf"{s1_label}, $\quad\alpha_z$")
ax.set_ylabel(Rf"{s2_label}, $\quad\alpha_x$")
cbar.set_label(R"$\left|\vec{\alpha}\right|$")
fig.savefig("022/alpha-field.svg")
plt.show()
Hide code cell source
%config InlineBackend.figure_formats = ['svg']
fig, ax, cbar = plot(
    x=B_norm[0, 3],
    y=B_norm[0, 1],
    z=B_norm[0, 2],
    strides=10,
)
ax.set_title(
    R"$B_{0,\rho} / B_{00} = \sum_{\nu,\lambda',\lambda} A^*_{\nu,\lambda'}"
    R" \vec\sigma_{\lambda',\lambda} A^*_{\nu,\lambda} / I_0$"
)
ax.set_xlabel(Rf"{s1_label}, $\quad \beta_z = B_{{03}}$")
ax.set_ylabel(Rf"{s2_label}, $\quad \beta_x = B_{{01}}$")
cbar.set_label(R"$\left|\vec{\beta}\right|$")
fig.savefig("022/beta-field.svg")
plt.show()

Note that \(|\alpha| = |\beta|\):

α_abs = jnp.sqrt(jnp.sum(B_norm[1:, 0] ** 2, axis=0))
β_abs = jnp.sqrt(jnp.sum(B_norm[0, 1:] ** 2, axis=0))
np.testing.assert_allclose(α_abs, β_abs, rtol=1e-14)
Hide code cell source
%config InlineBackend.figure_formats = ['png']
fig, axes = plt.subplots(
    figsize=(11, 6),
    ncols=2,
    sharey=True,
    tight_layout=True,
)
for ax in axes:
    ax.set_box_aspect(1)
ax1, ax2 = axes
ax1.set_title(R"$\alpha$")
ax2.set_title(R"$\beta$")
ax1.pcolormesh(X, Y, α_abs.real, cmap=cm.coolwarm).set_clim(vmin=-1, vmax=+1)
ax2.pcolormesh(X, Y, β_abs.real, cmap=cm.coolwarm).set_clim(vmin=-1, vmax=+1)
ax1.set_xlabel(s1_label)
ax2.set_xlabel(s1_label)
ax1.set_ylabel(s2_label)
fig.savefig("022/alpha-beta-comparison.png")
plt.show()
https://github.com/ComPWA/compwa.github.io/assets/29308176/c7268301-11c9-45f2-a5ec-4c2928352a68

3D plots with Plotly#

This TR tests whether the HTML build of the TR notebooks supports Plotly figures. It’s a follow-up to TR-006, without the interactivity of ipywidgets, but with the better 3D rendering of Plotly. For more info on how Plotly figures can be embedded in Sphinx HTML builds, see this page of MyST-NB (particularly the remark on html_js_files).
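
For reference, Plotly output in a Sphinx HTML build typically requires loading RequireJS through the html_js_files option. The snippet below is only a sketch of such a conf.py entry; the exact CDN URL and version are assumptions, so follow the MyST-NB page linked above for the authoritative setup.

# conf.py (sketch) -- load RequireJS so that Plotly figures render in the HTML build
html_js_files = [
    "https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js",
]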

The following example is copied from this tutorial.

import numpy as np
import plotly.graph_objects as go

X, Y, Z = np.mgrid[-5:5:40j, -5:5:40j, -5:5:40j]
ellipsoid = X * X * 0.5 + Y * Y + Z * Z * 2
fig = go.Figure(
    data=go.Isosurface(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=ellipsoid.flatten(),
        isomin=5,
        isomax=50,
        surface_fill=0.4,
        caps=dict(x_show=False, y_show=False),
        slices_z=dict(
            show=True,
            locations=[
                -1,
                -3,
            ],
        ),
        slices_y=dict(show=True, locations=[0]),
    )
)
fig.show()

Symbolic model serialization#

Hide code cell content
import os
from pathlib import Path
from textwrap import shorten

import graphviz
import polarimetry
import sympy as sp
from ampform.io import aslatex
from ampform.sympy import unevaluated
from IPython.display import Markdown, Math
from polarimetry.amplitude import simplify_latex_rendering
from polarimetry.io import perform_cached_doit
from polarimetry.lhcb import load_model
from polarimetry.lhcb.particle import load_particles
from sympy.printing.mathml import MathMLPresentationPrinter

simplify_latex_rendering()
Expression trees#

SymPy expressions are built up from symbols and mathematical operations as follows:

x, y, z = sp.symbols("x y z")
expression = sp.sin(x * y) / 2 - x**2 + 1 / z
expression
\[\displaystyle - x^{2} + \frac{\sin{\left(x y \right)}}{2} + \frac{1}{z}\]

Under the hood, SymPy represents these expressions as trees. There are a few ways to visualize the tree for this particular example:

sp.printing.tree.print_tree(expression, assumptions=False)
Add: -x**2 + sin(x*y)/2 + 1/z
+-Pow: 1/z
| +-Symbol: z
| +-NegativeOne: -1
+-Mul: sin(x*y)/2
| +-Half: 1/2
| +-sin: sin(x*y)
|   +-Mul: x*y
|     +-Symbol: x
|     +-Symbol: y
+-Mul: -x**2
  +-NegativeOne: -1
  +-Pow: x**2
    +-Symbol: x
    +-Integer: 2
Hide code cell source
src = sp.dotprint(
    expression,
    styles=[
        (sp.Number, {"color": "grey", "fontcolor": "grey"}),
        (sp.Symbol, {"color": "royalblue", "fontcolor": "royalblue"}),
    ],
)
graphviz.Source(src)

Expression trees are powerful, because we can use them as templates for any human-readable presentation we are interested in. In fact, the LaTeX representation that we saw when constructing the expression was generated by SymPy’s LaTeX printer.

src = sp.latex(expression)
Markdown(f"```latex\n{src}\n```")
- x^{2} + \frac{\sin{\left(x y \right)}}{2} + \frac{1}{z}

Hint

SymPy expressions can serve as a template for generating code!
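
For instance, the same expression tree can be turned into a numerical function with lambdify() (a minimal sketch, reusing the x, y, z symbols defined above; this is similar in spirit to the TensorWaves functions used earlier in this document):

# Generate a NumPy-backed function from the symbolic expression
numerical_function = sp.lambdify((x, y, z), expression, modules="numpy")
numerical_function(1.0, 2.0, 4.0)  # ≈ -0.295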

Here’s a number of other representations:

Hide code cell source
def to_mathml(expr: sp.Expr) -> str:
    printer = MathMLPresentationPrinter()
    xml = printer._print(expr)
    return xml.toprettyxml().replace("\t", "  ")


Markdown(f"""
```python
# Python
{sp.pycode(expression)}
```
```cpp
// C++
{sp.cxxcode(expression, standard="c++17")}
```
```fortran
! Fortran
{sp.fcode(expression).strip()}
```
```matlab
% Matlab / Octave
{sp.octave_code(expression)}
```
```julia
# Julia
{sp.julia_code(expression)}
```
```rust
// Rust
{sp.rust_code(expression)} 
```
```xml
<!-- MathML -->
{to_mathml(expression)}
```
""")
# Python
-x**2 + (1/2)*math.sin(x*y) + 1/z
// C++
-std::pow(x, 2) + (1.0/2.0)*std::sin(x*y) + 1.0/z
! Fortran
-x**2 + (1.0d0/2.0d0)*sin(x*y) + 1d0/z
% Matlab / Octave
-x.^2 + sin(x.*y)/2 + 1./z
# Julia
-x .^ 2 + sin(x .* y) / 2 + 1 ./ z
// Rust
-x.powi(2) + (1_f64/2.0)*(x*y).sin() + z.recip() 
<!-- MathML -->
<mrow>
  <mrow>
    <mo>-</mo>
    <msup>
      <mi>x</mi>
      <mn>2</mn>
    </msup>
  </mrow>
  <mo>+</mo>
  <mrow>
    <mfrac>
      <mrow>
        <mi>sin</mi>
        <mfenced>
          <mrow>
            <mi>x</mi>
            <mo>&InvisibleTimes;</mo>
            <mi>y</mi>
          </mrow>
        </mfenced>
      </mrow>
      <mn>2</mn>
    </mfrac>
  </mrow>
  <mo>+</mo>
  <mfrac>
    <mn>1</mn>
    <mi>z</mi>
  </mfrac>
</mrow>
Foldable expressions#

The previous example is quite simple, but SymPy works just as well with huge expressions, as we will see in Large expressions. Before that, though, let’s have a look at how to define these larger expressions in such a way that we can still read them. A nice solution is to define sp.Expr classes with the @unevaluated decorator (see ComPWA/ampform#364). Here, we define a Chew-Mandelstam function \(\rho^\text{CM}\) for \(S\)-waves. This function requires the definition of a break-up momentum \(q\).

@unevaluated(real=False)
class PhspFactorSWave(sp.Expr):
    s: sp.Symbol
    m1: sp.Symbol
    m2: sp.Symbol
    _latex_repr_ = R"\rho^\text{{CM}}\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        q = BreakupMomentum(s, m1, m2)
        cm = (
            (2 * q / sp.sqrt(s))
            * sp.log((m1**2 + m2**2 - s + 2 * sp.sqrt(s) * q) / (2 * m1 * m2))
            - (m1**2 - m2**2) * (1 / s - 1 / (m1 + m2) ** 2) * sp.log(m1 / m2)
        ) / (16 * sp.pi**2)
        return 16 * sp.pi * sp.I * cm


@unevaluated(real=False)
class BreakupMomentum(sp.Expr):
    s: sp.Symbol
    m1: sp.Symbol
    m2: sp.Symbol
    _latex_repr_ = R"q\left({s}\right)"

    def evaluate(self) -> sp.Expr:
        s, m1, m2 = self.args
        return sp.sqrt((s - (m1 + m2) ** 2) * (s - (m1 - m2) ** 2) / (s * 4))

We now have a very clean mathematical representation of how the \(\rho^\text{CM}\) function is defined in terms of \(q\):

s, m1, m2 = sp.symbols("s m1 m2")
q_expr = BreakupMomentum(s, m1, m2)
ρ_expr = PhspFactorSWave(s, m1, m2)
Math(aslatex({e: e.evaluate() for e in [ρ_expr, q_expr]}))
\[\begin{split}\displaystyle \begin{array}{rcl} \rho^\text{CM}\left(s\right) &=& \frac{i \left(- \left(m_{1}^{2} - m_{2}^{2}\right) \left(- \frac{1}{\left(m_{1} + m_{2}\right)^{2}} + \frac{1}{s}\right) \log{\left(\frac{m_{1}}{m_{2}} \right)} + \frac{2 \log{\left(\frac{m_{1}^{2} + m_{2}^{2} + 2 \sqrt{s} q\left(s\right) - s}{2 m_{1} m_{2}} \right)} q\left(s\right)}{\sqrt{s}}\right)}{\pi} \\ q\left(s\right) &=& \frac{\sqrt{\frac{\left(s - \left(m_{1} - m_{2}\right)^{2}\right) \left(s - \left(m_{1} + m_{2}\right)^{2}\right)}{s}}}{2} \\ \end{array}\end{split}\]

Now, let’s build up a more complicated expression that contains this phase space factor. Here, we use SymPy to derive a Breit-Wigner using a single-channel \(K\) matrix [Chung et al., 1995]:

I = sp.Identity(n=1)
K = sp.MatrixSymbol("K", m=1, n=1)
ρ = sp.MatrixSymbol("rho", m=1, n=1)
T = (I - sp.I * K * ρ).inv() * K
T
\[\displaystyle \left(\mathbb{I} + - i K \rho\right)^{-1} K\]
T.as_explicit()[0, 0]
\[\displaystyle \frac{K_{0, 0}}{- i K_{0, 0} \rho_{0, 0} + 1}\]

Here we need to provide definitions for the matrix elements of \(K\) and \(\rho\). A suitable choice for \(\rho_{0,0}\) is the phase space factor for \(S\) waves that we defined above:

m0, Γ0, γ0 = sp.symbols("m0 Gamma0 gamma0")
K_expr = (γ0**2 * m0 * Γ0) / (s - m0**2)
substitutions = {
    K[0, 0]: K_expr,
    ρ[0, 0]: ρ_expr,
}
Math(aslatex(substitutions))
\[\begin{split}\displaystyle \begin{array}{rcl} K_{0, 0} &=& \frac{\Gamma_{0} \gamma_{0}^{2} m_{0}}{- m_{0}^{2} + s} \\ \rho_{0, 0} &=& \rho^\text{CM}\left(s\right) \\ \end{array}\end{split}\]

And there we have it! After some algebraic simplifications, we get a Breit-Wigner with Chew-Mandelstam phase space factor for \(S\) waves:

T_expr = T.as_explicit().xreplace(substitutions)
BW_expr = T_expr[0, 0].simplify(doit=False)
BW_expr
\[\displaystyle \frac{\Gamma_{0} \gamma_{0}^{2} m_{0}}{- i \Gamma_{0} \gamma_{0}^{2} m_{0} \rho^\text{CM}\left(s\right) - m_{0}^{2} + s}\]

The expression tree now has a node that is ‘folded’:

Hide code cell source
dot_style = [
    (sp.Basic, {"style": "filled", "fillcolor": "white"}),
    (sp.Atom, {"color": "gray", "style": "filled", "fillcolor": "white"}),
    (sp.Symbol, {"color": "dodgerblue1"}),
    (PhspFactorSWave, {"color": "indianred2"}),
]
dot = sp.dotprint(BW_expr, bgcolor=None, styles=dot_style)
graphviz.Source(dot)

After unfolding, we get the full expression tree of fundamental mathematical operations:

Hide code cell source
dot = sp.dotprint(BW_expr.doit(), bgcolor=None, styles=dot_style)
graphviz.Source(dot)
Large expressions#

Here, we import the large symbolic intensity expression that was used for 10.1007/JHEP07(2023)228 and see how well SymPy serialization performs on a much more complicated model.

DATA_DIR = Path(polarimetry.__file__).parent / "lhcb"
PARTICLES = load_particles(DATA_DIR / "particle-definitions.yaml")
MODEL = load_model(DATA_DIR / "model-definitions.yaml", PARTICLES, model_id=0)
unfolded_intensity_expr = perform_cached_doit(MODEL.full_expression)

The model contains 43,198 mathematical operations. See ComPWA/polarimetry#319 for the origin of this investigation.

Serialization with srepr#

SymPy expressions can directly be serialized to Python code as well, with the function srepr(). For the full intensity expression, we can do so with:

%%time
eval_str = sp.srepr(unfolded_intensity_expr)
CPU times: user 1.25 s, sys: 3.96 ms, total: 1.25 s
Wall time: 1.25 s
Hide code cell source
n_nodes = sp.count_ops(unfolded_intensity_expr)
byt = len(eval_str.encode("utf-8"))
mb = f"{1e-6*byt:.2f}"
rendering = shorten(eval_str, placeholder=" ...", width=85)
src = f"""
This serializes the intensity expression of {n_nodes:,d} nodes
to a string of **{mb} MB**.

```python
{rendering} {")" * (rendering.count("(") - rendering.count(")"))}
```
"""
Markdown(src)

This serializes the intensity expression of 43,198 nodes to a string of 1.04 MB.

Add(Pow(Abs(Add(Mul(Add(Mul(Integer(-1), Pow(Add(Mul(Integer(-1), I, ... ))))))))))

It is up to the user, however, to import the classes of each exported node before the string can be parsed back with eval() (see this comment).

imported_intensity_expr = eval(eval_str)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[22], line 1
----> 1 imported_intensity_expr = eval(eval_str)

File <string>:1

NameError: name 'Add' is not defined

In the case of this intensity expression, it is sufficient to import all definitions from the main sympy module as well as the Str class. Optionally, the required import statements can be embedded into the string itself:

exec_str = f"""\
from sympy import *
from sympy.core.symbol import Str

def get_intensity_function() -> Expr:
    return {eval_str}
"""
exec_filename = Path("../_static/exported_intensity_model.py")
with open(exec_filename, "w") as f:
    f.write(exec_str)

See exported_intensity_model.py for the exported model.

The parsing is then done with exec() instead of the eval() function:

%%time
exec(exec_str)
imported_intensity_expr = get_intensity_function()
CPU times: user 517 ms, sys: 72 ms, total: 589 ms
Wall time: 587 ms

Notice how the imported expression is exactly the same as the original expression, including assumptions:

assert imported_intensity_expr == unfolded_intensity_expr
assert hash(imported_intensity_expr) == hash(unfolded_intensity_expr)
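
As a small aside (not part of the original notebook), the assumptions survive because srepr() embeds them in the generated code, so eval() reconstructs identical symbols. A minimal sketch:

# Hypothetical mini-example: srepr() keeps symbol assumptions
import sympy as sp

x = sp.Symbol("x", positive=True)
restored = eval(sp.srepr(x), vars(sp))  # parses "Symbol('x', positive=True)"
assert restored == x  # equality includes the positive=True assumption
assert hash(restored) == hash(x)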
Common sub-expressions#

A problem is that, for large expressions, the string generated with srepr() is not human-readable in practice. One way out may be to extract common components of the main expression with Foldable expressions. Another may be to use SymPy to detect and collect common sub-expressions.

sub_exprs, common_expr = sp.cse(unfolded_intensity_expr, order="none")
Hide code cell source
Math(sp.multiline_latex(sp.Symbol("I"), common_expr[0], environment="eqnarray"))
\[\begin{split}\displaystyle \begin{eqnarray} I & = & \left|{x_{113} x_{118} + x_{205} x_{210} + x_{220} x_{223} + x_{239} x_{240} - x_{247} x_{249} - x_{256} x_{258} - x_{262} x_{263} - x_{268} x_{269} + x_{273} \mathcal{H}^\mathrm{production}_{K(892), -1, - \frac{1}{2}} - x_{275} \mathcal{H}^\mathrm{production}_{K(892), 1, \frac{1}{2}} + x_{29} x_{34} + x_{35} x_{38}}\right|^{2} \nonumber\\ & & + \left|{- x_{113} x_{249} - x_{118} x_{247} + x_{205} x_{269} + x_{210} x_{268} - x_{220} x_{240} + x_{223} x_{239} + x_{256} x_{263} - x_{258} x_{262} - x_{270} x_{35} + x_{272} x_{34} \mathcal{H}^\mathrm{production}_{K(892), 1, \frac{1}{2}} + x_{272} x_{38} \mathcal{H}^\mathrm{production}_{K(892), -1, - \frac{1}{2}} + x_{274} x_{29}}\right|^{2} \nonumber\\ & & + \left|{- x_{113} x_{263} + x_{118} x_{262} + x_{205} x_{223} + x_{210} x_{220} - x_{239} x_{269} + x_{240} x_{268} - x_{247} x_{258} - x_{249} x_{256} + x_{273} \mathcal{H}^\mathrm{production}_{K(892), 1, \frac{1}{2}} - x_{275} \mathcal{H}^\mathrm{production}_{K(892), -1, - \frac{1}{2}} + x_{29} x_{38} + x_{34} x_{35}}\right|^{2} \nonumber\\ & & + \left|{x_{118} \left(x_{251} x_{281} + x_{253} x_{283} + x_{255} x_{285}\right) + x_{210} \left(- x_{226} x_{303} - x_{228} x_{304} - x_{230} x_{305} - x_{232} x_{306} - x_{236} x_{299} x_{307} - x_{238} x_{301} x_{307}\right) + x_{223} \left(x_{126} x_{130} x_{163} x_{177} x_{178} x_{179} x_{185} x_{186} x_{187} x_{188} x_{233} x_{299} \mathcal{H}^\mathrm{decay}_{L(1520), 0, \frac{1}{2}} \mathcal{H}^\mathrm{production}_{L(1520), \frac{1}{2}, 0} + x_{126} x_{130} x_{177} x_{178} x_{179} x_{191} x_{199} x_{200} x_{201} x_{202} x_{233} x_{301} \mathcal{H}^\mathrm{decay}_{L(1690), 0, \frac{1}{2}} \mathcal{H}^\mathrm{production}_{L(1690), \frac{1}{2}, 0} - x_{264} x_{303} - x_{265} x_{304} - x_{266} x_{305} - x_{267} x_{306}\right) + x_{240} \left(x_{135} x_{292} + x_{141} x_{294} + x_{146} x_{296} + x_{162} x_{298} + x_{190} x_{300} + x_{204} x_{302}\right) + x_{249} \left(x_{259} x_{281} + x_{260} x_{283} + x_{261} x_{285}\right) + x_{258} \left(x_{112} x_{290} + x_{287} x_{76} + x_{289} x_{96}\right) + x_{263} \left(x_{243} x_{286} x_{287} + x_{245} x_{288} x_{289} + x_{246} x_{288} x_{290}\right) + x_{269} \left(- x_{211} x_{292} - x_{212} x_{294} - x_{213} x_{296} - x_{215} x_{298} - x_{217} x_{300} - x_{219} x_{302}\right) + x_{270} \left(x_{276} \mathcal{H}^\mathrm{production}_{K(1430), 0, \frac{1}{2}} + x_{277} \mathcal{H}^\mathrm{production}_{K(700), 0, \frac{1}{2}} + x_{279} \mathcal{H}^\mathrm{production}_{K(892), 0, \frac{1}{2}}\right) + x_{274} \left(- x_{276} \mathcal{H}^\mathrm{production}_{K(1430), 0, - \frac{1}{2}} - x_{277} \mathcal{H}^\mathrm{production}_{K(700), 0, - \frac{1}{2}} - x_{279} \mathcal{H}^\mathrm{production}_{K(892), 0, - \frac{1}{2}}\right) - x_{308} x_{34} \mathcal{H}^\mathrm{production}_{K(892), -1, - \frac{1}{2}} - x_{308} x_{38} \mathcal{H}^\mathrm{production}_{K(892), 1, \frac{1}{2}}}\right|^{2} \end{eqnarray}\end{split}\]
Hide code cell source
Math(aslatex(dict(sub_exprs[:10])))
\[\begin{split}\displaystyle \begin{array}{rcl} x_{0} &=& m_{K(1430)}^{2} \\ x_{1} &=& m_{2}^{2} \\ x_{2} &=& m_{3}^{2} \\ x_{3} &=& \frac{x_{1}}{2} - x_{2} \\ x_{4} &=& i \left(\sigma_{1} + x_{3}\right) \\ x_{5} &=& \frac{\Gamma_{K(1430)} m_{K(1430)} x_{4} e^{- \gamma_{K(1430)} \sigma_{1}}}{x_{0} + x_{3}} + \sigma_{1} - x_{0} \\ x_{6} &=& \frac{\mathcal{H}^\mathrm{decay}_{K(1430), 0, 0}}{x_{5}} \\ x_{7} &=& m_{K(700)}^{2} \\ x_{8} &=& \frac{\Gamma_{K(700)} m_{K(700)} x_{4} e^{- \gamma_{K(700)} \sigma_{1}}}{x_{3} + x_{7}} + \sigma_{1} - x_{7} \\ x_{9} &=& \frac{\mathcal{H}^\mathrm{decay}_{K(700), 0, 0}}{x_{8}} \\ \end{array}\end{split}\]

This already works quite well with sp.lambdify() (without cse=True, lambdification would take minutes):

%%time
args = sorted(unfolded_intensity_expr.free_symbols, key=str)
_ = sp.lambdify(args, unfolded_intensity_expr, cse=True, dummify=True)
CPU times: user 2.05 s, sys: 0 ns, total: 2.05 s
Wall time: 2.05 s

Still, as can be seen above, many sub-expressions have exactly the same functional form. It would be better to identify expressions with the same structure, so that they can be serialized as functions or custom sub-definitions.

In SymPy, structural equivalence between expressions can be determined with the match() method using Wild symbols. We therefore first have to make all symbols in the common sub-expressions 'wild'. In addition, in the case of this intensity expression, some of the symbols are indexed and need to be replaced by plain symbols first.

pure_symbol_expr = unfolded_intensity_expr.replace(
    query=lambda z: isinstance(z, sp.Indexed),
    value=lambda z: sp.Symbol(sp.latex(z), **z.assumptions0),
)
sub_exprs, common_expr = sp.cse(pure_symbol_expr, order="none")

Note, for example, that the following two common sub-expressions have exactly the same structure:

\[\begin{split}\displaystyle \begin{array}{rcl} x_{5} &=& \frac{\Gamma_{K(1430)} m_{K(1430)} x_{4} e^{- \gamma_{K(1430)} \sigma_{1}}}{x_{0} + x_{3}} + \sigma_{1} - x_{0} \\ x_{8} &=& \frac{\Gamma_{K(700)} m_{K(700)} x_{4} e^{- \gamma_{K(700)} \sigma_{1}}}{x_{3} + x_{7}} + \sigma_{1} - x_{7} \\ \end{array}\end{split}\]

Wild symbols now allow us to find how these expressions relate to each other.

is_symbol = lambda z: isinstance(z, sp.Symbol)
make_wild = lambda z: sp.Wild(z.name)
X = [x.replace(is_symbol, make_wild) for _, x in sub_exprs]
Math(aslatex(X[5].match(X[8])))
\[\begin{split}\displaystyle \begin{array}{rcl} \gamma_{K(700)} &=& \gamma_{K(1430)} \\ x_{7} &=& x_{0} \\ m_{K(700)} &=& m_{K(1430)} \\ \Gamma_{K(700)} &=& \Gamma_{K(1430)} \\ \end{array}\end{split}\]

Hint

This can be used to define functions for larger, common expression blocks.
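
As a rough sketch of that idea (the expressions and symbol names below are made up for illustration), a matched template could be promoted to a single function that serves all structurally identical sub-expressions:

import sympy as sp

# Hypothetical sub-expressions with the same structure, but different parameters
s, m1, w1, m2, w2 = sp.symbols("s m1 Gamma1 m2 Gamma2")
expr1 = m1 * w1 / (m1**2 - s)
expr2 = m2 * w2 / (m2**2 - s)

# Turn one of them into a wild template and check that both are instances of it
template = expr1.replace(
    lambda node: isinstance(node, sp.Symbol),
    lambda node: sp.Wild(node.name),
)
assert expr1.match(template) is not None
assert expr2.match(template) is not None  # e.g. {m1_: m2, Gamma1_: Gamma2, s_: s}

# A single generic block can then be lambdified once and reused for both
m, w = sp.symbols("m Gamma")
generic_block = m * w / (m**2 - s)
block_func = sp.lambdify((s, m, w), generic_block)
block_func(1.2, 0.9, 0.1)  # evaluate the shared block for one parameter set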

Rotating square root cuts#

Hide code cell content
%matplotlib widget
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import sympy as sp
from ampform.io import aslatex
from ampform.sympy import unevaluated
from IPython.display import Image, Math, display
from ipywidgets import FloatSlider, VBox, interactive_output
from plotly.colors import DEFAULT_PLOTLY_COLORS
from plotly.subplots import make_subplots

For a given \(x\), the equation \(y^2 = x\) has multiple solutions for \(y\). The fact that we usually take \(y = \sqrt{x}\) with \(\sqrt{-1} = i\) to be 'the' solution to this equation is just a matter of convention. It would be more complete to represent the solutions as a set of points in the complex plane, that is, the set \(S = \left\{\left(z, w\right)\in\mathbb{C}^2 \mid w^2=z\right\}\). This set forms a Riemann surface in \(\mathbb{C}^2\) space.
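
For instance, solving the defining equation with SymPy (a small illustration, not from the original notebook) returns both branches:

import sympy as sp

w, z = sp.symbols("w z")
sp.solve(w**2 - z, w)  # [-sqrt(z), sqrt(z)]: the two points of S above a given z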

In the figure below we see the Riemann surface of a square root in \(\mathbb{C}^2\) space. The \(xy\) plane forms the complex domain \(\mathbb{C}\), the \(z\) axis indicates the imaginary part of the Riemann surface and the color indicates the real part.

Hide code cell source
resolution = 30
R, Θ = np.meshgrid(
    np.linspace(0, 1, num=resolution),
    np.linspace(-np.pi, +np.pi, num=resolution),
)
X = R * np.cos(Θ)
Y = R * np.sin(Θ)
Z = X + Y * 1j
T = np.sqrt(Z)
style = lambda t: dict(
    cmin=-1,
    cmax=+1,
    colorscale="RdBu_r",
    surfacecolor=t.real,
)
fig = go.Figure([
    go.Surface(x=X, y=Y, z=+T.imag, **style(+T), name="+√z"),
    go.Surface(x=X, y=Y, z=-T.imag, **style(-T), name="-√z", showscale=False),
])
fig.update_traces(selector=0, colorbar=dict(title="Re ±√z"))
fig.update_layout(
    height=600,
    scene=dict(
        xaxis_title="Re z",
        yaxis_title="Im z",
        zaxis_title="Im ±√z",
    ),
    title_text="Riemann surface of a square root",
    title_x=0.5,
)
fig.show()

From this figure it becomes clear that it is impossible to define a single single-valued function that gives all solutions to \(w^2 = z\) for \(z \neq 0\). The familiar single-valued square root operation \(\sqrt{}\) covers only one segment, or sheet, of the Riemann surface, and it is defined in such a way that \(\sqrt{-1}=i\). The other half of the surface is covered by \(-\sqrt{}\).

Notice, however, that the sheets for the imaginary component of \(\pm\sqrt{}\) are not smoothly connected everywhere. The sign flips across \(z\in\mathbb{R}^-\), because \(\sqrt{-1+0i}=+i\) whereas \(\sqrt{-1-0i}=-i\). We call this discontinuity in a Riemann sheet a branch cut.
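
This sign flip can be checked numerically by approaching the negative real axis from above and from below (a small illustration with NumPy, not part of the original notebook):

import numpy as np

np.sqrt(complex(-1, +1e-15))  # ≈ +1j (just above the cut)
np.sqrt(complex(-1, -1e-15))  # ≈ -1j (just below the cut)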

Hide code cell source
x = np.linspace(-1, 0, num=resolution // 2)
y = np.zeros(resolution // 2)
t = np.sqrt(x + 1e-8j)
T = np.sqrt(Z)

C0 = DEFAULT_PLOTLY_COLORS[0]
C1 = DEFAULT_PLOTLY_COLORS[1]

style = lambda color, legend: dict(
    colorscale=[[0, color], [1, color]],
    showlegend=legend,
    showscale=False,
    surfacecolor=np.ones(T.shape),
)
linestyle = dict(
    line_color="crimson",
    line_showscale=False,
    line_width=15,
    mode="lines",
    name="Branch cut",
)

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Re ±√z", "Im ±√z"),
    specs=[[{"type": "surface"}, {"type": "surface"}]],
)
fig.add_traces(
    [
        go.Surface(x=X, y=Y, z=+T.real, **style(C0, True), name="+√z"),
        go.Surface(x=X, y=Y, z=-T.real, **style(C1, True), name="-√z"),
    ],
    cols=1,
    rows=1,
)
fig.add_traces(
    [
        go.Surface(x=X, y=Y, z=+T.imag, **style(C0, False), name="+√z"),
        go.Surface(x=X, y=Y, z=-T.imag, **style(C1, False), name="-√z"),
        go.Scatter3d(x=x, y=y, z=-t.imag, **linestyle, showlegend=True),
        go.Scatter3d(x=x, y=y, z=+t.imag, **linestyle, showlegend=False),
    ],
    cols=2,
    rows=1,
)
ticks = dict(
    tickvals=[-1, 0, +1],
    ticktext=["-1", "0", "+1"],
)
fig.update_layout(height=400)
fig.update_scenes(
    xaxis=dict(title="Re z", **ticks),
    yaxis=dict(title="Im z", **ticks),
    zaxis=dict(title="±√z", **ticks),
)
fig.show()

By definition, the branch cut of \(\sqrt{}\) is located along \(\mathbb{R}^-\). There is nothing special about this choice, though: we can segment the Riemann surface into two sheets in any way, as long as each sheet remains single-valued. One option is to rotate the cut. With the following definition, we obtain a single-valued square root function whose cut is rotated over an angle \(\phi\) around \(z=0\).

Hide code cell source
@unevaluated
class RotatedSqrt(sp.Expr):
    z: Any
    phi: Any = 0
    _latex_repr_ = R"\sqrt[{phi}]{{{z}}}"

    def evaluate(self) -> sp.Expr:
        z, phi = self.args
        return sp.exp(-phi * sp.I / 2) * sp.sqrt(z * sp.exp(phi * sp.I))


z, phi = sp.symbols("z phi")
expr = RotatedSqrt(z, phi)
Math(aslatex({expr: expr.doit(deep=False)}))
\[\begin{split}\displaystyle \begin{array}{rcl} \sqrt[\phi]{z} &=& \sqrt{z e^{i \phi}} e^{- \frac{i \phi}{2}} \\ \end{array}\end{split}\]

In the following widget, we see what the rotated square root looks like in the complex plane. The left panes show the real part and the right panes show the imaginary part. The upper figures show the value of the rotated square root along the real axis, \(\mathrm{Re}\,z\).

Hide code cell source
symbols = (z, phi)
func = sp.lambdify(symbols, expr.doit())

mpl_fig, axes = plt.subplots(
    figsize=(12, 8.5),
    gridspec_kw=dict(
        height_ratios=[1, 2],
        width_ratios=[1, 1, 0.03],
    ),
    ncols=3,
    nrows=2,
)
mpl_fig.canvas.toolbar_visible = False
mpl_fig.canvas.header_visible = False
mpl_fig.canvas.footer_visible = False
axes[0, 2].remove()
ax1re, ax2re = axes[:, 0]
ax1im, ax2im = axes[:, 1]
ax_bar = axes[1, 2]
ax1re.set_ylabel(f"${sp.latex(expr)}$")
ax1im.set_title(Rf"$\mathrm{{Im}}\,{sp.latex(expr)}$")
ax1re.set_title(Rf"$\mathrm{{Re}}\,{sp.latex(expr)}$")
ax2re.set_ylabel(R"$\mathrm{Im}\,z$")
for ax in (ax1im, ax1re):
    ax.set_yticks([-1, -0.5, 0, +0.5, +1])
    ax.set_yticklabels(["-1", R"$-\frac{1}{2}$", "0", R"$+\frac{1}{2}$", "+1"])
for ax in axes[:, :2].flatten():
    ax.set_xlabel(R"$\mathrm{Re}\,z$")
    ax.set_xticks([-1, 0, +1])
    ax.set_xticklabels(["-1", "0", "+1"])
    ax.set_yticks([-1, 0, +1])
    ax.set_yticklabels(["-1", "0", "+1"])
for i, ax in enumerate((ax2im, ax2re)):
    ax.axhline(0, c=f"C{i}", ls="dotted", zorder=99)
    ax.set_ylim(-1, +1)

data = None
x = np.linspace(-1, +1, num=400)
X_mpl, Y_mpl = np.meshgrid(x, x)
Z_mpl = X_mpl + Y_mpl * 1j


def plot(phi):
    global data
    mpl_fig.suptitle(Rf"$\phi={phi / np.pi:.4g}\pi$")
    t_mpl = func(x, phi)
    T_mpl = func(Z_mpl, phi)
    if data is None:
        data = {
            "im": ax1im.plot(x, t_mpl.imag, label="imag", c="C0", ls="dotted")[0],
            "re": ax1re.plot(x, t_mpl.real, label="real", c="C1", ls="dotted")[0],
            "im2D": ax2im.pcolormesh(X_mpl, Y_mpl, T_mpl.imag, cmap=plt.cm.coolwarm),
            "re2D": ax2re.pcolormesh(X_mpl, Y_mpl, T_mpl.real, cmap=plt.cm.coolwarm),
        }
    else:
        data["re"].set_ydata(t_mpl.real)
        data["im"].set_ydata(t_mpl.imag)
        data["im2D"].set_array(T_mpl.imag)
        data["re2D"].set_array(T_mpl.real)
    data["im2D"].set_clim(vmin=-1, vmax=+1)
    data["re2D"].set_clim(vmin=-1, vmax=+1)
    ax1im.set_ylim(-1.2, +1.2)
    ax1re.set_ylim(-1.2, +1.2)
    mpl_fig.canvas.draw_idle()


sliders = dict(
    phi=FloatSlider(
        min=-3 * np.pi,
        max=+3 * np.pi,
        step=np.pi / 8,
        description="phi",
        value=-np.pi / 4,
    ),
)
ui = VBox(tuple(sliders.values()))
output = interactive_output(plot, controls=sliders)
cbar = plt.colorbar(data["re2D"], cax=ax_bar)
cbar.ax.set_xlabel(f"${sp.latex(expr)}$")
cbar.ax.set_yticks([-1, 0, +1])
mpl_fig.tight_layout()
display(ui, output)
(Figure: widget output showing the real and imaginary parts of the rotated square root over the complex plane, with slices along the real axis in the upper panels)

Note

The real part does not have a cut if \(\phi = 2\pi n, n \in \mathbb{Z}\). The cut in the imaginary part disappears if \(\phi = \pi + 2\pi n\).
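
This note can be verified numerically with the lambdified func from the widget cell above (an illustrative sketch, not part of the original notebook):

eps = 1e-8
# For phi = 0, the real part is continuous across the negative real axis ...
assert np.isclose(
    func(-0.5 + eps * 1j, 0).real,
    func(-0.5 - eps * 1j, 0).real,
)
# ... while for phi = pi, the imaginary part is continuous across the positive real axis
assert np.isclose(
    func(+0.5 + eps * 1j, np.pi).imag,
    func(+0.5 - eps * 1j, np.pi).imag,
)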

Execution times

| Document   | Modified         | Method | Run Time (s) | Status |
|------------|------------------|--------|--------------|--------|
| report/004 | 2024-01-19 22:06 | cache  | 20.59        | ✅     |
| report/006 | 2024-01-19 22:06 | cache  | 15.44        | ✅     |
| report/007 | 2024-01-19 22:07 | cache  | 13.55        | ✅     |
| report/023 | 2024-01-19 22:07 | cache  | 4.64         | ✅     |
| report/024 | 2024-01-19 22:08 | cache  | 62.17        | ✅     |
| report/025 | 2024-01-19 22:08 | cache  | 6.72         | ✅     |

Bibliography#

Tip

Download this bibliography as BibTeX here.

[1]

B. Slatkin. Effective Python: 90 Specific Ways to Write Better Python. Addison-Wesley, November 2019. ISBN:978-0-13-485398-7.

[2]

R. C. Martin, editor. Clean Code: A Handbook of Agile Software Craftsmanship. Prentice Hall, Upper Saddle River, NJ, 2009. ISBN:978-0-13-235088-4.

[3]

H. Percival. Test-Driven Development with Python: Obey the Testing Goat: Using Django, Selenium, and JavaScript. O'Reilly Media, Sebastopol, CA, second edition, 2017. ISBN:978-1-4919-5870-4.

[4]

K. Beck. Test-Driven Development by Example. The Addison-Wesley Signature Series. Addison-Wesley, Boston, 2003. ISBN:978-0-321-14653-3.

[5]

E. Gamma, editor. Design Patterns: Elements of Reusable Object-Oriented Software. Addison-Wesley Professional Computing Series. Addison-Wesley, Reading, Mass, 1995. ISBN:978-0-201-63361-0.

[6]

R. Sedgewick and K. D. Wayne. Algorithms. Addison-Wesley, Upper Saddle River, NJ, fourth edition, 2011. ISBN:978-0-321-57351-3.

[7]

D. Marangotto. Helicity Amplitudes for Generic Multibody Particle Decays Featuring Multiple Decay Chains. Advances in High Energy Physics, 2020:1–15, December 2020. doi:10.1155/2020/6674595.

[8]

I. J. R. Aitchison. Unitarity, Analyticity and Crossing Symmetry in Two- and Three-hadron Final State Interactions. arXiv:1507.02697 [hep-ph], July 2015. arXiv:1507.02697.

[9]

S.-U. Chung et al. Partial wave analysis in 𝐾-matrix formalism. Annalen der Physik, 507(5):404–430, May 1995. doi:10.1002/andp.19955070504.

[10]

K. Peters. Partial Wave Analysis. June 2004. slideplayer.com/slide/1676572.

[11]

C. A. Meyer. A 𝐾-Matrix Tutorial. October 2008. www.curtismeyer.com/talks/PWA_Munich_KMatrix.pdf.

[12]

I. J. R. Aitchison. The 𝐾-matrix formalism for overlapping resonances. Nuclear Physics A, 189(2):417–423, July 1972. doi:10.1016/0375-9474(72)90305-3.

[13]

M. Jacob and G.C. Wick. On the general theory of collisions for particles with spin. Annals of Physics, 7(4):404–428, August 1959. doi:10.1016/0003-4916(59)90051-X.

[14]

J. D. Richman. An Experimenter's Guide to the Helicity Formalism. June 1984. inspirehep.net/literature/202987.

[15]

R. Kutschke. An Angular Distribution Cookbook. January 1996. home.fnal.gov/~kutschke/Angdist/angdist.ps.

[16]

S.-U. Chung. Spin Formalisms (Updated Version). Technical Report, Brookhaven National Laboratory, July 2014. suchung.web.cern.ch/spinfm1.pdf.

[17]

M. Mikhasenko et al. Dalitz-plot decomposition for three-body decays. Physical Review D, 101(3):034033, February 2020. doi:10.1103/PhysRevD.101.034033.

[18]

M. Wang et al. A novel method to test particle ordering and final state alignment in helicity formalism. arXiv, December 2020. arXiv:2012.03699.

[19]

E. Byckling and K. Kajantie. Particle Kinematics. Wiley, London, New York, 1973. ISBN:978-0-471-12885-4.