Skip to content

Speeding up the plotting slider responsiveness #834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 173 additions & 43 deletions torax/plotting/plotruns_lib.py
Original file line number Diff line number Diff line change
@@ -413,21 +413,184 @@ def _transform_data(ds: xr.Dataset):
)


# --- Global state for full-res update management ---
_full_update_state = {'timer': None, 'last_value': None, 'cancel': False}


# --- Downsampling Helper Functions ---
def downsample_array(arr: np.ndarray, factor: int) -> np.ndarray:
"""
Downsample an array along axis 0.
Supports 1D and 2D arrays.
"""
if arr.ndim == 1:
return arr[::factor]
elif arr.ndim == 2:
return arr[::factor, :]
else:
return arr[::factor]


def downsample_time(arr: np.ndarray, factor: int) -> np.ndarray:
"""Downsample a 1D time array."""
return arr[::factor]


# --- Core Update Function ---
def _update(
newtime,
plot_config,
plotdata,
lines: Sequence[matplotlib.lines.Line2D],
use_downsampled: bool = True,
):
"""
Update the given plotdata's lines at the specified newtime.
If use_downsampled is True and a downsampled version exists, use it.
Check for cancellation between attribute updates.
"""
if use_downsampled and hasattr(plotdata, 'downsampled'):
time_array = plotdata.downsampled['t']
else:
time_array = plotdata.t

idx = np.searchsorted(time_array, newtime)
if idx > 0 and (
idx == len(time_array)
or abs(newtime - time_array[idx - 1]) < abs(newtime - time_array[idx])
):
idx -= 1

line_idx = 0
for cfg in plot_config.axes:
if cfg.plot_type == PlotType.TIME_SERIES:
continue # Skip time series plots.
for attr in cfg.attrs:
if (
use_downsampled
and hasattr(plotdata, 'downsampled')
and attr in plotdata.downsampled
):
data = plotdata.downsampled[attr]
else:
data = getattr(plotdata, attr)
if cfg.suppress_zero_values and np.all(data == 0):
continue
if _full_update_state.get('cancel', False):
return
lines[line_idx].set_ydata(data[idx, :])
line_idx += 1


# --- Module-Level Update Functions ---
def update_low_res(
newtime, plot_config, plotdata1, lines1, plotdata2, lines2, fig
):
"""
Immediately update the plots using downsampled (low-res) data.
"""
_update(newtime, plot_config, plotdata1, lines1, use_downsampled=True)
if plotdata2 is not None and lines2 is not None:
_update(newtime, plot_config, plotdata2, lines2, use_downsampled=True)
fig.canvas.draw_idle()


def update_full_res(
newtime, plot_config, plotdata1, lines1, plotdata2, lines2, fig
):
"""
Update the plots using full-resolution data.
"""

_full_update_state['cancel'] = False
_update(newtime, plot_config, plotdata1, lines1, use_downsampled=False)
if plotdata2 is not None and lines2 is not None:
_update(newtime, plot_config, plotdata2, lines2, use_downsampled=False)
fig.canvas.draw_idle()
_full_update_state['timer'] = None


def slider_callback(
newtime, plot_config, plotdata1, lines1, plotdata2, lines2, fig
):
"""
Slider callback: perform an immediate low-res update and schedule a full-res update.
Cancel any pending full-res update if the slider moves again.
"""
_full_update_state['last_value'] = newtime
update_low_res(
newtime, plot_config, plotdata1, lines1, plotdata2, lines2, fig
)
_full_update_state['cancel'] = True
# Cancel any pending full-res update timer.
if _full_update_state['timer'] is not None:
_full_update_state['timer'].stop()
_full_update_state['timer'] = None
# Schedule a full-res update after 1 ms.
timer = fig.canvas.new_timer(interval=1)
timer.single_shot = True # Ensure it fires only once.
# Reset cancellation flag right before full update.
timer.add_callback(
lambda: full_update_if_still(
newtime, plot_config, plotdata1, lines1, plotdata2, lines2, fig
)
)
timer.start()
_full_update_state['timer'] = timer


