PyTorch Training

flood_forecast.pytorch_training.train_transformer_style(model: flood_forecast.time_model.PyTorchForecast, training_params: Dict[KT, VT], takes_target=False, forward_params: Dict[KT, VT] = {}, model_filepath: str = 'model_save') → None

Function to train any PyTorchForecast model.

:model: The initialized PyTorchForecast model
:training_params: A dictionary of the parameters needed to train the model
:takes_target: Boolean that determines whether to pass the target during training
:forward_params: A dictionary of additional forward parameters (for instance, the target)
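The roles of takes_target and forward_params can be pictured with a small single-epoch loop. This is only a sketch in plain PyTorch, not the library's implementation: the function name single_epoch_sketch, the choice to pass the target as a second positional argument, and the toy linear model and data are all assumptions made for illustration.

```python
from typing import Any, Dict, Optional

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def single_epoch_sketch(
    model: nn.Module,
    opt: torch.optim.Optimizer,
    criterion: nn.Module,
    data_loader: DataLoader,
    takes_target: bool,
    forward_params: Optional[Dict[str, Any]] = None,
) -> float:
    """Hypothetical one-epoch training loop; returns the mean batch loss."""
    # Copy the extra keyword arguments so the caller's dictionary is never mutated.
    forward_params = dict(forward_params or {})
    model.train()
    running_loss = 0.0
    batches = 0
    for src, trg in data_loader:
        opt.zero_grad()
        if takes_target:
            # Assumption: the model accepts the target as a second positional argument.
            output = model(src, trg, **forward_params)
        else:
            output = model(src, **forward_params)
        loss = criterion(output, trg)
        loss.backward()
        opt.step()
        running_loss += loss.item()
        batches += 1
    return running_loss / max(batches, 1)


# Toy regression problem: predict the sum of three features.
torch.manual_seed(0)
x = torch.randn(32, 3)
y = x.sum(dim=1, keepdim=True)
loader = DataLoader(TensorDataset(x, y), batch_size=8)
net = nn.Linear(3, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)

first_loss = single_epoch_sketch(net, optimizer, nn.MSELoss(), loader, takes_target=False)
for _ in range(20):
    last_loss = single_epoch_sketch(net, optimizer, nn.MSELoss(), loader, takes_target=False)
```

The sketch uses None as the default for forward_params and copies the dictionary inside the function, which avoids sharing one dictionary object across calls.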

flood_forecast.pytorch_training.torch_single_train(model: flood_forecast.time_model.PyTorchForecast, opt: torch.optim.optimizer.Optimizer, criterion: Type[torch.nn.modules.loss._Loss], data_loader: torch.utils.data.dataloader.DataLoader, takes_target: bool, forward_params: Dict[KT, VT] = {}) → float

flood_forecast.pytorch_training.compute_validation(validation_loader: torch.utils.data.dataloader.DataLoader, model, epoch: int, sequence_size: int, criterion: Type[torch.nn.modules.loss._Loss], device: torch.device, decoder_structure=False, use_wandb: bool = False, val_or_test='validation_loss') → float

Function to compute the validation or test loss.
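As a rough illustration of what computing a validation or test loss involves, the sketch below averages a criterion over a data loader with gradients disabled. It is a simplified stand-in, not the library's code: the helper name validation_loss_sketch is hypothetical, and the sequence_size, decoder_structure, use_wandb, and val_or_test arguments are deliberately left out.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def validation_loss_sketch(
    loader: DataLoader,
    model: nn.Module,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """Hypothetical helper: sample-weighted mean loss over a loader."""
    model.eval()  # switch off dropout and similar training-only behaviour
    total = 0.0
    count = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for src, trg in loader:
            src, trg = src.to(device), trg.to(device)
            output = model(src)
            # Weight each batch by its size so a short final batch is not over-counted.
            total += criterion(output, trg).item() * src.size(0)
            count += src.size(0)
    return total / max(count, 1)


# Toy check: a model that always outputs 0 against targets of 1 has an MSE of 1.
x = torch.zeros(10, 2)
y = torch.ones(10, 1)
val_loader = DataLoader(TensorDataset(x, y), batch_size=4)
zero_model = nn.Linear(2, 1)
nn.init.zeros_(zero_model.weight)
nn.init.zeros_(zero_model.bias)

val_loss = validation_loss_sketch(val_loader, zero_model, nn.MSELoss(), torch.device("cpu"))
```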