import torch
import torch.optim as optim
from typing import Type, Dict, List, Union
from torch.utils.data import DataLoader
import json
import wandb
from flood_forecast.utils import numpy_to_tvar
from flood_forecast.time_model import PyTorchForecast
from flood_forecast.model_dict_function import pytorch_opt_dict, pytorch_criterion_dict
from flood_forecast.transformer_xl.transformer_basic import greedy_decode
from flood_forecast.basic.linear_regression import simple_decode
from flood_forecast.training_utils import EarlyStopper
from flood_forecast.custom.custom_opt import GaussianLoss, MASELoss
from flood_forecast.series_id_helper import handle_csv_id_output, handle_csv_id_validation
from torch.nn import CrossEntropyLoss
# Class names of models that compute_validation calls with four tensors
# (src[0], src[1], a zero-filled copy of targ[1], targ[0]) instead of a single src tensor.
# The check is done against type(model).__name__.
TEMPORAL_FEATS_MODELS = ["ITransformer", "Informer"]
[docs]
def multi_crit(crit_multi: List, output, labels, valid=None):
    """Sum the losses of several criteria, pairing criterion ``k`` with target column ``k``.

    The previous implementation initialised a column index but never advanced it, so
    every criterion was evaluated on column 0. The index now follows the position of
    the criterion in ``crit_multi``.

    :param crit_multi: The list of criteria to use for training.
    :type crit_multi: List
    :param output: The model output. The last axis is indexed per criterion; a 3-D
        output is sliced as ``output[:, :, k]``, anything else as ``output[:, k]``.
    :type output: torch.Tensor
    :param labels: The targets, sliced the same way as ``output``.
    :type labels: torch.Tensor
    :param valid: Forwarded to ``compute_loss`` as its ``validation_dataset`` argument,
        defaults to None
    :type valid: optional
    :raises IndexError: If there are more criteria than columns on the last axis.
    :return: The sum of the individual losses (``0.0`` when ``crit_multi`` is empty).
    :rtype: torch.Tensor or float
    """
    loss = 0.0
    is_three_dim = len(output.shape) == 3
    for idx, crit in enumerate(crit_multi):
        if is_three_dim:
            crit_labels, crit_output = labels[:, :, idx], output[:, :, idx]
        else:
            crit_labels, crit_output = labels[:, idx], output[:, idx]
        # compute_loss only forwards src to MASELoss; a random placeholder is passed here
        # exactly as the original code did.
        loss += compute_loss(crit_labels, crit_output, torch.rand(1, 2), crit, valid)
    return loss
[docs]
def make_crit(model_params: Dict) -> Union[torch.nn.Module, List]:
    """A function to create the criterion for training from the parameters.

    ``model_params["criterion"]`` is either a single key of ``pytorch_criterion_dict`` or a
    list of keys. The optional ``model_params["criterion_params"]`` holds the constructor
    keyword arguments: one dict for a single criterion, or a list of dicts (one per
    criterion) when a list of criteria is given.

    For a list of criteria, missing parameter entries are padded with ``{}`` so that every
    listed criterion is built. Previously the criteria were zipped against the parameters,
    which silently dropped every criterion without a matching entry (and returned an empty
    list when no ``criterion_params`` were supplied at all).

    :param model_params: The training params Dict block in FF
    :type model_params: Dict
    :return: One instantiated criterion, or a list of them in the configured order.
    :rtype: Union[torch.nn.Module, List]
    """
    training_params = model_params
    criterion_init_params = training_params.get("criterion_params", {})
    criterion_names = training_params["criterion"]
    if not isinstance(criterion_names, list):
        return pytorch_criterion_dict[criterion_names](**criterion_init_params)
    per_criterion_params = list(criterion_init_params) if criterion_init_params else []
    # Pad with empty kwargs so that zip() cannot truncate the list of criteria.
    per_criterion_params += [{}] * (len(criterion_names) - len(per_criterion_params))
    return [pytorch_criterion_dict[name](**kwargs)
            for name, kwargs in zip(criterion_names, per_criterion_params)]
