Source code for smadi.workflow

"""
run_workflow.py - SMADI Workflow Execution

"""

from argparse import ArgumentParser, Namespace, ArgumentError
from typing import List, Tuple, Union, Dict
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import pandas as pd
import numpy as np


from smadi.data_reader import AscatData, read_era5
from smadi.utils import (
    create_logger,
    log_exception,
    log_time,
    load_gpis_by_country,
    get_gpis_from_bbox,
)
from smadi.anomaly_detectors import (
    ZScore,
    SMAPI,
    SMDI,
    SMCA,
    SMAD,
    SMCI,
    SMDS,
    ESSMI,
    ParaDis,
)

# Supported anomaly detection methods
# Maps each method name to the detector class that implements it.
# The key order is preserved and is the order in which the supported names
# are listed in the error message raised by ``validate_anomaly_method``.
_Detectors = {
    "zscore": ZScore,
    # Names of the form "<name>-<metric>" share one class; ``single_po_run``
    # passes the part after the dash to the detector as ``normal_metrics``.
    "smapi-mean": SMAPI,
    "smapi-median": SMAPI,
    "smdi": SMDI,
    "smca-mean": SMCA,
    "smca-median": SMCA,
    "smad": SMAD,
    "smci": SMCI,
    "smds": SMDS,
    "essmi": ESSMI,
    # Both names share ParaDis; ``single_po_run`` passes the key itself to
    # the detector as ``dist``.
    "beta": ParaDis,
    "gamma": ParaDis,
}


def _str_to_bool(value: str) -> bool:
    """
    Convert a command-line token into a real boolean.

    ``type=bool`` cannot be used with argparse because ``bool("False")`` is
    ``True`` (any non-empty string is truthy).

    Parameters:
    -----------
    value: str
        The raw command-line token (case-insensitive).

    Returns:
    --------
    bool
        True for true/t/yes/y/1, False for false/f/no/n/0 or an empty string.

    Raises:
    -------
    argparse.ArgumentTypeError:
        If the token is not one of the recognised spellings
    """
    from argparse import ArgumentTypeError

    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def _parse_timespan(value: str) -> List[str]:
    """
    Convert the ``--timespan`` token into a ``[start, end]`` list of strings.

    Both the list-like spelling shown in the help text
    (``"['2012-01-01', '2012-12-31']"``) and a plain comma separated pair
    (``"2012-01-01,2012-12-31"``) are accepted. ``type=list`` cannot be used
    because ``list("abc")`` splits the token into single characters.

    Raises:
    -------
    argparse.ArgumentTypeError:
        If the token does not contain exactly two non-empty entries
    """
    from argparse import ArgumentTypeError

    inner = value.strip().strip("[]()")
    bounds = [part.strip().strip("'\"").strip() for part in inner.split(",")]
    if len(bounds) != 2 or not all(bounds):
        raise ArgumentTypeError(
            f"expected two dates such as ['2012-01-01', '2012-12-31'], got {value!r}"
        )
    return bounds


