Source code for flood_forecast.time_model

from abc import ABC, abstractmethod
from typing import Dict
import torch
import json
import os
from datetime import datetime
from flood_forecast.model_dict_function import pytorch_model_dict
from flood_forecast.pre_dict import scaler_dict
from flood_forecast.preprocessing.pytorch_loaders import CSVDataLoader
from flood_forecast.gcp_integration.basic_utils import get_storage_client, upload_file
import wandb

class TimeSeriesModel(ABC):
    """
    An abstract class used to handle different configurations
    of models + hyperparams for training, test, and predict functions.
    This class assumes that data is already split into test, train,
    and validation at this point.

    Sub-classes supply the framework-specific pieces by implementing
    ``load_model``, ``make_data_load`` and ``save_model``.
    """

    def __init__(
            self,
            model_base: str,
            training_data: str,
            validation_data: str,
            test_data: str,
            params: Dict):
        """
        Builds the model and the three data loaders from ``params``.

        :param model_base: Name of the model, forwarded to ``load_model``.
        :param training_data: Location of the training data, forwarded to ``make_data_load``.
        :param validation_data: Location of the validation data, forwarded to ``make_data_load``.
        :param test_data: Location of the test data, forwarded to ``make_data_load``.
        :param params: Configuration dictionary. The keys ``model_params``,
            ``dataset_params`` and ``inference_params`` (with ``hours_to_forecast``)
            are read unconditionally; ``weight_path``, ``GCS``, ``wandb`` and
            ``sweep`` are optional.
        """
        self.params = params
        if "weight_path" in params:
            self.model = self.load_model(model_base, params["model_params"], params["weight_path"])
        else:
            self.model = self.load_model(model_base, params["model_params"])
        # The test loader forecasts as far ahead as inference does. Note that this
        # writes into the caller's ``params`` dictionary.
        params["dataset_params"]["forecast_test_len"] = params["inference_params"]["hours_to_forecast"]
        # NOTE(review): the attribute name was lost in the garbled source line; ``training``
        # is restored to parallel ``validation`` below -- confirm against callers.
        self.training = self.make_data_load(training_data, params["dataset_params"], "train")
        self.validation = self.make_data_load(validation_data, params["dataset_params"], "valid")
        self.test_data = self.make_data_load(test_data, params["dataset_params"], "test")
        if "GCS" in self.params and self.params["GCS"]:
            self.gcs_client = get_storage_client()
        else:
            self.gcs_client = None
        self.wandb = self.wandb_init()

    @abstractmethod
    def load_model(self, model_base: str, model_params: Dict, weight_path=None) -> object:
        """
        This function should load and return the model.
        This will vary based on the underlying framework used.
        """
        raise NotImplementedError

    @abstractmethod
    def make_data_load(self, data_path, params: Dict, loader_type: str) -> object:
        """
        Initializes a data loader based on the provided data_path.
        This may be as simple as a pandas dataframe or as complex as
        a custom PyTorch data loader.
        """
        raise NotImplementedError

    @abstractmethod
    def save_model(self, output_path: str):
        """
        Saves a model to a specific path along with a configuration report
        of the parameters and data info.
        """
        raise NotImplementedError

    def upload_gcs(self, save_path: str, name: str, file_type: str, epoch=0, bucket_name=None):
        """
        Function to upload model checkpoints to GCS.

        Does nothing when no GCS client was configured in ``__init__``.

        :param save_path: Local path of the file to upload.
        :param name: Object name; the file is stored under ``experiments/<name>``.
        :param file_type: Suffix used in the wandb config key (e.g. ``"_model"``).
        :param epoch: Epoch number used in the wandb config key.
        :param bucket_name: Target bucket; defaults to the ``MODEL_BUCKET`` environment variable.
        :raises KeyError: If ``bucket_name`` is None and ``MODEL_BUCKET`` is not set.
        """
        if not self.gcs_client:
            return
        if bucket_name is None:
            bucket_name = os.environ["MODEL_BUCKET"]
        print("Data saved to: ")
        print(name)
        upload_file(bucket_name, os.path.join("experiments", name), save_path, self.gcs_client)
        online_path = os.path.join("gs://", bucket_name, "experiments", name)
        if self.wandb:
            wandb.config.update({"gcs_m_path_" + str(epoch) + file_type: online_path})

    def wandb_init(self):
        """
        Starts a wandb run when the configuration asks for one.

        :return: True when ``params["wandb"]`` is set (a run is initialised) or when
            ``params`` contains a ``"sweep"`` key; False otherwise.
        """
        # ``.get`` so that a configuration without a "wandb" entry falls through to the
        # sweep check / ``False`` instead of raising KeyError.
        wandb_params = self.params.get("wandb")
        if wandb_params:
            wandb.init(
                project=wandb_params["project"],
                config=self.params,
                name=wandb_params["name"],
                tags=wandb_params["tags"])
            return True
        if "sweep" in self.params:
            print("Using Wandb config:")
            print(wandb.config)
            return True
        return False
