fix: add check for NaN validation loss in EarlyStopping (#28)
This PR fixes an issue where `EarlyStopping` incorrectly treated a `nan`
validation loss, typically produced by exploding gradients, as an
improvement (a short sketch of the failure mode follows the change list).

Key changes:
- Added an `np.isnan(val_loss)` check so that `nan` validation losses
are ignored.
- Updated the logic so that the patience counter and model
checkpointing are unaffected by `nan` values.
- Introduced a new unit test, `test_validation_loss_nan`, verifying that
`EarlyStopping` behaves correctly when `nan` values are encountered
during training.

Closes #16
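
The failure mode comes from NaN comparison semantics: in Python, every
ordered comparison involving NaN evaluates to False, so depending on how
the improvement test is written, a `nan` score can fall straight through
into the improvement branch. A minimal sketch of the problem and the fix
(the variable names are illustrative, not the library's exact internals):

    import numpy as np

    best_score = -0.95         # best score so far (negated validation loss)
    val_loss = float('nan')    # e.g. produced by exploding gradients
    score = -val_loss          # still NaN

    # Every ordered comparison with NaN is False, so a "no improvement"
    # test like this never fires for a NaN score, and control can fall
    # through to the branch that saves a checkpoint.
    print(score < best_score)  # False

    # The fix: detect NaN explicitly before comparing.
    print(np.isnan(val_loss))  # True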
Bjarten authored Oct 15, 2024
1 parent ffe12ee commit 676686b
Showing 2 changed files with 31 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pytorchtools.py
@@ -32,6 +32,10 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
         self.trace_func = trace_func

     def __call__(self, val_loss, model):
+        # Check if validation loss is nan
+        if np.isnan(val_loss):
+            self.trace_func("Validation loss is NaN. Ignoring this epoch.")
+            return

         if self.best_val_loss is None:
             self.best_val_loss = val_loss
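
For context, a hedged usage sketch of the guarded `__call__` in a
training loop; the stand-in model and the loss sequence below are
illustrative assumptions, not part of this commit:

    import torch
    from pytorchtools import EarlyStopping

    model = torch.nn.Linear(1, 1)  # stand-in model, for illustration only
    early_stopping = EarlyStopping(patience=3, verbose=True, path='checkpoint.pt')

    for epoch, val_loss in enumerate([1.0, 0.95, float('nan'), 0.9]):
        early_stopping(val_loss, model)  # the NaN epoch is logged and skipped
        if early_stopping.early_stop:
            print(f"Stopping early after epoch {epoch}")
            break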
27 changes: 27 additions & 0 deletions tests/test_early_stopping.py
@@ -330,4 +330,31 @@ def test_delta_functionality(mock_model, temp_checkpoint_path):
         # Assert no new checkpoints were saved after early stopping was triggered
         assert mock_save_checkpoint.call_count == 2, "No additional checkpoints should be saved after early stopping was triggered"

+def test_validation_loss_nan(mock_model, temp_checkpoint_path):
+    """
+    Test that EarlyStopping ignores epochs where the validation loss is NaN:
+    a NaN loss must not update the model checkpoint, must not touch the
+    patience counter, and is skipped rather than treated as an improvement.
+    """
+    # Patch the save_checkpoint method used inside EarlyStopping
+    with patch.object(EarlyStopping, 'save_checkpoint') as mock_save_checkpoint:
+        # Initialize EarlyStopping with specified parameters
+        early_stopping = EarlyStopping(patience=3, verbose=False, path=temp_checkpoint_path)
+
+        # Simulate validation losses, including NaN
+        losses = [1.0, 0.95, float('nan'), 0.9]
+        for loss in losses:
+            early_stopping(loss, mock_model)
+
+        # Assert that save_checkpoint was called three times:
+        # - initial call (loss=1.0)
+        # - improvement (loss=0.95)
+        # - improvement (loss=0.9)
+        assert mock_save_checkpoint.call_count == 3, "Checkpoints should be saved on the initial loss and each improvement, ignoring NaN"
+
+        # Assert that early stopping was not triggered
+        assert not early_stopping.early_stop, "Early stop should not be triggered while validation losses improve"
+
+        # Assert that the patience counter was not incremented for the NaN loss
+        assert early_stopping.counter == 0, "Counter should remain 0 since the NaN loss was ignored"
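
The new test can be run on its own with
`pytest tests/test_early_stopping.py::test_validation_loss_nan` (assuming
the repository's existing pytest setup).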
