Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safe bulk create #291

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions validator/app/src/compute_horde_validator/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,21 +439,21 @@ def wrapped(*args, **kwargs):
"schedule": timedelta(minutes=5),
"options": {},
},
"llm_prompt_generation": {
"task": "compute_horde_validator.validator.tasks.llm_prompt_generation",
"schedule": timedelta(minutes=5),
"options": {},
},
"llm_prompt_sampling": {
"task": "compute_horde_validator.validator.tasks.llm_prompt_sampling",
"schedule": timedelta(minutes=30),
"options": {},
},
"llm_prompt_answering": {
"task": "compute_horde_validator.validator.tasks.llm_prompt_answering",
"schedule": timedelta(minutes=5),
"options": {},
},
# "llm_prompt_generation": {
# "task": "compute_horde_validator.validator.tasks.llm_prompt_generation",
# "schedule": timedelta(minutes=5),
# "options": {},
# },
# "llm_prompt_sampling": {
# "task": "compute_horde_validator.validator.tasks.llm_prompt_sampling",
# "schedule": timedelta(minutes=30),
# "options": {},
# },
# "llm_prompt_answering": {
# "task": "compute_horde_validator.validator.tasks.llm_prompt_answering",
# "schedule": timedelta(minutes=5),
# "options": {},
# },
}
if env.bool("DEBUG_RUN_BEAT_VERY_OFTEN", default=False):
CELERY_BEAT_SCHEDULE["run_synthetic_jobs"]["schedule"] = crontab(minute="*")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
LlmPromptsSyntheticJobGenerator,
)
from compute_horde_validator.validator.synthetic_jobs.scoring import get_manifest_multiplier
from compute_horde_validator.validator.synthetic_jobs.db import safe_bulk_create
from compute_horde_validator.validator.utils import MACHINE_SPEC_CHANNEL

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1421,18 +1422,21 @@ def _db_persist_system_events(ctx: BatchContext) -> None:
# it's possible some events were already inserted during
# a previous call, but the operation failed before clearing
# the events list, so ignore insert conflicts
SystemEvent.objects.bulk_create(ctx.events, ignore_conflicts=True)
safe_bulk_create(SystemEvent, ctx.events, batch_size=500, timeout=30, ignore_conflicts=True)
except Exception as exc:
logger.error("Failed to persist system events: %r", exc)
finally:
# we call this function multiple times during a batch,
# clear the list to avoid persisting the same event
# multiple times
# also lets clear even if there was some error - better to miss some event
# than have this issue multiple times
ctx.events.clear()
except Exception as exc:
logger.error("Failed to persist system events: %r", exc)


# sync_to_async is needed since we use the sync Django ORM
@sync_to_async
def _db_persist(ctx: BatchContext) -> None:
def _db_persist_critical(ctx: BatchContext) -> None:
start_time = time.time()

# persist the batch and the jobs in the same transaction, to
Expand Down Expand Up @@ -1470,7 +1474,20 @@ def _db_persist(ctx: BatchContext) -> None:
score=job.score,
)
synthetic_jobs.append(synthetic_job)
synthetic_jobs = SyntheticJob.objects.bulk_create(synthetic_jobs)
synthetic_jobs = safe_bulk_create(SyntheticJob, synthetic_jobs, batch_size=500, timeout=60)
duration = time.time() - start_time
logger.info("Persisted to database in %.2f seconds", duration)


# sync_to_async is needed since we use the sync Django ORM
@sync_to_async
def _db_persist(ctx: BatchContext) -> None:
start_time = time.time()

if ctx.batch_id is not None:
batch = SyntheticJobBatch.objects.get(id=ctx.batch_id)
else:
batch = SyntheticJobBatch.objects.get(started_at=ctx.stage_start_time["BATCH_BEGIN"])

miner_manifests: list[MinerManifest] = []
for miner in ctx.miners.values():
Expand All @@ -1484,11 +1501,11 @@ def _db_persist(ctx: BatchContext) -> None:
online_executor_count=ctx.online_executor_count[miner.hotkey],
)
)
MinerManifest.objects.bulk_create(miner_manifests)
safe_bulk_create(MinerManifest, miner_manifests, batch_size=500, timeout=30)

