model_trainer module
ModelTrainer
¶
A class for training, building, compiling, and running specified deep learning models.
Attributes:
Name | Type | Description |
---|---|---|
config |
An object containing the configuration settings for model training. |
|
model_builder |
An instance of ModelBuilder for building the model. |
|
build_model |
A partial function for building the model with the specified model type. |
Source code in aces/model_trainer.py
class ModelTrainer:
    """
    A class for training, building, compiling, and running specified deep learning models.

    Attributes:
        config: An object containing the configuration settings for model training.
        model_builder: An instance of ModelBuilder for building the model.
        build_model: A partial function for building the model with the specified model type.
    """

    def __init__(self, config: Config):
        """
        Initialize the ModelTrainer object.

        Args:
            config: An object containing the configuration settings for model training.
        """
        self.config = config

        # @FIXME: This isn't producing reproducible results across runs.
        if self.config.USE_SEED:
            import random
            print(f"Using seed: {self.config.SEED}")
            tf.random.set_seed(self.config.SEED)
            np.random.seed(self.config.SEED)
            random.seed(self.config.SEED)

        # @TODO: Autoload the loss function from a registry instead of if/else.
        if self.config.LOSS == "custom_focal_tversky_loss":
            self.config.LOSS = Metrics.focal_tversky_loss
            # Bug fix: was `self.LOSS_TXT = ...`, but save_parameters() reads
            # LOSS_TXT out of config.__dict__, so it must live on the config.
            self.config.LOSS_TXT = Metrics.focal_tversky_loss.__func__.__name__  # "focal_tversky_loss"
        else:
            self.config.LOSS_TXT = self.config.LOSS

        self.model_builder = ModelBuilder(
            features=self.config.FEATURES,
            out_classes=self.config.OUT_CLASS_NUM,
            optimizer=self.config.OPTIMIZER,
            loss=self.config.LOSS,
        )
        self.build_model = partial(
            self.model_builder.build_model,
            model_type=self.config.MODEL_TYPE,
            FOR_AI_PLATFORM=self.config.USE_AI_PLATFORM,
            # getattr with a default replaces the hasattr/ternary construct.
            DERIVE_FEATURES=getattr(self.config, "DERIVE_FEATURES", False),
        )

    def train_model(self) -> None:
        """
        Train the model using the provided configuration settings.

        This method performs the following steps:
        1. Configures memory growth for TensorFlow.
        2. Creates TensorFlow datasets for training, testing, and validation.
        3. Builds and compiles the model.
        4. Prepares the output directory for saving models and results.
        5. Starts the training process.
        6. Evaluates and prints validation metrics.
        7. Saves training parameters, plots, and models.
        """
        print("****************************************************************************")
        print("****************************** Clear Session... ****************************")
        keras.backend.clear_session()

        print("****************************************************************************")
        print("****************************** Configure memory growth... ************************")
        physical_devices = TFUtils.configure_memory_growth()
        self.config.physical_devices = physical_devices

        print("****************************************************************************")
        print("****************************** creating datasets... ************************")
        self.create_datasets(print_info=self.config.PRINT_INFO)

        if self.config.USE_AI_PLATFORM:
            print("****************************************************************************")
            print("******* building and compiling model for ai platform... ********************")
            self.build_and_compile_model_ai_platform()
        else:
            print("****************************************************************************")
            print("************************ building and compiling model... *******************")
            self.build_and_compile_model(print_model_summary=True)

        print("****************************************************************************")
        print("************************ preparing output directory... *********************")
        self.prepare_output_dir()

        print("****************************************************************************")
        print("****************************** training model... ***************************")
        self.start_training()
        if self.config.USE_AI_PLATFORM:
            print(self.model.summary())

        print("****************************************************************************")
        print("****************************** evaluating model... *************************")
        self.evaluate_and_print_val()

        print("****************************************************************************")
        print("****************************** saving parameters... ************************")
        ModelTrainer.save_parameters(**self.config.__dict__)

        print("****************************************************************************")
        print("*************** saving model config and history object... ******************")
        self.save_history_object()
        # On AI Platform, the unwrapped model (self._model) holds the real
        # architecture config; otherwise self.model does.
        if self.config.USE_AI_PLATFORM:
            ModelTrainer.save_model_config(self.config.MODEL_SAVE_DIR, **self._model.get_config())
        else:
            ModelTrainer.save_model_config(self.config.MODEL_SAVE_DIR, **self.model.get_config())

        print("****************************************************************************")
        print("****************************** saving plots... *****************************")
        self.save_plots()

        print("****************************************************************************")
        print("****************************** saving models... ****************************")
        self.save_models()
        print("****************************************************************************")

    def prepare_output_dir(self) -> None:
        """
        Prepare the output directory for saving models and results.

        Either uses the fixed MODEL_DIR_NAME, or creates a date-stamped
        directory whose version suffix is incremented until a free name is
        found.
        """
        if not self.config.AUTO_MODEL_DIR_NAME:
            self.config.MODEL_SAVE_DIR = self.config.OUTPUT_DIR / self.config.MODEL_DIR_NAME
            print(f"> Saving models and results at {self.config.MODEL_SAVE_DIR}...")
            if not os.path.exists(self.config.MODEL_SAVE_DIR):
                os.mkdir(self.config.MODEL_SAVE_DIR)
        else:
            today = datetime.date.today().strftime("%Y_%m_%d")
            iterator = 1
            while True:
                model_dir_name = f"trial_{self.config.MODEL_TYPE}_{today}_v{iterator}"
                self.config.MODEL_SAVE_DIR = self.config.OUTPUT_DIR / model_dir_name
                try:
                    # EAFP: attempt the mkdir and bump the version on collision.
                    os.mkdir(self.config.MODEL_SAVE_DIR)
                except FileExistsError:
                    print(f"> {self.config.MODEL_SAVE_DIR} exists, creating another version...")
                    iterator += 1
                    continue
                else:
                    print(f"> Saving models and results at {self.config.MODEL_SAVE_DIR}...")
                    break

    def create_datasets(self, print_info: bool = False) -> None:
        """
        Create TensorFlow datasets for training, testing, and validation.

        Args:
            print_info: Flag indicating whether to print dataset information.

        Training data is batched at BATCH_SIZE and repeated indefinitely
        (bounded later by steps_per_epoch); validation is batched at 1 and
        repeated; testing is batched at 1 and finite (no .repeat()).
        """
        # "training": True is forwarded to DataProcessor.get_dataset —
        # presumably enables training-specific handling; confirm there.
        self.TRAINING_DATASET = DataProcessor.get_dataset(
            f"{str(self.config.TRAINING_DIR)}/*",
            self.config.FEATURES,
            self.config.LABELS,
            self.config.PATCH_SHAPE[0],
            self.config.BATCH_SIZE,
            self.config.OUT_CLASS_NUM,
            **{**self.config.__dict__, "training": True},
        ).repeat()

        self.VALIDATION_DATASET = DataProcessor.get_dataset(
            f"{str(self.config.VALIDATION_DIR)}/*",
            self.config.FEATURES,
            self.config.LABELS,
            self.config.PATCH_SHAPE[0],
            1,
            self.config.OUT_CLASS_NUM,
            **self.config.__dict__,
        ).repeat()

        self.TESTING_DATASET = DataProcessor.get_dataset(
            f"{str(self.config.TESTING_DIR)}/*",
            self.config.FEATURES,
            self.config.LABELS,
            self.config.PATCH_SHAPE[0],
            1,
            self.config.OUT_CLASS_NUM,
            **self.config.__dict__,
        )

        if print_info:
            print("Printing dataset info:")
            DataProcessor.print_dataset_info(self.TRAINING_DATASET, "Training")
            DataProcessor.print_dataset_info(self.TESTING_DATASET, "Testing")
            DataProcessor.print_dataset_info(self.VALIDATION_DATASET, "Validation")

    def build_and_compile_model(self, print_model_summary: bool = True) -> None:
        """
        Build and compile the model.

        Args:
            print_model_summary: Flag indicating whether to print the model summary.
        """
        self.model = self.build_model(**self.config.__dict__)
        if print_model_summary:
            print(self.model.summary())

    def build_and_compile_model_ai_platform(self) -> None:
        """
        Build and compile the model for AI Platform.

        In this mode build_model returns both the original model and a
        serving wrapper; the original is kept as `self._model` and the
        wrapped model becomes `self.model` (used for training).
        """
        model, wrapped_model = self.build_model(**self.config.__dict__)
        print(model.summary())
        self._model = model
        self.model = wrapped_model

    def start_training(self) -> None:
        """
        Start the training process.

        Trains the model with best-checkpoint saving, TensorBoard logging,
        and optional LR scheduling / early stopping, then saves the trained
        model.
        """
        model_checkpoint = callbacks.ModelCheckpoint(
            f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_CHECKPOINT_NAME}",
            monitor=self.config.CALLBACK_PARAMETER,
            save_best_only=True,
            mode="auto",
            verbose=1,
            save_weights_only=False,
        )  # save best model
        tensorboard = callbacks.TensorBoard(log_dir=str(self.config.MODEL_SAVE_DIR / "logs"), write_images=True)

        def lr_scheduler(epoch):
            # Piecewise-constant schedule: MAX_LR during ramp-up, MID_LR
            # during the sustain phase, MIN_LR afterwards.
            if epoch < self.config.RAMPUP_EPOCHS:
                return self.config.MAX_LR
            elif epoch < self.config.RAMPUP_EPOCHS + self.config.SUSTAIN_EPOCHS:
                return self.config.MID_LR
            else:
                return self.config.MIN_LR

        model_callbacks = [model_checkpoint, tensorboard]
        if self.config.USE_ADJUSTED_LR:
            # Pass the scheduler directly; the wrapping lambda was redundant.
            lr_callback = callbacks.LearningRateScheduler(lr_scheduler, verbose=True)
            model_callbacks.append(lr_callback)
        if self.config.EARLY_STOPPING:
            early_stopping = callbacks.EarlyStopping(
                monitor=self.config.CALLBACK_PARAMETER,
                patience=int(0.3 * self.config.EPOCHS),  # wait 30% of total epochs
                verbose=1,
                mode="auto",
                restore_best_weights=True,
            )
            model_callbacks.append(early_stopping)
        self.model_callbacks = model_callbacks

        self.history = self.model.fit(
            x=self.TRAINING_DATASET,
            epochs=self.config.EPOCHS,
            steps_per_epoch=(self.config.TRAIN_SIZE // self.config.BATCH_SIZE),
            validation_data=self.VALIDATION_DATASET,
            validation_steps=self.config.VAL_SIZE,
            callbacks=model_callbacks,
        )
        # Save the final model as "trained-model" to distinguish it from the
        # best-checkpoint files written by ModelCheckpoint.
        self.model.save(f"{self.config.MODEL_SAVE_DIR}/trained-model")

    def evaluate_and_print_val(self) -> None:
        """
        Evaluate the model and print/persist the resulting metrics.

        NOTE(review): despite the "Validation" banner this evaluates on
        self.TESTING_DATASET — confirm which split is intended.
        """
        print("************************************************")
        print("************************************************")
        print("Validation")
        # Tip: You can remove steps=self.config.TEST_SIZE and match the TEST_SIZE from the env
        evaluate_results = self.model.evaluate(self.TESTING_DATASET)  # , steps=self.config.TEST_SIZE
        with open(f"{self.config.MODEL_SAVE_DIR}/evaluation.txt", "w") as evaluate:
            evaluate.write(json.dumps(dict(zip(self.model.metrics_names, evaluate_results))))
        for name, value in zip(self.model.metrics_names, evaluate_results):
            print(f"{name}: {value}")
        print("\n")

    @staticmethod
    def save_parameters(**config) -> None:
        """
        Save the training parameters to a text file.

        Writes the key training settings from the configuration dict to
        MODEL_SAVE_DIR/parameters.txt; LR-schedule settings are written only
        when USE_ADJUSTED_LR is truthy.
        """
        # The `with` statement closes the file; no explicit close needed.
        with open(f"{str(config.get('MODEL_SAVE_DIR'))}/parameters.txt", "w") as f:
            f.write(f"TRAIN_SIZE: {config.get('TRAIN_SIZE')}\n")
            f.write(f"TEST_SIZE: {config.get('TEST_SIZE')}\n")
            f.write(f"VAL_SIZE: {config.get('VAL_SIZE')}\n")
            f.write(f"BATCH_SIZE: {config.get('BATCH_SIZE')}\n")
            f.write(f"EPOCHS: {config.get('EPOCHS')}\n")
            f.write(f"LOSS: {config.get('LOSS_TXT')}\n")
            f.write(f"TRAINING_DIR: {config.get('TRAINING_DIR')}\n")
            f.write(f"TESTING_DIR: {config.get('TESTING_DIR')}\n")
            f.write(f"VALIDATION_DIR: {config.get('VALIDATION_DIR')}\n")
            if config.get('USE_ADJUSTED_LR'):
                f.write(f"USE_ADJUSTED_LR: {config.get('USE_ADJUSTED_LR')}\n")
                f.write(f"MAX_LR: {config.get('MAX_LR')}\n")
                f.write(f"MID_LR: {config.get('MID_LR')}\n")
                f.write(f"MIN_LR: {config.get('MIN_LR')}\n")
                f.write(f"RAMPUP_EPOCHS: {config.get('RAMPUP_EPOCHS')}\n")
                f.write(f"SUSTAIN_EPOCHS: {config.get('SUSTAIN_EPOCHS')}\n")
            f.write(f"DROPOUT_RATE: {config.get('DROPOUT_RATE')}\n")
            f.write(f"ACTIVATION_FN: {config.get('ACTIVATION_FN')}\n")
            f.write(f"FEATURES: {config.get('FEATURES')}\n")
            f.write(f"LABELS: {config.get('LABELS')}\n")
            f.write(f"PATCH_SHAPE: {config.get('PATCH_SHAPE')}\n")
            f.write(f"CALLBACK_PARAMETER: {config.get('CALLBACK_PARAMETER')}\n")
            f.write(f"MODEL_TYPE: {config.get('MODEL_TYPE')}\n")
            f.write(f"TRANSFORM_DATA: {config.get('TRANSFORM_DATA')}\n")
            f.write(f"MODEL_NAME: {config.get('MODEL_NAME')}.h5\n")
            f.write(f"MODEL_CHECKPOINT_NAME: {config.get('MODEL_CHECKPOINT_NAME')}.h5\n")

    @staticmethod
    def save_model_config(save_dir, **model_config) -> None:
        """Write the model's architecture config as JSON to save_dir/config.json."""
        with open(f"{save_dir}/config.json", "w") as f:
            json.dump(model_config, f, indent=4)

    def save_plots(self) -> None:
        """
        Save plots and model visualization.

        Saves per-metric training-history plots and the model architecture
        diagram(s) under MODEL_SAVE_DIR.
        """
        print(f"Saving plots and model visualization at {self.config.MODEL_SAVE_DIR}...")
        # Plot every metric that has a matching "val_" entry in the history.
        Utils.plot_metrics([key.replace("val_", "") for key in self.history.history.keys() if key.startswith("val_")],
                           self.history.history, len(self.history.epoch), self.config.MODEL_SAVE_DIR)
        if self.config.USE_AI_PLATFORM:
            keras.utils.plot_model(self._model, f"{self.config.MODEL_SAVE_DIR}/model.png", show_shapes=True, rankdir="TB")
            keras.utils.plot_model(self.model, f"{self.config.MODEL_SAVE_DIR}/wrapped_model.png", show_shapes=True, rankdir="LR")
        else:
            keras.utils.plot_model(self.model, f"{self.config.MODEL_SAVE_DIR}/model.png", show_shapes=True, rankdir="TB")

    def save_history_object(self) -> None:
        """Persist the training history dict as pickle (model.pkl) and JSON text (model.txt)."""
        with open(f"{self.config.MODEL_SAVE_DIR}/model.pkl", "wb") as f:
            pickle.dump(self.history.history, f)
        with open(f"{self.config.MODEL_SAVE_DIR}/model.txt", "w") as f:
            f.write(json.dumps(self.history.history))

    def load_and_save_models(self) -> None:
        """
        Load the best checkpointed model and re-save it.

        Loads the .tf checkpoint from MODEL_SAVE_DIR, wraps it with the
        serialization layers via serialize_model(), and saves the wrapped
        model in both HDF5 and TensorFlow SavedModel formats.
        """
        self.config.MODEL_SAVE_DIR = self.config.OUTPUT_DIR / self.config.MODEL_DIR_NAME
        self.model = tf.keras.models.load_model(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_CHECKPOINT_NAME}.tf")
        updated_model = self.serialize_model()
        # serialize_model() returns a Functional model, so HDF5 export is valid.
        updated_model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}.h5", save_format="h5")
        updated_model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}.tf", save_format="tf")

    def save_models(self) -> None:
        """
        Save the trained model(s) to MODEL_SAVE_DIR.

        On AI Platform the model is first wrapped with (de)serialization
        layers via serialize_model(); otherwise the trained model is saved
        directly in SavedModel format.
        """
        if self.config.USE_AI_PLATFORM:
            updated_model = self.serialize_model()
            updated_model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}")
        else:
            # NOTE(review): this saves .h5 only for NON-keras.Model classes,
            # which looks inverted (keras models are what .h5 export supports);
            # preserved as-is — confirm intent before changing.
            if not issubclass(self.model.__class__, keras.Model):
                self.model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}.h5", save_format="h5")
            self.model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}", save_format="tf")

    def serialize_model(self) -> tf.keras.Model:
        """
        Wrap the trained model for serving with string-serialized I/O.

        Builds a new Functional model that accepts one scalar string input
        per feature, deserializes them, runs the trained model, and
        re-serializes the output. Also saves a diagram of the wrapped model.

        Returns:
            tf.keras.Model: The wrapped, serving-ready model.
        """
        input_deserializer = DeSerializeInput(self.config.FEATURES)
        output_deserializer = ReSerializeOutput()
        serialized_inputs = {
            b: tf.keras.Input(shape=[], dtype="string", name=b) for b in self.config.FEATURES
        }
        updated_model_input = input_deserializer(serialized_inputs)
        updated_model = self.model(updated_model_input)
        updated_model = output_deserializer(updated_model, "output")
        updated_model = tf.keras.Model(serialized_inputs, updated_model)
        keras.utils.plot_model(updated_model, f"{self.config.MODEL_SAVE_DIR}/serialized_model.png", show_shapes=True, rankdir="LR")
        return updated_model
