Skip to content

Commit

Permalink
apply isort
Browse files Browse the repository at this point in the history
  • Loading branch information
cr-xu committed Feb 4, 2024
1 parent 3fdf070 commit bd9206f
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 16 deletions.
9 changes: 2 additions & 7 deletions meta-rl/maml_rl/envs/awake_steering_simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,8 @@
from gymnasium import Wrapper, spaces
from gymnasium.core import WrapperObsType

from maml_rl.envs.helpers import (
Awake_Benchmarking_Wrapper,
MamlHelpers,
Plane,
plot_optimal_policy,
plot_results,
)
from maml_rl.envs.helpers import (Awake_Benchmarking_Wrapper, MamlHelpers,
Plane, plot_optimal_policy, plot_results)

# Standard environment for the AWAKE environment,
# adjusted, so it can be used for the MAML therefore containing
Expand Down
8 changes: 2 additions & 6 deletions meta-rl/maml_rl/metalearners/maml_trpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@
from maml_rl.metalearners.base import GradientBasedMetaLearner
from maml_rl.utils.optimization import conjugate_gradient
from maml_rl.utils.reinforcement_learning import reinforce_loss
from maml_rl.utils.torch_utils import (
detach_distribution,
to_numpy,
vector_to_parameters,
weighted_mean,
)
from maml_rl.utils.torch_utils import (detach_distribution, to_numpy,
vector_to_parameters, weighted_mean)


class MAMLTRPO(GradientBasedMetaLearner):
Expand Down
3 changes: 2 additions & 1 deletion meta-rl/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

import numpy as np
import torch
from stable_baselines3 import PPO

from maml_rl.envs.awake_steering_simulated import AwakeSteering as awake_env
from policy_test import verify_external_policy_on_specific_env
from stable_baselines3 import PPO


def main(args):
Expand Down
3 changes: 2 additions & 1 deletion meta-rl/read_out_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

import matplotlib.pyplot as plt
import numpy as np
from maml_rl.utils.reinforcement_learning import get_returns
from sympy import root

from maml_rl.utils.reinforcement_learning import get_returns

# from maml_rl.utils.torch_utils import to_numpy


Expand Down
3 changes: 2 additions & 1 deletion meta-rl/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from maml_rl.envs.awake_steering_simulated import AwakeSteering as awake_env
from maml_rl.samplers import MultiTaskSampler
from maml_rl.utils.helpers import get_input_size, get_policy_for_env
from maml_rl.utils.reinforcement_learning import get_episode_lengths, get_returns
from maml_rl.utils.reinforcement_learning import (get_episode_lengths,
get_returns)
from policy_test import _layout_verficication_plot, verify


Expand Down

0 comments on commit bd9206f

Please sign in to comment.