Skip to content

Commit

Permalink
feat: add Pipeline.training attribute (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
hbredin authored Sep 15, 2021
1 parent f6946e4 commit 5df4f99
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
14 changes: 13 additions & 1 deletion pyannote/pipeline/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# The MIT License (MIT)

# Copyright (c) 2018-2020 CNRS
# Copyright (c) 2018-2021 CNRS

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -261,6 +261,9 @@ def tune(
['params'] nested dictionary of optimal parameters
"""

# pipeline is currently being optimized
self.pipeline.training = True

objective = self.get_objective(inputs, show_progress=show_progress)

if warm_start:
Expand All @@ -272,6 +275,9 @@ def tune(

self.study_.optimize(objective, n_trials=n_iterations, timeout=None, n_jobs=1)

# pipeline is no longer being optimized
self.pipeline.training = False

return {"loss": self.best_loss, "params": self.best_params}

def tune_iter(
Expand Down Expand Up @@ -311,6 +317,9 @@ def tune_iter(

while True:

# pipeline is currently being optimized
self.pipeline.training = True

# one trial at a time
self.study_.optimize(objective, n_trials=1, timeout=None, n_jobs=1)

Expand All @@ -320,4 +329,7 @@ def tune_iter(
except ValueError as e:
continue

# pipeline is no longer being optimized
self.pipeline.training = False

yield {"loss": best_loss, "params": best_params}
5 changes: 4 additions & 1 deletion pyannote/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# The MIT License (MIT)

# Copyright (c) 2018-2020 CNRS
# Copyright (c) 2018-2021 CNRS

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -57,6 +57,9 @@ def __init__(self):
# sub-pipelines
self._pipelines = OrderedDict()

# whether pipeline is currently being optimized
self.training = False

def __hash__(self):
# FIXME -- also keep track of (sub)pipeline attribtes
frozen = self.parameters(frozen=True)
Expand Down

0 comments on commit 5df4f99

Please sign in to comment.