######################################################################
# BioSimSpace: Making biomolecular simulation a breeze!
#
# Copyright: 2017-2024
#
# Authors: Lester Hedges <lester.hedges@gmail.com>
#
# BioSimSpace is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# BioSimSpace is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with BioSimSpace. If not, see <http://www.gnu.org/licenses/>.
#####################################################################
"""Tools for plotting data."""
__author__ = "Lester Hedges"
__email__ = "lester.hedges@gmail.com"
__all__ = ["plot", "plotContour", "plotOverlapMatrix"]
import numpy as _np
from warnings import warn as _warn
from os import environ as _environ
from .. import _is_interactive, _is_notebook
from ..Types._type import Type as _Type
# Look up the DISPLAY environment variable; the result is None when unset.
_display = _environ.get("DISPLAY")
del _environ

# Record whether a display is available, then drop the temporary.
_has_display = _display is not None
del _display

# Matplotlib is only imported when there is a display, or when running
# inside a notebook. In every other case plotting is flagged as unavailable.
if _has_display or _is_notebook:
    try:
        import matplotlib.pyplot as _plt
        import matplotlib.colors as _colors
    except ImportError:
        _has_matplotlib = False
    else:
        _has_matplotlib = True
else:
    _has_matplotlib = False

# Disabled warning for the case where DISPLAY is unset:
# _warn("The DISPLAY environment variable is unset. Plotting functionality disabled!")
if _has_matplotlib:
    # Font sizes shared by every plot created in this module.
    _SMALL_SIZE, _MEDIUM_SIZE, _BIGGER_SIZE = 14, 16, 18

    # Default text size.
    _plt.rc("font", size=_SMALL_SIZE)
    # Axes title and x/y axis labels.
    _plt.rc("axes", titlesize=_SMALL_SIZE, labelsize=_MEDIUM_SIZE)
    # Tick labels on both axes.
    _plt.rc(("xtick", "ytick"), labelsize=_SMALL_SIZE)
    # Legend text.
    _plt.rc("legend", fontsize=_SMALL_SIZE)
    # Figure title.
    _plt.rc("figure", titlesize=_BIGGER_SIZE)
def _validate_series(values, errors, name):
    """
    Validate one axis of data (and its optional errors) for plot().

    Parameters
    ----------

    values : list, tuple
        The data records. Missing records are represented by None.

    errors : list, tuple, None
        The error records associated with the data, if any.

    name : str
        The axis name ("x" or "y"), used when building error messages.

    Returns
    -------

    (values, errors, is_unit) : tuple
        The data and errors (integers converted to floats) and a flag that
        is True when the records are unit-based '_Type' objects.

    Raises
    ------

    TypeError
        If 'values' is not a list or tuple, or the records are of mixed type.

    ValueError
        If 'values' contains no record other than None.
    """

    if not isinstance(values, (list, tuple)):
        raise TypeError("'%s' must be of type 'list'" % name)

    # Missing data is None, so the first real record determines the type.
    first = next((v for v in values if v is not None), None)
    if first is None:
        raise ValueError(
            "'%s' data must contain at least one value that is not None" % name
        )
    _type = type(first)

    if not all(isinstance(v, (_type, type(None))) for v in values):
        raise TypeError("All '%s' data values must be of same type" % name)

    # Convert int to float, leaving missing records in place.
    if _type is int:
        values = [None if v is None else float(v) for v in values]
        _type = float
        if errors is not None:
            try:
                errors = [None if e is None else float(e) for e in errors]
            except (TypeError, ValueError):
                # Best effort only: errors that cannot be converted are left
                # untouched and are reported by the type check below.
                pass

    # Make sure any associated error has the same type as the data.
    if errors is not None:
        if not all(isinstance(e, (_type, type(None))) for e in errors):
            raise TypeError(
                "All '%serr' values must be of same type as %s data" % (name, name)
            )

    return values, errors, isinstance(first, _Type)


