Interactive 3D plots

Hide code cell content
import os

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import sympy as sp
from IPython.display import display
from ipywidgets import widgets as ipywidgets
from matplotlib import cm
from matplotlib import widgets as mpl_widgets

# Set of the marker environment variables that are present. It is non-empty
# (truthy) when the notebook is executed for the static documentation build,
# presumably by MyST-NB (EXECUTE_NB) or on Read the Docs (READTHEDOCS).
STATIC_WEB_PAGE = {
    env_var
    for env_var in ("EXECUTE_NB", "READTHEDOCS")
    if env_var in os.environ
}

This report illustrates how to interact with matplotlib 3D plots through Matplotlib sliders and ipywidgets. This might be implemented later on in symplot and/or mpl_interactions (see ianhi/mpl-interactions#89).

In this example, we create a surface plot (see plot_surface()) for the following function.

# Plane coordinates x, y and the slider-controlled parameters a, b
x, y, a, b = sp.symbols("x, y, a, b")
# sqrt(x^a + sin^2(y / b)), written with explicit powers
expression = sp.sqrt(sp.Pow(x, a) + sp.Pow(sp.sin(y / b), 2))
expression
\[\displaystyle \sqrt{x^{a} + \sin^{2}{\left(\frac{y}{b} \right)}}\]

The function is formulated with sympy, but we use lambdify() to express it as a numpy function.

# Vectorized numpy version of the symbolic expression; its positional
# arguments are ordered (x, y, a, b)
numpy_function = sp.lambdify((x, y, a, b), expression, "numpy")

A surface plot has to be generated over a numpy.meshgrid(). This defines the \(xy\)-plane over which we want to plot our function.

# Extent and sampling of the xy-plane over which the surface is evaluated
x_min, x_max = 0.1, 2
y_min, y_max = -50, 50
x_values = np.linspace(start=x_min, stop=x_max, num=20)
y_values = np.linspace(start=y_min, stop=y_max, num=40)
# With "xy" indexing (numpy's default), X and Y have shape
# (len(y_values), len(x_values)), i.e. (40, 20)
X, Y = np.meshgrid(x_values, y_values, indexing="xy")

The \(z\)-values for plot_surface() can now be simply computed as follows:

# Initial parameter values; they also set the starting slider positions below
a_init, b_init = -0.5, 20
Z = numpy_function(X, Y, a_init, b_init)

We now want to create sliders for \(a\) and \(b\), so that we can live-update the surface plot through those sliders.

Matplotlib widgets

Matplotlib provides its own interactive widgets in matplotlib.widgets, such as the Slider used below.

# Figure with a single 3D axes; the subplot area starts at 25% of the figure
# height so that the sliders fit underneath
fig1, ax1 = plt.subplots(subplot_kw=dict(projection="3d"))
fig1.subplots_adjust(bottom=0.25)

# One horizontal slider per parameter, stacked below the plot.
# Axes rectangles are [left, bottom, width, height] in figure coordinates.
a_slider = mpl_widgets.Slider(
    ax=fig1.add_axes([0.2, 0.1, 0.65, 0.03]),
    label=f"${sp.latex(a)}$",
    valmin=-2,
    valmax=2,
    valinit=a_init,
)
b_slider = mpl_widgets.Slider(
    ax=fig1.add_axes([0.2, 0.05, 0.65, 0.03]),
    label=f"${sp.latex(b)}$",
    valmin=10,
    valmax=50,
    valinit=b_init,
    valstep=1,  # b moves in integer steps
)


# Callback that redraws the surface whenever a slider changes
def update_plot(val=None):
    """Redraw the surface on ``ax1`` from the current slider positions.

    Parameters
    ----------
    val:
        Value handed over by ``Slider.on_changed``. It is ignored; both
        parameters are read directly from ``a_slider`` and ``b_slider``.
    """
    a_value, b_value = a_slider.val, b_slider.val
    ax1.clear()  # drop the previous surface before drawing the new one
    ax1.plot_surface(
        X,
        Y,
        numpy_function(X, Y, a_value, b_value),
        rstride=3,
        cstride=1,
        cmap=cm.coolwarm,
        antialiased=False,
    )
    for set_label, symbol in (
        (ax1.set_xlabel, x),
        (ax1.set_ylabel, y),
        (ax1.set_zlabel, expression),
    ):
        set_label(f"${sp.latex(symbol)}$")
    # Remove the tick marks on all three axes
    for set_ticks in (ax1.set_xticks, ax1.set_yticks, ax1.set_zticks):
        set_ticks([])
    ax1.set_facecolor("white")
    fig1.canvas.draw_idle()


# Redraw the surface whenever either slider is moved
for slider in (a_slider, b_slider):
    slider.on_changed(update_plot)

# Draw the surface once for the sliders' initial values, then show the figure
update_plot()
plt.show()

Interactive inline matplotlib output

ipywidgets

As an alternative, you can use ipywidgets. This package offers a lot more sliders than Matplotlib, and they look nicer, but it only works within a Jupyter notebook.

For more info, see Using Interact.

Using interact

The simplest option is to use the ipywidgets.interact() function:

fig2, ax2 = plt.subplots(subplot_kw=dict(projection="3d"))


@ipywidgets.interact(a=(-2.0, 2.0), b=(10, 50))
def plot2(a=a_init, b=b_init):
    """Redraw the surface on ``ax2`` for the parameter values ``a`` and ``b``.

    ``interact`` builds one slider per keyword from the ``(min, max)`` bounds
    in the decorator and calls this function with the slider values.
    """
    ax2.clear()  # drop the previous surface before drawing the new one
    ax2.plot_surface(
        X,
        Y,
        numpy_function(X, Y, a, b),
        rstride=3,
        cstride=1,
        cmap=cm.coolwarm,
        antialiased=False,
    )
    for set_label, symbol in (
        (ax2.set_xlabel, x),
        (ax2.set_ylabel, y),
        (ax2.set_zlabel, expression),
    ):
        set_label(f"${sp.latex(symbol)}$")
    # Remove the tick marks on all three axes
    for set_ticks in (ax2.set_xticks, ax2.set_yticks, ax2.set_zticks):
        set_ticks([])
    ax2.set_facecolor("white")
    fig2.canvas.draw_idle()

Using interactive_output

You can have more control with ipywidgets.interactive_output(). That allows defining the sliders independently, so that you can arrange them as a user interface:

fig3, ax3 = plt.subplots(subplot_kw=dict(projection="3d"))

# The sliders are created up front so that they can be arranged freely in a
# custom layout (see the HBox below)
a_ipyslider = ipywidgets.FloatSlider(
    description=f"${sp.latex(a)}$",
    value=a_init,
    min=-2,
    max=2,
    step=0.1,
    readout_format=".1f",  # one decimal, matching the 0.1 step
)
b_ipyslider = ipywidgets.IntSlider(
    description=f"${sp.latex(b)}$",
    value=b_init,
    min=10,
    max=50,
)


def plot3(a=a_init, b=b_init):
    """Redraw the surface on ``ax3`` for the parameter values ``a`` and ``b``.

    Meant to be driven by ``ipywidgets.interactive_output``, which passes the
    slider values as the keyword arguments ``a`` and ``b``.
    """
    ax3.clear()  # drop the previous surface before drawing the new one
    ax3.plot_surface(
        X,
        Y,
        numpy_function(X, Y, a, b),
        rstride=3,
        cstride=1,
        cmap=cm.coolwarm,
        antialiased=False,
    )
    for set_label, symbol in (
        (ax3.set_xlabel, x),
        (ax3.set_ylabel, y),
        (ax3.set_zlabel, expression),
    ):
        set_label(f"${sp.latex(symbol)}$")
    # Remove the tick marks on all three axes
    for set_ticks in (ax3.set_xticks, ax3.set_yticks, ax3.set_zticks):
        set_ticks([])
    ax3.set_facecolor("white")
    fig3.canvas.draw_idle()


# Put the two sliders side by side and feed their values into plot3() as the
# keyword arguments "a" and "b"
ui = ipywidgets.HBox(children=[a_ipyslider, b_ipyslider])
slider_controls = {"a": a_ipyslider, "b": b_ipyslider}
output = ipywidgets.interactive_output(plot3, controls=slider_controls)
display(ui, output)

ipywidgets interactive output with interactive_output()

Plotly with ipywidgets

3D plots with Plotly look a lot nicer and make it possible for the user to pan and zoom the 3D object. As an added bonus, Plotly figures render as interactive 3D objects in the static HTML Sphinx build.

Making 3D Plotly plots interactive with ipywidgets is quite similar to the previous examples with matplotlib. Three recommendations are:

  1. Set continuous_update=False, because plotly is slower than matplotlib in updating the figure.

  2. Save the camera orientation before updating the figure and restore it after calling Figure.show().

  3. When embedding the notebook in a static webpage with MyST-NB, avoid calling Figure.show() through ipywidgets.interactive_output(), because it causes the notebook execution to hang (see CI for ComPWA/compwa.github.io@d9240f1). In the example below, the update_plotly() function therefore returns immediately if the notebook is run through Sphinx (that is, when STATIC_WEB_PAGE is set).

Hide code cell source
# Sliders for the Plotly figure. continuous_update=False makes them report a
# new value only when the slider is released, which limits the number of
# (slow) Plotly redraws.
plotly_a = ipywidgets.FloatSlider(
    description=f"${sp.latex(a)}$",
    value=a_init,
    min=-2,
    max=2,
    step=0.1,
    continuous_update=False,
    readout_format=".1f",  # one decimal, matching the 0.1 step
)
plotly_b = ipywidgets.IntSlider(
    description=f"${sp.latex(b)}$",
    value=b_init,
    min=10,
    max=50,
    continuous_update=False,
)
# Maps the keyword arguments of update_plotly() to their sliders
plotly_controls = dict(a=plotly_a, b=plotly_b)

# Initial surface. The trace is named so that update_plotly() can select it
# with update_traces().
plotly_surface = go.Surface(
    x=X,
    y=Y,
    z=Z,
    surfacecolor=Z,
    colorscale="RdBu_r",
    name="Surface",
)
plotly_fig = go.Figure(data=[plotly_surface]).update_layout(height=500)

# update_plotly() returns early on the static web page, so the figure is
# shown here once instead
if STATIC_WEB_PAGE:
    plotly_fig.show()


def update_plotly(a, b):
    """Recompute the Plotly surface for ``a`` and ``b`` and show the figure.

    Does nothing when ``STATIC_WEB_PAGE`` is set, because calling
    ``Figure.show()`` through ``interactive_output()`` makes the static
    MyST-NB build hang.
    """
    if not STATIC_WEB_PAGE:
        new_z = numpy_function(X, Y, a, b)
        # Remember the camera so that the view can be restored after show()
        saved_camera = plotly_fig.layout.scene.camera
        plotly_fig.update_traces(
            x=X,
            y=Y,
            z=new_z,
            surfacecolor=new_z,
            selector={"name": "Surface"},
        )
        plotly_fig.show()
        plotly_fig.update_layout(scene={"camera": saved_camera})


# Put both sliders in one row and route their values into update_plotly()
plotly_ui = ipywidgets.HBox(children=[plotly_a, plotly_b])
plotly_output = ipywidgets.interactive_output(
    update_plotly,
    controls=plotly_controls,
)
display(plotly_ui, plotly_output)