__init__(self, config)
special
¶
Initialize the ModelTrainer object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config |
Config |
An object containing the configuration settings for model training. |
required |
Attributes:
Name | Type | Description |
---|---|---|
config |
The configuration settings for model training. |
|
model_builder |
An instance of ModelBuilder for building the model. |
|
build_model |
A partial function for building the model with the specified model type. |
Source code in aces/model_trainer.py
def __init__(self, config: Config):
    """
    Initialize the ModelTrainer object.

    Args:
        config: An object containing the configuration settings for model training.

    Attributes:
        config: The configuration settings for model training.
        model_builder: An instance of ModelBuilder for building the model.
        build_model: A partial function for building the model with the specified model type.
    """
    self.config = config

    # @FIXME: This isn't producing reproducible results across runs.
    if self.config.USE_SEED:
        import random
        print(f"Using seed: {self.config.SEED}")
        tf.random.set_seed(self.config.SEED)
        np.random.seed(self.config.SEED)
        random.seed(self.config.SEED)

    # @TODO: Autoload the loss function from a registry instead of if/else.
    if self.config.LOSS == "custom_focal_tversky_loss":
        self.config.LOSS = Metrics.focal_tversky_loss
        # Bug fix: was `self.LOSS_TXT = ...`, but save_parameters() reads
        # LOSS_TXT out of config.__dict__, so it must live on the config.
        self.config.LOSS_TXT = Metrics.focal_tversky_loss.__func__.__name__  # "focal_tversky_loss"
    else:
        self.config.LOSS_TXT = self.config.LOSS

    self.model_builder = ModelBuilder(
        features=self.config.FEATURES,
        out_classes=self.config.OUT_CLASS_NUM,
        optimizer=self.config.OPTIMIZER,
        loss=self.config.LOSS,
    )
    self.build_model = partial(
        self.model_builder.build_model,
        model_type=self.config.MODEL_TYPE,
        FOR_AI_PLATFORM=self.config.USE_AI_PLATFORM,
        # getattr with a default replaces the hasattr/ternary construct.
        DERIVE_FEATURES=getattr(self.config, "DERIVE_FEATURES", False),
    )
