Skip to content

Commit

Permalink
Protect StatefulPool from class methods
Browse files Browse the repository at this point in the history
StatefulPool doesn't support batching over class methods because it
always passes `G` as the first argument to the worker function. If one of the
`run_` methods in StatefulPool is called with a class method it can lead to
a silent lock-up of the pool, which is very difficult to debug.

Note: this bug does not appear unless n_parallel > 1
  • Loading branch information
ryanjulian committed Mar 27, 2018
1 parent b3a2899 commit 03a0271
Showing 1 changed file with 52 additions and 21 deletions.
73 changes: 52 additions & 21 deletions rllab/sampler/stateful_pool.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@


from joblib.pool import MemmapingPool
import inspect
import multiprocessing as mp
from rllab.misc import logger
import pyprind
import time
import traceback
import sys

from joblib.pool import MemmapingPool
import pyprind

from rllab.misc import logger


class ProgBarCounter(object):
def __init__(self, total_count):
Expand Down Expand Up @@ -63,18 +64,24 @@ def initialize(self, n_parallel):

def run_each(self, runner, args_list=None):
"""
Run the method on each worker process, and collect the result of execution.
The runner method will receive 'G' as its first argument, followed by the arguments
in the args_list, if any
Run the method on each worker process, and collect the result of
execution.
The runner method will receive 'G' as its first argument, followed by
the arguments in the args_list, if any
:return:
"""
assert not inspect.ismethod(runner), (
"run_each() cannot run a class method. Please ensure that runner is"
" a function with the prototype def foo(G, ...), where G is an "
"object of type rllab.sampler.stateful_pool.SharedGlobal")

if args_list is None:
args_list = [tuple()] * self.n_parallel
assert len(args_list) == self.n_parallel
if self.n_parallel > 1:
results = self.pool.map_async(
_worker_run_each, [(runner, args) for args in args_list]
)
_worker_run_each, [(runner, args) for args in args_list])
for i in range(self.n_parallel):
self.worker_queue.get()
for i in range(self.n_parallel):
Expand All @@ -83,50 +90,74 @@ def run_each(self, runner, args_list=None):
return [runner(self.G, *args_list[0])]

def run_map(self, runner, args_list):
    """
    Call the runner once per entry of args_list and return the results as
    a list.
    In single-process mode the runner receives this pool's 'G' as its
    first argument, followed by the unpacked entry of args_list. When
    n_parallel > 1 the (runner, args) pairs are handed to the worker pool
    through _worker_run_map instead.
    :param runner: a plain function with the prototype 'def foo(G, ...)';
        bound methods are rejected by the assertion below
    :param args_list: iterable of argument tuples, one per call
    :return: list with one result per entry of args_list
    """
    assert not inspect.ismethod(runner), (
        "run_map() cannot run a class method. Please ensure that runner is "
        "a function with the prototype 'def foo(G, ...)', where G is an "
        "object of type rllab.sampler.stateful_pool.SharedGlobal")

    if self.n_parallel > 1:
        return self.pool.map(_worker_run_map,
                             [(runner, args) for args in args_list])
    # Single-process mode: no pool is involved, so call the runner inline.
    return [runner(self.G, *args) for args in args_list]

def run_imap_unordered(self, runner, args_list):
    """
    Lazily call the runner once per entry of args_list, yielding each
    result as it becomes available.
    In single-process mode the runner receives this pool's 'G' as its
    first argument, followed by the unpacked entry of args_list, and the
    results come back in input order. When n_parallel > 1 the
    (runner, args) pairs are handed to the worker pool through
    pool.imap_unordered, so the order of results is not guaranteed.
    Because this is a generator, nothing runs (including the check that
    rejects bound methods) until the first item is requested.
    :param runner: a plain function with the prototype 'def foo(G, ...)'
    :param args_list: iterable of argument tuples, one per call
    :return: generator over the results
    """
    # The first fragment ends with a space so the message does not read
    # "...ensure thatrunner is...".
    assert not inspect.ismethod(runner), (
        "run_imap_unordered() cannot run a class method. Please ensure that "
        "runner is a function with the prototype 'def foo(G, ...)', where "
        "G is an object of type rllab.sampler.stateful_pool.SharedGlobal")

    if self.n_parallel > 1:
        for x in self.pool.imap_unordered(
                _worker_run_map, [(runner, args) for args in args_list]):
            yield x
    else:
        for args in args_list:
            yield runner(self.G, *args)

def run_collect(self, collect_once, threshold, args=None, show_prog_bar=True):
def run_collect(self,
collect_once,
threshold,
args=None,
show_prog_bar=True):
"""
Run the collector method using the worker pool. The collect_once method will receive 'G' as
its first argument, followed by the provided args, if any. The method should return a pair of values.
The first should be the object to be collected, and the second is the increment to be added.
This will continue until the total increment reaches or exceeds the given threshold.
Run the collector method using the worker pool. The collect_once method
will receive 'G' as its first argument, followed by the provided args,
if any. The method should return a pair of values. The first should be
the object to be collected, and the second is the increment to be added.
This will continue until the total increment reaches or exceeds the
given threshold.
Sample script:
def collect_once(G):
return 'a', 1
stateful_pool.run_collect(collect_once, threshold=3) # => ['a', 'a', 'a']
stateful_pool.run_collect(collect_once, threshold=3)
# should return ['a', 'a', 'a']
:param collector:
:param threshold:
:return:
"""
assert not inspect.ismethod(collect_once), (
"run_collect() cannot run a class method. Please ensure that "
"collect_once is a function with the prototype 'def foo(G, ...)', "
"where G is an object of type "
"rllab.sampler.stateful_pool.SharedGlobal")

if args is None:
args = tuple()
if self.pool:
manager = mp.Manager()
counter = manager.Value('i', 0)
lock = manager.RLock()
results = self.pool.map_async(
_worker_run_collect,
[(collect_once, counter, lock, threshold, args)] * self.n_parallel
)
_worker_run_collect, [(collect_once, counter, lock, threshold,
args)] * self.n_parallel)
if show_prog_bar:
pbar = ProgBarCounter(threshold)
last_value = 0
Expand Down

0 comments on commit 03a0271

Please sign in to comment.