Skip to content

Commit

Permalink
add max_retries for parallel tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
leo-schick committed Nov 22, 2023
1 parent 4837cdf commit 0670c22
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 15 deletions.
31 changes: 20 additions & 11 deletions mara_pipelines/parallel_tasks/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ def __init__(self, id: str, description: str, file_pattern: str, read_mode: Read
max_number_of_parallel_tasks: Optional[int] = None, file_dependencies: Optional[List[str]] = None, date_regex: Optional[str] = None,
partition_target_table_by_day_id: bool = False, truncate_partitions: bool = False,
commands_before: Optional[List[pipelines.Command]] = None, commands_after: Optional[List[pipelines.Command]] = None,
db_alias: Optional[str] = None, timezone: Optional[str] = None) -> None:
db_alias: Optional[str] = None, timezone: Optional[str] = None,
max_retries: Optional[int] = None) -> None:
pipelines.ParallelTask.__init__(self, id=id, description=description,
max_number_of_parallel_tasks=max_number_of_parallel_tasks,
commands_before=commands_before, commands_after=commands_after)
commands_before=commands_before, commands_after=commands_after,
max_retries=max_retries)
self.file_pattern = file_pattern
self.read_mode = read_mode
self.date_regex = date_regex
Expand Down Expand Up @@ -139,12 +141,14 @@ def update_file_dependencies():
id='create_partitions',
description='Creates required target table partitions',
commands=[sql.ExecuteSQL(sql_statement='\n'.join(slice), echo_queries=False, db_alias=self.db_alias)
for slice in more_itertools.sliced(sql_statements, 50)])
for slice in more_itertools.sliced(sql_statements, 50)],
max_retries=self.max_retries)

sub_pipeline.add(create_partitions_task)

for n, chunk in enumerate(more_itertools.chunked(files_per_day.items(), chunk_size)):
task = pipelines.Task(id=str(n), description='Reads a portion of the files')
task = pipelines.Task(id=str(n), description='Reads a portion of the files',
max_retries=self.max_retries)
for (day, files) in chunk:
target_table = self.target_table + '_' + day.strftime("%Y%m%d")
for file in files:
Expand All @@ -155,7 +159,8 @@ def update_file_dependencies():
for n, chunk in enumerate(more_itertools.chunked(files, chunk_size)):
sub_pipeline.add(
pipelines.Task(id=str(n), description=f'Reads {len(chunk)} files',
commands=sum([self.parallel_commands(x[0]) for x in chunk], [])))
commands=sum([self.parallel_commands(x[0]) for x in chunk], []),
max_retries=self.max_retries))

def parallel_commands(self, file_name: str) -> List[pipelines.Command]:
return [self.read_command(file_name)] + (
Expand All @@ -180,14 +185,16 @@ def __init__(self, id: str, description: str, file_pattern: str, read_mode: Read
mapper_script_file_name: Optional[str] = None, make_unique: bool = False, db_alias: Optional[str] = None,
delimiter_char: Optional[str] = None, quote_char: Optional[str] = None, null_value_string: Optional[str] = None,
skip_header: Optional[bool] = None, csv_format: bool = False,
timezone: Optional[str] = None, max_number_of_parallel_tasks: Optional[int] = None) -> None:
timezone: Optional[str] = None, max_number_of_parallel_tasks: Optional[int] = None,
max_retries: Optional[int] = None) -> None:
_ParallelRead.__init__(self, id=id, description=description, file_pattern=file_pattern,
read_mode=read_mode, target_table=target_table, file_dependencies=file_dependencies,
date_regex=date_regex, partition_target_table_by_day_id=partition_target_table_by_day_id,
truncate_partitions=truncate_partitions,
commands_before=commands_before, commands_after=commands_after,
db_alias=db_alias, timezone=timezone,
max_number_of_parallel_tasks=max_number_of_parallel_tasks)
max_number_of_parallel_tasks=max_number_of_parallel_tasks,
max_retries=max_retries)
self.compression = compression
self.mapper_script_file_name = mapper_script_file_name or ''
self.make_unique = make_unique
Expand Down Expand Up @@ -231,16 +238,18 @@ def html_doc_items(self) -> List[Tuple[str, str]]:

class ParallelReadSqlite(_ParallelRead):
def __init__(self, id: str, description: str, file_pattern: str, read_mode: ReadMode, sql_file_name: str,
             target_table: str, file_dependencies: Optional[List[str]] = None, date_regex: Optional[str] = None,
             partition_target_table_by_day_id: bool = False, truncate_partitions: bool = False,
             commands_before: Optional[List[pipelines.Command]] = None, commands_after: Optional[List[pipelines.Command]] = None,
             db_alias: Optional[str] = None, timezone: Optional[str] = None, max_number_of_parallel_tasks: Optional[int] = None,
             max_retries: Optional[int] = None) -> None:
    """
    Reads a set of SQLite files in parallel into a database table.

    Args:
        id: Unique id of the parallel task within its pipeline.
        description: Human-readable purpose of the task.
        file_pattern: Glob pattern of the SQLite files to read.
        read_mode: Which files to (re-)read (see ReadMode).
        sql_file_name: File with the SQL query to run against each SQLite file.
        target_table: Table into which the query results are loaded.
        file_dependencies: Files whose change triggers a full re-read.
        date_regex: Regex that extracts a day id from a file name
            (used for partitioning, presumably with named groups — confirm against _ParallelRead).
        partition_target_table_by_day_id: Whether to partition the target table by day id.
        truncate_partitions: Whether to truncate partitions before loading.
        commands_before: Commands to run once before the parallel tasks.
        commands_after: Commands to run once after the parallel tasks.
        db_alias: Alias of the target database connection.
        timezone: Timezone to apply when reading (passed through to _ParallelRead).
        max_number_of_parallel_tasks: Upper bound on concurrently running sub-tasks.
        max_retries: How often a failed sub-task is retried before the task fails.
    """
    # Everything except sql_file_name is handled by the shared _ParallelRead base;
    # max_retries is forwarded so each spawned sub-task inherits the retry policy.
    _ParallelRead.__init__(self, id=id, description=description, file_pattern=file_pattern,
                           read_mode=read_mode, target_table=target_table, file_dependencies=file_dependencies,
                           date_regex=date_regex, partition_target_table_by_day_id=partition_target_table_by_day_id,
                           truncate_partitions=truncate_partitions,
                           commands_before=commands_before, commands_after=commands_after, db_alias=db_alias,
                           timezone=timezone, max_number_of_parallel_tasks=max_number_of_parallel_tasks,
                           max_retries=max_retries)
    self.sql_file_name = sql_file_name

def read_command(self, file_name: str) -> List[pipelines.Command]:
Expand Down
6 changes: 4 additions & 2 deletions mara_pipelines/parallel_tasks/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def add_parallel_tasks(self, sub_pipeline: 'pipelines.Pipeline') -> None:
sub_pipeline.add(pipelines.Task(
id='_'.join([re.sub('[^0-9a-z\-_]+', '', str(x).lower().replace('-', '_')) for x in parameter_tuple]),
description=f'Runs the script with parameters {repr(parameter_tuple)}',
commands=[python.ExecutePython(file_name=self.file_name, args=list(parameter_tuple))]))
commands=[python.ExecutePython(file_name=self.file_name, args=list(parameter_tuple))],
max_retries=self.max_retries))

def html_doc_items(self) -> List[Tuple[str, str]]:
path = self.parent.base_path() / self.file_name
Expand Down Expand Up @@ -58,7 +59,8 @@ def add_parallel_tasks(self, sub_pipeline: 'pipelines.Pipeline') -> None:
sub_pipeline.add(pipelines.Task(
id=str(parameter).lower().replace(' ', '_').replace('-', '_'),
description=f'Runs the function with parameters {repr(parameter)}',
commands=[python.RunFunction(lambda args=parameter: self.function(args))]))
commands=[python.RunFunction(lambda args=parameter: self.function(args))],
max_retries=self.max_retries))

def html_doc_items(self) -> List[Tuple[str, str]]:
return [('function', _.pre[escape(str(self.function))]),
Expand Down
3 changes: 2 additions & 1 deletion mara_pipelines/parallel_tasks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def add_parallel_tasks(self, sub_pipeline: 'pipelines.Pipeline') -> None:
echo_queries=self.echo_queries, timezone=self.timezone, replace=replace)
if self.sql_file_name else
sql.ExecuteSQL(sql_statement=self.sql_statement, db_alias=self.db_alias,
echo_queries=self.echo_queries, timezone=self.timezone, replace=replace)]))
echo_queries=self.echo_queries, timezone=self.timezone, replace=replace)],
max_retries=self.max_retries))

def html_doc_items(self) -> List[Tuple[str, str]]:
return [('db', _.tt[self.db_alias])] \
Expand Down
4 changes: 3 additions & 1 deletion mara_pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,16 @@ def run(self):

class ParallelTask(Node):
def __init__(self, id: str, description: str, max_number_of_parallel_tasks: Optional[int] = None,
             commands_before: Optional[List[Command]] = None, commands_after: Optional[List[Command]] = None,
             max_retries: Optional[int] = None) -> None:
    """
    A task that dynamically spawns a number of sub-tasks which run in parallel.

    Args:
        id: Unique id of the node within its parent pipeline.
        description: Human-readable purpose of the task.
        max_number_of_parallel_tasks: Upper bound on concurrently running sub-tasks
            (None presumably falls back to a global config default — confirm).
        commands_before: Commands to run once before the parallel sub-tasks start.
        commands_after: Commands to run once after all parallel sub-tasks finished.
        max_retries: How often a failed sub-task is retried before it counts as failed;
            forwarded by subclasses to each spawned Task.
    """
    super().__init__(id, description)
    # Route commands through add_command_before/after so any per-command
    # bookkeeping those methods do is applied uniformly.
    self.commands_before = []
    for command in commands_before or []:
        self.add_command_before(command)
    self.commands_after = []
    for command in commands_after or []:
        self.add_command_after(command)
    self.max_retries = max_retries
    self.max_number_of_parallel_tasks = max_number_of_parallel_tasks

def add_command_before(self, command: Command):
Expand Down

0 comments on commit 0670c22

Please sign in to comment.