build_and_compile_model(self, print_model_summary=True)
¶
Build and compile the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
print_model_summary |
bool |
Flag indicating whether to print the model summary. |
True |
Builds and compiles the model using the provided configuration settings.
Prints the model summary if print_model_summary is set to True.
Source code in aces/model_trainer.py
def build_and_compile_model(self, print_model_summary: bool = True) -> None:
    """
    Build and compile the model from the current configuration.

    Args:
        print_model_summary: Flag indicating whether to print the model summary.
    """
    compiled = self.build_model(**self.config.__dict__)
    self.model = compiled
    if print_model_summary:
        print(compiled.summary())
build_and_compile_model_ai_platform(self)
¶
Build and compile the model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
print_model_summary |
Flag indicating whether to print the model summary. |
required |
Builds and compiles the model using the provided configuration settings.
Prints the model summary if print_model_summary is set to True.
Source code in aces/model_trainer.py
def build_and_compile_model_ai_platform(self) -> None:
    """
    Build and compile the model for AI Platform.

    In this mode the partial `build_model` returns both the original model
    and a serving wrapper: the original is kept as `self._model` and the
    wrapped model becomes `self.model` (the one used for training).
    """
    model, wrapped_model = self.build_model(**self.config.__dict__)
    print(model.summary())
    self._model = model
    self.model = wrapped_model
