Skip to content

Commit 13137df

Browse files
committed
saving time by not saving models during grid search
1 parent f809f9d commit 13137df

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

libmultilabel/nn/nn_utils.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def init_trainer(
131131
limit_val_batches=1.0,
132132
limit_test_batches=1.0,
133133
save_checkpoints=True,
134+
is_tune_mode=False,
134135
):
135136
"""Initialize a torch lightning trainer.
136137
@@ -146,6 +147,7 @@ def init_trainer(
146147
limit_val_batches (Union[int, float]): Percentage of validation dataset to use. Defaults to 1.0.
147148
limit_test_batches (Union[int, float]): Percentage of test dataset to use. Defaults to 1.0.
148149
save_checkpoints (bool): Whether to save the last and the best checkpoint or not. Defaults to True.
150+
is_tune_mode (bool): Whether a parameter search is running or not. Defaults to False.
149151
150152
Returns:
151153
lightning.trainer: A torch lightning trainer.
@@ -163,7 +165,19 @@ def init_trainer(
163165
strict=False,
164166
)
165167
callbacks = [early_stopping_callback]
166-
if save_checkpoints:
168+
169+
if is_tune_mode:
170+
callbacks += [
171+
ModelCheckpoint(
172+
dirpath=checkpoint_dir,
173+
filename="best_model",
174+
save_top_k=1,
175+
save_weights_only=True,
176+
monitor=val_metric,
177+
mode="min" if val_metric == "Loss" else "max",
178+
)
179+
]
180+
elif save_checkpoints:
167181
callbacks += [
168182
ModelCheckpoint(
169183
dirpath=checkpoint_dir,

search_params.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def train_libmultilabel_tune(config, datasets, classes, word_dict):
4242
datasets=datasets,
4343
classes=classes,
4444
word_dict=word_dict,
45-
save_checkpoints=True,
45+
save_checkpoints=False,
46+
is_tune_mode=True,
4647
)
4748
val_score = trainer.train()
4849
return {f"val_{config.val_metric}": val_score}

torch_trainer.py

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
word_dict: dict = None,
3434
embed_vecs=None,
3535
save_checkpoints: bool = True,
36+
is_tune_mode: bool = False,
3637
):
3738
self.run_name = config.run_name
3839
self.checkpoint_dir = config.checkpoint_dir
@@ -119,6 +120,7 @@ def __init__(
119120
limit_val_batches=config.limit_val_batches,
120121
limit_test_batches=config.limit_test_batches,
121122
save_checkpoints=save_checkpoints,
123+
is_tune_mode=is_tune_mode,
122124
)
123125
callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)]
124126
self.checkpoint_callback = callbacks[0] if callbacks else None

0 commit comments

Comments
 (0)