def setup_argument_parser() -> ArgumentParser:
    """
    Setup argument parser for SMADI workflow execution.

    Returns:
    --------
    parser: ArgumentParser
        A parser with the positional arguments ``data_path``, ``aoi`` and
        ``time_step`` and the optional arguments controlling the variable,
        the date selection, the detection methods, pre-processing,
        multiprocessing and output.
    """
    parser = ArgumentParser(
        description="Run the SMADI workflow for anomaly detection on ASCAT data"
    )

    # Required arguments
    parser.add_argument(
        "data_path", metavar="data_path", type=str, help="Path to the ASCAT data"
    )
    parser.add_argument(
        "aoi",
        metavar="aoi",
        type=str,
        help="Country name or bounding box coordinates "
        "in str format 'lon_min, lon_max , lat_min, lat_max'",
    )
    parser.add_argument(
        "time_step",
        metavar="time_step",
        type=str,
        default="month",
        choices=["month", "dekad", "week", "bimonth", "day"],
        help="The time step for the climatology calculation. Supported values: month, dekad, week, bimonth, day",
    )

    # Optional arguments
    parser.add_argument(
        "--data_read_bulk",
        type=_str_to_bool,  # not ``bool``: bool("False") is True
        default=False,
        help="Read data in bulk mode. If 'True' all data will be read in memory",
    )
    parser.add_argument(
        "--variable",
        metavar="variable",
        type=str,
        default="sm",
        help="The variable to be used for the anomaly detection.",
    )
    parser.add_argument(
        "--year",
        metavar="year",
        type=int,
        nargs="*",
        default=None,
        required=True,
        choices=range(2007, 2023),
        help="The year(s) for the date parameters",
    )
    parser.add_argument(
        "--month",
        metavar="month",
        type=int,
        nargs="*",
        default=None,
        choices=range(1, 13),
        help="The month(s) for the date parameters",
    )
    parser.add_argument(
        "--dekad",
        metavar="dekad",
        type=int,
        nargs="*",
        default=None,
        choices=(1, 2, 3),
        help="The dekad(s) for the date parameters",
    )
    parser.add_argument(
        "--week",
        metavar="week",
        type=int,
        nargs="*",
        default=None,
        choices=range(1, 53),
        help="The week(s) for the date parameters",
    )
    parser.add_argument(
        "--bimonth",
        metavar="bimonth",
        type=int,
        nargs="*",
        default=None,
        choices=(1, 2),
        help="The bimonth(s) for the date parameters",
    )
    parser.add_argument(
        "--day",
        metavar="day",
        type=int,
        nargs="*",
        default=None,
        choices=range(1, 32),
        help="The day(s) for the date parameters",
    )
    parser.add_argument(
        "--methods",
        metavar="methods",
        type=str,
        nargs="*",
        default=["zscore"],
        # Only names that exist in the module's detector table are listed.
        help="Anomaly detection methods. Supported methods: zscore, smapi-mean, "
        "smapi-median, smdi, smca-mean, smca-median, smad, smci, smds, essmi, "
        "beta, gamma",
    )
    parser.add_argument(
        "--timespan",
        metavar="timespan",
        type=_parse_timespan,  # not ``list``: list("ab") is ['a', 'b']
        default=None,
        help="To work on a subset of the data. Example: ['2012-01-01', '2012-12-31']",
    )
    parser.add_argument(
        "--fillna", type=_str_to_bool, default=False, help="Fill missing values"
    )
    parser.add_argument("--fillna_size", type=int, default=3, help="Fillna window size")
    parser.add_argument(
        "--smoothing", type=_str_to_bool, default=False, help="Apply smoothing"
    )
    parser.add_argument(
        "--smooth_size", type=int, default=31, help="Smoothing window size"
    )
    parser.add_argument(
        "--workers",
        metavar="workers",
        type=int,
        default=None,
        help="The number of workers to determine the degree of concurrent processing in a multiprocessing setup",
    )
    parser.add_argument(
        "--addi_retrive",
        metavar="addi_retrive",
        type=str,
        nargs="*",
        default=None,
        help="Additional parameters to retrieve from the anomaly dataframe. Supported values: 'var', 'norm' , 'abs'",
    )
    parser.add_argument(
        "--save_to",
        type=str,
        default=None,
        help="Save the output to a file to the given path",
    )

    return parser
def parse_arguments(parser):
    """
    Parse the arguments and return the parsed arguments as a dictionary.

    An ``aoi`` value containing a comma is treated as a bounding box and is
    converted to a tuple of four floats; any other value is kept as a string
    (country name).

    Parameters:
    -----------
    parser: ArgumentParser
        The parser whose ``parse_args()`` is called (reads ``sys.argv``)

    returns:
    --------
    parsed_args: dict
        The parsed arguments as a dictionary

    Raises:
    -------
    SystemExit:
        Via ``parser.error`` if the bounding box is not made of exactly four
        numeric values
    """
    args: Namespace = parser.parse_args()

    aoi = args.aoi
    if "," in aoi:
        try:
            aoi = tuple(float(part) for part in aoi.split(","))
        except ValueError:
            # Report a malformed box as a usage error instead of a traceback.
            parser.error(
                f"aoi bounding box must contain only numbers, got '{args.aoi}'"
            )
        if len(aoi) != 4:
            parser.error(
                "aoi bounding box must have exactly four values "
                f"'lon_min, lon_max, lat_min, lat_max', got '{args.aoi}'"
            )

    parsed_args = {
        "data_path": args.data_path,
        "aoi": aoi,
        "variable": args.variable,
        "time_step": args.time_step,
        "data_read_bulk": args.data_read_bulk,
        "methods": args.methods,
        "workers": args.workers,
        "fillna": args.fillna,
        "fillna_size": args.fillna_size,
        "smoothing": args.smoothing,
        "smooth_size": args.smooth_size,
        "timespan": args.timespan,
        "year": args.year,
        "month": args.month,
        "dekad": args.dekad,
        "week": args.week,
        "bimonth": args.bimonth,
        "day": args.day,
        "addi_retrive": args.addi_retrive,
        "save_to": args.save_to,
    }

    return parsed_args
