Source code for flood_forecast.da_rnn.train_da

import typing
from typing import Tuple
import json
import os

import torch
from torch import nn
from torch import optim
import matplotlib.pyplot as plt

import numpy as np
from flood_forecast.da_rnn import utils
from flood_forecast.da_rnn.constants import device
from flood_forecast.da_rnn.modules import Encoder, Decoder
from flood_forecast.da_rnn.custom_types import DaRnnNet, TrainData, TrainConfig
from flood_forecast.da_rnn.utils import numpy_to_tvar
from torch.utils.tensorboard import SummaryWriter


# Module-level logger, obtained once at import time from the project's
# logging helper and shared by every function in this module.
logger = utils.setup_log()
# Emitted at import time: reports the `device` constant imported from
# flood_forecast.da_rnn.constants.
logger.info(f"Using computation device: {device}")


def da_rnn(
        train_data: TrainData,
        n_targs: int,
        encoder_hidden_size=64,
        decoder_hidden_size=64,
        T=10,
        learning_rate=0.01,
        batch_size=128,
        param_output_path="",
        save_path: typing.Optional[str] = None) -> Tuple[TrainConfig, DaRnnNet]:
    """Build the DA-RNN encoder/decoder pair together with its training config.

    The first 70% of the rows of ``train_data.feats`` are designated as the
    training portion and the loss is ``nn.MSELoss``. The constructor keyword
    arguments of both sub-networks are written as JSON
    (``enc_kwargs.json`` / ``dec_kwargs.json``) into ``param_output_path``.

    :param train_data: Container exposing ``feats`` (2-D, rows x features).
    :param n_targs: The number of target columns (not steps).
    :param encoder_hidden_size: Hidden size passed to ``Encoder``; also passed
        to ``Decoder`` as ``encoder_hidden_size``.
    :param decoder_hidden_size: Hidden size passed to ``Decoder``.
    :param T: The number of timesteps in the window.
    :param learning_rate: Learning rate for both Adam optimizers.
    :param batch_size: Batch size stored in the returned ``TrainConfig``.
    :param param_output_path: Directory for the kwargs JSON files. An empty
        string means the current working directory. The directory is created
        if it does not exist yet.
    :param save_path: If truthy, directory containing ``encoder.pth`` and
        ``decoder.pth`` whose weights are loaded to resume training.
    :return: ``(train_cfg, da_rnn_net)``.
    """
    # NOTE: this local deliberately mirrors the original behaviour of picking
    # the device here; it shadows the module-level `device` import.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using device: " + str(device))

    train_cfg = TrainConfig(T, int(train_data.feats.shape[0] * 0.7), batch_size, nn.MSELoss())
    logger.info(f"Training size: {train_cfg.train_size:d}.")

    # Without this, open() below raises FileNotFoundError for a fresh
    # output directory.
    if param_output_path:
        os.makedirs(param_output_path, exist_ok=True)

    enc_kwargs = {
        "input_size": train_data.feats.shape[1],
        "hidden_size": encoder_hidden_size,
        "T": T,
    }
    encoder = Encoder(**enc_kwargs).to(device)
    with open(os.path.join(param_output_path, "enc_kwargs.json"), "w") as fi:
        json.dump(enc_kwargs, fi, indent=4)

    dec_kwargs = {
        "encoder_hidden_size": encoder_hidden_size,
        "decoder_hidden_size": decoder_hidden_size,
        "T": T,
        "out_feats": n_targs,
    }
    decoder = Decoder(**dec_kwargs).to(device)
    with open(os.path.join(param_output_path, "dec_kwargs.json"), "w") as fi:
        json.dump(dec_kwargs, fi, indent=4)

    if save_path:
        encoder_file = os.path.join(save_path, "encoder.pth")
        decoder_file = os.path.join(save_path, "decoder.pth")
        print("Resuming training from " + encoder_file)
        encoder.load_state_dict(torch.load(encoder_file, map_location=device))
        decoder.load_state_dict(torch.load(decoder_file, map_location=device))

    # Only parameters that require gradients are handed to the optimizers.
    encoder_optimizer = optim.Adam(
        params=[p for p in encoder.parameters() if p.requires_grad],
        lr=learning_rate)
    decoder_optimizer = optim.Adam(
        params=[p for p in decoder.parameters() if p.requires_grad],
        lr=learning_rate)
    da_rnn_net = DaRnnNet(encoder, decoder, encoder_optimizer, decoder_optimizer)

    return train_cfg, da_rnn_net