# TODO: refactor into nicer abstraction
synthetic_jobs_map: dict[str, SyntheticJob] = {
str(synthetic_job.job_uuid): synthetic_job for synthetic_job in synthetic_jobs
str(synthetic_job.job_uuid): synthetic_job for synthetic_job in batch.synthetic_jobs.all()
}
prompt_samples: list[PromptSample] = []

Expand Down Expand Up @@ -1518,7 +1535,7 @@ def _db_persist(ctx: BatchContext) -> None:
max_timeout=started_payload.max_timeout,
)
)
JobStartedReceipt.objects.bulk_create(job_started_receipts)
safe_bulk_create(JobStartedReceipt, job_started_receipts, batch_size=500, timeout=30)

job_finished_receipts: list[JobFinishedReceipt] = []
for job in ctx.jobs.values():
Expand All @@ -1534,7 +1551,7 @@ def _db_persist(ctx: BatchContext) -> None:
score_str=finished_payload.score_str,
)
)
JobFinishedReceipt.objects.bulk_create(job_finished_receipts)
safe_bulk_create(JobFinishedReceipt, job_finished_receipts, batch_size=500, timeout=30)

duration = time.time() - start_time
logger.info("Persisted to database in %.2f seconds", duration)
Expand Down Expand Up @@ -1640,6 +1657,9 @@ async def execute_synthetic_batch_run(
func="_multi_close_client",
)

await ctx.checkpoint_system_event("_db_persist_critical")
await _db_persist_critical(ctx)

await ctx.checkpoint_system_event("_emit_telemetry_events")
try:
_emit_telemetry_events(ctx)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import logging
import time

from collections.abc import Iterable
from itertools import islice
from typing import TypeVar

from django.db.models import Model

logger = logging.getLogger(__name__)

# Create a TypeVar that's bound to models.Model
T = TypeVar("T", bound=Model)


def safe_bulk_create(
    model_class: type[T],
    objects_to_create: Iterable[T],
    batch_size: int = 1000,
    timeout: float | None = None,
    ignore_conflicts: bool = False,
) -> list[T]:
    """Insert *objects_to_create* in batches, with a soft time budget.

    The iterable is consumed lazily and inserted in chunks of ``batch_size``
    via the model manager's ``bulk_create``. If ``timeout`` (seconds) is
    exceeded after a batch completes, the REMAINING objects are silently
    dropped: an error is logged and the objects created so far are returned.
    Callers must therefore tolerate a partial insert whenever a timeout is
    set.

    :param model_class: Django model class whose default manager performs
        the inserts.
    :param objects_to_create: unsaved model instances (any iterable,
        including a generator).
    :param batch_size: maximum number of rows per ``bulk_create`` call.
    :param timeout: soft time budget in seconds; ``None`` (or ``0``)
        disables the check.
    :param ignore_conflicts: passed through to ``bulk_create``; note that
        with conflicts ignored the returned instances may lack primary keys
        on some database backends.
    :return: the instances returned by the underlying ``bulk_create`` calls.
    """
    start = time.time()

    objs = iter(objects_to_create)

    result_objects: list[T] = []

    while True:
        batch = list(islice(objs, batch_size))
        if not batch:
            break
        result_objects.extend(
            model_class.objects.bulk_create(batch, batch_size, ignore_conflicts=ignore_conflicts)
        )

        # The budget is only checked between batches, so a single
        # bulk_create call is never interrupted mid-flight.
        if timeout and time.time() - start > timeout:
            logger.error("Bulk create operation timed out: model=%s", str(model_class))
            break
    return result_objects
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import logging
import time
from collections.abc import Iterable
from itertools import islice
from typing import TypeVar

import bittensor
import uvloop
from asgiref.sync import async_to_sync
from django.conf import settings
from django.db.models import Model

from compute_horde_validator.validator.models import Miner, SystemEvent
from compute_horde_validator.validator.synthetic_jobs.batch_run import execute_synthetic_batch_run
Expand Down
Loading