# Create a logger logger = create_logger("workflow-logger")
@log_exception(logger)
def load_ts(gpi, variable="sm"):
    """
    Read the ASCAT time series of one grid point and extract one variable.

    The series is read through the module-level ``ascat_obj`` (created in
    ``main``). Rows for which ``num_sigma >= 2`` does not hold get their
    ``sm``, ``sigma40``, ``slope40`` and ``curvature40`` values replaced by
    NaN before the requested variable is extracted.

    Parameters:
    -----------
    gpi:
        The grid point index passed to ``ascat_obj.read``
    variable: str
        Name of the column to return

    Returns:
    --------
    pd.DataFrame
        A data frame built from the requested column
    """
    series = ascat_obj.read(gpi)

    masked_columns = ["sm", "sigma40", "slope40", "curvature40"]
    enough_sigma = series["num_sigma"] >= 2
    series.loc[~enough_sigma, masked_columns] = np.nan

    return pd.DataFrame(series.get(variable))
def _validate_param(param_name: str, param_value: List[int]) -> None: """ Validate the date parameters (year, month, dekad, week, bimonth, day) for the anomaly detection workflow. Parameters: ----------- param_name: str The name of the parameter to validate param_value: List[int] The value of the parameter to validate Raises: ------- ValueError: If the parameter value is None If the parameter value is not a list of integers """ if param_value is None: raise ValueError(f"The '{param_name}' parameter must be provided") if not (isinstance(param_value, (list, int))): raise ValueError(f"The '{param_name}' parameter must be an int of list of ints") def _validate_required_params( time_step: str, required_params: Dict[str, List[str]], local_vars: Dict[str, List[int]], ) -> None: """ Check if the required parameters are provided for the given time_step. Parameters: ----------- time_step: str The time step for the anomaly detection workflow required_params: Dict[str, List[str]] A dictionary containing the required parameters for each time step local_vars: Dict[str, List[int]] A dictionary containing the local variables Raises: ------- ValueError: If any of the required parameters are missing """ missing_params = [ param for param in required_params[time_step] if local_vars.get(param) is None ] if missing_params: raise ValueError( f"For time_step '{time_step}', the following parameters must be provided: " f"{', '.join(missing_params)}" )
def validate_date_params(
    time_step: str,
    year: Union[int, List[int]] = None,
    month: Union[int, List[int]] = None,
    dekad: Union[int, List[int]] = None,
    week: Union[int, List[int]] = None,
    bimonth: Union[int, List[int]] = None,
    day: Union[int, List[int]] = None,
) -> Dict[str, List[int]]:
    """
    Validate the date parameters for the anomaly detection workflow.

    Parameters:
    -----------
    time_step: str
        One of month, dekad, week, bimonth, day
    year, month, dekad, week, bimonth, day: int or List[int]
        The date components. A single int is wrapped into a one-element list.

    Returns:
    --------
    date_param: Dict[str, List[int]]
        Only the parameters relevant for ``time_step``, in the order
        year, month, <time-step specific> (year, week for ``week``)

    Raises:
    -------
    ValueError:
        If ``time_step`` is not supported
        If any parameter required for ``time_step`` is missing
        If a parameter is neither a list nor an int
        If the parameter lists do not all have the same length
    """
    required_params = {
        "month": ["year", "month"],
        "dekad": ["year", "month", "dekad"],
        "week": ["year", "week"],
        "bimonth": ["year", "month", "bimonth"],
        "day": ["year", "month", "day"],
    }

    if time_step not in required_params:
        raise ValueError(
            f"Unsupported time_step: {time_step}. Supported time_steps are month, dekad, week, bimonth, day"
        )

    # Explicit mapping instead of ``locals()`` so only real parameters are seen.
    supplied = {
        "year": year,
        "month": month,
        "dekad": dekad,
        "week": week,
        "bimonth": bimonth,
        "day": day,
    }
    supplied = {
        name: [value] if isinstance(value, int) else value
        for name, value in supplied.items()
    }

    # Run the completeness check first: it reports every missing parameter at
    # once, and it would be unreachable after the per-parameter check below
    # (which already raises on the first None).
    _validate_required_params(time_step, required_params, supplied)

    date_param = {name: supplied[name] for name in required_params[time_step]}

    # Validation for parameters
    for param_name, param_value in date_param.items():
        _validate_param(param_name, param_value)

    # Checking if the value lists are of the same length
    if len({len(values) for values in date_param.values()}) > 1:
        raise ValueError(
            "The length of the date parameters lists must be the same for multiple dates"
        )

    return date_param