def plot(
    x=None,
    y=None,
    xerr=None,
    yerr=None,
    xlabel=None,
    ylabel=None,
    logx=False,
    logy=False,
):
    """
    A simple function to create x/y plots with matplotlib.

    Records where either the x or y value is None are skipped. If the
    lists differ in length a warning is issued and all lists are truncated
    to the length of the shortest.

    Parameters
    ----------

    x : list
        A list of x data values. If omitted, the index of each y record
        is used. If only 'x' is given, it is plotted as a series against
        its index.

    y : list
        A list of y data values.

    xerr : list
        A list of error values for the x data.

    yerr : list
        A list of error values for the y data.

    xlabel : str
        The x axis label string. If omitted and the x data has units, a
        label is generated from the type name and unit.

    ylabel : str
        The y axis label string. If omitted and the y data has units, a
        label is generated from the type name and unit.

    logx : bool
        Whether the x axis is logarithmic.

    logy : bool
        Whether the y axis is logarithmic.

    Returns
    -------

    None when plotting is unavailable, otherwise the result of the
    matplotlib.pyplot.show() call.

    Raises
    ------

    TypeError
        If an argument is of the wrong type, or the data is of mixed type.

    ValueError
        If no data is given, or an axis contains no values.
    """

    # Make sure we're running interactively.
    if not _is_interactive:
        _warn("You can only use BioSimSpace.Notebook.plot when running interactively.")
        return None

    # Matplotlib is unavailable: either the import failed, or it was never
    # attempted because there is no display. Either way '_plt' is undefined,
    # so bail out before it is used.
    if not _has_matplotlib:
        if _has_display:
            _warn(
                "BioSimSpace.Notebook.plot is disabled as matplotlib failed "
                "to load. Please check your matplotlib installation."
            )
        else:
            _warn(
                "BioSimSpace.Notebook.plot is disabled as the DISPLAY "
                "environment variable is unset."
            )
        return None

    if not isinstance(logx, bool):
        raise TypeError("'logx' must be of type 'bool'.")

    if not isinstance(logy, bool):
        raise TypeError("'logy' must be of type 'bool'.")

    if x is None:
        if y is None:
            raise ValueError("'y' data must be defined!")
        # No x data, use array index as value.
        x = list(range(len(y)))
    elif y is None:
        # No y data, we assume that the user wants to plot the x
        # data as a series.
        y = x
        x = list(range(len(y)))

    # Validate both axes, converting integers to floats.
    x, xerr, is_unit_x = _validate_series(x, xerr, "x")
    y, yerr, is_unit_y = _validate_series(y, yerr, "y")

    # Lists must contain the same number of records. Truncate to the length
    # of the shortest *before* stripping missing values, so that the record
    # indices used below are valid for every list.
    num_records = min(len(x), len(y))
    if len(x) != len(y):
        _warn("Mismatch in list sizes: len(x) = %d, len(y) = %d" % (len(x), len(y)))
    for err_name, err in (("xerr", xerr), ("yerr", yerr)):
        if err is not None and len(err) < num_records:
            _warn(
                "Mismatch in list sizes: len(%s) = %d, number of data records = %d"
                % (err_name, len(err), num_records)
            )
            num_records = len(err)

    x = x[:num_records]
    y = y[:num_records]

    # Strip any record for which either coordinate is missing.
    keep = [
        i
        for i, (x_val, y_val) in enumerate(zip(x, y))
        if x_val is not None and y_val is not None
    ]
    x = [x[i] for i in keep]
    y = [y[i] for i in keep]
    if xerr is not None:
        xerr = [xerr[i] for i in keep]
    if yerr is not None:
        yerr = [yerr[i] for i in keep]

    if xlabel is not None:
        if not isinstance(xlabel, str):
            raise TypeError("'xlabel' must be of type 'str'")
    elif is_unit_x and x:
        # Default label built from the type name and unit of the data.
        xlabel = (
            x[0].__class__.__qualname__ + " (" + x[0]._print_format[x[0].unit()] + ")"
        )

    if ylabel is not None:
        if not isinstance(ylabel, str):
            raise TypeError("'ylabel' must be of type 'str'")
    elif is_unit_y and y:
        ylabel = (
            y[0].__class__.__qualname__ + " (" + y[0]._print_format[y[0].unit()] + ")"
        )

    # Convert the x and y values to floats.
    if is_unit_x:
        x = [record.value() for record in x]
    if is_unit_y:
        y = [record.value() for record in y]

    # Convert the errors to floats. A missing error becomes NaN, which
    # matplotlib treats as "no error bar" (it rejects None).
    nan = float("nan")
    if xerr is not None:
        xerr = [
            nan if e is None else (e.value() if is_unit_x else e) for e in xerr
        ]
    if yerr is not None:
        yerr = [
            nan if e is None else (e.value() if is_unit_y else e) for e in yerr
        ]

    # Set the figure size.
    _plt.figure(figsize=(8, 6))

    # Create the plot. errorbar() accepts None for either error argument.
    if xerr is None and yerr is None:
        _plt.plot(x, y, "-bo")
    else:
        _plt.errorbar(x, y, xerr=xerr, yerr=yerr, fmt="-bo")

    # Add axis labels.
    if xlabel is not None:
        _plt.xlabel(xlabel)
    if ylabel is not None:
        _plt.ylabel(ylabel)

    # Scale the axes.
    if logx:
        _plt.xscale("log")
    if logy:
        _plt.yscale("log")

    # Turn on grid.
    _plt.grid()

    return _plt.show()
