PyTorch Training
-
flood_forecast.pytorch_training.train_transformer_style(model: flood_forecast.time_model.PyTorchForecast, training_params: Dict[KT, VT], takes_target=False, forward_params: Dict[KT, VT] = {}, model_filepath: str = 'model_save') → None

Function to train any PyTorchForecast model.

:model: The initialized PyTorchForecast model
:training_params: A dictionary of the parameters needed to train the model
:takes_target: Boolean; determines whether to pass the target during training
:forward_params: A dictionary of additional forward parameters (for instance, the target)
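The interplay between `takes_target` and `forward_params` can be illustrated with a minimal sketch. It assumes the target is handed to the model's `forward()` under a keyword named `t`, and it uses a toy module (`ToyDecoder`) and helper (`call_model`) in place of a real PyTorchForecast model; all three names are hypothetical and not part of flood_forecast.

```python
from typing import Optional

import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a model whose forward() optionally takes the target."""

    def __init__(self, n_features: int):
        super().__init__()
        self.linear = nn.Linear(n_features, 1)

    def forward(self, src: torch.Tensor, t: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = self.linear(src)
        if t is not None:
            # A real decoder might use the target for teacher forcing; here it is
            # simply added so the effect of passing it is easy to check.
            out = out + t
        return out


def call_model(model: nn.Module, src: torch.Tensor, trg: torch.Tensor,
               takes_target: bool, forward_params: Optional[dict] = None) -> torch.Tensor:
    # Work on a copy: a mutable `{}` default is shared between calls, so writing
    # the target into it directly would leak state into later calls.
    params = dict(forward_params or {})
    if takes_target:
        params["t"] = trg  # hypothetical keyword name for the target
    return model(src, **params)


torch.manual_seed(0)
model = ToyDecoder(n_features=3)
src = torch.zeros(2, 3)
trg = torch.ones(2, 1)
user_params = {}
without_target = call_model(model, src, trg, takes_target=False)
with_target = call_model(model, src, trg, takes_target=True, forward_params=user_params)
```

Copying `forward_params` before adding the target is a deliberate choice here: the documented signature uses `{}` as a default, and Python evaluates that default once, so mutating it in place would carry the target over into later calls.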
-
flood_forecast.pytorch_training.torch_single_train(model: flood_forecast.time_model.PyTorchForecast, opt: torch.optim.optimizer.Optimizer, criterion: Type[torch.nn.modules.loss._Loss], data_loader: torch.utils.data.dataloader.DataLoader, takes_target: bool, forward_params: Dict[KT, VT] = {}) → float
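A single training epoch with parameters shaped like these usually follows the standard zero_grad / forward / loss / backward / step loop. The sketch below is a generic version of that loop using a plain `nn.Module` instead of a PyTorchForecast model; the function name `single_epoch`, the `t` keyword for the target, and returning the mean per-batch loss as the float result are assumptions, not confirmed behavior of `torch_single_train`.

```python
from typing import Optional

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def single_epoch(model: nn.Module, opt: torch.optim.Optimizer, criterion: nn.Module,
                 data_loader: DataLoader, takes_target: bool,
                 forward_params: Optional[dict] = None) -> float:
    """One pass over data_loader; returns the mean per-batch loss (assumed metric)."""
    params = dict(forward_params or {})  # copy so the caller's dict is untouched
    model.train()  # enable training-mode behavior (dropout, batch-norm updates)
    running_loss, n_batches = 0.0, 0
    for src, trg in data_loader:
        opt.zero_grad()  # clear gradients left over from the previous step
        if takes_target:
            params["t"] = trg  # hypothetical keyword name for the target
        output = model(src, **params)
        loss = criterion(output, trg)
        loss.backward()
        opt.step()
        running_loss += loss.item()
        n_batches += 1
    return running_loss / max(n_batches, 1)


# Toy linear-regression data standing in for a real time-series loader.
torch.manual_seed(0)
X = torch.randn(64, 2)
y = X @ torch.tensor([[1.5], [-2.0]]) + 0.5
loader = DataLoader(TensorDataset(X, y), batch_size=8, shuffle=True)
model = nn.Linear(2, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()

first_loss = single_epoch(model, opt, criterion, loader, takes_target=False)
for _ in range(20):
    last_loss = single_epoch(model, opt, criterion, loader, takes_target=False)
```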
-
flood_forecast.pytorch_training.compute_validation(validation_loader: torch.utils.data.dataloader.DataLoader, model, epoch: int, sequence_size: int, criterion: Type[torch.nn.modules.loss._Loss], device: torch.device, decoder_structure=False, use_wandb: bool = False, val_or_test='validation_loss') → float

Function to compute the validation or test loss.
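Computing a validation or test loss in PyTorch generally means switching the model to evaluation mode, disabling gradient tracking, and averaging the criterion over the loader. The sketch below shows that pattern with a hypothetical `validation_loss` helper; it leaves out `epoch`, `sequence_size`, `decoder_structure`, `use_wandb`, and `val_or_test`, and its per-sample weighting is an illustrative choice rather than what `compute_validation` necessarily does.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def validation_loss(loader: DataLoader, model: nn.Module, criterion: nn.Module,
                    device: torch.device) -> float:
    """Mean loss over every sample in loader, computed without gradients."""
    model.eval()  # disable dropout and use running batch-norm statistics
    total, n_samples = 0.0, 0
    with torch.no_grad():  # no autograd graph is built during evaluation
        for src, trg in loader:
            src, trg = src.to(device), trg.to(device)
            output = model(src)
            # Assuming a mean-reduced criterion, weight each batch by its size
            # so a smaller final batch is not over-counted.
            total += criterion(output, trg).item() * src.size(0)
            n_samples += src.size(0)
    return total / n_samples


torch.manual_seed(0)
X, y = torch.randn(10, 2), torch.randn(10, 1)
model = nn.Linear(2, 1)
criterion = nn.MSELoss()
val = validation_loss(DataLoader(TensorDataset(X, y), batch_size=4), model, criterion,
                      torch.device("cpu"))
with torch.no_grad():
    full_batch = criterion(model(X), y).item()
```

With 10 samples and a batch size of 4, the last batch holds only 2 samples; weighting by batch size makes `val` match the loss computed over the whole dataset at once.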