Skip to content

Commit

Permalink
Create ReinforcementLearningAgent.py
Browse files Browse the repository at this point in the history
  • Loading branch information
KOSASIH authored Jul 5, 2024
1 parent 5361ee3 commit 1a31f28
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions .ai/models/ReinforcementLearningAgent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import gym
import numpy as np
from stable_baselines3 import PPO

class ReinforcementLearningAgent:
def __init__(self, env_name, model_path):
self.env = gym.make(env_name)
self.model = PPO.load(model_path)

def train(self, num_episodes):
for episode in range(num_episodes):
obs = self.env.reset()
done = False
rewards = 0
while not done:
action, _ = self.model.predict(obs)
obs, reward, done, _ = self.env.step(action)
rewards += reward
print(f"Episode {episode+1}, Reward: {rewards}")

def test(self, num_episodes):
for episode in range(num_episodes):
obs = self.env.reset()
done = False
rewards = 0
while not done:
action, _ = self.model.predict(obs)
obs, reward, done, _ = self.env.step(action)
rewards += reward
print(f"Episode {episode+1}, Reward: {rewards}")

0 comments on commit 1a31f28

Please sign in to comment.