def _validate_contour_axis(values, name):
    """
    Validate one axis of data for plotContour().

    Parameters
    ----------

    values : list, tuple
        The data records.

    name : str
        The axis name ("x", "y", or "z"), used in error messages.

    Returns
    -------

    (values, is_unit) : tuple
        The data (integers converted to floats) and a flag that is True
        when the records are unit-based '_Type' objects.

    Raises
    ------

    TypeError
        If 'values' is not a list or tuple, or the records are of mixed type.

    ValueError
        If 'values' is empty.
    """

    if not isinstance(values, (list, tuple)):
        raise TypeError("'%s' must be of type 'list'" % name)

    if len(values) == 0:
        raise ValueError("'%s' must contain at least one value" % name)

    # Make sure all records are of the same type.
    _type = type(values[0])
    if not all(isinstance(v, _type) for v in values):
        raise TypeError("All '%s' data values must be of same type" % name)

    # Does this type have units?
    is_unit = isinstance(values[0], _Type)

    # Convert int to float.
    if _type is int:
        values = [float(v) for v in values]

    return values, is_unit


def _contour_unit_label(record):
    """Return an axis label of the form 'TypeName (unit)' for a unit-based record."""
    return (
        record.__class__.__qualname__
        + " ("
        + record._print_format[record.unit()]
        + ")"
    )