create_datasets(self, print_info=False)
¶
Create TensorFlow datasets for training, testing, and validation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
print_info |
bool |
Flag indicating whether to print dataset information. |
False |
Prints information about the created datasets if print_info is set to True.
Source code in aces/model_trainer.py
def create_datasets(self, print_info: bool = False) -> None:
    """
    Create TensorFlow datasets for training, testing, and validation.

    Args:
        print_info: Flag indicating whether to print dataset information.

    Training data is batched at BATCH_SIZE and repeated indefinitely
    (bounded later by steps_per_epoch); validation is batched at 1 and
    repeated; testing is batched at 1 and finite (no .repeat()).
    """
    # "training": True is forwarded to DataProcessor.get_dataset —
    # presumably enables training-specific handling; confirm there.
    self.TRAINING_DATASET = DataProcessor.get_dataset(
        f"{str(self.config.TRAINING_DIR)}/*",
        self.config.FEATURES,
        self.config.LABELS,
        self.config.PATCH_SHAPE[0],
        self.config.BATCH_SIZE,
        self.config.OUT_CLASS_NUM,
        **{**self.config.__dict__, "training": True},
    ).repeat()

    self.VALIDATION_DATASET = DataProcessor.get_dataset(
        f"{str(self.config.VALIDATION_DIR)}/*",
        self.config.FEATURES,
        self.config.LABELS,
        self.config.PATCH_SHAPE[0],
        1,
        self.config.OUT_CLASS_NUM,
        **self.config.__dict__,
    ).repeat()

    self.TESTING_DATASET = DataProcessor.get_dataset(
        f"{str(self.config.TESTING_DIR)}/*",
        self.config.FEATURES,
        self.config.LABELS,
        self.config.PATCH_SHAPE[0],
        1,
        self.config.OUT_CLASS_NUM,
        **self.config.__dict__,
    )

    if print_info:
        print("Printing dataset info:")
        DataProcessor.print_dataset_info(self.TRAINING_DATASET, "Training")
        DataProcessor.print_dataset_info(self.TESTING_DATASET, "Testing")
        DataProcessor.print_dataset_info(self.VALIDATION_DATASET, "Validation")
