Skip to content

Commit

Permalink
Fixes img/sec/core.
Browse files Browse the repository at this point in the history
Before this fix, only `lt0` but not `lstep` was updated after computing an evaluation / writing a checkpoint, which led to a img/sec/core computation that was too high.

PiperOrigin-RevId: 555396130
  • Loading branch information
andsteing authored and copybara-github committed Aug 10, 2023
1 parent ac6e056 commit 6e76888
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions vit_jax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def init_model():
(step == total_steps)):

accuracies = []
lt0 = time.time()
tt0 = time.time()
for test_batch in input_pipeline.prefetch(ds_test, config.prefetch):
logits = infer_fn_repl(
dict(params=params_repl), test_batch['image'])
Expand All @@ -223,8 +223,7 @@ def init_model():
accuracy_test = np.mean(accuracies)
img_sec_core_test = (
config.batch_eval * ds_test.cardinality().numpy() /
(time.time() - lt0) / jax.device_count())
lt0 = time.time()
(time.time() - tt0) / jax.device_count())

lr = float(lr_fn(step))
logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation
Expand All @@ -237,14 +236,17 @@ def init_model():
accuracy_test=accuracy_test,
lr=lr,
img_sec_core_test=img_sec_core_test))
lt0 += time.time() - tt0

# Store checkpoint.
if ((config.checkpoint_every and step % config.eval_every == 0) or
step == total_steps):
tt0 = time.time()
checkpoint_path = flax_checkpoints.save_checkpoint(
workdir, (flax.jax_utils.unreplicate(params_repl),
flax.jax_utils.unreplicate(opt_state_repl), step), step)
logging.info('Stored checkpoint at step %d to "%s"', step,
checkpoint_path)
lt0 += time.time() - tt0

return flax.jax_utils.unreplicate(params_repl)

0 comments on commit 6e76888

Please sign in to comment.