def plotContour(x, y, z, xlabel=None, ylabel=None, zlabel=None):
    """
    A simple function to create two-dimensional contour plots with matplotlib.

    The data is not assumed to lie on a grid: the z values are linearly
    interpolated onto a regular 1000 x 1000 grid spanning the x and y
    ranges. If the lists differ in length a warning is issued and they
    are truncated to the length of the shortest.

    Parameters
    ----------

    x : list
        A list of x data values.

    y : list
        A list of y data values.

    z : list
        A list of z data values.

    xlabel : str
        The x axis label string. If omitted and the x data has units, a
        label is generated from the type name and unit.

    ylabel : str
        The y axis label string. If omitted and the y data has units, a
        label is generated from the type name and unit.

    zlabel : str
        The z axis (colour bar) label string. If omitted and the z data
        has units, a label is generated from the type name and unit.

    Returns
    -------

    None when plotting is unavailable, otherwise the result of the
    matplotlib.pyplot.show() call.

    Raises
    ------

    TypeError
        If an argument is of the wrong type, or the data is of mixed type.

    ValueError
        If a data list is empty, or the data can't be interpolated to a grid.
    """

    # Make sure we're running interactively.
    if not _is_interactive:
        _warn(
            "You can only use BioSimSpace.Notebook.plotContour when running interactively."
        )
        return None

    # Matplotlib is unavailable: either the import failed, or it was never
    # attempted because there is no display. Either way '_plt' is undefined,
    # so bail out before it is used.
    if not _has_matplotlib:
        if _has_display:
            _warn(
                "BioSimSpace.Notebook.plotContour is disabled as matplotlib failed "
                "to load. Please check your matplotlib installation."
            )
        else:
            _warn(
                "BioSimSpace.Notebook.plotContour is disabled as the DISPLAY "
                "environment variable is unset."
            )
        return None

    # Import the optional dependencies only once we know that we can plot,
    # so that the guards above can return gracefully when they are missing.
    import scipy.interpolate as _interp
    from mpl_toolkits.axes_grid1 import make_axes_locatable as _make_axes_locatable

    # Validate the data, converting integers to floats.
    x, is_unit_x = _validate_contour_axis(x, "x")
    y, is_unit_y = _validate_contour_axis(y, "y")
    z, is_unit_z = _validate_contour_axis(z, "z")

    # Lists must contain the same number of records.
    # Truncate the longer lists to the length of the shortest.
    lengths = (len(x), len(y), len(z))
    if len(set(lengths)) > 1:
        _warn(
            "Mismatch in list sizes: len(x) = %d, len(y) = %d, len(z) = %d" % lengths
        )
        shortest = min(lengths)
        x = x[:shortest]
        y = y[:shortest]
        z = z[:shortest]

    if xlabel is not None:
        if not isinstance(xlabel, str):
            raise TypeError("'xlabel' must be of type 'str'")
    elif is_unit_x:
        xlabel = _contour_unit_label(x[0])

    if ylabel is not None:
        if not isinstance(ylabel, str):
            raise TypeError("'ylabel' must be of type 'str'")
    elif is_unit_y:
        ylabel = _contour_unit_label(y[0])

    if zlabel is not None:
        if not isinstance(zlabel, str):
            raise TypeError("'zlabel' must be of type 'str'")
    elif is_unit_z:
        zlabel = _contour_unit_label(z[0])

    # Convert the x, y, and z values to floats.
    if is_unit_x:
        x = [record.value() for record in x]
    if is_unit_y:
        y = [record.value() for record in y]
    if is_unit_z:
        z = [record.value() for record in z]

    # Convert to two-dimensional arrays. We don't assume the data is on a grid,
    # so we interpolate the z values.
    try:
        (X, Y) = _np.meshgrid(
            _np.linspace(_np.min(x), _np.max(x), 1000),
            _np.linspace(_np.min(y), _np.max(y), 1000),
        )
        Z = _interp.griddata((x, y), z, (X, Y), method="linear")
    except Exception as e:
        # Keep the original cause attached to help diagnose bad input.
        raise ValueError("Unable to interpolate x, y, and z data to a grid.") from e

    # Set the figure size.
    _plt.figure(figsize=(8, 8))

    # Create the contour plot.
    cp = _plt.contourf(X, Y, Z)

    # Add axis labels.
    if xlabel is not None:
        _plt.xlabel(xlabel)
    if ylabel is not None:
        _plt.ylabel(ylabel)

    # Get the current axes.
    ax = _plt.gca()

    # Make sure the axes are equal.
    ax.set_aspect("equal", adjustable="box")

    # Make sure the colour bar matches size of the axes.
    divider = _make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)

    # Add a colour bar and label it.
    cbar = _plt.colorbar(cp, cax=cax)
    if zlabel is not None:
        cbar.set_label(zlabel)

    return _plt.show()
