"""Plotting module for PyTorch prototyping.
:author: Julian M. Kleber
"""
from typing import Type, Optional, Dict, Any, List, Tuple
import logging
import numpy as np
import matplotlib.pyplot as plt
from amarium.utils import load_json_from_file, prepare_file_name_saving, attach_slash
from carate.statistics.analysis import (
get_avg,
get_min,
get_max,
unpack_run,
get_stacked_list,
load_result_json,
get_min_max_avg_cv_run,
)
logger = logging.getLogger(__name__)
[docs]
def plot_trajectory_single_algorithm(
    path_to_directory: str,
    parameter: str,
    save_dir: str = "./plots",
    data_name: Optional[str] = None,
    num_cv: int = 5,
) -> None:
    """Plot the cross-validation trajectory of one metric for one algorithm.

    The stacked results are loaded from the ``data/`` sub-directory of
    ``path_to_directory`` via ``get_stacked_list`` and handed to
    ``plot_range_band_single``, which draws the average curve together with
    the min/max band and saves the figure.

    :param path_to_directory: Directory of the run. ``data/`` is appended to
        it before the results are loaded.
    :param parameter: Key of the metric to plot. It is used as ``key_val``
        for the plot and as the suffix of the output file name.
    :param save_dir: Directory the plot is saved to.
    :param data_name: Name of the JSON file to load. If ``None``,
        ``"<legend_text>.json"`` is used, where ``legend_text`` is derived
        from ``path_to_directory`` (see the comment in the body).
    :param num_cv: Number of cross-validation runs passed on to
        ``get_stacked_list``. Defaults to 5, the value that used to be
        hard-coded.
    :return: None.
    :doc-author: Julian M. Kleber
    """
    path_to_directory = attach_slash(path_to_directory) + "data/"
    # The path now ends in ".../<name>/data/" (presumably attach_slash
    # guarantees the separator before "data/" - confirm), so splitting on "/"
    # yields [..., "<name>", "data", ""] and index -3 picks "<name>".
    legend_text = path_to_directory.split("/")[-3]
    if data_name is None:
        data_name = f"{legend_text}.json"

    result = get_stacked_list(
        path_to_directory=path_to_directory,
        num_cv=num_cv,
        json_name=data_name,
    )
    plot_range_band_single(
        result,
        file_name=f"{legend_text}_{parameter}",
        save_dir=save_dir,
        key_val=parameter,
        alpha=0.4,
        legend_text=legend_text,
    )