[docs]
def handle_scaling(validation_dataset, src, output: torch.Tensor, labels, probabilistic, m, output_std):
    """Un-scale model output (and, depending on the case, labels and source) via the dataset.

    Exactly one of four cases applies, checked in this order:

    1. ``probabilistic`` is truthy: only ``output`` is passed through ``inverse_scale`` and
       used as the mean of a ``Normal`` distribution; ``src``, ``output`` and ``labels`` are
       returned untouched.
    2. ``m > 1``: ``output`` and ``labels`` are moved to CPU and un-scaled; ``src`` is untouched.
    3. ``output`` is 3-D: the last two axes of ``output``, ``labels`` and ``src`` are swapped
       before un-scaling.
    4. Otherwise: axes 0 and 1 of ``output``, ``labels`` and ``src`` are swapped before un-scaling.

    :param validation_dataset: A dataset object for the validation dataset. We use its inverse scale method.
    :type validation_dataset: object exposing ``inverse_scale``
    :param src: The source tensor
    :type src: torch.Tensor
    :param output: The model output
    :type output: torch.Tensor
    :param labels: The target values
    :type labels: torch.Tensor
    :param probabilistic: Whether the model is probabilistic or not.
    :type probabilistic: bool
    :param m: The number of targets; values above 1 select case 2
    :type m: int
    :param output_std: The standard deviation used for the distribution in case 1
    :type output_std: torch.Tensor or array-like
    :return: ``(src, output, labels, output_dist)``; ``output_dist`` is None outside case 1
    :rtype: tuple
    """
    # To-do move to class function
    unscale = validation_dataset.inverse_scale

    if probabilistic:
        rescaled_mean = unscale(output)
        try:
            output_std = numpy_to_tvar(output_std)
        except Exception:
            # Best effort: when the conversion fails output_std is used as it was passed in.
            pass
        return src, output, labels, torch.distributions.Normal(rescaled_mean, output_std)

    if m > 1:
        return src, unscale(output.cpu()), unscale(labels.cpu()), None

    if len(output.shape) == 3:
        output_swapped = output.cpu().numpy().transpose(0, 2, 1)
        labels_swapped = labels.cpu().numpy().transpose(0, 2, 1)
        output = unscale(torch.from_numpy(output_swapped))
        labels = unscale(torch.from_numpy(labels_swapped))
        src_swapped = src.cpu().numpy().transpose(0, 2, 1)
        src = unscale(torch.from_numpy(src_swapped))
        return src, output, labels, None

    output = unscale(output.cpu().transpose(1, 0))
    labels = unscale(labels.cpu().transpose(1, 0))
    src = unscale(src.cpu().transpose(1, 0))
    return src, output, labels, None
