"""Roll out a pre-trained PPO policy in a Gym environment.

Source of this module: patch 1a31f2848fb89b76fd952f420ca86920ac795193
("Create ReinforcementLearningAgent.py", KOSASIH, Fri, 5 Jul 2024), which
adds the file ``.ai/models/ReinforcementLearningAgent.py``.
"""

import numpy as np


class ReinforcementLearningAgent:
    """Pair a Gym environment with a PPO model loaded from disk.

    Attributes are plain and public:

    * ``env``   -- the environment created by ``gym.make(env_name)``.
    * ``model`` -- the policy returned by ``PPO.load(model_path)``.
    """

    def __init__(self, env_name, model_path):
        """Create the environment and load the saved PPO model.

        Args:
            env_name: Environment id passed to ``gym.make``.
            model_path: Path of the saved model passed to ``PPO.load``.
        """
        # Imported here, at their only point of use, so that importing this
        # module (and driving the rollout logic with a substitute env/model)
        # does not require the full RL stack to be installed.
        import gym
        from stable_baselines3 import PPO

        self.env = gym.make(env_name)
        self.model = PPO.load(model_path)

    def train(self, num_episodes):
        """Run ``num_episodes`` rollouts, printing each episode's reward.

        NOTE(review): despite its name this method never updates
        ``self.model`` -- it only calls ``predict``, exactly like ``test``.
        If real training is intended, it has to go through the model's own
        learning API; confirm the intent before changing the behaviour.

        Args:
            num_episodes: Number of episodes to run.

        Returns:
            A list with the summed reward of every episode, in order.
        """
        return self._run_episodes(num_episodes)

    def test(self, num_episodes):
        """Run ``num_episodes`` rollouts, printing each episode's reward.

        Args:
            num_episodes: Number of episodes to run.

        Returns:
            A list with the summed reward of every episode, in order.
        """
        return self._run_episodes(num_episodes)

    def close(self):
        """Release the environment by calling its ``close`` method."""
        self.env.close()

    def _run_episodes(self, num_episodes):
        """Shared rollout loop behind ``train`` and ``test``."""
        episode_rewards = []
        for episode_number in range(1, num_episodes + 1):
            obs = self._reset_env()
            total = 0
            done = False
            while not done:
                action, _ = self.model.predict(obs)
                obs, reward, done = self._step_env(action)
                total += reward
            print(f"Episode {episode_number}, Reward: {total}")
            episode_rewards.append(total)
        return episode_rewards

    def _reset_env(self):
        """Reset the environment and return only the observation."""
        result = self.env.reset()
        # Newer Gym/Gymnasium releases return ``(obs, info)`` with ``info``
        # a dict; older releases return the bare observation.
        if (
            isinstance(result, tuple)
            and len(result) == 2
            and isinstance(result[1], dict)
        ):
            return result[0]
        return result

    def _step_env(self, action):
        """Advance one step and return ``(obs, reward, done)``."""
        outcome = self.env.step(action)
        if len(outcome) == 5:
            # Newer API: termination and truncation are reported separately;
            # either one ends the episode.
            obs, reward, terminated, truncated, _ = outcome
            return obs, reward, bool(terminated or truncated)
        obs, reward, done, _ = outcome
        return obs, reward, done