evaluate_and_print_val(self)
¶
Evaluate and print validation metrics.
Evaluates the model on the validation dataset and prints the metrics.
Source code in aces/model_trainer.py
def evaluate_and_print_val(self) -> None:
    """
    Evaluate the model and print/persist the resulting metrics.

    NOTE(review): despite the "Validation" banner this evaluates on
    self.TESTING_DATASET — confirm which split is intended.
    """
    print("************************************************")
    print("************************************************")
    print("Validation")
    # Tip: You can remove steps=self.config.TEST_SIZE and match the TEST_SIZE from the env
    evaluate_results = self.model.evaluate(self.TESTING_DATASET)  # , steps=self.config.TEST_SIZE
    # Persist metric name -> value pairs as JSON, then echo them.
    with open(f"{self.config.MODEL_SAVE_DIR}/evaluation.txt", "w") as evaluate:
        evaluate.write(json.dumps(dict(zip(self.model.metrics_names, evaluate_results))))
    for name, value in zip(self.model.metrics_names, evaluate_results):
        print(f"{name}: {value}")
    print("\n")
load_and_save_models(self)
¶
Load the trained models.
Loads the trained models from different formats: h5 and tf formats.
Source code in aces/model_trainer.py
def load_and_save_models(self) -> None:
    """
    Load the best checkpointed model and re-save it.

    Loads the .tf checkpoint from MODEL_SAVE_DIR, wraps it with the
    serialization layers via serialize_model(), and saves the wrapped
    model in both HDF5 and TensorFlow SavedModel formats.
    """
    self.config.MODEL_SAVE_DIR = self.config.OUTPUT_DIR / self.config.MODEL_DIR_NAME
    self.model = tf.keras.models.load_model(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_CHECKPOINT_NAME}.tf")
    updated_model = self.serialize_model()
    # serialize_model() returns a Functional model, so HDF5 export is valid.
    updated_model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}.h5", save_format="h5")
    updated_model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}.tf", save_format="tf")