def full_update_if_still(
newtime, plot_config, plotdata1, lines1, plotdata2, lines2, fig
):
"""
Trigger full-res update only if the slider's last value hasn't changed.
"""
if newtime != _full_update_state['last_value']:
return # The slider moved again; skip this update.
# Clear cancellation flag and run full-res update.
_full_update_state['cancel'] = False
update_full_res(
newtime, plot_config, plotdata1, lines1, plotdata2, lines2, fig
)
_full_update_state['timer'] = None


# --- Main Plotting Function ---
def plot_run(
plot_config: FigureProperties, outfile: str, outfile2: str | None = None
):
"""Plots a single run or comparison of two runs."""
"""
Plots a single run or comparison of two runs.
Computes downsampled data for interactive updates.
"""
if not path.exists(outfile):
raise ValueError(f'File {outfile} does not exist.')
if outfile2 is not None and not path.exists(outfile2):
raise ValueError(f'File {outfile2} does not exist.')
plotdata1 = load_data(outfile)
plotdata2 = load_data(outfile2) if outfile2 else None

# Attribute check. Sufficient to check one PlotData object.
plotdata_attrs = set(
plotdata1.__dataclass_fields__
) # Get PlotData attributes
# Compute downsampled versions for interactive updates.
downsample_factor = 50 # for more better accuracy we are kepping it low
ds1 = {}
ds1['t'] = downsample_time(plotdata1.t, downsample_factor)
for cfg in plot_config.axes:
for attr in cfg.attrs:
data = getattr(plotdata1, attr)
ds1[attr] = downsample_array(data, downsample_factor)
plotdata1.downsampled = ds1

if plotdata2:
ds2 = {}
ds2['t'] = downsample_time(plotdata2.t, downsample_factor)
for cfg in plot_config.axes:
for attr in cfg.attrs:
data = getattr(plotdata2, attr)
ds2[attr] = downsample_array(data, downsample_factor)
plotdata2.downsampled = ds2

# Attribute check.
plotdata_attrs = set(plotdata1.__dataclass_fields__)
for cfg in plot_config.axes:
for attr in cfg.attrs:
if attr not in plotdata_attrs:
@@ -436,8 +599,6 @@ def plot_run(
)

fig, axes, slider_ax = create_figure(plot_config)

# Title handling:
title_lines = [f'(1)={outfile}']
if outfile2:
title_lines.append(f'(2)={outfile2}')
@@ -455,47 +616,16 @@ def plot_run(
# Only create the slider if needed.
if plot_config.contains_spatial_plot_type:
timeslider = create_slider(slider_ax, plotdata1, plotdata2)
def update(newtime):
"""Update plots with new values following slider manipulation."""
fig.constrained_layout = False
_update(newtime, plot_config, plotdata1, lines1, plotdata2, lines2)
fig.constrained_layout = True
fig.canvas.draw_idle()

timeslider.on_changed(update)
timeslider.on_changed(
lambda newtime: slider_callback(
newtime, plot_config, plotdata1, lines1, plotdata2, lines2, fig
)
)

fig.canvas.draw()
plt.show()


def _update(
newtime,
plot_config: FigureProperties,
plotdata1: PlotData,
lines1: Sequence[matplotlib.lines.Line2D],
plotdata2: PlotData | None = None,
lines2: Sequence[matplotlib.lines.Line2D] | None = None,
):
"""Update plots with new values following slider manipulation."""

def update_lines(plotdata, lines):
idx = np.abs(plotdata.t - newtime).argmin()
line_idx = 0
for cfg in plot_config.axes: # Iterate through axes based on plot_config
if cfg.plot_type == PlotType.TIME_SERIES:
continue # Time series plots do not need to be updated
for attr in cfg.attrs: # Update all lines in current subplot.
data = getattr(plotdata, attr)
if cfg.suppress_zero_values and np.all(data == 0):
continue
lines[line_idx].set_ydata(data[idx, :])
line_idx += 1

update_lines(plotdata1, lines1)
if plotdata2 and lines2:
update_lines(plotdata2, lines2)


def create_slider(
ax: matplotlib.axes.Axes,
plotdata1: PlotData,