"""flood_forecast.explain_model_output: SHAP (DeepExplainer) feature-attribution plots for trained models."""

import random
from datetime import datetime
from typing import Optional, Tuple
import numpy as np
import shap
import torch

import wandb
from flood_forecast.plot_functions import (
from flood_forecast.preprocessing.pytorch_loaders import CSVTestLoader


def handle_dl_output(dl, dl_class: str, datetime_start: datetime, device: str) -> Tuple[torch.Tensor, int]:
    """Fetch the model input window that starts at ``datetime_start``.

    :param dl: The test data-loader. Should be passed directly
    :type dl: Union[CSVTestLoader, TemporalTestLoader]
    :param dl_class: The name of the data-loader class passed from the params file
        (``"TemporalLoader"`` selects the multi-tensor path)
    :type dl_class: str
    :param datetime_start: The start datetime for the forecast
    :type datetime_start: datetime
    :param device: Device the history is moved to, typically ``"cpu"`` or ``"cuda"``.
        Ignored for ``"TemporalLoader"``, whose tensors are always placed on the CPU.
    :type device: str
    :return: ``(history, forecast_start_idx)``. For ``"TemporalLoader"``, ``history`` is the
        list ``[his[0], his[1], tar[1], tar[0]]``, each tensor given a leading dimension of
        size 1. Otherwise ``history`` is a single tensor with a leading dimension of size 1.
    :rtype: Tuple[Union[torch.Tensor, List[torch.Tensor]], int]
    """
    if dl_class == "TemporalLoader":
        # Temporal inputs are always kept on the CPU regardless of ``device``.
        his, tar, _, forecast_start_idx = dl.get_from_start_date(datetime_start)
        history = [
            his[0].unsqueeze(0).to("cpu"),
            his[1].unsqueeze(0).to("cpu"),
            tar[1].unsqueeze(0).to("cpu"),
            tar[0].unsqueeze(0).to("cpu"),
        ]
    else:
        history, _, forecast_start_idx = dl.get_from_start_date(datetime_start)
        # Add a leading batch dimension of 1, mirroring the TemporalLoader branch.
        # NOTE(review): assumes the loader returns an unbatched (history_len, n_features)
        # window — confirm against CSVTestLoader.get_from_start_date.
        history = history.to(device).unsqueeze(0)
    return history, forecast_start_idx
def _prepare_background_tensor(
    csv_test_loader: CSVTestLoader, backgound_batch_size: int = BACKGROUND_BATCH_SIZE
) -> torch.Tensor:
    """Randomly sample history windows to use as DeepExplainer background data.

    Args:
        csv_test_loader (CSVTestLoader): test data loader
        backgound_batch_size (int): number of history windows sampled as background
            data for the deep explainer. Defaults to BACKGROUND_BATCH_SIZE.

    Returns:
        torch.Tensor: float tensor of size (batch_size, history_len, num_feature)

    Raises:
        ValueError: propagated from ``random.sample`` when fewer than
            ``backgound_batch_size`` windows are available.
    """
    all_windows = csv_test_loader.convert_history_batches(
        csv_test_loader.df.columns, csv_test_loader.original_df
    )
    # The final window is never sampled: it may not be of size
    # (history_len, num_feature) due to the length of the original df.
    usable_windows = all_windows[:-1]
    sampled_windows = random.sample(usable_windows, backgound_batch_size)
    return torch.stack(sampled_windows).float()
def deep_explain_model_summary_plot(
    model, csv_test_loader: CSVTestLoader, datetime_start: Optional[datetime] = None
) -> None:
    """Generate feature summary plots for trained deep learning models.

    Produces an overall feature-ranking plot and a per-prediction-time-step ranking
    plot over background data, plus per-feature plots for the single prediction at
    ``datetime_start``. Figures (and per-feature mean absolute shap values) are logged
    to wandb when ``model.wandb`` is truthy.

    Args:
        model (object): trained model
        csv_test_loader (CSVTestLoader): test data loader
        datetime_start (datetime, optional): start date of the test prediction,
            Defaults to None, i.e. using model inference parameters.
    """
    if model.params["model_name"] == "SimpleTransformer":
        print("SimpleTransformer currently not supported.")
        return
    use_wandb = model.wandb
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model.params["model_name"] == "DARNN" and device.type == "cuda":
        print("DARNN does not work with shap on CUDA")
        return
    if datetime_start is None:
        datetime_start = model.params["inference_params"]["datetime_start"]
    history, _ = handle_dl_output(
        csv_test_loader, model.params["dataset_params"]["class"], datetime_start, device
    )
    background_tensor = _prepare_background_tensor(csv_test_loader).to(device)
    model.model.eval()

    def _to_named_shap_tensor(raw_shap_values) -> torch.Tensor:
        """Stack raw shap output into a named (preds, batches, observations, features) tensor."""
        stacked = np.stack(fix_shap_values(raw_shap_values, history))
        # shap_values needs to be 4-dimensional: add a leading "preds" axis when the
        # stacked output is not already 4-D.
        if len(stacked.shape) != 4:
            stacked = np.expand_dims(stacked, axis=0)
        return torch.tensor(
            stacked, names=["preds", "batches", "observations", "features"]
        )

    # background shape (L, N, M)
    # L - batch size, N - history length, M - feature size
    if isinstance(history, list):
        # handle_dl_output keeps TemporalLoader inputs on the CPU, so the model follows.
        model.model = model.model.to("cpu")
        deep_explainer = shap.DeepExplainer(model.model, history)
        raw_shap_values = deep_explainer.shap_values(history)
    else:
        deep_explainer = shap.DeepExplainer(model.model, background_tensor)
        raw_shap_values = deep_explainer.shap_values(background_tensor)
    shap_values = _to_named_shap_tensor(raw_shap_values)

    columns = csv_test_loader.df.columns
    # summary plot shows overall feature ranking
    # by average absolute shap values
    fig = plot_summary_shap_values(shap_values, columns)
    abs_mean_shap_values = shap_values.abs().mean(axis=["preds", "batches"])
    multi_shap_values = abs_mean_shap_values.mean(axis="observations")
    if use_wandb:
        wandb.log({"Overall feature ranking by shap values": fig})
        # One scalar per feature; names are dropped so the tensor converts to a plain list.
        per_feature_values = multi_shap_values.rename(None).tolist()
        for col, value in zip(columns, per_feature_values):
            wandb.log({"shap_value_" + col: value})

    # summary plot for multi-step outputs
    fig = plot_summary_shap_values_over_time_series(shap_values, columns)
    if use_wandb:
        wandb.log({"Overall feature ranking per prediction time-step": fig})

    # summary plot for one prediction at datetime_start
    hist = history[0] if isinstance(history, list) else history
    history_numpy = torch.tensor(
        hist.cpu().numpy(), names=["batches", "observations", "features"]
    )
    shap_values = _to_named_shap_tensor(deep_explainer.shap_values(history))
    figs = plot_shap_values_from_history(shap_values, history_numpy)
    if use_wandb:
        for fig, feature in zip(figs, columns.tolist()):
            wandb.log(
                {
                    "Feature ranking for prediction"
                    f" at {datetime_start} - {feature}": fig
                }
            )
def fix_shap_values(shap_values, history):
    """Regroup shap output for list-valued model inputs.

    When ``history`` is a list, ``shap_values`` is transposed via ``zip(*shap_values)``
    and only the first group is returned. That group is a tuple of the first element of
    each entry; presumably these are the attributions for the first model input, but
    confirm against shap's multi-input output layout. Otherwise ``shap_values`` is
    returned unchanged.
    """
    if not isinstance(history, list):
        return shap_values
    regrouped = list(zip(*shap_values))
    return regrouped[0]
def deep_explain_model_heatmap(
    model, csv_test_loader: CSVTestLoader, datetime_start: Optional[datetime] = None
) -> None:
    """Generate feature heatmaps for a trained model's predictions.

    Produces average heatmaps over background data and heatmaps for the single
    prediction at ``datetime_start``. Figures are logged to wandb when ``model.wandb``
    is truthy.

    Args:
        model (object): trained model
        csv_test_loader (CSVTestLoader): test data loader
        datetime_start (Optional[datetime], optional): start date of the test
            prediction, Defaults to None, i.e. using model inference parameters.

    Returns:
        None
    """
    if model.params["model_name"] == "SimpleTransformer":
        print("SimpleTransformer currently not supported.")
        return
    elif "probabilistic" in model.params:
        print("Probabilistic currently not supported.")
        return
    use_wandb = model.wandb
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model.params["model_name"] == "DARNN" and device.type == "cuda":
        # TO-DO check if this is still true
        print("Currently DARNN doesn't work with shap on CUDA")
        return
    if datetime_start is None:
        datetime_start = model.params["inference_params"]["datetime_start"]
    history, _ = handle_dl_output(
        csv_test_loader, model.params["dataset_params"]["class"], datetime_start, device
    )
    background_tensor = _prepare_background_tensor(csv_test_loader).to(device)
    model.model.eval()

    def _to_named_shap_tensor(raw_shap_values) -> torch.Tensor:
        """Stack raw shap output into a named (preds, batches, observations, features) tensor."""
        stacked = np.stack(fix_shap_values(raw_shap_values, history))
        # shap_values needs to be 4-dimensional: add a leading "preds" axis when the
        # stacked output is not already 4-D.
        if len(stacked.shape) != 4:
            stacked = np.expand_dims(stacked, axis=0)
        return torch.tensor(
            stacked, names=["preds", "batches", "observations", "features"]
        )

    # background shape (L, N, M)
    # L - batch size, N - history length, M - feature size
    # for each element in each N x M batch in L,
    # attribute to each prediction in forecast len
    if isinstance(history, list):
        # handle_dl_output keeps TemporalLoader inputs on the CPU, so the model follows
        # (same as deep_explain_model_summary_plot).
        model.model = model.model.to("cpu")
        deep_explainer = shap.DeepExplainer(model.model, history)
        raw_shap_values = deep_explainer.shap_values(history)
    else:
        deep_explainer = shap.DeepExplainer(model.model, background_tensor)
        raw_shap_values = deep_explainer.shap_values(background_tensor)
    # (preds, batches, observations, features), i.e. forecast_len x L x N x M
    shap_values = _to_named_shap_tensor(raw_shap_values)
    figs = plot_shap_value_heatmaps(shap_values)
    if use_wandb:
        for fig, feature in zip(figs, csv_test_loader.df.columns):
            wandb.log({f"Average prediction heatmaps - {feature}": fig})

    # heatmap of one prediction sequence at datetime_start
    # (seq_len * forecast_len) per feature
    shap_values = _to_named_shap_tensor(deep_explainer.shap_values(history))
    figs = plot_shap_value_heatmaps(shap_values)
    if use_wandb:
        for fig, feature in zip(figs, csv_test_loader.df.columns):
            wandb.log(
                {"Heatmap for prediction " f"at {datetime_start} - {feature}": fig}
            )