def train(
        net: DaRnnNet,
        train_data: TrainData,
        t_cfg: TrainConfig,
        train_config="",
        n_epochs=10,
        save_plots=True,
        wandb=False,
        tensorboard=False):
    """Train the DA-RNN, evaluate it after every epoch, and checkpoint it.

    Each epoch draws a random permutation of window start indices, runs one
    optimisation step per non-empty batch, then predicts on both the training
    and the held-out portion, logs the validation error, and plots the
    predictions. After the last epoch the encoder and decoder weights are
    saved to a ``checkpoint`` directory next to this module.

    :param net: Network bundle exposing ``encoder``, ``decoder`` and the two
        optimizers.
    :param train_data: Container exposing ``feats`` and ``targs``.
    :param t_cfg: Training configuration (``T``, ``train_size``,
        ``batch_size``, ``loss_func``).
    :param train_config: Currently unused by this function.
    :param n_epochs: Number of passes over the training windows.
    :param save_plots: Forwarded to ``utils.save_or_show_plot``.
    :param wandb: If truthy, validation metrics are sent to Weights & Biases.
    :param tensorboard: If truthy, validation MSE is written with a
        ``SummaryWriter``.
    :return: ``([iter_losses, epoch_losses], net)``. ``iter_losses`` keeps one
        slot per scheduled iteration; slots of skipped (empty) batches stay 0.
    """
    wandb_lib = None
    if wandb:
        # Imported lazily and bound to its own name so the boolean `wandb`
        # flag is not overwritten by the module object.
        import wandb as wandb_lib

    iter_per_epoch = int(np.ceil(t_cfg.train_size * 1. / t_cfg.batch_size))
    iter_losses = np.zeros(n_epochs * iter_per_epoch)
    epoch_losses = np.zeros(n_epochs)
    logger.info(
        f"Iterations per epoch: {t_cfg.train_size * 1. / t_cfg.batch_size:3.3f} ~ {iter_per_epoch:d}.")

    n_iter = 0
    writer = SummaryWriter() if tensorboard else None

    for e_i in range(n_epochs):
        # Only train_size - T start indices exist, so the trailing batch
        # slice(s) of the loop below can be empty.
        perm_idx = np.random.permutation(t_cfg.train_size - t_cfg.T)
        epoch_start = e_i * iter_per_epoch
        batch_losses = []

        for t_i in range(0, t_cfg.train_size, t_cfg.batch_size):
            batch_idx = perm_idx[t_i:(t_i + t_cfg.batch_size)]
            feats, y_history, y_target = prep_train_data(batch_idx, t_cfg, train_data)
            if len(feats) > 0 and len(y_target) > 0 and len(y_history) > 0:
                loss = train_iteration(net, t_cfg.loss_func, feats, y_history, y_target)
                iter_losses[epoch_start + t_i // t_cfg.batch_size] = loss
                batch_losses.append(loss)
                n_iter += 1
                adjust_learning_rate(net, n_iter)

        # Average only over batches that actually ran; the zero placeholders
        # of skipped batches would otherwise pull the epoch loss down.
        if batch_losses:
            epoch_losses[e_i] = np.mean(batch_losses)

        y_test_pred = predict(net, train_data, t_cfg.train_size, t_cfg.batch_size, t_cfg.T, on_train=False)
        # TODO: make this MSE and make it work for multiple inputs
        val_err = y_test_pred - train_data.targs[t_cfg.train_size:]
        val_loss = np.mean(np.abs(val_err))
        logger.info(
            f"Epoch {e_i:d}, train loss: {epoch_losses[e_i]:3.3f}, val loss: {val_loss}.")
        y_train_pred = predict(net, train_data, t_cfg.train_size, t_cfg.batch_size, t_cfg.T, on_train=True)
        mse = np.mean(val_err ** 2)
        # TODO: the training-set MSE is not computed or logged yet.

        if wandb_lib is not None:
            # Log the current epoch index and scalar metrics (previously the
            # total epoch count and the raw per-sample error array were sent).
            wandb_lib.log({"epoch": e_i, "validation_loss": float(val_loss), "validation_mse": float(mse)})

        fig = plt.figure()
        plt.plot(range(1, 1 + len(train_data.targs)), train_data.targs, label="True")
        plt.plot(range(t_cfg.T, len(y_train_pred) + t_cfg.T), y_train_pred, label='Predicted - Train')
        plt.plot(range(t_cfg.T + len(y_train_pred), len(train_data.targs) + 1), y_test_pred, label='Predicted - Test')
        plt.legend(loc='upper left')
        utils.save_or_show_plot(f"pred_{e_i}.png", save_plots)
        if save_plots:
            # pyplot keeps every figure alive until it is closed; release it so
            # long runs do not accumulate one open figure per epoch.
            # NOTE(review): assumes the helper is done with the figure once it
            # returns in save mode - confirm against utils.save_or_show_plot.
            plt.close(fig)

        if writer is not None:
            writer.add_scalar('Validation/MSE', mse, e_i)

    if writer is not None:
        # Flush pending events to disk.
        writer.close()

    # Create the directory that is actually written to (next to this module);
    # the previous code created "checkpoint" relative to the working directory
    # instead, so saving could fail.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    checkpoint_dir = os.path.join(dir_path, "checkpoint")
    os.makedirs(checkpoint_dir, exist_ok=True)
    encoder_path = os.path.join(checkpoint_dir, "encoder.pth")
    print(encoder_path)
    torch.save(net.encoder.state_dict(), encoder_path)
    torch.save(net.decoder.state_dict(), os.path.join(checkpoint_dir, "decoder.pth"))

    return [iter_losses, epoch_losses], net
def prep_train_data(batch_idx: np.ndarray, t_cfg: TrainConfig, train_data: TrainData) -> Tuple:
    """Assemble one training batch of sliding windows.

    For every start index in ``batch_idx`` a window of ``t_cfg.T - 1``
    consecutive rows is copied out of ``train_data.feats`` and
    ``train_data.targs``. The target belonging to a window is the row of
    ``train_data.targs`` at ``start + t_cfg.T``.

    :param batch_idx: Window start indices for this batch.
    :param t_cfg: Configuration providing the window parameter ``T``.
    :param train_data: Container exposing 2-D ``feats`` and ``targs``.
    :return: ``(feats, y_history, y_target)`` where the first two have shape
        ``(len(batch_idx), T - 1, n_columns)`` and ``y_target`` has one row
        per start index.
    """
    window_len = t_cfg.T - 1
    n_samples = len(batch_idx)
    source_feats = train_data.feats
    source_targs = train_data.targs

    # Pre-allocated (zero-filled, float64) buffers, one slab per sample.
    feats = np.zeros((n_samples, window_len, source_feats.shape[1]))
    y_history = np.zeros((n_samples, window_len, source_targs.shape[1]))
    y_target = source_targs[batch_idx + t_cfg.T]

    for row, start in enumerate(batch_idx):
        window = slice(start, start + window_len)
        feats[row, :, :] = source_feats[window, :]
        y_history[row, :] = source_targs[window]

    return feats, y_history, y_target
def adjust_learning_rate(
        net: DaRnnNet,
        n_iter: int,
        decay_every: int = 10000,
        decay_factor: float = 0.9) -> None:
    """Apply a step decay to the encoder and decoder learning rates.

    Whenever ``n_iter`` is a positive multiple of ``decay_every``, the ``lr``
    of every parameter group of both optimizers is multiplied by
    ``decay_factor``. Otherwise nothing changes.

    :param net: Network bundle exposing ``enc_opt`` and ``dec_opt``, each with
        a ``param_groups`` list of dicts holding an ``'lr'`` entry.
    :param n_iter: Number of training iterations completed so far.
    :param decay_every: Decay period in iterations (default 10000).
    :param decay_factor: Multiplicative factor applied at each decay step
        (default 0.9).
    :raises ValueError: If ``decay_every`` is not positive.
    """
    # TODO: Where did this Learning Rate adjustment schedule come from?
    # Should be modified to use Cosine Annealing with warm restarts
    # https://www.jeremyjordan.me/nn-learning-rate/
    if decay_every <= 0:
        raise ValueError(f"decay_every must be positive, got {decay_every}")
    if n_iter <= 0 or n_iter % decay_every != 0:
        return

    # Walk each optimizer on its own: zipping the two group lists would
    # silently skip groups if the optimizers had different group counts.
    for optimizer in (net.enc_opt, net.dec_opt):
        for group in optimizer.param_groups:
            group['lr'] = group['lr'] * decay_factor
def train_iteration(t_net: DaRnnNet, loss_func: typing.Callable, X, y_history, y_target):
    """Perform a single optimisation step on one batch.

    Gradients of both optimizers are cleared, the batch is pushed through the
    encoder and then the decoder, the loss is back-propagated, and both
    optimizers take a step.

    :param t_net: Network bundle exposing ``encoder``, ``decoder``,
        ``enc_opt`` and ``dec_opt``.
    :param loss_func: Callable invoked as ``loss_func(prediction, truth)``.
    :param X: Batch of input feature windows (converted with
        ``numpy_to_tvar``).
    :param y_history: Batch of past target windows (converted with
        ``numpy_to_tvar``).
    :param y_target: Batch of targets to predict (converted with
        ``numpy_to_tvar``).
    :return: The batch loss as returned by ``loss.item()``.
    """
    optimizers = (t_net.enc_opt, t_net.dec_opt)
    for opt in optimizers:
        opt.zero_grad()

    # The encoder's first output is not needed for the loss.
    _, encoded = t_net.encoder(numpy_to_tvar(X))
    prediction = t_net.decoder(encoded, numpy_to_tvar(y_history))

    batch_loss = loss_func(prediction, numpy_to_tvar(y_target))
    batch_loss.backward()

    for opt in optimizers:
        opt.step()

    return batch_loss.item()
def predict(
        t_net: DaRnnNet,
        t_dat: TrainData,
        train_size: int,
        batch_size: int,
        T: int,
        on_train=False):
    """Run the network over sliding windows and collect its predictions.

    With ``on_train=True`` the windows start at rows ``0 .. train_size - T``
    (``train_size - T + 1`` predictions). Otherwise one prediction is produced
    for each row after ``train_size``, using the ``T - 1`` rows that start at
    ``row - T``.

    :param t_net: Network bundle exposing ``encoder`` and ``decoder``.
    :param t_dat: Container exposing 2-D ``feats`` and ``targs``.
    :param train_size: Number of leading rows that form the training portion.
    :param batch_size: Number of windows evaluated per forward pass.
    :param T: Window parameter; each window holds ``T - 1`` rows.
    :param on_train: Predict on the training portion instead of the held-out
        portion.
    :return: Array of shape ``(n_predictions, t_dat.targs.shape[1])``.
    """
    out_size = t_dat.targs.shape[1]
    n_feats = t_dat.feats.shape[1]
    window_len = T - 1

    if on_train:
        n_pred = train_size - T + 1
        offset = 0
    else:
        n_pred = t_dat.feats.shape[0] - train_size
        # Held-out windows begin T rows before the row being predicted.
        offset = train_size - T
    y_pred = np.zeros((n_pred, out_size))

    # Inference only: skip autograd bookkeeping so no graph is built (and
    # kept in memory) for every evaluated batch.
    with torch.no_grad():
        for y_i in range(0, n_pred, batch_size):
            batch_end = min(y_i + batch_size, n_pred)
            b_len = batch_end - y_i

            X = np.zeros((b_len, window_len, n_feats))
            y_history = np.zeros((b_len, window_len, out_size))
            for b_i, b_idx in enumerate(range(y_i, batch_end)):
                start = b_idx + offset
                idx = range(start, start + window_len)
                X[b_i, :, :] = t_dat.feats[idx, :]
                y_history[b_i, :] = t_dat.targs[idx]

            _, input_encoded = t_net.encoder(numpy_to_tvar(X))
            batch_pred = t_net.decoder(input_encoded, numpy_to_tvar(y_history))
            y_pred[y_i:batch_end] = batch_pred.detach().cpu().numpy()

    return y_pred