From cbcf3e59e2959b5cb7f9eb7b3dc29a67dbb188a0 Mon Sep 17 00:00:00 2001 From: Junjia Liu Date: Tue, 8 Aug 2023 13:21:27 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20[RofuncRL]=20Update=20tqdm=20pos?= =?UTF-8?q?tfix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rofunc/learning/RofuncRL/trainers/base_trainer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/rofunc/learning/RofuncRL/trainers/base_trainer.py b/rofunc/learning/RofuncRL/trainers/base_trainer.py index b3b46f933..3c977fa8a 100644 --- a/rofunc/learning/RofuncRL/trainers/base_trainer.py +++ b/rofunc/learning/RofuncRL/trainers/base_trainer.py @@ -91,6 +91,7 @@ def __init__(self, self.start_time = None self.eval_steps = self.cfg.Trainer.eval_steps self.inference_steps = self.cfg.Trainer.inference_steps + self.total_rew_mean = -1e4 '''Environment''' env.device = self.device @@ -190,10 +191,10 @@ def post_interaction(self): # Update best models and tensorboard if not self._step % self.write_interval and self.write_interval > 0: # update best models - reward = np.mean(self.agent.tracking_data.get("Reward / Total reward (mean)", -1e4)) - if reward > self.agent.checkpoint_best_modules["reward"]: + self.total_rew_mean = np.mean(self.agent.tracking_data.get("Reward / Total reward (mean)", -1e4)) + if self.total_rew_mean > self.agent.checkpoint_best_modules["reward"]: self.agent.checkpoint_best_modules["timestep"] = self._step - self.agent.checkpoint_best_modules["reward"] = reward + self.agent.checkpoint_best_modules["reward"] = self.total_rew_mean self.agent.checkpoint_best_modules["saved"] = False self.agent.checkpoint_best_modules["modules"] = {k: copy.deepcopy(self.agent._get_internal_value(v)) for k, v in self.agent.checkpoint_modules.items()} @@ -203,8 +204,9 @@ def post_interaction(self): self.write_tensorboard() # Update tqdm bar message - self.t_bar.set_postfix_str(f"Rew/Best: {reward:.2f}/{self.agent.checkpoint_best_modules['reward']:.2f}") - self.rofunc_logger.info(f"Step: {self._step}, Reward: {reward:.2f}", local_verbose=False) + self.t_bar.set_postfix_str( + f"Rew/Best: {self.total_rew_mean:.2f}/{self.agent.checkpoint_best_modules['reward']:.2f}") + self.rofunc_logger.info(f"Step: {self._step}, Reward: {self.total_rew_mean:.2f}", local_verbose=False) # Save checkpoints if not (self._step + 1) % self.agent.checkpoint_interval and \