[docs]
def compute_loss(labels, output, src, criterion, validation_dataset, probabilistic=None, output_std=None, m=1):
    """Function for computing the loss.

    Dispatch order:

    * ``GaussianLoss`` criterion: a new ``GaussianLoss`` is built from ``output[0]`` and
      ``output[1]`` and applied to ``labels``; nothing else in this function runs.
    * ``probabilistic``: the summed negative log-likelihood of ``labels`` under a
      ``Normal(output, output_std)`` distribution.
    * ``MASELoss`` criterion: ``criterion(labels, output, src, m)``.
    * ``CrossEntropyLoss`` criterion: labels are reduced to class indices first.
    * anything else: ``criterion(output, labels)``.

    When ``validation_dataset`` is truthy, ``handle_scaling`` un-scales the tensors before
    the loss is evaluated.

    :param labels: The real values for the target. Shape can be variable but should follow (batch_size, time)
    :type labels: torch.Tensor
    :param output: The output of the model
    :type output: torch.Tensor
    :param src: The source values (only really needed for the MASELoss function)
    :type src: torch.Tensor
    :param criterion: The loss function to use
    :type criterion: torch.nn.Loss or some variation
    :param validation_dataset: Only passed when unscaling of data is needed.
    :type validation_dataset: torch.utils.data.dataset
    :param probabilistic: Whether the model is probabilistic (returns a distribution), defaults to None
    :type probabilistic: bool, optional
    :param output_std: The standard deviation of the predicted distribution, defaults to None
    :type output_std: torch.Tensor or numpy.ndarray, optional
    :param m: The number of targets defaults to 1
    :type m: int, optional
    :raises AssertionError: If labels and output disagree in rank (or batch size) where
        the selected criterion requires them to match.
    :return: Returns the computed loss
    :rtype: torch.Tensor
    """
    if isinstance(criterion, GaussianLoss):
        # The passed instance only selects this branch; the loss actually applied is a
        # fresh GaussianLoss parameterised with the first channel of output[0] / output[1].
        if len(output[0].shape) > 2:
            g_loss = GaussianLoss(output[0][:, :, 0], output[1][:, :, 0])
        else:
            g_loss = GaussianLoss(output[0][:, 0], output[1][:, 0])
        return g_loss(labels)

    if not probabilistic and isinstance(output, torch.Tensor):
        # Give labels an extra axis when their rank differs from the output's.
        if len(labels.shape) != len(output.shape) and len(labels.shape) > 1:
            if labels.shape[1] == output.shape[1]:
                labels = labels.unsqueeze(2)
            else:
                labels = labels.unsqueeze(0)

    if probabilistic:
        # isinstance (rather than an exact type comparison) keeps Tensor subclasses from
        # being handed to torch.from_numpy, which only accepts numpy arrays.
        if not isinstance(output_std, torch.Tensor):
            output_std = torch.from_numpy(output_std)
        if not isinstance(output, torch.Tensor):
            output = torch.from_numpy(output)
        if len(output.shape) == 3:
            output = output[:, :, 0]
            output_std = output_std[:, :, 0]
        output_dist = torch.distributions.Normal(output, output_std)

    if validation_dataset:
        src, output, labels, output_dist = handle_scaling(validation_dataset, src, output, labels,
                                                          probabilistic, m, output_std)

    if probabilistic:
        if len(labels.shape) != len(output.shape):
            # A Normal distribution cannot be subscripted, so the first channel is selected
            # by rebuilding the distribution from its sliced parameters.
            output_dist = torch.distributions.Normal(output_dist.mean[:, :, 0], output_dist.stddev[:, :, 0])
        loss = -output_dist.log_prob(labels.float()).sum()  # FIX THIS?
    elif isinstance(criterion, MASELoss):
        assert len(labels.shape) == len(output.shape), "labels and output must have the same rank"
        loss = criterion(labels.float(), output, src, m)
    elif isinstance(criterion, CrossEntropyLoss):
        if len(labels.shape) > 2:
            # Swap the last two axes of both tensors before reducing over dim 1.
            labels = labels.permute(0, 2, 1)
            output = output.permute(0, 2, 1)
        if len(labels.shape) > 1:
            # NOTE(review): assumes labels are one-hot along dim 1 and reduces them to class
            # indices; 1-D labels are passed through unchanged as indices. The nesting of this
            # reduction was ambiguous in the source - confirm against the loaders in use.
            labels = labels.max(dim=1)[1]
        loss = criterion(output, labels)
    else:
        assert len(labels.shape) == len(output.shape), "labels and output must have the same rank"
        assert labels.shape[0] == output.shape[0], "labels and output must agree on the first axis"
        loss = criterion(output, labels.float())
    return loss
