Skip to content

Commit

Permalink
🚀 [RofuncRL] Update tqdm postfix
Browse files Browse the repository at this point in the history
  • Loading branch information
Skylark0924 committed Aug 8, 2023
1 parent 3e0e97d commit cbcf3e5
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions rofunc/learning/RofuncRL/trainers/base_trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self,
self.start_time = None
self.eval_steps = self.cfg.Trainer.eval_steps
self.inference_steps = self.cfg.Trainer.inference_steps
+ self.total_rew_mean = -1e4

'''Environment'''
env.device = self.device
Expand Down Expand Up @@ -190,10 +191,10 @@ def post_interaction(self):
# Update best models and tensorboard
if not self._step % self.write_interval and self.write_interval > 0:
# update best models
- reward = np.mean(self.agent.tracking_data.get("Reward / Total reward (mean)", -1e4))
- if reward > self.agent.checkpoint_best_modules["reward"]:
+ self.total_rew_mean = np.mean(self.agent.tracking_data.get("Reward / Total reward (mean)", -1e4))
+ if self.total_rew_mean > self.agent.checkpoint_best_modules["reward"]:
  self.agent.checkpoint_best_modules["timestep"] = self._step
- self.agent.checkpoint_best_modules["reward"] = reward
+ self.agent.checkpoint_best_modules["reward"] = self.total_rew_mean
self.agent.checkpoint_best_modules["saved"] = False
self.agent.checkpoint_best_modules["modules"] = {k: copy.deepcopy(self.agent._get_internal_value(v)) for
k, v in self.agent.checkpoint_modules.items()}
Expand All @@ -203,8 +204,9 @@ def post_interaction(self):
self.write_tensorboard()

# Update tqdm bar message
- self.t_bar.set_postfix_str(f"Rew/Best: {reward:.2f}/{self.agent.checkpoint_best_modules['reward']:.2f}")
- self.rofunc_logger.info(f"Step: {self._step}, Reward: {reward:.2f}", local_verbose=False)
+ self.t_bar.set_postfix_str(
+ f"Rew/Best: {self.total_rew_mean:.2f}/{self.agent.checkpoint_best_modules['reward']:.2f}")
+ self.rofunc_logger.info(f"Step: {self._step}, Reward: {self.total_rew_mean:.2f}", local_verbose=False)

# Save checkpoints
if not (self._step + 1) % self.agent.checkpoint_interval and \
Expand Down

0 comments on commit cbcf3e5

Please sign in to comment.