From 78b020f66d4570629e79758c064beb381c802cd9 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Wed, 28 Oct 2020 12:19:10 +0100 Subject: [PATCH 1/4] Add failing test_parallel_scheduling_with_unpicklable_tasks test. --- test/scheduler_test.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/test/scheduler_test.py b/test/scheduler_test.py index 6bfd39b4d6..4e6ffa766c 100644 --- a/test/scheduler_test.py +++ b/test/scheduler_test.py @@ -23,13 +23,16 @@ from multiprocessing import Process from helpers import unittest -import luigi.scheduler +import luigi import luigi.server -import luigi.configuration -from helpers import with_config +from helpers import with_config, RunOnceTask from luigi.target import FileAlreadyExists +class PicklableTask(RunOnceTask): + i = luigi.IntParameter() + + class SchedulerIoTest(unittest.TestCase): def test_pretty_id_unicode(self): @@ -286,6 +289,10 @@ class SchedulerWorkerTest(unittest.TestCase): def get_pending_ids(self, worker, state): return {task.id for task in worker.get_tasks(state, 'PENDING')} + def schedule_parallel(self, tasks): + return luigi.interface.build(tasks, local_scheduler=True, parallel_scheduling=True, + parallel_scheduling_processes=2) + def test_get_pending_tasks_with_many_done_tasks(self): sch = luigi.scheduler.Scheduler() sch.add_task(worker='NON_TRIVIAL', task_id='A', resources={'a': 1}) @@ -300,6 +307,17 @@ def test_get_pending_tasks_with_many_done_tasks(self): non_trivial_worker = scheduler_state.get_worker('NON_TRIVIAL') self.assertEqual({'A'}, self.get_pending_ids(non_trivial_worker, scheduler_state)) + def test_parallel_scheduling_with_picklable_tasks(self): + tasks = [PicklableTask(i=i) for i in range(5)] + self.assertTrue(self.schedule_parallel(tasks)) + + def test_parallel_scheduling_with_unpicklable_tasks(self): + class UnpicklableTask(RunOnceTask): + i = luigi.IntParameter() + + tasks = [UnpicklableTask(i=i) for i in range(5)] + self.assertFalse(self.schedule_parallel(tasks)) + class FailingOnDoubleRunTask(luigi.Task): time_to_check_secs = 1 From 043b4c670d57330ce2ed8fb235afe8b3984a3181 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Mon, 5 Jul 2021 17:52:44 +0200 Subject: [PATCH 2/4] Improve and fix loop checking for task completeness. --- luigi/worker.py | 123 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 83 insertions(+), 40 deletions(-) diff --git a/luigi/worker.py b/luigi/worker.py index ba575b7fdf..47a7b1fdf2 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -39,6 +39,7 @@ import subprocess import sys import contextlib +import warnings import queue as Queue import random @@ -342,6 +343,50 @@ def respond(self, response): self._scheduler.add_scheduler_message_response(self._task_id, self._message_id, response) +class SyncResult(object): + """ + Synchronous implementation of ``multiprocessing.pool.AsyncResult`` that immediately calls *func* + with *args* and *kwargs*. Its methods :py:meth:`get`, :py:meth:`wait`, :py:meth:`ready` and + :py:meth:`successful` work in a similar fashion, depending on the result of the function call. + """ + + def __init__(self, func, args=None, kwargs=None): + super(SyncResult, self).__init__() + + # store function and arguments + self._func = func + self._args = args or () + self._kwargs = kwargs or {} + + # store return value and potential exceptions + self._return_value = None + self._exception = None + + # immediately call + self._call() + + def _call(self): + try: + self._return_value = self._func(*self._args, **self._kwargs) + except BaseException as e: + self._exception = e + + def get(self, timeout=None): + if self._exception: + raise self._exception + else: + return self._return_value + + def wait(self, timeout=None): + return + + def ready(self): + return True + + def successful(self): + return self._exception is None + + class SingleProcessPool: """ Dummy process pool for using a single processor. @@ -349,14 +394,14 @@ class SingleProcessPool: Imitates the api of multiprocessing.Pool using single-processor equivalents. """ - def apply_async(self, function, args): - return function(*args) + def apply_async(self, function, args=None, kwargs=None): + return SyncResult(function, args=args, kwargs=kwargs) def close(self): - pass + return def join(self): - pass + return class DequeQueue(collections.deque): @@ -380,6 +425,8 @@ class AsyncCompletionException(Exception): """ def __init__(self, trace): + warnings.warn("{} is deprecated and will be removed in a future release".format( + self.__class__.__name__), DeprecationWarning) self.trace = trace @@ -389,19 +436,17 @@ class TracebackWrapper: """ def __init__(self, trace): + warnings.warn("{} is deprecated and will be removed in a future release".format( + self.__class__.__name__), DeprecationWarning) self.trace = trace -def check_complete(task, out_queue): +def check_complete(task): """ - Checks if task is complete, puts the result to out_queue. + Checks if task is complete. """ logger.debug("Checking if %s is complete", task) - try: - is_complete = task.complete() - except Exception: - is_complete = TracebackWrapper(traceback.format_exc()) - out_queue.put((task, is_complete)) + return task.complete() class worker(Config): @@ -727,7 +772,7 @@ def _handle_task_load_error(self, exception, task_ids): ) notifications.send_error_email(subject, error_message) - def add(self, task, multiprocess=False, processes=0): + def add(self, task, multiprocess=False, processes=0, wait_interval=0.01): """ Add a Task for the worker to check and possibly schedule and run. @@ -737,36 +782,36 @@ def add(self, task, multiprocess=False, processes=0): self._first_task = task.task_id self.add_succeeded = True if multiprocess: - queue = multiprocessing.Manager().Queue() pool = multiprocessing.Pool(processes=processes if processes > 0 else None) else: - queue = DequeQueue() pool = SingleProcessPool() self._validate_task(task) - pool.apply_async(check_complete, [task, queue]) + results = [(task, pool.apply_async(check_complete, (task,)))] - # we track queue size ourselves because len(queue) won't work for multiprocessing - queue_size = 1 try: seen = {task.task_id} - while queue_size: - current = queue.get() - queue_size -= 1 - item, is_complete = current - for next in self._add(item, is_complete): - if next.task_id not in seen: - self._validate_task(next) - seen.add(next.task_id) - pool.apply_async(check_complete, [next, queue]) - queue_size += 1 - except (KeyboardInterrupt, TaskException): - raise - except Exception as ex: - self.add_succeeded = False - formatted_traceback = traceback.format_exc() - self._log_unexpected_error(task) - task.trigger_event(Event.BROKEN_TASK, task, ex) - self._email_unexpected_error(task, formatted_traceback) + while results: + # fetch the first done result + for i, (task, result) in enumerate(list(results)): + if result.ready(): + results.pop(i) + break + else: + time.sleep(wait_interval) + continue + + # get the result or error + try: + is_complete = result.get() + except Exception as e: + is_complete = e + + for dep in self._add(task, is_complete): + if dep.task_id not in seen: + self._validate_task(dep) + seen.add(dep.task_id) + results.append((dep, pool.apply_async(check_complete, (dep,)))) + except BaseException: raise finally: pool.close() @@ -800,8 +845,6 @@ def _add(self, task, is_complete): self._check_complete_value(is_complete) except KeyboardInterrupt: raise - except AsyncCompletionException as ex: - formatted_traceback = ex.trace except BaseException: formatted_traceback = traceback.format_exc() @@ -881,9 +924,9 @@ def _validate_dependency(self, dependency): raise Exception('requires() must return Task objects but {} is a {}'.format(dependency, type(dependency))) def _check_complete_value(self, is_complete): - if is_complete not in (True, False): - if isinstance(is_complete, TracebackWrapper): - raise AsyncCompletionException(is_complete.trace) + if isinstance(is_complete, BaseException): + raise is_complete + elif not isinstance(is_complete, bool): raise Exception("Return value of Task.complete() must be boolean (was %r)" % is_complete) def _add_worker(self): From f60dcc70b68ae55821d21227cb2266dbe9bdb597 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Mon, 5 Jul 2021 18:48:25 +0200 Subject: [PATCH 3/4] Catch erros again while announcing scheduler failures. --- luigi/worker.py | 9 ++++++++- test/worker_parallel_scheduling_test.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/luigi/worker.py b/luigi/worker.py index 47a7b1fdf2..b127ac0ba7 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -811,7 +811,14 @@ def add(self, task, multiprocess=False, processes=0, wait_interval=0.01): self._validate_task(dep) seen.add(dep.task_id) results.append((dep, pool.apply_async(check_complete, (dep,)))) - except BaseException: + except (KeyboardInterrupt, TaskException): + raise + except Exception as ex: + self.add_succeeded = False + formatted_traceback = traceback.format_exc() + self._log_unexpected_error(task) + task.trigger_event(Event.BROKEN_TASK, task, ex) + self._email_unexpected_error(task, formatted_traceback) raise finally: pool.close() diff --git a/test/worker_parallel_scheduling_test.py b/test/worker_parallel_scheduling_test.py index fb3a56fba5..6528978627 100644 --- a/test/worker_parallel_scheduling_test.py +++ b/test/worker_parallel_scheduling_test.py @@ -175,7 +175,7 @@ def test_raise_unpicklable_exception_in_complete(self, send): send.check_called_once() self.assertEqual(UNKNOWN, self.sch.add_task.call_args[1]['status']) self.assertFalse(self.sch.add_task.call_args[1]['runnable']) - self.assertTrue('raise UnpicklableException()' in send.call_args[0][1]) + self.assertTrue("Can't pickle local object 'UnpicklableExceptionTask" in send.call_args[0][1]) @mock.patch('luigi.notifications.send_error_email') def test_raise_exception_in_requires(self, send): From a21754a22f2f8d9aa3bd5766a27bbf18de25f047 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Mon, 5 Jul 2021 21:08:20 +0200 Subject: [PATCH 4/4] Increase dropped code coverage. --- test/scheduler_test.py | 22 +--------- test/worker_parallel_scheduling_test.py | 57 +++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/test/scheduler_test.py b/test/scheduler_test.py index 4e6ffa766c..ee4c2a24c3 100644 --- a/test/scheduler_test.py +++ b/test/scheduler_test.py @@ -23,16 +23,11 @@ from multiprocessing import Process from helpers import unittest -import luigi import luigi.server -from helpers import with_config, RunOnceTask +from helpers import with_config from luigi.target import FileAlreadyExists -class PicklableTask(RunOnceTask): - i = luigi.IntParameter() - - class SchedulerIoTest(unittest.TestCase): def test_pretty_id_unicode(self): @@ -289,10 +284,6 @@ class SchedulerWorkerTest(unittest.TestCase): def get_pending_ids(self, worker, state): return {task.id for task in worker.get_tasks(state, 'PENDING')} - def schedule_parallel(self, tasks): - return luigi.interface.build(tasks, local_scheduler=True, parallel_scheduling=True, - parallel_scheduling_processes=2) - def test_get_pending_tasks_with_many_done_tasks(self): sch = luigi.scheduler.Scheduler() sch.add_task(worker='NON_TRIVIAL', task_id='A', resources={'a': 1}) @@ -307,17 +298,6 @@ def test_get_pending_tasks_with_many_done_tasks(self): non_trivial_worker = scheduler_state.get_worker('NON_TRIVIAL') self.assertEqual({'A'}, self.get_pending_ids(non_trivial_worker, scheduler_state)) - def test_parallel_scheduling_with_picklable_tasks(self): - tasks = [PicklableTask(i=i) for i in range(5)] - self.assertTrue(self.schedule_parallel(tasks)) - - def test_parallel_scheduling_with_unpicklable_tasks(self): - class UnpicklableTask(RunOnceTask): - i = luigi.IntParameter() - - tasks = [UnpicklableTask(i=i) for i in range(5)] - self.assertFalse(self.schedule_parallel(tasks)) - class FailingOnDoubleRunTask(luigi.Task): time_to_check_secs = 1 diff --git a/test/worker_parallel_scheduling_test.py b/test/worker_parallel_scheduling_test.py index 6528978627..50af98f836 100644 --- a/test/worker_parallel_scheduling_test.py +++ b/test/worker_parallel_scheduling_test.py @@ -20,6 +20,7 @@ import os import pickle import time +import warnings from helpers import unittest import mock @@ -28,6 +29,7 @@ import luigi from luigi.worker import Worker from luigi.task_status import UNKNOWN +from helpers import RunOnceTask def running_children(): @@ -95,6 +97,10 @@ class UnpicklableException(Exception): raise UnpicklableException() +class PicklableTask(RunOnceTask): + i = luigi.IntParameter() + + class ParallelSchedulingTest(unittest.TestCase): def setUp(self): @@ -183,3 +189,54 @@ def test_raise_exception_in_requires(self, send): send.check_called_once() self.assertEqual(UNKNOWN, self.sch.add_task.call_args[1]['status']) self.assertFalse(self.sch.add_task.call_args[1]['runnable']) + + def test_parallel_scheduling_with_picklable_tasks(self): + tasks = [PicklableTask(i=i) for i in range(5)] + success = luigi.interface.build(tasks, local_scheduler=True, parallel_scheduling=True, + parallel_scheduling_processes=2) + self.assertTrue(success) + + def test_parallel_scheduling_with_unpicklable_tasks(self): + class UnpicklableTask(RunOnceTask): + i = luigi.IntParameter() + + tasks = [UnpicklableTask(i=i) for i in range(5)] + success = luigi.interface.build(tasks, local_scheduler=True, parallel_scheduling=True, + parallel_scheduling_processes=2) + self.assertFalse(success) + + def test_sync_result(self): + def func1(a, b): + return a + b + + def func2(a, b): + raise Exception("unknown") + + def func3(a): + raise Exception("never called") + + r = luigi.worker.SyncResult(func1, (1, 2)) + self.assertIsNone(r.wait()) + self.assertTrue(r.ready()) + self.assertTrue(r.successful()) + self.assertEqual(r.get(), 3) + + r = luigi.worker.SyncResult(func2, (1, 2)) + self.assertIsNone(r.wait()) + self.assertTrue(r.ready()) + self.assertFalse(r.successful()) + with self.assertRaises(Exception): + r.get() + + r = luigi.worker.SyncResult(func3, (1, 2)) + self.assertIsNone(r.wait()) + self.assertTrue(r.ready()) + self.assertFalse(r.successful()) + with self.assertRaises(TypeError): + r.get() + + def test_deprecations(self): + with warnings.catch_warnings(record=True) as w: + luigi.worker.AsyncCompletionException("foo") + luigi.worker.TracebackWrapper("foo") + self.assertEqual(len(w), 2)