def plotOverlapMatrix(
    overlap, continuous_cbar=False, color_bar_cutoffs=[0.03, 0.1, 0.3]
):
    """
    Plot the overlap matrix from a free-energy perturbation analysis.

    Each cell of the matrix is coloured according to its value and
    labelled with the value to two decimal places.

    Parameters
    ----------

    overlap : List of List of float, or 2D numpy array of float
        The overlap matrix. Must be square.

    continuous_cbar : bool, optional, default=False
        If True, use a continuous colour bar. Otherwise, use a discrete
        set of values defined by the 'color_bar_cutoffs' argument to
        assign a colour to each element in the matrix.

    color_bar_cutoffs : List of float, optional, default=[0.03, 0.1, 0.3]
        The cutoffs to use when assigning a colour to each element in the
        matrix. This is used for both the continuous and discrete color bars.
        Can not contain more than 3 elements. A tuple or numpy array of
        floats is also accepted.

    Returns
    -------

    None when plotting is unavailable, otherwise the result of the
    matplotlib.pyplot.show() call.

    Raises
    ------

    TypeError
        If an argument, or the matrix data, is of the wrong type.

    ValueError
        If the matrix is empty or not square, or if more than 3 cutoffs
        are given.
    """

    # Make sure we're running interactively.
    if not _is_interactive:
        _warn(
            "You can only use BioSimSpace.Notebook.plotOverlapMatrix when running interactively."
        )
        return None

    # Matplotlib is unavailable: either the import failed, or it was never
    # attempted because there is no display. Either way '_plt' is undefined,
    # so bail out before it is used.
    if not _has_matplotlib:
        if _has_display:
            _warn(
                "BioSimSpace.Notebook.plotOverlapMatrix is disabled as matplotlib failed "
                "to load. Please check your matplotlib installation."
            )
        else:
            _warn(
                "BioSimSpace.Notebook.plotOverlapMatrix is disabled as the DISPLAY "
                "environment variable is unset."
            )
        return None

    # Validate the input
    if not isinstance(overlap, (list, tuple, _np.ndarray)):
        raise TypeError(
            "The 'overlap' matrix must be a list of list types, or a numpy array!"
        )

    # Try converting to a NumPy array.
    try:
        overlap = _np.array(overlap)
    except Exception as e:
        raise TypeError(
            "'overlap' must be of type 'np.matrix', 'np.ndarray', or a list of lists."
        ) from e

    # Store the number of rows.
    num_rows = len(overlap)

    if num_rows == 0:
        raise ValueError("The 'overlap' matrix must not be empty!")

    # Check the data in each row.
    for row in overlap:
        if not isinstance(row, (list, tuple, _np.ndarray)):
            raise TypeError("The 'overlap' matrix must be a list of list types!")
        if len(row) != num_rows:
            raise ValueError("The 'overlap' matrix must be square!")
        if not all(isinstance(value, float) for value in row):
            raise TypeError("The 'overlap' matrix must contain 'float' types!")

    # Check the colour bar options
    if not isinstance(continuous_cbar, bool):
        raise TypeError("The 'continuous_cbar' option must be a boolean!")
    if not isinstance(color_bar_cutoffs, (list, tuple, _np.ndarray)):
        raise TypeError(
            "The 'color_bar_cutoffs' option must be a list of floats "
            " or a numpy array when 'continuous_cbar' is False!"
        )
    if not all(isinstance(cutoff, float) for cutoff in color_bar_cutoffs):
        raise TypeError("The 'color_bar_cutoffs' option must be a list of floats!")
    if len(color_bar_cutoffs) > 3:
        raise ValueError(
            "The 'color_bar_cutoffs' option must contain no more than 3 elements!"
        )

    # Add 0 and 1 to the colour bar cutoffs. Copy into a new list first so
    # that tuples and numpy arrays are concatenated rather than raising (or,
    # for arrays, being added element-wise), and so that the caller's
    # sequence (including the default list) is never modified.
    color_bounds = [0] + list(color_bar_cutoffs) + [1]

    # Tuple of colours and associated font colours.
    # The discrete colour bar ignores the first and last entries.
    all_colors = (
        ("#FBE8EB", "black"),  # Lighter pink
        ("#FFD3E0", "black"),
        ("#88CCEE", "black"),
        ("#78C592", "black"),
        ("#117733", "white"),
        ("#004D00", "white"),  # Darker green
    )

    # The (box colour, font colour) pairs used for the discrete bands.
    band_colors = all_colors[1:-1]

    # Set the colour map.
    if continuous_cbar:
        # Anchor one colour at each bound, taking colours from the start
        # of the palette.
        anchors = [
            (bound, box_color)
            for bound, (box_color, _) in zip(color_bounds, all_colors)
        ]
        cmap = _colors.LinearSegmentedColormap.from_list("CustomMap", anchors)

        # Normalise the same way each time so that plots are always comparable.
        norm = _colors.Normalize(vmin=0, vmax=1)
    else:
        # One colour per band between consecutive bounds.
        num_bands = len(color_bounds) - 1
        cmap = _colors.ListedColormap(
            [box_color for box_color, _ in band_colors[:num_bands]]
        )
        norm = _colors.BoundaryNorm(color_bounds, cmap.N)

    # Create the figure and axis. Use a default size for fewer than 16 windows,
    # otherwise scale the figure size to the number of windows.
    if num_rows < 16:
        fig, ax = _plt.subplots(figsize=(8, 8), dpi=300)
    else:
        fig, ax = _plt.subplots(figsize=(num_rows / 2, num_rows / 2), dpi=300)

    # Create the heatmap. Separate the cells with white lines.
    im = ax.imshow(overlap, cmap=cmap, norm=norm)
    for i in range(num_rows - 1):
        # Make sure these are on the edges of the cells. Each separator is
        # drawn exactly once.
        ax.axhline(i + 0.5, color="white", linewidth=0.5)
        ax.axvline(i + 0.5, color="white", linewidth=0.5)

    # Label each cell with the overlap value.
    last_bound = len(color_bounds) - 1
    for i in range(num_rows):
        for j in range(num_rows):
            overlap_val = overlap[i][j]

            # Get the index of the first colour bound greater than the overlap
            # value; values at or above the top bound use the last band.
            upper = next(
                (k for k, bound in enumerate(color_bounds) if bound > overlap_val),
                last_bound,
            )

            # The band is the one below that bound. Clamp so that values
            # below the lowest bound use the first band rather than wrapping
            # around to the last.
            text_color = band_colors[max(upper - 1, 0)][1]

            ax.text(
                j,
                i,
                "{:.2f}".format(overlap_val),
                ha="center",
                va="center",
                fontsize=10,
                color=text_color,
            )

    # Create a colorbar. Reduce the height of the colorbar to match the figure and remove the border.
    if continuous_cbar:
        cbar = ax.figure.colorbar(im, ax=ax, cmap=cmap, norm=norm, shrink=0.7)
    else:
        cbar = ax.figure.colorbar(
            im,
            ax=ax,
            cmap=cmap,
            norm=norm,
            boundaries=color_bounds,
            ticks=color_bounds,
            shrink=0.7,
        )
    cbar.outline.set_visible(False)

    # Set the axis labels.
    # Set the x axis at the top of the plot.
    _plt.xlabel(r"$\lambda$ Index")
    ax.xaxis.set_label_position("top")
    _plt.ylabel(r"$\lambda$ Index")

    # Set ticks every lambda window.
    ticks = list(range(num_rows))
    _plt.xticks(ticks)
    ax.xaxis.tick_top()
    _plt.yticks(ticks)

    # Remove the borders.
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_visible(False)

    # Create a tight layout to trim whitespace.
    fig.tight_layout()

    return _plt.show()