prepare_output_dir(self)
¶
Prepare the output directory for saving models and results.
Creates a directory with a timestamped name and increments the version number if necessary.
Source code in aces/model_trainer.py
def prepare_output_dir(self) -> None:
    """
    Prepare the output directory for saving models and results.

    Either uses the fixed MODEL_DIR_NAME, or creates a date-stamped
    directory whose version suffix is incremented until a free name is
    found.
    """
    if not self.config.AUTO_MODEL_DIR_NAME:
        self.config.MODEL_SAVE_DIR = self.config.OUTPUT_DIR / self.config.MODEL_DIR_NAME
        print(f"> Saving models and results at {self.config.MODEL_SAVE_DIR}...")
        if not os.path.exists(self.config.MODEL_SAVE_DIR):
            os.mkdir(self.config.MODEL_SAVE_DIR)
    else:
        today = datetime.date.today().strftime("%Y_%m_%d")
        iterator = 1
        while True:
            model_dir_name = f"trial_{self.config.MODEL_TYPE}_{today}_v{iterator}"
            self.config.MODEL_SAVE_DIR = self.config.OUTPUT_DIR / model_dir_name
            try:
                # EAFP: attempt the mkdir and bump the version on collision.
                os.mkdir(self.config.MODEL_SAVE_DIR)
            except FileExistsError:
                print(f"> {self.config.MODEL_SAVE_DIR} exists, creating another version...")
                iterator += 1
                continue
            else:
                print(f"> Saving models and results at {self.config.MODEL_SAVE_DIR}...")
                break
save_history_object(self)
¶
Save the history object.
Source code in aces/model_trainer.py
def save_history_object(self) -> None:
    """
    Persist the Keras training-history dict to disk.

    Writes the same dict twice under MODEL_SAVE_DIR: pickled to
    ``model.pkl`` and as JSON text to ``model.txt``.
    """
    history_dict = self.history.history
    save_dir = self.config.MODEL_SAVE_DIR
    with open(f"{save_dir}/model.pkl", "wb") as pkl_file:
        pickle.dump(history_dict, pkl_file)
    with open(f"{save_dir}/model.txt", "w") as txt_file:
        txt_file.write(json.dumps(history_dict))
save_models(self)
¶
Save the trained models.
Saves the trained models in different formats: h5 and tf formats.
Source code in aces/model_trainer.py
def save_models(self) -> None:
    """
    Save the trained model(s) to MODEL_SAVE_DIR.

    On AI Platform the model is first wrapped with (de)serialization
    layers via serialize_model(); otherwise the trained model is saved
    directly in SavedModel format (plus an .h5 copy when the condition
    below holds).
    """
    if self.config.USE_AI_PLATFORM:
        updated_model = self.serialize_model()
        updated_model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}")
    else:
        # NOTE(review): this saves .h5 only for NON-keras.Model classes,
        # which looks inverted (keras models are what HDF5 export
        # supports) — confirm intent before changing.
        if not issubclass(self.model.__class__, keras.Model):
            self.model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}.h5", save_format="h5")
        self.model.save(f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_NAME}", save_format="tf")