class PyTorchForecast(TimeSeriesModel):
    """
    PyTorch implementation of ``TimeSeriesModel``: models come from
    ``pytorch_model_dict`` and data is served through ``CSVDataLoader``.
    """

    def __init__(
            self,
            model_base: str,
            training_data,
            validation_data,
            test_data,
            params_dict: Dict):
        """
        :param model_base: Key of the model class in ``pytorch_model_dict``.
        :param training_data: Path of the training data, passed to ``make_data_load``.
        :param validation_data: Path of the validation data, passed to ``make_data_load``.
        :param test_data: Path of the test data, passed to ``make_data_load``.
        :param params_dict: Configuration dictionary (see ``TimeSeriesModel.__init__``).
        """
        # The device must exist before the parent constructor runs, because the parent
        # calls ``load_model``, which reads ``self.device``.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        super().__init__(model_base, training_data, validation_data, test_data, params_dict)
        print("Torch is using " + str(self.device))

    def load_model(self, model_base: str, model_params: Dict, weight_path: str = None, strict=True):
        """
        Instantiates the model and optionally restores weights from a checkpoint.

        :param model_base: Key of the model class in ``pytorch_model_dict``.
        :param model_params: Keyword arguments passed to the model constructor.
        :param weight_path: Optional path of a checkpoint to load with ``torch.load``.
        :param strict: Passed to ``load_state_dict``; forced to False whenever
            ``params`` contains ``weight_path_add``.
        :return: The constructed model.
        :raises Exception: If ``model_base`` is not a key of ``pytorch_model_dict``.
        :raises KeyError: If an entry of ``excluded_layers`` is not in the checkpoint.
        """
        if model_base not in pytorch_model_dict:
            raise Exception(
                "Error the model " + model_base + " was not found in the model dict. Please add it.")
        model = pytorch_model_dict[model_base](**model_params)
        if weight_path:
            checkpoint = torch.load(weight_path, map_location=self.device)
            if "weight_path_add" in self.params:
                if "excluded_layers" in self.params["weight_path_add"]:
                    excluded_layers = self.params["weight_path_add"]["excluded_layers"]
                    for layer in excluded_layers:
                        del checkpoint[layer]
                    print("sucessfully deleted layers")
                # Layers may have been dropped above, so missing keys must be tolerated.
                strict = False
            model.load_state_dict(checkpoint, strict=strict)
            print("Weights sucessfully loaded")
        # NOTE(review): this line was not visible in the garbled source; it is added so that
        # the parameters live on the same device as the masks moved below -- confirm.
        model.to(self.device)
        # TODO create a general loop to convert all model tensor params to device
        if hasattr(model, "mask"):
            model.mask = model.mask.to(self.device)
        if hasattr(model, "tgt_mask"):
            model.tgt_mask = model.tgt_mask.to(self.device)
        return model

    def save_model(self, final_path: str, epoch: int = 0) -> None:
        """
        Function to save a model to a given file path.

        Writes ``<timestamp>_model.pth`` (the model's ``state_dict``) and
        ``<timestamp>.json`` (the parameters) into ``final_path``, then hands both
        files to ``upload_gcs``.

        :param final_path: Directory to save into; created if it does not exist.
        :param epoch: Epoch number forwarded to ``upload_gcs``. Defaults to 0, like
            ``upload_gcs`` itself.
        """
        # ``makedirs`` also creates missing parent directories and does not fail when
        # the directory already exists.
        os.makedirs(final_path, exist_ok=True)
        # NOTE(review): the call producing the timestamp was lost in the garbled source;
        # local ``datetime.now()`` is assumed -- confirm.
        time_stamp = datetime.now().strftime("%d_%B_%Y%I_%M%p")
        model_name = time_stamp + "_model.pth"
        params_name = time_stamp + ".json"
        model_save_path = os.path.join(final_path, model_name)
        params_save_path = os.path.join(final_path, params_name)
        # Saved as a state dict so that ``load_model`` can restore it via ``load_state_dict``.
        torch.save(self.model.state_dict(), model_save_path)
        with open(params_save_path, "w+") as p:
            json.dump(self.params, p)
        self.upload_gcs(params_save_path, params_name, "_params", epoch)
        self.upload_gcs(model_save_path, model_name, "_model", epoch)
        if self.wandb:
            wandb.config.save_path = model_save_path

    def make_data_load(
            self,
            data_path: str,
            dataset_params: Dict,
            loader_type: str,
            the_class="default"):
        """
        Builds a ``CSVDataLoader`` for one split of the data.

        :param data_path: Path handed to ``CSVDataLoader``.
        :param dataset_params: Dataset configuration. ``forecast_history``,
            ``target_col`` and ``relevant_cols`` are always read; ``forecast_length``
            is read unless the test branch below applies. Optional keys:
            ``<loader_type>_start``, ``<loader_type>_end``, ``scaler``, ``interpolate``.
        :param loader_type: Name of the split (``__init__`` passes ``"train"``,
            ``"valid"`` and ``"test"``); used as the prefix of the start/end keys.
        :param the_class: Loader flavour; only ``"default"`` is implemented.
        :return: A ``CSVDataLoader``, or None for an unsupported ``the_class``.
        """
        start_end_params = {}
        # TODO clean up else if blocks
        if loader_type + "_start" in dataset_params:
            start_end_params["start_stamp"] = dataset_params[loader_type + "_start"]
        if loader_type + "_end" in dataset_params:
            start_end_params["end_stamp"] = dataset_params[loader_type + "_end"]
        if "scaler" in dataset_params:
            start_end_params["scaling"] = scaler_dict[dataset_params["scaler"]]
        if "interpolate" in dataset_params:
            start_end_params["interpolate_param"] = dataset_params["interpolate"]

        if loader_type == "test" and "forecast_test_len" in dataset_params:
            # The test split uses the inference horizon, regardless of ``the_class``.
            loader = CSVDataLoader(
                data_path,
                dataset_params["forecast_history"],
                dataset_params["forecast_test_len"],
                dataset_params["target_col"],
                dataset_params["relevant_cols"],
                **start_end_params)
        elif the_class == "default":
            loader = CSVDataLoader(
                data_path,
                dataset_params["forecast_history"],
                dataset_params["forecast_length"],
                dataset_params["target_col"],
                dataset_params["relevant_cols"],
                **start_end_params)
        else:
            # TODO support custom DataLoader
            loader = None
        return loader