[docs]
def torch_single_train(model: PyTorchForecast,
                       opt: optim.Optimizer,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       data_loader: DataLoader,
                       takes_target: bool,
                       meta_data_model: PyTorchForecast,
                       meta_data_model_representation: torch.Tensor,
                       meta_loss=None,
                       multi_targets=1,
                       forward_params: Union[Dict, None] = None) -> float:
    """Function that performs training of a single model. Runs through one epoch of the data.

    :param model: The PyTorchForecast model that is trained
    :type model: PyTorchForecast
    :param opt: The optimizer to use in the code
    :type opt: optim.Optimizer
    :param criterion: The loss to optimise. A list of criteria is routed through
        ``multi_crit``; anything else goes to ``compute_loss``.
    :type criterion: Type[torch.nn.modules.loss._Loss]
    :param data_loader: Yields ``(src, trg)`` pairs, one per batch
    :type data_loader: DataLoader
    :param takes_target: A boolean that indicates whether the model takes the target during training
    :type takes_target: bool
    :param meta_data_model: If supplied a model that handles meta-data else None.
    :type meta_data_model: PyTorchForecast
    :param meta_data_model_representation: Input handed to the meta-data model to generate
        the representation stored under ``forward_params["meta_data"]``
    :type meta_data_model_representation: torch.Tensor
    :param meta_loss: Optional criterion applied to the meta-data model's output against
        ``meta_data_model_representation``, defaults to None
    :type meta_loss: optional
    :param multi_targets: The number of target columns taken from the last axis, defaults to 1
    :type multi_targets: int, optional
    :param forward_params: Extra keyword arguments for the model's forward call. The dict is
        copied, so neither the caller's dict nor a shared default is mutated. Defaults to None
        (treated as an empty dict).
    :type forward_params: Dict, optional
    :raises ValueError: If the loss of a batch is NaN or infinite. The check runs before the
        backward pass so the optimizer never steps on a non-finite loss.
    :raises ZeroDivisionError: If ``data_loader`` yields no batches.
    :return: The accumulated loss divided by the number of batches processed
    :rtype: float
    """
    # Work on a private copy: this function adds keys ("meta_data", "t", "x_mark_enc", ...)
    # which previously leaked into later calls through the shared mutable default.
    forward_params = {} if forward_params is None else dict(forward_params)
    probabilistic = True if "probabilistic" in model.params["model_params"] else None
    print('running torch_single_train')
    batch_count = 0
    output_std = None
    multi_targets_copy = multi_targets
    running_loss = 0.0
    dataset_class = model.params["dataset_params"]["class"]
    for src, trg in data_loader:
        opt.zero_grad()
        if meta_data_model:
            representation = meta_data_model.model.generate_representation(meta_data_model_representation)
            forward_params["meta_data"] = representation
            if meta_loss:
                output = meta_data_model.model(meta_data_model_representation)
                met_loss = compute_loss(meta_data_model_representation, output, torch.rand(2, 3, 2), meta_loss, None)
                met_loss.backward()
        if takes_target:
            forward_params["t"] = trg
        elif "TemporalLoader" == dataset_class:
            forward_params["x_mark_enc"] = src[1].to(model.device)
            forward_params["x_dec"] = trg[1].to(model.device)
            forward_params["x_mark_dec"] = trg[0].to(model.device)
            src = src[0]
            pred_len = model.model.pred_len
            trg = trg[0]
            # NOTE(review): the last pred_len steps of trg are zeroed in place here, and the
            # labels below are sliced from this same trg - confirm this is intended.
            trg[:, -pred_len:, :] = torch.zeros_like(trg[:, -pred_len:, :].long()).float().to(model.device)
            # Assign to avoid other if statement
        if "SeriesIDLoader" == dataset_class:
            # handle_csv_id_output receives the optimizer and returns the loss to accumulate.
            running_loss += handle_csv_id_output(src, trg, model, criterion, opt, False, multi_targets)
            batch_count += 1
            continue
        src = src.to(model.device)
        trg = trg.to(model.device)
        output = model.model(src, **forward_params)
        if hasattr(model.model, "pred_len"):
            multi_targets = multi_targets_copy
            pred_len = model.model.pred_len
            output = output[:, :, 0:multi_targets]
            labels = trg[:, -pred_len:, 0:multi_targets]
            # False matches neither "== 1" nor "> 1" below, so the labels chosen here are
            # kept unless the dataset class overrides them.
            multi_targets = False
        if dataset_class == "GeneralClassificationLoader":
            labels = trg
        elif dataset_class == "CSVSeriesIDLoader":
            labels = trg
        elif multi_targets == 1:
            labels = trg[:, :, 0]
        elif multi_targets > 1:
            labels = trg[:, :, 0:multi_targets]
        if probabilistic:
            # The model returned a distribution: split it into mean and standard deviation.
            output_dist = output
            output = output_dist.mean
            output_std = output_dist.stddev
        if isinstance(criterion, list):
            loss = multi_crit(criterion, output, labels, None)
        else:
            loss = compute_loss(labels, output, src, criterion, None, probabilistic, output_std, m=multi_targets)
        if torch.isnan(loss) or torch.isinf(loss):
            raise ValueError("Error infinite or NaN loss detected. Try normalizing data or performing interpolation")
        if loss > 100:
            print("Warning: high loss detected")
        loss.backward()
        opt.step()
        running_loss += loss.item()
        batch_count += 1
    print("The running loss iss: ")
    print(running_loss)
    print("The number of items in train is: " + str(batch_count))
    total_loss = running_loss / float(batch_count)
    return total_loss
