import matplotlib
import matplotlib.pyplot as plt
from smadi.metadata import indicators_thresholds
[docs]
def set_thresholds(method):
    """
    Look up the threshold table registered for an anomaly method.

    parameters:
    -----------
    method: str
        Name of the method whose thresholds are requested. Supported methods are:
        'zscore', 'smapi', 'smdi', 'smca', 'smad', 'smci', 'smds', 'essmi', 'beta', 'gamma'

    returns:
    --------
    The entry of ``indicators_thresholds`` stored under ``method``, or None
    when ``method`` is not one of its keys.
    """
    return indicators_thresholds.get(method)
[docs]
def get_plot_options(**kwargs):
    """
    Collect the figure-level plot options from the given keyword arguments.

    Only the recognised option names are picked up. Any option that was not
    supplied is set to None, and unrecognised keyword arguments are ignored.

    parameters:
    -----------
    kwargs: dict
        Keyword arguments passed by the caller of a plotting function.

    returns:
    --------
    plot_options: dict
        Mapping with the keys 'title', 'xlabel', 'ylabel', 'legend',
        'legend_labels', 'figsize', 'grid' and 'savefig' (in that order).
    """
    option_names = (
        "title",
        "xlabel",
        "ylabel",
        "legend",
        "legend_labels",
        "figsize",
        "grid",
        "savefig",
    )
    return {name: kwargs.get(name) for name in option_names}
[docs]
def plot_colmns(df, x_axis, colmns_kwargs):
    """
    Draw one line per requested dataframe column against a shared x-axis.

    parameters:
    -----------
    df: pd.DataFrame
        The dataframe holding the columns to draw.
    x_axis: list
        The x-axis values shared by every line.
    colmns_kwargs: dict
        Maps each column name to the keyword arguments forwarded to
        ``plt.plot`` for that column's line.
    """
    for column_name in colmns_kwargs:
        line_options = colmns_kwargs[column_name]
        plt.plot(x_axis, df[column_name], **line_options)
[docs]
def draw_hbars(thresholds, x_axis):
    """
    Shade the horizontal threshold bands of an anomaly method across the plot.

    Each category in the method's threshold table is drawn as a filled band
    between its lower bound (``value[0]``) and upper bound (``value[1]``).

    - Categories whose name starts with "D" are shaded brown.
    - Categories starting with "W" are shaded blue.
    - For both of these, the opacity is 0.2 times the integer after the "-"
      in the name, capped at 1.0.
    - Every other category is shaded green with opacity 0.1.

    parameters:
    -----------
    thresholds: str
        The name of the anomaly method whose thresholds are drawn. It is
        looked up with ``set_thresholds``.
    x_axis: list
        The x-axis values over which the bands are filled.

    raises:
    -------
    ValueError
        If no thresholds are registered for the given method name.
    """
    method = thresholds
    bands = set_thresholds(method)
    if bands is None:
        # set_thresholds signals an unknown method with None. Fail with a clear
        # message instead of an AttributeError on None.items().
        raise ValueError(
            f"No thresholds are defined for the anomaly method {method!r}."
        )
    for category, bounds in bands.items():
        if category.startswith(("D", "W")):
            # The suffix after "-" is read as an integer level, and higher levels
            # are drawn more opaque. The value is capped at 1.0 because
            # matplotlib rejects alpha values outside [0, 1].
            alpha = min(1.0, int(category.split("-")[1]) * 0.2)
        else:
            alpha = 0.1
        if category.startswith("D"):
            color = "brown"
        elif category.startswith("W"):
            color = "blue"
        else:
            color = "green"
        plt.fill_between(x_axis, bounds[0], bounds[1], color=color, alpha=alpha)
[docs]
def clss_counter(df, columns, thresholds):
    """
    Count, for each column, how many values fall inside each threshold category.

    parameters:
    -----------
    df: pd.DataFrame
        The dataframe containing the data to count.
    columns: dict
        The dictionary whose keys are the column names to count. The values
        (matplotlib plot options) are not used here. At most two columns are
        accepted.
    thresholds: str
        The name of the anomaly method whose thresholds define the categories.

    returns:
    --------
    results: list of dict
        One dict per column, in the iteration order of ``columns``. Each dict
        maps every category name of the method to the number of values in
        that column lying within the category's bounds.

    raises:
    -------
    ValueError
        If more than two columns are given, or if no thresholds are registered
        for ``thresholds``.
    """
    # Presumably limited to two because plot_categories_count only has two
    # label positions: the left edge for the first column, the right edge for
    # any other.
    if len(columns) > 2:
        raise ValueError("The number of columns should not exceed 2.")
    anomaly_thresholds = set_thresholds(thresholds)
    if anomaly_thresholds is None:
        raise ValueError(
            f"No thresholds are defined for the anomaly method {thresholds!r}."
        )
    results = []
    for column in columns:
        values = df[column]
        # NOTE(review): Series.between is inclusive on both ends by default.
        # A value lying exactly on a bound shared by two categories is
        # therefore counted in both. Confirm this is intended.
        results.append(
            {
                category: values.between(bounds[0], bounds[1]).sum()
                for category, bounds in anomaly_thresholds.items()
            }
        )
    return results
