Source code for flood_forecast.training_utils

import torch


class EarlyStopper(object):
    """Early-stopping handler that can be used to stop training if the
    validation loss shows no improvement after a given number of checks.

    Args:
        patience (int): Number of checks to wait without improvement before
            stopping the training.
        min_delta (float, optional): A minimum decrease in the validation loss
            to qualify as an improvement, i.e. a decrease of less than or
            equal to `min_delta` counts as no improvement. Default is 0.0.
        cumulative_delta (bool, optional): If True, `min_delta` defines a
            decrease since the last `patience` reset; otherwise, it defines a
            decrease relative to the most recent check. Default is False.

    Examples:

        .. code-block:: python

            early_stopper = EarlyStopper(patience=10)

            for epoch in range(max_epochs):
                val_loss = evaluate(model)  # user-supplied validation step
                if not early_stopper.check_loss(model, val_loss):
                    break
    """

    def __init__(
        self,
        patience: int,
        min_delta: float = 0.0,
        cumulative_delta: bool = False,
    ):
        if patience < 1:
            raise ValueError("Argument patience should be a positive integer.")
        if min_delta < 0.0:
            raise ValueError("Argument min_delta should not be a negative number.")
        self.patience = patience
        self.min_delta = min_delta
        self.cumulative_delta = cumulative_delta
        self.counter = 0
        self.best_score = None

    def check_loss(self, model, validation_loss) -> bool:
        """Checks whether the validation loss improved and returns False once
        training should stop (no improvement for `patience` consecutive checks)."""
        score = validation_loss
        if self.best_score is None:
            # First evaluation: treat it as the best so far.
            self.save_model_checkpoint(model)
            self.best_score = score
        elif score + self.min_delta >= self.best_score:
            # No sufficient improvement: the loss did not drop by more than min_delta.
            if not self.cumulative_delta and score > self.best_score:
                # When cumulative_delta is False, measure min_delta against the
                # most recent score rather than the historical best.
                self.best_score = score
            self.counter += 1
            print(self.counter)
            if self.counter >= self.patience:
                return False
        else:
            # Improvement: save a checkpoint and reset the patience counter.
            self.save_model_checkpoint(model)
            self.best_score = score
            self.counter = 0
        return True

    def save_model_checkpoint(self, model):
        """Saves the model's state_dict to ``checkpoint.pth`` in the current
        working directory."""
        torch.save(model.state_dict(), "checkpoint.pth")
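
For context, the following is a minimal sketch of how EarlyStopper might be wired into a training loop. It is not part of flood_forecast.training_utils; the train_one_epoch and compute_validation_loss helpers are hypothetical placeholders for project-specific training and evaluation code, and the checkpoint path mirrors the one hard-coded in save_model_checkpoint.

# Example usage sketch (assumptions noted above; not part of the module).
import torch


def fit(model: torch.nn.Module, max_epochs: int = 100) -> torch.nn.Module:
    early_stopper = EarlyStopper(patience=10, min_delta=0.0)
    for epoch in range(max_epochs):
        train_one_epoch(model)                     # hypothetical helper
        val_loss = compute_validation_loss(model)  # hypothetical helper
        # check_loss returns False once `patience` checks pass without improvement.
        if not early_stopper.check_loss(model, val_loss):
            print(f"Stopping early at epoch {epoch}")
            break
    # Restore the best weights written by save_model_checkpoint.
    model.load_state_dict(torch.load("checkpoint.pth"))
    return model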