[docs]
def compute_validation(validation_loader: DataLoader,
                       model,
                       epoch: int,
                       sequence_size: int,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       device: torch.device,
                       decoder_structure=False,
                       meta_data_model=None,
                       use_wandb: bool = False,
                       meta_model=None,
                       multi_targets=1,
                       val_or_test="validation_loss",
                       probabilistic=False,
                       classification=False) -> float:
    """Function to compute the validation loss metrics.

    The model is switched to eval mode, every batch of ``validation_loader`` is run without
    gradients, and for each criterion the batch loss times the batch length is accumulated.
    When the dataset's ``scale`` attribute is truthy an un-scaled loss is accumulated as well.
    The model is switched back to train mode before returning.

    :param validation_loader: The data-loader of either validation or test-data
    :type validation_loader: DataLoader
    :param model: The model to evaluate; it is called directly (not through a wrapper)
    :type model: torch.nn.Module
    :param epoch: The epoch where the validation/test loss is being computed.
    :type epoch: int
    :param sequence_size: Passed to ``simple_decode`` as ``output_len``
    :type sequence_size: int
    :param criterion: An iterable of criteria; each one becomes a key of the result dicts
    :type criterion: Type[torch.nn.modules.loss._Loss]
    :param device: The device
    :type device: torch.device
    :param decoder_structure: Whether the model should use sequential decoding, defaults to False
    :type decoder_structure: bool, optional
    :param meta_data_model: The model to handle the meta-data (not used in this function body),
        defaults to None
    :type meta_data_model: PyTorchForecast, optional
    :param use_wandb: Whether Weights and Biases is in use, defaults to False
    :type use_wandb: bool, optional
    :param meta_model: Whether the model leverages meta-data (not used in this function body),
        defaults to None
    :type meta_model: bool, optional
    :param multi_targets: The number of target columns taken from the last axis, defaults to 1
    :type multi_targets: int, optional
    :param val_or_test: Whether validation or test loss is computed, defaults to "validation_loss"
    :type val_or_test: str, optional
    :param probabilistic: Whether the model is probabilistic, defaults to False
    :type probabilistic: bool, optional
    :param classification: Whether labels are the full target and classification plots are
        logged to Weights and Biases, defaults to False
    :type classification: bool, optional
    :return: The accumulated (un-normalised) loss of the first metric in the list.
    :rtype: float
    """
    print('Computing validation loss')
    unscaled_crit = dict.fromkeys(criterion, 0)
    scaled_crit = dict.fromkeys(criterion, 0)
    model.eval()
    output_std = None
    multi_targs1 = multi_targets
    validation_dataset = validation_loader.dataset
    scaler = validation_dataset if validation_dataset.no_scale else None
    model_name = type(model).__name__
    uses_temporal_feats = model_name in TEMPORAL_FEATS_MODELS
    is_series_id_loader = validation_dataset.__class__.__name__ == "CSVSeriesIDLoader"
    with torch.no_grad():
        loss_unscaled_full = 0.0
        label_list = []
        mod_output_list = []
        for src, targ in validation_loader:
            if is_series_id_loader:
                # The helper's return value replaces the accumulator for this loader type.
                scaled_crit = handle_csv_id_validation(src, targ, model, criterion, False, multi_targets)
                unscaled_crit = {}
                continue
            src = src if isinstance(src, list) else src.to(device)
            targ = targ if isinstance(targ, list) else targ.to(device)
            if decoder_structure:
                if model_name == "SimpleTransformer":
                    targ_clone = targ.detach().clone()
                    output = greedy_decode(model, src, targ.shape[1], targ_clone, device=device)[:, :, 0]
                elif uses_temporal_feats:
                    multi_targets = multi_targs1
                    # Zero the last pred_len steps on a copy so targ[1] stays intact for the labels.
                    filled_targ = targ[1].clone()
                    pred_len = model.pred_len
                    filled_targ[:, -pred_len:, :] = torch.zeros_like(filled_targ[:, -pred_len:, :]).float().to(device)
                    output = model(src[0].to(device), src[1].to(device), filled_targ.to(device), targ[0].to(device))
                    labels = targ[1][:, -pred_len:, 0:multi_targets].to(device)
                    src = src[0]
                    assert output.shape[1] != 0
                    assert labels.shape[1] != 0
                else:
                    output = simple_decode(model=model,
                                           src=src,
                                           max_seq_len=targ.shape[1],
                                           real_target=targ,
                                           output_len=sequence_size,
                                           multi_targets=multi_targets,
                                           probabilistic=probabilistic,
                                           scaler=scaler)
                    if probabilistic:
                        output, output_std = output[0], output[1]
                        output, output_std = output[:, :, 0], output_std[0]
                        output_dist = torch.distributions.Normal(output, output_std)
            elif probabilistic:
                output_dist = model(src.float())
                # .cpu() is required before .numpy() when the model runs on a non-CPU device.
                output = output_dist.mean.detach().cpu().numpy()
                output_std = output_dist.stddev.detach().cpu().numpy()
            else:
                output = model(src.float())
                # Raw outputs are kept for the classification plots logged after the loop.
                mod_output_list.append(output)
            # NOTE(review): labels are only assigned for temporal-feature models inside the
            # decoder branch above; such a model with decoder_structure=False would reach the
            # loss computation without labels.
            if uses_temporal_feats:
                output = output[:, :, 0:multi_targets]
            elif classification:
                labels = targ
                label_list.append(labels)
            elif multi_targets == 1:
                labels = targ[:, :, 0]
            elif multi_targets > 1:
                labels = targ[:, :, 0:multi_targets]
            for crit in criterion:
                if validation_dataset.scale:
                    if len(src.shape) == 2:
                        src = src.unsqueeze(0)
                    src1 = src[:, :, 0:multi_targets]
                    loss_unscaled_full = compute_loss(labels, output, src1, crit, validation_dataset,
                                                      probabilistic, output_std, m=multi_targets)
                    unscaled_crit[crit] += loss_unscaled_full.item() * len(labels.float())
                loss = compute_loss(labels, output, src, crit, False, probabilistic, output_std, m=multi_targets)
                scaled_crit[crit] += loss.item() * len(labels.float())
    if use_wandb:
        # The logged values are normalised by len(dataset) - 1; the max() guard avoids a
        # division by zero for a single-sample dataset.
        normaliser = max(len(validation_dataset) - 1, 1)
        payload = {'epoch': epoch,
                   val_or_test: {k.__class__.__name__: v / normaliser for k, v in scaled_crit.items()}}
        if loss_unscaled_full:
            payload["unscaled_" + val_or_test] = {k.__class__.__name__: v / normaliser
                                                  for k, v in unscaled_crit.items()}
        wandb.log(payload)
        # NOTE(review): the plots are kept under use_wandb because they only call wandb.log;
        # the nesting was ambiguous in the source.
        if classification:
            print("Plotting test classification metrics")
            label_list = torch.cat(label_list)
            label_list = label_list[:, 0, :].detach().cpu()
            mod_output1 = torch.cat(mod_output_list)[:, 0, :].detach().cpu()
            mod_output_final = torch.nn.Softmax(dim=1)(mod_output1).numpy()
            # Index of the maximum along dim 1 is used as the true class.
            fin = label_list.max(dim=1)[1]
            wandb.log({"roc_" + str(epoch): wandb.plot.roc_curve(fin, mod_output_final, classes_to_plot=None,
                                                                 labels=None, title="roc_" + str(epoch))})
            wandb.log({"pr": wandb.plot.pr_curve(fin, mod_output_final)})
            wandb.log({"conf_": wandb.plot.confusion_matrix(probs=mod_output_final,
                                                            y_true=fin.detach().cpu().numpy(), class_names=None)})
    model.train()
    return list(scaled_crit.values())[0]