[docs]
def plot_categories_count(x_axis, results, anomaly_method):
    """
    Annotate the plot with the per-category value counts of an anomaly method.

    Each count is written as bold italic text "<category>:<count>".

    - The text is placed at the height of the category's upper bound.
    - Counts for the first result are placed at the first x value,
      right-aligned.
    - Counts for any later result are placed at the last x value,
      left-aligned.

    parameters:
    -----------
    x_axis: list
        The x-axis values for the plot. Only the first and last entries are
        used.
    results: list
        A list of dicts mapping category names to counts, in the shape
        returned by ``clss_counter``.
    anomaly_method: str
        The name of the anomaly method whose thresholds give the label heights.

    raises:
    -------
    ValueError
        If no thresholds are registered for ``anomaly_method``.
    """
    if not any(results):
        # Nothing to label. Skip the threshold lookup entirely, as the
        # original loop would never have reached it.
        return
    # The thresholds do not depend on the loop variables, so look them up
    # once instead of once per label.
    thresholds = set_thresholds(anomaly_method)
    if thresholds is None:
        raise ValueError(
            f"No thresholds are defined for the anomaly method {anomaly_method!r}."
        )
    for i, counts in enumerate(results):
        if i == 0:
            x, halignment = x_axis[0], "right"
        else:
            x, halignment = x_axis[-1], "left"
        for category, count in counts.items():
            plt.text(
                x=x,
                y=thresholds[category][1],
                s=f"{category}:{count}",
                fontsize=10,
                color="black",
                horizontalalignment=halignment,
                fontstyle="italic",
                in_layout=True,
                weight="bold",
            )
[docs]
def plot_anomaly(
    df, x_axis, colmns, thresholds, plot_hbars=True, plot_categories=True, **kwargs
):
    """
    Draw anomaly-detection results in a new figure, optionally with threshold
    bands and per-category counts.

    parameters:
    -----------
    df: pd.DataFrame
        The dataframe containing the data to plot.
    x_axis: list
        The x-axis values for the plot.
    colmns: dict
        Maps each column name to its matplotlib plot options.
    thresholds: str
        The name of the anomaly method whose thresholds are used.
    plot_hbars: bool
        If True, shade the threshold bands of the anomaly method.
    plot_categories: bool
        If True, annotate how many values of each column fall in each category.
    kwargs: dict
        Figure-level options such as title, xlabel, ylabel, legend, figsize and
        grid. They are collected with ``get_plot_options`` and applied by
        ``plot_figure``.
    """
    figure_options = get_plot_options(**kwargs)
    # A fresh figure is always opened here, unlike plot_ts. With figsize=None,
    # matplotlib falls back to its default figure size.
    plt.figure(figsize=figure_options["figsize"])
    plot_colmns(df, x_axis, colmns)
    if plot_hbars:
        draw_hbars(thresholds, x_axis)
    if plot_categories:
        plot_categories_count(
            x_axis, clss_counter(df, colmns, thresholds), thresholds
        )
    plot_figure(figure_options)
[docs]
def plot_ts(
    df,
    x_axis,
    colmns_kwargs,
    plot_raw=False,
    clim_obj=None,
    raw_var=None,
    raw_kwargs=None,
    **kwargs,
):
    """
    Plot the time series data for the provided dataframe.

    A new figure is opened only when ``figsize`` is given. Otherwise the lines
    are drawn onto the current pyplot figure.

    parameters:
    -----------
    df: pd.DataFrame
        The dataframe containing the data to plot.
    x_axis: list
        The x-axis values for the plot.
    colmns_kwargs: dict
        Maps each column name to its matplotlib plot options.
    plot_raw: bool
        Whether to draw the raw data as a background line before the columns.
    clim_obj: Climatology
        The climatology object holding the original data in ``original_df``.
        Required when ``plot_raw`` is True.
    raw_var: str
        The column of ``clim_obj.original_df`` to draw as raw data. Required
        when ``plot_raw`` is True.
    raw_kwargs: dict
        Matplotlib plot options for the raw-data line. Defaults to none.
    kwargs: dict
        Figure-level options such as title, xlabel, ylabel, legend, figsize and
        grid.

    raises:
    -------
    ValueError
        If ``plot_raw`` is True but ``clim_obj`` or ``raw_var`` is missing.
    """
    # Validate before touching matplotlib so that a bad call fails with a clear
    # message and does not leave an empty figure behind.
    if plot_raw:
        if clim_obj is None:
            raise ValueError("clim_obj must be provided when plot_raw is True.")
        if raw_var is None:
            raise ValueError("raw_var must be provided when plot_raw is True.")
    plot_params = get_plot_options(**kwargs)
    if plot_params["figsize"] is not None:
        plt.figure(figsize=plot_params["figsize"])
    if plot_raw:
        raw_df = clim_obj.original_df
        # The raw series is drawn against its own index, not against x_axis.
        plt.plot(raw_df.index, raw_df[raw_var], **(raw_kwargs or {}))
    plot_colmns(df, x_axis, colmns_kwargs)
    plot_figure(plot_params)