def validate_anomaly_method(methods):
    """
    Validate the names of the anomaly detection methods.

    Parameters:
    -----------
    methods: List[str] or str
        A list of anomaly detection methods names to validate. A single
        name given as a plain string is also accepted.

    Raises:
    -------
    ValueError:
        If any of the methods is not supported
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(methods, str):
        methods = [methods]

    for method in methods:
        if method not in _Detectors:
            raise ValueError(
                f"Anomaly method '{method}' is not supported. "
                f"Supported methods are one of the following: {tuple(_Detectors)}"
            )
def addi_retrival(
    addi_retrive: list = ("var", "norm", "abs"),
    results: Union[Dict[str, np.ndarray], pd.DataFrame] = None,
    anomaly: pd.DataFrame = None,
    variable: str = None,
    date_str: str = None,
    agg_metric: str = "mean",
):
    """
    Validate and retrieve additional parameters from the anomaly dataframe.

    Parameters:
    -----------
    addi_retrive: list
        Which extras to copy. Recognised entries are 'var', 'norm' and 'abs';
        other entries are ignored. ``None`` disables the retrieval.
    results: dict
        Container the extras are written into (modified in place). A new
        dict is created when ``None`` is passed.
    anomaly: pd.DataFrame
        Output of a detector. Must contain the column
        ``"{variable}-{agg_metric}"`` when 'var' or 'abs' is requested; the
        columns 'norm-mean' / 'norm-median' are used when present.
    variable: str
        Name of the analysed variable
    date_str: str
        Date label appended in parentheses to every key written
    agg_metric: str
        Aggregation metric that is part of the variable column name

    Returns:
    --------
    results:
        The container with these keys added (each suffixed "(date_str)"):
        'var'  -> "{variable}-{agg_metric}"
        'norm' -> "norm-mean" / "norm-median"
        'abs'  -> "abs-mean" / "abs-median" (aggregated value minus the norm)
    """
    if addi_retrive is None:
        return results

    if results is None:
        results = {}

    suffix = f"({date_str})"
    agg_column = f"{variable}-{agg_metric}"

    if "var" in addi_retrive:
        results[agg_column + suffix] = anomaly[agg_column].values

    if "norm" in addi_retrive:
        for metric in ("mean", "median"):
            norm_column = f"norm-{metric}"
            if norm_column in anomaly.columns:
                results[norm_column + suffix] = anomaly[norm_column].values

    if "abs" in addi_retrive:
        for metric in ("mean", "median"):
            norm_column = f"norm-{metric}"
            if norm_column in anomaly.columns:
                # Absolute anomaly: aggregated value minus the climatology.
                results[f"abs-{metric}" + suffix] = (
                    anomaly[agg_column].values - anomaly[norm_column].values
                )

    return results
@log_exception(logger)
def single_po_run(
    gpi: int,
    methods: str = ["zscore"],
    variable: str = "sm",
    time_step: str = "month",
    fillna: bool = False,
    fillna_window_size: int = None,
    smoothing: bool = False,
    smooth_window_size: int = None,
    year: Union[int, List[int]] = None,
    month: Union[int, List[int]] = None,
    dekad: Union[int, List[int]] = None,
    week: Union[int, List[int]] = None,
    bimonth: Union[int, List[int]] = None,
    day: Union[int, List[int]] = None,
    timespan: List[str] = None,
    addi_retrive: list = ["var", "norm"],
    agg_metric: list = "mean",
    mode: str = "anomaly",
    era5_path: str = None,
) -> Tuple[int, Dict[str, float]]:
    """
    Run the anomaly detection workflow for a single grid point index.

    For every requested method and every requested date a detector from
    ``_Detectors`` is instantiated and its ``detect_anomaly`` is called. The
    values of the column named after the method are stored under the key
    ``"<method>(<date>)"``, where ``<date>`` is the date components joined
    with "-". Extra columns selected by ``addi_retrive`` are added through
    ``addi_retrival``.

    ``mode`` and ``era5_path`` are accepted but not used by this function.

    Returns:
    --------
    (gpi, results) on success, or ``None`` if an ``AttributeError`` is raised
    while running the detectors.
    """
    # Load the time series for the given gpi
    global ascat_obj
    df = load_ts(gpi, variable=variable)

    # Validate the date parameters, then turn the parallel lists into one
    # keyword dictionary per requested date
    validated = validate_date_params(
        time_step, year, month, dekad, week, bimonth, day
    )
    field_names = validated.keys()
    date_selections = [
        dict(zip(field_names, combo)) for combo in zip(*validated.values())
    ]

    # Keyword arguments shared by every detector
    shared_kwargs = {
        "df": df,
        "variable": variable,
        "time_step": time_step,
        "fillna": fillna,
        "fillna_window_size": fillna_window_size,
        "smoothing": smoothing,
        "smooth_window_size": smooth_window_size,
        "timespan": timespan,
        "agg_metric": agg_metric,
    }

    results = {}
    for method in methods:
        detector_kwargs = dict(shared_kwargs)

        # Methods carrying a metric suffix (e.g. smapi-mean, smapi-median)
        # forward that suffix; beta / gamma forward the distribution name
        if "-" in method:
            detector_kwargs["normal_metrics"] = [method.split("-")[1]]
        elif method in ["beta", "gamma"]:
            detector_kwargs["dist"] = [method]

        try:
            for selection in date_selections:
                detector = _Detectors[method](**detector_kwargs)
                anomaly = detector.detect_anomaly(**selection)

                date_str = "-".join(map(str, selection.values()))
                results[f"{method}({date_str})"] = anomaly[method].values
                results = addi_retrival(
                    addi_retrive=addi_retrive,
                    results=results,
                    anomaly=anomaly,
                    variable=variable,
                    date_str=date_str,
                    agg_metric=agg_metric,
                )
        except AttributeError:
            return None

    return (gpi, results)
@log_exception(logger)
def _finalize(result: Tuple[int, dict], df: pd.DataFrame, gpis_col="point"):
    """
    Write the results of one grid point into the overall data frame.

    Parameters:
    -----------
    result: Tuple[int, dict]
        A ``(gpi, values)`` pair. Anything that cannot be unpacked into two
        items (for example ``None``) leaves ``df`` unchanged.
    df: pd.DataFrame
        The data frame to update (modified in place and returned)
    gpis_col: str
        Name of the column holding the grid point indices

    Returns:
    --------
    df: pd.DataFrame
        The data frame with one column per key of ``values`` filled in for
        the rows whose ``gpis_col`` equals ``gpi``
    """
    try:
        gpi, values_by_column = result
    except Exception:
        return df

    for column, value in values_by_column.items():
        df.loc[df[gpis_col] == gpi, column] = value
    return df
@log_time(logger)
def run(
    aoi: Union[str, Tuple[float, float, float, float]],
    methods: Union[str, List[str]] = ["zscore"],
    variable: str = "sm",
    time_step: str = "month",
    fillna: bool = False,
    fillna_window_size: int = None,
    smoothing: bool = False,
    smooth_window_size: int = None,
    timespan: List[str] = None,
    year: List[int] = None,
    month: List[int] = None,
    dekad: List[int] = None,
    week: List[int] = None,
    bimonth: List[int] = None,
    day: List[int] = None,
    workers: int = None,
    addi_retrive: list = ["anomaly"],
) -> pd.DataFrame:
    """
    Run the anomaly detection workflow for multiple grid point indices with
    multiprocessing support.

    The grid points are loaded by country name when ``aoi`` is a string and
    from a bounding box otherwise. ``single_po_run`` is then mapped over the
    ``point`` column in a ``ProcessPoolExecutor`` with ``workers`` processes,
    and every result is merged into the grid point table by ``_finalize``.

    Returns:
    --------
    pointlist: pd.DataFrame
        The grid point table extended with the result columns
    """
    logger.info("Workflow started....\n")
    logger.info(f"Loading grid points for {aoi}....")

    gpi_loader = load_gpis_by_country if isinstance(aoi, str) else get_gpis_from_bbox
    pointlist = gpi_loader(aoi)

    logger.info(f"Grid points loaded successfully for {aoi}\n")
    logger.info(f"\n\n{pointlist.head()}")

    # Everything except the grid point itself is fixed for all workers
    fixed_kwargs = dict(
        methods=methods,
        variable=variable,
        time_step=time_step,
        fillna=fillna,
        fillna_window_size=fillna_window_size,
        smoothing=smoothing,
        smooth_window_size=smooth_window_size,
        year=year,
        month=month,
        dekad=dekad,
        week=week,
        bimonth=bimonth,
        day=day,
        timespan=timespan,
        addi_retrive=addi_retrive,
    )
    pre_compute = partial(single_po_run, **fixed_kwargs)

    logger.info(f"Running the anomaly detection workflow for {aoi}....\n")
    logger.info("Workflow parameters:\n")
    logger.info(f" Anomaly detection methods: {', '.join(methods)}")
    logger.info(f" Variable: {variable}")
    logger.info(f" Time step for Climatology: {time_step}")
    logger.info(" Date parameters:\n")

    # Only the date components that were actually supplied are logged
    labelled_dates = (
        ("Year", year),
        ("Month", month),
        ("Dekad", dekad),
        ("Week", week),
        ("Bimonth", bimonth),
        ("Day", day),
    )
    for label, selection in labelled_dates:
        if selection:
            logger.info(f" {label}: {selection}")

    with ProcessPoolExecutor(workers) as executor:
        for outcome in executor.map(pre_compute, pointlist["point"]):
            pointlist = _finalize(outcome, pointlist)

    return pointlist
def main():
    """
    Command-line entry point of the SMADI workflow.

    Parses the command line, validates the method names and date parameters,
    creates the module-level ``ascat_obj`` data reader used by ``load_ts``,
    runs the workflow and optionally writes the resulting data frame to the
    CSV file given by ``--save_to``.

    Invalid methods or date parameters, and a failure to write the output
    file, are reported through ``parser.error`` (usage message, exit
    status 2).
    """
    global ascat_obj

    parser = setup_argument_parser()
    args = parse_arguments(parser)

    # Validate before creating the data reader so bad input fails fast.
    # The validators raise ValueError, which must be caught here as well,
    # otherwise the user gets a traceback instead of a usage error.
    try:
        validate_anomaly_method(args["methods"])
        validate_date_params(
            args["time_step"],
            args["year"],
            args["month"],
            args["dekad"],
            args["week"],
            args["bimonth"],
            args["day"],
        )
    except (ArgumentError, ValueError) as e:
        parser.error(str(e))

    # Create an instance of the AscatData class
    ascat_obj = AscatData(args["data_path"], args["data_read_bulk"])

    df = run(
        aoi=args["aoi"],
        methods=args["methods"],
        variable=args["variable"],
        time_step=args["time_step"],
        year=args["year"],
        month=args["month"],
        week=args["week"],
        day=args["day"],
        bimonth=args["bimonth"],
        dekad=args["dekad"],
        timespan=args["timespan"],
        fillna=args["fillna"],
        fillna_window_size=args["fillna_size"],
        smoothing=args["smoothing"],
        smooth_window_size=args["smooth_size"],
        workers=args["workers"],
        addi_retrive=args["addi_retrive"],
    )

    logger.info(f"\n\n {df}")

    if args["save_to"]:
        logger.info(f"Saving the output data frame to {args['save_to']}....")
        try:
            df.to_csv(args["save_to"])
        except (ArgumentError, OSError) as e:
            # An unwritable path is reported as a problem with --save_to.
            parser.error(str(e))
        logger.info("Output saved successfully")
# Run the command-line workflow only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()