From 1978a2507160da94ce6bfca999b5ab8953c95876 Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Tue, 20 Aug 2024 22:41:48 -0400 Subject: [PATCH] feat(workflows): use WorkflowConfiguration accelerator_type to override train_model accelerator Signed-off-by: Cameron Smith --- src/pyrovelocity/workflows/main_workflow.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pyrovelocity/workflows/main_workflow.py b/src/pyrovelocity/workflows/main_workflow.py index b7817a808..c7a1d58b9 100644 --- a/src/pyrovelocity/workflows/main_workflow.py +++ b/src/pyrovelocity/workflows/main_workflow.py @@ -389,6 +389,7 @@ def map_model_configurations_over_data_set( postprocessing_resource_limits: ResourcesJSON = default_resource_limits, summarizing_resource_requests: ResourcesJSON = default_resource_requests, summarizing_resource_limits: ResourcesJSON = default_resource_limits, + accelerator_type: str = "nvidia-tesla-t4", upload_results: bool = True, ) -> list[SummarizeOutputs]: """ @@ -441,6 +442,7 @@ def map_model_configurations_over_data_set( ).with_overrides( requests=Resources(**asdict(train_model_resource_requests)), limits=Resources(**asdict(train_model_resource_limits)), + accelerator=GPUAccelerator(accelerator_type), ) model_outputs.append(model_output) @@ -614,6 +616,7 @@ def training_workflow( postprocessing_resource_limits=config.postprocessing_resources_limits, summarizing_resource_requests=config.summarizing_resources_requests, summarizing_resource_limits=config.summarizing_resources_limits, + accelerator_type=config.accelerator_type, upload_results=config.upload_results, ) results.append(result)