Skip to content

Commit

Permalink
feat(workflows): use WorkflowConfiguration accelerator_type to overri…
Browse files Browse the repository at this point in the history
…de train_model accelerator

Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Aug 21, 2024
1 parent d306c48 commit 1978a25
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/pyrovelocity/workflows/main_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def map_model_configurations_over_data_set(
postprocessing_resource_limits: ResourcesJSON = default_resource_limits,
summarizing_resource_requests: ResourcesJSON = default_resource_requests,
summarizing_resource_limits: ResourcesJSON = default_resource_limits,
accelerator_type: str = "nvidia-tesla-t4",
upload_results: bool = True,
) -> list[SummarizeOutputs]:
"""
Expand Down Expand Up @@ -441,6 +442,7 @@ def map_model_configurations_over_data_set(
).with_overrides(
requests=Resources(**asdict(train_model_resource_requests)),
limits=Resources(**asdict(train_model_resource_limits)),
accelerator=GPUAccelerator(accelerator_type),
)
model_outputs.append(model_output)

Expand Down Expand Up @@ -614,6 +616,7 @@ def training_workflow(
postprocessing_resource_limits=config.postprocessing_resources_limits,
summarizing_resource_requests=config.summarizing_resources_requests,
summarizing_resource_limits=config.summarizing_resources_limits,
accelerator_type=config.accelerator_type,
upload_results=config.upload_results,
)
results.append(result)
Expand Down

0 comments on commit 1978a25

Please sign in to comment.