[docs]
def plot_range_band_multi(
    result: List[Dict[str, List[float]]],
    key_vals: List[str],
    file_name: str,
    alpha: float = 0.5,
    y_lim: Tuple[float, float] = (0.0, 1.01),
    set_ylim_dynamically: bool = False,
    save_dir: Optional[str] = None,
    title_text: Optional[str] = None,
) -> None:
    """Plot the average curve and min/max band of several metrics in one axis.

    For every key in ``key_vals`` the maximum, minimum and average values are
    obtained from ``get_min_max_avg_cv_run``. The average is drawn as a
    labelled line and the area between minimum and maximum is shaded. The
    figure is then saved with ``save_publication_graphic`` and closed.

    :param result: List of result dictionaries, passed unchanged to
        ``get_min_max_avg_cv_run``.
    :param key_vals: Keys of the metrics to plot; each one also serves as the
        legend label of its curve.
    :param file_name: Name of the file the plot is saved to.
    :param alpha: Transparency of the shaded min/max band.
    :param y_lim: ``(lower, upper)`` y-axis limits used when
        ``set_ylim_dynamically`` is ``False`` or when ``key_vals`` is empty.
    :param set_ylim_dynamically: If ``True``, the y-axis limits are taken
        from the smallest minimum and largest maximum over all plotted
        metrics, each padded outwards by 3 % of its magnitude.
    :param save_dir: Directory the plot is saved to (passed as ``prefix``).
    :param title_text: Title of the plot. No title is set when ``None``.
    :return: None.
    :doc-author: Julian M. Kleber
    """
    fig, axis = plt.subplots()
    axis.set_xlabel("Training step")

    lower_bounds: List[float] = []
    upper_bounds: List[float] = []
    for key_val in key_vals:
        max_val: List[float]
        min_val: List[float]
        avg_val: List[float]
        # NOTE(review): the unpacking order (max, min, avg) is kept from the
        # original code - confirm it matches the helper's return order.
        max_val, min_val, avg_val = get_min_max_avg_cv_run(
            result=result, key_val=key_val
        )
        if set_ylim_dynamically:
            lower_bounds.append(np.min(min_val))
            upper_bounds.append(np.max(max_val))
        axis.plot(avg_val, label=key_val)
        plot_range_fill(max_val, min_val, alpha, axis)

    if set_ylim_dynamically and lower_bounds:
        y_min = min(lower_bounds)
        y_max = max(upper_bounds)
        # Pad by 3 % of the magnitude. abs() keeps the padding pointing
        # outwards for negative values, where "y - 0.03 * y" would move the
        # limit inwards and clip the band.
        axis.set_ylim(y_min - 0.03 * abs(y_min), y_max + 0.03 * abs(y_max))
    else:
        # Also reached when nothing was plotted, so no dynamic range exists.
        axis.set_ylim(*y_lim)

    axis.set_ylabel("Value")
    axis.legend()
    if title_text is not None:
        axis.set_title(title_text)

    try:
        save_publication_graphic(fig_object=fig, file_name=file_name, prefix=save_dir)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
[docs]
def plot_range_band_single(
    result: List[Dict[str, List[float]]],
    key_val: str,
    file_name: str,
    alpha: float = 0.5,
    fixed_y_lim: Tuple[float, float] = (0.0, 1.01),
    set_ylim_dynamically: bool = False,
    save_dir: Optional[str] = None,
    legend_text: Optional[str] = None,
) -> None:
    """Plot the average curve and min/max band of a single metric.

    The maximum, minimum and average values for ``key_val`` are obtained from
    ``get_min_max_avg_cv_run``. The average is drawn as a solid line and the
    area between minimum and maximum is shaded. The figure is then saved with
    ``save_publication_graphic`` and closed.

    :param result: List of result dictionaries, passed unchanged to
        ``get_min_max_avg_cv_run``.
    :param key_val: Key of the metric to plot; also used as y-axis label.
    :param file_name: Name of the file the plot is saved to.
    :param alpha: Transparency of the shaded min/max band.
    :param fixed_y_lim: ``(lower, upper)`` y-axis limits used when
        ``set_ylim_dynamically`` is ``False``.
    :param set_ylim_dynamically: If ``True``, the y-axis limits are taken
        from the smallest minimum and largest maximum of the metric, each
        padded outwards by 3 % of its magnitude.
    :param save_dir: Directory the plot is saved to (passed as ``prefix``).
    :param legend_text: Label of the curve. A legend is only drawn when this
        is not ``None``.
    :return: None.
    :doc-author: Julian M. Kleber
    :author: Julian M. Kleber
    """
    max_val: List[float]
    min_val: List[float]
    avg_val: List[float]
    # NOTE(review): the unpacking order (max, min, avg) is kept from the
    # original code - confirm it matches the helper's return order.
    max_val, min_val, avg_val = get_min_max_avg_cv_run(result=result, key_val=key_val)

    fig, axis = plt.subplots()
    # Both branches of the former "legend_text is not None" check issued this
    # exact call, so a single call is sufficient.
    axis.plot(avg_val, "-", label=legend_text)
    plot_range_fill(max_val, min_val, alpha, axis)

    if set_ylim_dynamically:
        y_min = np.min(min_val)
        y_max = np.max(max_val)
        # Pad by 3 % of the magnitude. abs() keeps the padding pointing
        # outwards for negative values, where "y - 0.03 * y" would move the
        # limit inwards and clip the band.
        axis.set_ylim(y_min - 0.03 * abs(y_min), y_max + 0.03 * abs(y_max))
    else:
        axis.set_ylim(*fixed_y_lim)

    axis.set_ylabel(key_val)
    if legend_text is not None:
        axis.legend()
    axis.set_xlabel("Training step")

    try:
        save_publication_graphic(fig_object=fig, file_name=file_name, prefix=save_dir)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated calls do not accumulate open figures.
        plt.close(fig)
