Skip to content

Commit

Permalink
Add separate plot_model arg to training (#582)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfhealy authored Apr 12, 2024
1 parent 143fba4 commit a6986ed
Showing 1 changed file with 8 additions and 32 deletions.
40 changes: 8 additions & 32 deletions scope/scope_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,11 @@ def parse_run_train(self):
action="store_true",
help="if set, generate/save diagnostic training plots",
)
parser.add_argument(
"--plot-model",
action="store_true",
help="if set, plot model architecture",
)
parser.add_argument(
"--weights-only",
action="store_true",
Expand All @@ -688,37 +693,6 @@ def parse_run_train(self):
args, _ = parser.parse_known_args()
self.train(**vars(args))

# args to add for ds.make (override config-specified values)
# threshold
# balance
# weight_per_class (test this to make sure it works as intended)
# scale_features
# test_size
# val_size
# random_state
# feature_stats
# batch_size
# shuffle_buffer_size
# epochs
# float_convert_types

# Args to add with descriptions (or references to tf docs)
# lr
# beta_1
# beta_2
# epsilon
# decay
# amsgrad
# momentum
# monitor
# patience
# callbacks
# run_eagerly
# pre_trained_model
# save
# plot
# weights_only

def train(
self,
tag: str,
Expand Down Expand Up @@ -756,6 +730,7 @@ def train(
pre_trained_model: str = None,
save: bool = False,
plot: bool = False,
plot_model: bool = False,
weights_only: bool = False,
skip_cv: bool = False,
**kwargs,
Expand Down Expand Up @@ -797,6 +772,7 @@ def train(
:param pre_trained_model: name of dnn pre-trained model to load, if any (str)
:param save: if set, save trained model (bool)
:param plot: if set, generate/save diagnostic training plots (bool)
:param plot_model: if set, plot model architecture (bool)
:param weights_only: if set and pre-trained model specified, load only weights (bool)
:param skip_cv: if set, skip XGB cross-validation (bool)
Expand Down Expand Up @@ -1121,7 +1097,7 @@ def train(
amsgrad=amsgrad,
)

if plot:
if plot_model:
tf.keras.utils.plot_model(
classifier.model,
to_file=self.base_path / "DNN.pdf",
Expand Down

0 comments on commit a6986ed

Please sign in to comment.