Threading not working #104

Open
tchaye59 opened this issue Oct 4, 2020 · 1 comment
tchaye59 commented Oct 4, 2020

Hi,

When using multiple threads to collect experience, every thread blocks on a lock as soon as it calls the step function.
I have been facing this problem since I updated the API. I tried both the halite and football environments.

`new_obs, reward, done, info = self.env.step(actions)`
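Reduced to a minimal sketch, the pattern that stalls looks like the following (assuming any gym-style env whose `step` blocks; `make_env` and `policy` are placeholders for my setup):

```python
import threading

# Minimal sketch of the failure mode: each thread drives its own env,
# yet every thread stalls inside env.step() as if a shared lock is held.
def worker(env, policy):
    obs = env.reset()
    done = False
    while not done:
        actions, _ = policy.get_action(obs, state=None)
        obs, reward, done, info = env.step(actions[0])  # blocks here

threads = [threading.Thread(target=worker, args=(make_env(env_id=i), policy))
           for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```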

Full code:

```python
import os
import threading
import uuid
from queue import Queue

import dill

# FootEnv, Policy and Memory come from the surrounding project.
lock = threading.Lock()  # guards console output across collector threads


class EpisodeCollector(threading.Thread):
    n_episode = 0
    reward_sum = 0
    max_episode = 0

    def __init__(self, env: FootEnv, policy: Policy, result_queue=None, replays_dir=None):
        super().__init__()
        self.result_queue = result_queue
        self.env = env
        self.policy = policy
        self.replays_dir = replays_dir
        self.n_episode = -1

    def clone(self):
        obj = EpisodeCollector(self.env, self.policy)
        obj.result_queue = self.result_queue
        obj.replays_dir = self.replays_dir
        obj.n_episode = self.n_episode
        return obj

    def run(self):
        self.result_queue.put(self.collect(1))

    def collect(self, n=1):
        n = max(n, self.n_episode)
        return [self.collect_() for _ in range(n)]

    def collect_(self):
        memory = Memory()
        done = False
        EpisodeCollector.n_episode += 1
        obs = self.env.reset()
        i = 0
        total_reward = 0
        state = None
        while not done:
            actions, state = self.policy.get_action(obs, state=state)
            new_obs, reward, done, info = self.env.step(actions[0])
            total_reward += reward  # accumulate instead of overwriting
            # store data
            memory.store(obs, actions, reward, done)

            if done or i % 100 == 0:
                with lock:
                    print(
                        f"Episode: {EpisodeCollector.n_episode}/{EpisodeCollector.max_episode} | "
                        f"Step: {i} | "
                        f"Env ID: {self.env.env_id} | "
                        f"Reward: {total_reward} | "
                        f"Done: {done} | "
                        f"Total Rewards: {EpisodeCollector.reward_sum} | "
                    )
                    print(info)

            obs = new_obs
            i += 1
        EpisodeCollector.reward_sum += total_reward
        if self.replays_dir:
            with open(os.path.join(self.replays_dir, f'replay-{uuid.uuid4().hex}.dill'), 'wb') as f:
                dill.dump(memory, f)
        return memory


class ParallelEpisodeCollector:

    def __init__(self, env_fn, n_jobs, policy: Policy, replays_dir=None):
        self.n_jobs = n_jobs
        self.policy = policy
        self.envs = []
        self.result_queue = Queue()
        self.replays_dir = replays_dir
        for i in range(n_jobs):
            self.envs.append(env_fn(env_id=i))
        self.collectors = [EpisodeCollector(env,
                                            policy=policy,
                                            result_queue=self.result_queue,
                                            replays_dir=replays_dir) for env in self.envs]

    def collect(self, n_steps=1):
        if not n_steps:
            n_steps = 1
        result_queue = self.result_queue
        for i, collector in enumerate(self.collectors):
            collector = collector.clone()
            self.collectors[i] = collector
            collector.n_episode = max(1, int(n_steps / len(self.collectors)))
            print("Starting collector {}".format(i))
            collector.start()
        tmp = []
        for _ in self.collectors:
            res = result_queue.get()
            tmp.extend(res)
        for collector in self.collectors:
            collector.join()
        return tmp
```
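For comparison, a process-based variant works around the blocking for me in principle, since each worker gets its own interpreter and `env.step` cannot contend on an in-process lock or the GIL. This is only a sketch, not the code above; `env_fn` and `make_policy` are hypothetical helpers, and both must be picklable or constructible inside the worker:

```python
import multiprocessing as mp

def collect_one(env_id):
    # Hypothetical helpers: env_fn builds an env, make_policy a policy,
    # both created inside the worker so nothing is shared across processes.
    env = env_fn(env_id=env_id)
    policy = make_policy()
    obs, done, state, memory = env.reset(), False, None, []
    while not done:
        actions, state = policy.get_action(obs, state=state)
        new_obs, reward, done, info = env.step(actions[0])
        memory.append((obs, actions, reward, done))
        obs = new_obs
    return memory

if __name__ == "__main__":
    with mp.Pool(processes=4) as pool:
        episodes = pool.map(collect_one, range(4))
```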
ModdyLP commented Nov 15, 2021

Any progress on this issue? Any tips for handling it?
