Skip to content

Commit

Permalink
Fixed FreeReach
Browse files Browse the repository at this point in the history
  • Loading branch information
weidler committed Jun 12, 2022
1 parent edd963e commit cf760e5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
2 changes: 1 addition & 1 deletion angorapy/environments/reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _sample_goal(self):

def get_target_finger_position(self):
"""Get position of the target finger in space."""
return self.sim.data.get_site_xpos(FINGERTIP_SITE_NAMES[np.where(self.goal == 1)[0].item()]).flatten()
return self.data.site(FINGERTIP_SITE_NAMES[np.where(self.goal == 1)[0].item()]).xpos.flatten()

def _is_success(self, achieved_goal, desired_goal):
d = get_fingertip_distance(self.get_thumb_position(), self.get_target_finger_position())
Expand Down
28 changes: 26 additions & 2 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_drill_discrete(self):
except Exception:
self.fail("Discrete drill raises error.")

def test_drill_dexterity_multicontinuous(self):
def test_drill_manipulate_multicontinuous(self):
"""Test drilling of discrete agent (LunarLander)."""

try:
Expand All @@ -62,7 +62,7 @@ def test_drill_dexterity_multicontinuous(self):
except Exception:
self.fail("HumanoidManipulateBlockDiscreteAsynchronous drill raises error.")

def test_drill_dexterity_continuous(self):
def test_drill_manipulate_continuous(self):
"""Test drilling of discrete agent (LunarLander)."""

try:
Expand All @@ -74,6 +74,30 @@ def test_drill_dexterity_continuous(self):
except Exception:
self.fail("HumanoidManipulateBlockDiscreteAsynchronous drill raises error.")

def test_drill_reach(self):
"""Test drilling of discrete agent (LunarLander)."""

try:
wrappers = [StateNormalizationTransformer, RewardNormalizationTransformer]
env = make_env("ReachAbsolute-v0", reward_config=None, transformers=wrappers)
build_models = get_model_builder(model="shadow", model_type="lstm", shared=False)
agent = PPOAgent(build_models, env, workers=2, horizon=128, distribution=BetaPolicyDistribution(env))
agent.drill(n=2, epochs=2, batch_size=64)
except Exception:
self.fail("ReachAbsolute drill raises error.")

def test_drill_freereach(self):
"""Test drilling of free reach agent."""

try:
wrappers = [StateNormalizationTransformer, RewardNormalizationTransformer]
env = make_env("FreeReachAbsolute-v0", reward_config=None, transformers=wrappers)
build_models = get_model_builder(model="shadow", model_type="lstm", shared=False)
agent = PPOAgent(build_models, env, workers=2, horizon=128, distribution=BetaPolicyDistribution(env))
agent.drill(n=2, epochs=2, batch_size=64)
except Exception:
self.fail("FreeReachAbsolute drill raises error.")


if __name__ == '__main__':
unittest.main()

0 comments on commit cf760e5

Please sign in to comment.