From cf760e5a331768c696f89e67556dd864706113ec Mon Sep 17 00:00:00 2001 From: Tonio Weidler Date: Sun, 12 Jun 2022 22:24:16 +0200 Subject: [PATCH] Fixed FreeReach --- angorapy/environments/reach.py | 2 +- tests/test_agent.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/angorapy/environments/reach.py b/angorapy/environments/reach.py index 4dc29844..63fd7ff8 100644 --- a/angorapy/environments/reach.py +++ b/angorapy/environments/reach.py @@ -318,7 +318,7 @@ def _sample_goal(self): def get_target_finger_position(self): """Get position of the target finger in space.""" - return self.sim.data.get_site_xpos(FINGERTIP_SITE_NAMES[np.where(self.goal == 1)[0].item()]).flatten() + return self.data.site(FINGERTIP_SITE_NAMES[np.where(self.goal == 1)[0].item()]).xpos.flatten() def _is_success(self, achieved_goal, desired_goal): d = get_fingertip_distance(self.get_thumb_position(), self.get_target_finger_position()) diff --git a/tests/test_agent.py b/tests/test_agent.py index adb5f41a..2eb11cf5 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -50,7 +50,7 @@ def test_drill_discrete(self): except Exception: self.fail("Discrete drill raises error.") - def test_drill_dexterity_multicontinuous(self): + def test_drill_manipulate_multicontinuous(self): """Test drilling of discrete agent (LunarLander).""" try: @@ -62,7 +62,7 @@ def test_drill_dexterity_multicontinuous(self): except Exception: self.fail("HumanoidManipulateBlockDiscreteAsynchronous drill raises error.") - def test_drill_dexterity_continuous(self): + def test_drill_manipulate_continuous(self): """Test drilling of discrete agent (LunarLander).""" try: @@ -74,6 +74,30 @@ def test_drill_dexterity_continuous(self): except Exception: self.fail("HumanoidManipulateBlockDiscreteAsynchronous drill raises error.") + def test_drill_reach(self): + """Test drilling of discrete agent (LunarLander).""" + + try: + wrappers = [StateNormalizationTransformer, RewardNormalizationTransformer] + env = make_env("ReachAbsolute-v0", reward_config=None, transformers=wrappers) + build_models = get_model_builder(model="shadow", model_type="lstm", shared=False) + agent = PPOAgent(build_models, env, workers=2, horizon=128, distribution=BetaPolicyDistribution(env)) + agent.drill(n=2, epochs=2, batch_size=64) + except Exception: + self.fail("ReachAbsolute drill raises error.") + + def test_drill_freereach(self): + """Test drilling of free reach agent.""" + + try: + wrappers = [StateNormalizationTransformer, RewardNormalizationTransformer] + env = make_env("FreeReachAbsolute-v0", reward_config=None, transformers=wrappers) + build_models = get_model_builder(model="shadow", model_type="lstm", shared=False) + agent = PPOAgent(build_models, env, workers=2, horizon=128, distribution=BetaPolicyDistribution(env)) + agent.drill(n=2, epochs=2, batch_size=64) + except Exception: + self.fail("FreeReachAbsolute drill raises error.") + if __name__ == '__main__': unittest.main()