Source code for pyyeti.column_plotter
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
[docs]
class ColumnPlotter:
    r"""
    A very basic class to visually compare matrices by plotting
    columns in an interactive window.

    The user can scroll forward and backward through the columns of
    data, while using the plotting tools provided by `matplotlib
    <https://matplotlib.org>`_ (such as zoom).

    Examples
    --------
    .. plot::
        :context: close-figs

        >>> import numpy as np
        >>> from pyyeti.column_plotter import ColumnPlotter
        >>> #
        >>> h = 0.01
        >>> t = np.arange(0, 1, h)
        >>> sin = np.sin(3 * 2 * np.pi * t)
        >>> sin_a = sin + np.random.randn(*sin.shape) * 0.1
        >>> sin_b = sin + np.random.randn(*sin.shape) * 0.1
        >>> A = np.column_stack((sin, sin_a, sin_b))
        >>> B = A + np.random.randn(*A.shape) * 0.2 + 2
        >>> C = A + np.random.randn(*A.shape) * 0.2 + 3
        >>> #
        >>> cp = ColumnPlotter(
        ...     t,
        ...     dict(A=A, B=B, C=C),
        ...     ["1st column", "2nd column", "3rd column"],
        ... )
    """

    def __init__(self, x, dct, column_labels=None):
        """
        Instantiates a :class:`ColumnPlotter` object

        Parameters
        ----------
        x : 1d array_like
            The x-axis data
        dct : dict
            Dictionary of y-axis data. The key is the label for the
            data and will be in the legend. The values are the
            matrices to compare and the number of rows must equal
            ``len(x)``. All matrices are expected to be the same
            size; if they are not, only the columns common to all of
            them (the smallest column count) can be viewed.
        column_labels : list or None; optional
            List of strings to be used for the plot titles;
            ``len(column_labels)`` is expected to be equal to the
            number of columns in the y-axis data matrices.

        Raises
        ------
        ValueError
            If `dct` is empty (there is nothing to plot).
        """
        if not dct:
            raise ValueError("`dct` must contain at least one matrix")

        self.x = x
        self.dct = dct
        self.column_labels = column_labels
        self.ind = 0

        self.fig, self.ax = plt.subplots()
        # leave room at the bottom of the figure for the slider and
        # the buttons
        self.fig.subplots_adjust(bottom=0.2)

        # map each plotted line to the matrix it draws its columns from
        self.lines = {}
        for key, value in self.dct.items():
            line = self.ax.plot(self.x, value[:, self.ind], label=key)[0]
            self.lines[line] = value

        # only offer columns that every matrix actually has
        self.n_columns = min(value.shape[1] for value in self.dct.values())

        self.ax.legend()
        self.ax.grid(True)
        if self.column_labels:
            self.ttl = self.ax.set_title(self.column_labels[self.ind])

        # Widgets are stored on ``self`` so they stay referenced for
        # as long as this object lives; matplotlib widgets stop
        # responding if they are garbage collected.
        axprev = self.fig.add_axes([0.7, 0.05, 0.1, 0.075])
        axnext = self.fig.add_axes([0.81, 0.05, 0.1, 0.075])
        self.bnext = Button(axnext, "Next")
        self.bnext.on_clicked(self._next)
        self.bprev = Button(axprev, "Previous")
        self.bprev.on_clicked(self._prev)

        axcolor = "lightgoldenrodyellow"
        slider_ax = self.fig.add_axes([0.1, 0.1, 0.50, 0.03], facecolor=axcolor)
        # valstep=1 snaps the slider to whole column indices so the
        # displayed number always matches the plotted column
        self.slider = Slider(
            slider_ax,
            "Column",
            0,
            self.n_columns - 1,
            valinit=0,
            valfmt="%.0f",
            valstep=1,
        )
        self.slider.on_changed(self._update_column)

    @staticmethod
    def _to_index(value, n_columns):
        """
        Convert a slider value to a valid column index.

        The value is rounded to the nearest integer (matching the
        ``"%.0f"`` slider label) and clamped to ``[0, n_columns - 1]``.
        """
        return min(max(int(round(value)), 0), n_columns - 1)

    def _update_plot(self):
        """Redraw every curve using column ``self.ind`` and rescale."""
        for line, value in self.lines.items():
            line.set_ydata(value[:, self.ind])
        # recompute the data limits from the new y-data, then update
        # the view limits from those data limits
        self.ax.relim()
        self.ax.autoscale()
        if self.column_labels:
            self.ttl.set_text(self.column_labels[self.ind])
        # Redraw *this* figure. ``plt.draw()`` redraws whichever figure
        # is current, which need not be this one.
        self.fig.canvas.draw_idle()

    def _next(self, event):
        """'Next' button callback: advance one column, if possible."""
        if self.ind < self.n_columns - 1:
            # set_val fires the slider's on_changed callback
            # (_update_column), which sets ``self.ind`` and redraws once
            self.slider.set_val(self.ind + 1)

    def _prev(self, event):
        """'Previous' button callback: go back one column, if possible."""
        if self.ind > 0:
            # set_val fires the slider's on_changed callback
            # (_update_column), which sets ``self.ind`` and redraws once
            self.slider.set_val(self.ind - 1)

    def _update_column(self, new_column):
        """Slider callback: display the column nearest `new_column`."""
        self.ind = self._to_index(new_column, self.n_columns)
        self._update_plot()