[docs]
def plot_range_fill(
    max_val: List[float], min_val: List[float], alpha: float, axis
) -> None:
    """Shade the area between ``min_val`` and ``max_val`` on ``axis``.

    The x positions are the integer step indices ``0 .. len(max_val) - 1``.
    Nothing else is drawn; the function only calls ``axis.fill_between``.

    :param max_val: Upper boundary of the band, one value per step.
    :param min_val: Lower boundary of the band, one value per step.
    :param alpha: Transparency of the filled area.
    :param axis: Object providing ``fill_between`` (presumably a matplotlib
        ``Axes`` - it is the axis created by the callers in this module).
    :return: None.
    :doc-author: Julian M. Kleber
    """
    step_count = len(max_val)
    x_positions = np.arange(step_count)
    axis.fill_between(x_positions, min_val, max_val, alpha=alpha)
[docs]
def save_publication_graphic(
    fig_object: Optional[plt.Figure], file_name: str, prefix: Optional[str] = None
) -> None:
    """Save a figure as a 300 dpi PNG file.

    The target path is built by ``prepare_file_name_saving`` from
    ``file_name``, ``prefix`` and the suffix ``".png"``. The layout of the
    figure is tightened before it is written, and the resulting path is
    logged at INFO level on the module logger.

    :param fig_object: Figure to save. If ``None``, the current pyplot figure
        is used, which is what this function always saved before it honoured
        this argument.
    :param file_name: Name of the file to be saved.
    :param prefix: Directory the file is saved to; passed on to
        ``prepare_file_name_saving``.
    :return: None.
    :doc-author: Julian M. Kleber
    """
    file_name = prepare_file_name_saving(
        file_name=file_name, prefix=prefix, suffix=".png"
    )
    # Operate on the figure that was passed in rather than on pyplot's
    # implicit "current figure", which may be a different one.
    figure = plt.gcf() if fig_object is None else fig_object
    figure.tight_layout()
    figure.savefig(file_name, dpi=300)
    logger.info("Saved plot %s", file_name)
[docs]
def parse_min_max_avg(result_list: List[List[float]]) -> List[List[float]]:
    """Collect per-step results of ``get_min``, ``get_max`` and ``get_avg``.

    ``result_list`` is treated as a two-dimensional table. For every column
    ``i`` the values ``result_list[:, i]`` are passed as ``step_list`` to
    ``get_min``, ``get_max`` and ``get_avg``, and the three results are
    collected in separate lists.

    :param result_list: Two-dimensional data (nested lists or an array) whose
        rows all have the same length. Presumably one row per run and one
        column per training step - confirm against the callers.
    :return: ``[minima, maxima, averages]``, each holding one entry per
        column. Three empty lists are returned for empty input.
    :raises ValueError: If ``result_list`` is non-empty but not
        two-dimensional.
    :doc-author: Julian M. Kleber
    """
    # Plain nested lists do not support "[:, i]" indexing, so convert to an
    # array first; existing arrays are passed through without copying.
    table = np.asarray(result_list)
    if table.size == 0:
        return [[], [], []]
    if table.ndim != 2:
        raise ValueError(
            f"result_list must be two-dimensional, got {table.ndim} dimension(s)"
        )

    minima = []
    maxima = []
    averages = []
    # Iterating over the transpose yields the columns, i.e. table[:, i].
    for step_list in table.T:
        minima.append(get_min(step_list=step_list))
        maxima.append(get_max(step_list=step_list))
        averages.append(get_avg(step_list=step_list))
    return [minima, maxima, averages]