Source code for flood_forecast.pytorch_training
import torch
import torch.optim as optim
from typing import Type, Dict
from torch.utils.data import DataLoader
from flood_forecast.time_model import PyTorchForecast
from flood_forecast.model_dict_function import pytorch_opt_dict, pytorch_criterion_dict
from flood_forecast.transformer_xl.transformer_basic import greedy_decode
from flood_forecast.basic.linear_regression import simple_decode
from flood_forecast.training_utils import EarlyStopper
def train_transformer_style(
        model: PyTorchForecast,
        training_params: Dict,
        takes_target: bool = False,
        forward_params: Dict = None,
        model_filepath: str = "model_save") -> None:
    """
    Function to train any PyTorchForecast model.

    :param model: The initialized PyTorchForecast model
    :param training_params: A dictionary of the parameters needed to train the model
    :param takes_target: Whether to pass the target to the model during training
    :param forward_params: A dictionary of additional forward parameters (for instance, the target)
    :param model_filepath: The path prefix under which the trained model is saved
    """
    # Avoid a mutable default argument: the dict is mutated further down the call chain.
    if forward_params is None:
        forward_params = {}
    use_wandb = model.wandb
    es = None
    if "early_stopping" in model.params:
        es = EarlyStopper(model.params["early_stopping"]["patience"])
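    # ``EarlyStopper`` is assumed to checkpoint the best weights to
    # "checkpoint.pth" and to have ``check_loss`` return False once
    # ``patience`` epochs pass without improvement; the epoch loop below then
    # restores the checkpoint and stops.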
    opt = pytorch_opt_dict[training_params["optimizer"]](
        model.model.parameters(), **training_params["optim_params"])
    criterion = pytorch_criterion_dict[training_params["criterion"]]
    max_epochs = training_params["epochs"]
    # DataLoaders for the train/validation/test splits; every other
    # DataLoader argument is left at its PyTorch default (no shuffling,
    # single-process loading).
    data_loader = DataLoader(
        model.training,
        batch_size=training_params["batch_size"],
        shuffle=False)
    validation_data_loader = DataLoader(
        model.validation,
        batch_size=training_params["batch_size"],
        shuffle=False)
    test_data_loader = DataLoader(
        model.test_data, batch_size=1, shuffle=False)
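    # A hedged note on shapes, inferred from the ``trg[:, :, 0]`` indexing in
    # torch_single_train: each loader is assumed to yield (src, trg) pairs of
    # 3-D tensors shaped (batch_size, sequence_length, n_features), with the
    # forecast target in feature column 0.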
    if use_wandb:
        import wandb
        wandb.watch(model.model)
    session_params = []
    for epoch in range(max_epochs):
        total_loss = torch_single_train(
            model,
            opt,
            criterion,
            data_loader,
            takes_target,
            forward_params)
        print("The loss for epoch " + str(epoch) + " is " + str(total_loss))
        use_decoder = False
        if "use_decoder" in model.params:
            use_decoder = True
        valid = compute_validation(
            validation_data_loader,
            model.model,
            epoch,
            model.params["dataset_params"]["forecast_length"],
            criterion,
            model.device,
            decoder_structure=use_decoder,
            use_wandb=use_wandb)
        if valid < 0.01:
            raise ValueError(
                "Error: validation loss is effectively zero; "
                "there is likely a problem with the validator.")
        if use_wandb:
            wandb.log({'epoch': epoch, 'loss': total_loss})
        epoch_params = {
            "epoch": epoch,
            "train_loss": str(total_loss),
            "validation_loss": str(valid)}
        session_params.append(epoch_params)
        if es:
            if not es.check_loss(model.model, valid):
                print("Stopping model now")
                model.model.load_state_dict(torch.load("checkpoint.pth"))
                break
    test = compute_validation(
        test_data_loader,
        model.model,
        epoch,
        model.params["dataset_params"]["forecast_length"],
        criterion,
        model.device,
        decoder_structure=True,
        use_wandb=use_wandb,
        val_or_test="test_loss")
    print("test loss:", test)
    model.params["run"] = session_params
    model.save_model(model_filepath, max_epochs)
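# A hedged usage sketch. It assumes a ``PyTorchForecast`` instance built
# elsewhere from a configuration dict that contains "training_params" and
# "dataset_params" sections; the variable names are illustrative only:
#
# forecast_model = PyTorchForecast(...)  # constructed from a config file
# train_transformer_style(forecast_model,
#                         config["training_params"],
#                         model_filepath="model_save")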
def torch_single_train(model: PyTorchForecast,
                       opt: optim.Optimizer,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       data_loader: DataLoader,
                       takes_target: bool,
                       forward_params: Dict = None) -> float:
    """Run a single epoch of training and return the mean per-batch loss."""
    # Avoid a mutable default argument, since the dict is mutated below.
    if forward_params is None:
        forward_params = {}
    i = 0
    running_loss = 0.0
    for src, trg in data_loader:
        opt.zero_grad()
        # Move the batch to the model's device (CPU/GPU/TPU)
        src = src.to(model.device)
        trg = trg.to(model.device)
        # TODO: figure out how to avoid mutating forward_params here
        if takes_target:
            forward_params["t"] = trg
        output = model.model(src, **forward_params)
        # The target variable is taken from the first feature column
        labels = trg[:, :, 0]
        loss = criterion(output, labels.float())
        if loss > 100:
            print("Warning: high loss detected")
        # Check for a bad loss before backpropagating it
        if torch.isnan(loss) or torch.isinf(loss):
            raise ValueError(
                "Error: infinite or NaN loss detected. "
                "Try normalizing data or performing interpolation")
        loss.backward()
        opt.step()
        running_loss += loss.item()
        i += 1
    print("The running loss is: " + str(running_loss))
    print("The number of batches in train is: " + str(i))
    total_loss = running_loss / float(i)
    return total_loss
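# How ``takes_target`` interacts with ``forward_params`` (a hedged sketch;
# the keyword name ``t`` must match whatever the underlying model's
# ``forward`` signature expects):
#
# forward_params["t"] = trg                    # injected when takes_target
# output = model.model(src, **forward_params)  # equivalent to
# output = model.model(src, t=trg)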
def compute_validation(validation_loader: DataLoader,
                       model,
                       epoch: int,
                       sequence_size: int,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       device: torch.device,
                       decoder_structure: bool = False,
                       use_wandb: bool = False,
                       val_or_test: str = "validation_loss") -> float:
    """
    Function to compute the validation or test loss.

    Returns the per-sample loss averaged over the dataset; if the underlying
    dataset defines a scaler, an additional loss on the unscaled data is
    computed and logged to wandb.
    """
    model.eval()
    loop_loss = 0.0
    with torch.no_grad():
        i = 0
        loss_unscaled_full = 0.0
        for src, targ in validation_loader:
            src = src.to(device)
            targ = targ.to(device)
            i += 1
            if decoder_structure:
                # Transformer models decode autoregressively; other models
                # use the basic step-by-step decode.
                if type(model).__name__ == "SimpleTransformer":
                    targ_clone = targ.detach().clone()
                    output = greedy_decode(
                        model, src, targ.shape[1], targ_clone,
                        device=device)[:, :, 0]
                else:
                    output = simple_decode(model, src, targ.shape[1], targ, 1)[:, :, 0]
            else:
                output = model(src.float())
            labels = targ[:, :, 0]
            validation_dataset = validation_loader.dataset
            if validation_dataset.scale:
                # Also compute the loss on the unscaled (original-unit) data
                unscaled_out = validation_dataset.inverse_scale(output.cpu())
                unscaled_labels = validation_dataset.inverse_scale(labels.cpu())
                loss_unscaled = criterion(unscaled_out, unscaled_labels.float())
                loss_unscaled_full += len(labels.float()) * loss_unscaled.item()
                if i % 10 == 0 and use_wandb:
                    import wandb
                    wandb.log({"trg": unscaled_labels, "model_pred": unscaled_out})
            loss = criterion(output, labels.float())
            loop_loss += len(labels.float()) * loss.item()
    if use_wandb:
        import wandb
        if loss_unscaled_full:
            tot_unscaled_loss = loss_unscaled_full / (len(validation_loader.dataset) - 1)
            wandb.log({'epoch': epoch,
                       val_or_test: loop_loss / (len(validation_loader.dataset) - 1),
                       "unscaled_" + val_or_test: tot_unscaled_loss})
        else:
            wandb.log({'epoch': epoch,
                       val_or_test: loop_loss / (len(validation_loader.dataset) - 1)})
    model.train()
    return loop_loss / (len(validation_loader.dataset) - 1)
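# Worked example of the weighted averaging above (numbers are hypothetical):
# with batch sizes [20, 20, 10] and per-batch mean losses [0.5, 0.3, 0.4],
# loop_loss = 20 * 0.5 + 20 * 0.3 + 10 * 0.4 = 20.0, so the function returns
# 20.0 / (50 - 1) ≈ 0.408. Note the divisor the code uses is
# len(dataset) - 1, not len(dataset).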