save_parameters(**config)
staticmethod
¶
Save the training parameters to a text file.
Saves the training parameters used in the configuration settings to a text file.
Source code in aces/model_trainer.py
@staticmethod
def save_parameters(**config) -> None:
    """
    Save the training parameters to a text file.

    Writes key training settings from the configuration dict to
    MODEL_SAVE_DIR/parameters.txt. LR-schedule settings are written only
    when USE_ADJUSTED_LR is truthy. Missing keys are written as ``None``.
    """
    # The `with` statement closes the file; the redundant f.close() was removed.
    with open(f"{str(config.get('MODEL_SAVE_DIR'))}/parameters.txt", "w") as f:
        f.write(f"TRAIN_SIZE: {config.get('TRAIN_SIZE')}\n")
        f.write(f"TEST_SIZE: {config.get('TEST_SIZE')}\n")
        f.write(f"VAL_SIZE: {config.get('VAL_SIZE')}\n")
        f.write(f"BATCH_SIZE: {config.get('BATCH_SIZE')}\n")
        f.write(f"EPOCHS: {config.get('EPOCHS')}\n")
        f.write(f"LOSS: {config.get('LOSS_TXT')}\n")
        f.write(f"TRAINING_DIR: {config.get('TRAINING_DIR')}\n")
        f.write(f"TESTING_DIR: {config.get('TESTING_DIR')}\n")
        f.write(f"VALIDATION_DIR: {config.get('VALIDATION_DIR')}\n")
        if config.get('USE_ADJUSTED_LR'):
            f.write(f"USE_ADJUSTED_LR: {config.get('USE_ADJUSTED_LR')}\n")
            f.write(f"MAX_LR: {config.get('MAX_LR')}\n")
            f.write(f"MID_LR: {config.get('MID_LR')}\n")
            f.write(f"MIN_LR: {config.get('MIN_LR')}\n")
            f.write(f"RAMPUP_EPOCHS: {config.get('RAMPUP_EPOCHS')}\n")
            f.write(f"SUSTAIN_EPOCHS: {config.get('SUSTAIN_EPOCHS')}\n")
        f.write(f"DROPOUT_RATE: {config.get('DROPOUT_RATE')}\n")
        f.write(f"ACTIVATION_FN: {config.get('ACTIVATION_FN')}\n")
        f.write(f"FEATURES: {config.get('FEATURES')}\n")
        f.write(f"LABELS: {config.get('LABELS')}\n")
        f.write(f"PATCH_SHAPE: {config.get('PATCH_SHAPE')}\n")
        f.write(f"CALLBACK_PARAMETER: {config.get('CALLBACK_PARAMETER')}\n")
        f.write(f"MODEL_TYPE: {config.get('MODEL_TYPE')}\n")
        f.write(f"TRANSFORM_DATA: {config.get('TRANSFORM_DATA')}\n")
        f.write(f"MODEL_NAME: {config.get('MODEL_NAME')}.h5\n")
        f.write(f"MODEL_CHECKPOINT_NAME: {config.get('MODEL_CHECKPOINT_NAME')}.h5\n")
save_plots(self)
¶
Save plots and model visualization.
Saves the model architecture plot, training history plot, and model object.
Source code in aces/model_trainer.py
def save_plots(self) -> None:
    """
    Save plots and model visualization.

    Saves per-metric training-history plots and the model architecture
    diagram(s) under MODEL_SAVE_DIR.
    """
    print(f"Saving plots and model visualization at {self.config.MODEL_SAVE_DIR}...")
    # Plot every metric that has a matching "val_" entry in the history.
    Utils.plot_metrics([key.replace("val_", "") for key in self.history.history.keys() if key.startswith("val_")],
                       self.history.history, len(self.history.epoch), self.config.MODEL_SAVE_DIR)
    if self.config.USE_AI_PLATFORM:
        # Plot both the unwrapped model and its serving wrapper.
        keras.utils.plot_model(self._model, f"{self.config.MODEL_SAVE_DIR}/model.png", show_shapes=True, rankdir="TB")
        keras.utils.plot_model(self.model, f"{self.config.MODEL_SAVE_DIR}/wrapped_model.png", show_shapes=True, rankdir="LR")  # rankdir='TB'
    else:
        keras.utils.plot_model(self.model, f"{self.config.MODEL_SAVE_DIR}/model.png", show_shapes=True, rankdir="TB")  # rankdir='TB'
serialize_model(self)
¶
Serialize and save the trained models.
Saves the trained models in different formats: h5 and tf formats.
Source code in aces/model_trainer.py
def serialize_model(self) -> tf.keras.Model:
    """
    Wrap the trained model for serving with string-serialized I/O.

    Builds a new Functional model that accepts one scalar string input per
    feature, deserializes them, runs the trained model, and re-serializes
    the output. Also saves a diagram of the wrapped architecture.

    Returns:
        tf.keras.Model: The wrapped, serving-ready model.
    """
    input_deserializer = DeSerializeInput(self.config.FEATURES)
    output_deserializer = ReSerializeOutput()
    # One scalar string input per feature, keyed and named by feature name.
    serialized_inputs = {
        b: tf.keras.Input(shape=[], dtype="string", name=b) for b in self.config.FEATURES
    }
    updated_model_input = input_deserializer(serialized_inputs)
    updated_model = self.model(updated_model_input)
    updated_model = output_deserializer(updated_model, "output")
    updated_model = tf.keras.Model(serialized_inputs, updated_model)
    keras.utils.plot_model(updated_model, f"{self.config.MODEL_SAVE_DIR}/serialized_model.png", show_shapes=True, rankdir="LR")
    return updated_model
start_training(self)
¶
Start the training process.
Trains the model using the provided configuration settings and callbacks.
Source code in aces/model_trainer.py
def start_training(self) -> None:
    """
    Start the training process.

    Trains the model with best-checkpoint saving, TensorBoard logging, and
    optional LR scheduling / early stopping, then saves the trained model.
    """
    model_checkpoint = callbacks.ModelCheckpoint(
        f"{str(self.config.MODEL_SAVE_DIR)}/{self.config.MODEL_CHECKPOINT_NAME}",
        monitor=self.config.CALLBACK_PARAMETER,
        save_best_only=True,
        mode="auto",
        verbose=1,
        save_weights_only=False,
    )  # save best model
    tensorboard = callbacks.TensorBoard(log_dir=str(self.config.MODEL_SAVE_DIR / "logs"), write_images=True)

    def lr_scheduler(epoch):
        # Piecewise-constant schedule: MAX_LR during ramp-up, MID_LR during
        # the sustain phase, MIN_LR afterwards.
        if epoch < self.config.RAMPUP_EPOCHS:
            return self.config.MAX_LR
        elif epoch < self.config.RAMPUP_EPOCHS + self.config.SUSTAIN_EPOCHS:
            return self.config.MID_LR
        else:
            return self.config.MIN_LR

    model_callbacks = [model_checkpoint, tensorboard]
    if self.config.USE_ADJUSTED_LR:
        lr_callback = callbacks.LearningRateScheduler(lambda epoch: lr_scheduler(epoch), verbose=True)
        model_callbacks.append(lr_callback)
    if self.config.EARLY_STOPPING:
        early_stopping = callbacks.EarlyStopping(
            monitor=self.config.CALLBACK_PARAMETER,
            patience=int(0.3 * self.config.EPOCHS),  # wait 30% of total epochs
            verbose=1,
            mode="auto",
            restore_best_weights=True,
        )
        model_callbacks.append(early_stopping)
    self.model_callbacks = model_callbacks

    self.history = self.model.fit(
        x=self.TRAINING_DATASET,
        epochs=self.config.EPOCHS,
        steps_per_epoch=(self.config.TRAIN_SIZE // self.config.BATCH_SIZE),
        validation_data=self.VALIDATION_DATASET,
        validation_steps=self.config.VAL_SIZE,
        callbacks=model_callbacks,
    )
    # either save the wrapped model or the original model
    # named as "trained-model" to avoid confusion with checkpoint files
    self.model.save(f"{self.config.MODEL_SAVE_DIR}/trained-model")
train_model(self)
¶
Train the model using the provided configuration settings.
This method performs the following steps:
-
Configures memory growth for TensorFlow.
-
Creates TensorFlow datasets for training, testing, and validation.
-
Builds and compiles the model.
-
Prepares the output directory for saving models and results.
-
Starts the training process.
-
Evaluates and prints validation metrics.
-
Saves training parameters, plots, and models.
Source code in aces/model_trainer.py
def train_model(self) -> None:
    """
    Train the model using the provided configuration settings.

    This method performs the following steps:
    1. Configures memory growth for TensorFlow.
    2. Creates TensorFlow datasets for training, testing, and validation.
    3. Builds and compiles the model.
    4. Prepares the output directory for saving models and results.
    5. Starts the training process.
    6. Evaluates and prints validation metrics.
    7. Saves training parameters, plots, and models.

    Returns:
        None. All results are written to ``self.config.MODEL_SAVE_DIR``.
    """
    print("****************************************************************************")
    print("****************************** Clear Session... ****************************")
    # Reset Keras' global state so repeated runs in one process don't accumulate graphs.
    keras.backend.clear_session()
    print("****************************************************************************")
    # FIX: was an f-string with no placeholders (lint F541); plain literal is identical.
    print("****************************** Configure memory growth... ************************")
    physical_devices = TFUtils.configure_memory_growth()
    # Stash the detected devices on the config so later steps/saved params can see them.
    self.config.physical_devices = physical_devices
    print("****************************************************************************")
    print("****************************** creating datasets... ************************")
    self.create_datasets(print_info=self.config.PRINT_INFO)
    if self.config.USE_AI_PLATFORM:
        print("****************************************************************************")
        print("******* building and compiling model for ai platform... ********************")
        self.build_and_compile_model_ai_platform()
    else:
        print("****************************************************************************")
        print("************************ building and compiling model... *******************")
        self.build_and_compile_model(print_model_summary=True)
    print("****************************************************************************")
    print("************************ preparing output directory... *********************")
    self.prepare_output_dir()
    print("****************************************************************************")
    print("****************************** training model... ***************************")
    self.start_training()
    if self.config.USE_AI_PLATFORM:
        # BUGFIX: Model.summary() prints to stdout itself and returns None, so the
        # original print(self.model.summary()) emitted a spurious "None" line.
        self.model.summary()
    print("****************************************************************************")
    print("****************************** evaluating model... *************************")
    self.evaluate_and_print_val()
    print("****************************************************************************")
    print("****************************** saving parameters... ************************")
    ModelTrainer.save_parameters(**self.config.__dict__)
    print("****************************************************************************")
    print("*************** saving model config and history object... ******************")
    self.save_history_object()
    if self.config.USE_AI_PLATFORM:
        # NOTE(review): on the AI-platform path the config comes from self._model —
        # presumably the unwrapped inner model behind a serving wrapper; verify.
        ModelTrainer.save_model_config(self.config.MODEL_SAVE_DIR, **self._model.get_config())
    else:
        ModelTrainer.save_model_config(self.config.MODEL_SAVE_DIR, **self.model.get_config())
    print("****************************************************************************")
    print("****************************** saving plots... *****************************")
    self.save_plots()
    print("****************************************************************************")
    print("****************************** saving models... ****************************")
    self.save_models()
    print("****************************************************************************")