Skip to content

Commit

Permalink
Intermittent bug fixes (#341)
Browse files Browse the repository at this point in the history
* Fix up ADIFeaturesUploader to work with autoincrementing IDs; remove duplicate db connections by workers; add sig alarm timeout for single value computation

* Add autocleanup in case feature seems to never complete.

* Fix indefinite wait, check flag.

* Fix tuple typo.

* Fix possible race condition in notification processing

* Fix timing of notifications block in workers

* Generalizes timeout, applies it also to single sample computation.
  • Loading branch information
jimmymathews authored Jul 30, 2024
1 parent 6414341 commit dcc5927
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 155 deletions.
2 changes: 0 additions & 2 deletions spatialprofilingtoolbox/db/importance_score_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,6 @@ def _upload(
get_feature_description("gnn importance score"),
len(importance_score_set_indexer),
),
impute_zeros=True,
upload_anyway=True,
) as feature_uploader:
for histological_structure, row in df.iterrows():
feature_uploader.stage_feature_value(
Expand Down
2 changes: 1 addition & 1 deletion spatialprofilingtoolbox/db/source_file_parser_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_field_names(self, tablename):

def generate_basic_insert_query(self, tablename):
fields_sorted = self.get_field_names(tablename)
if tablename == 'quantitative_feature_value':
if tablename in ('quantitative_feature_value', 'feature_specification'):
fields_sorted = fields_sorted[1:]
handle_duplicates = 'ON CONFLICT DO NOTHING '
query = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _warn_no_value(self) -> None:
specification = str(self.job.feature_specification)
study = self.job.study
sample = self.job.sample
logger.warning(f'Feature {specification} ({sample}, {study}) could not be computed, worker generated None.')
logger.warning(f'Feature {specification} ({sample}, {study}) could not be computed, worker generated None. May insert None.')

def _insert_value(self, value: float | int) -> None:
study = self.job.study
Expand Down
114 changes: 79 additions & 35 deletions spatialprofilingtoolbox/ondemand/request_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from typing import cast
from typing import Callable
import signal

from psycopg import Connection as PsycopgConnection

from spatialprofilingtoolbox.ondemand.job_reference import ComputationJobReference
from spatialprofilingtoolbox.db.database_connection import DBConnection
from spatialprofilingtoolbox.db.database_connection import DBCursor
from spatialprofilingtoolbox.ondemand.providers.counts_provider import CountsProvider
Expand All @@ -18,7 +18,8 @@
CompositePhenotype,
UnivariateMetricsComputationResult,
)
from spatialprofilingtoolbox.ondemand.relevant_specimens import relevant_specimens_query
from spatialprofilingtoolbox.ondemand.timeout import create_timeout_handler
from spatialprofilingtoolbox.ondemand.timeout import SPTTimeoutError
from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger
Metrics1D = UnivariateMetricsComputationResult

Expand All @@ -40,6 +41,51 @@ def _nonempty(string: str) -> bool:
return string != ''


class FeatureComputationTimeoutHandler:
    """Clean up after a feature computation that timed out.

    Used as the timeout callback while waiting for a feature's values to be
    computed: if the job queue has drained but fewer values than expected were
    written, the partially-computed feature is deleted so it can be requested
    again from scratch.
    """
    feature: str
    study: str

    def __init__(self, feature: str, study: str):
        self.feature = feature
        self.study = study

    def handle(self) -> None:
        message = f'Timed out waiting for the feature {self.feature} to complete. Aborting.'
        logger.error(message)
        # Only count completed/expected if the queue is already empty, to
        # avoid the extra queries when jobs are still pending.
        if self._queue_size() != 0:
            return
        if self._completed_size() < self._expected_size():
            self._delete_feature()

    def _count(self, query: str, params=None) -> int:
        # Run a single COUNT(*) query and return the scalar result.
        with DBCursor(study=self.study) as cursor:
            cursor.execute(query, params)
            rows = tuple(cursor.fetchall())
        return rows[0][0]

    def _queue_size(self) -> int:
        # Jobs for this feature still waiting in the queue.
        query = 'SELECT COUNT(*) FROM quantitative_feature_value_queue WHERE feature=%s ;'
        return self._count(query, (self.feature,))

    def _completed_size(self) -> int:
        # Values already written for this feature.
        query = 'SELECT COUNT(*) FROM quantitative_feature_value WHERE feature=%s ;'
        return self._count(query, (self.feature,))

    def _expected_size(self) -> int:
        # One value per specimen is expected.
        query = 'SELECT COUNT(*) FROM specimen_data_measurement_process ;'
        return self._count(query)

    def _delete_feature(self) -> None:
        logger.error('Also deleting the feature, since the queue was empty; we assume the remaining jobs failed.')
        statements = (
            'DELETE FROM quantitative_feature_value WHERE feature=%s ;',
            'DELETE FROM feature_specifier WHERE feature_specification=%s ;',
            'DELETE FROM feature_specification WHERE identifier=%s ;',
        )
        with DBCursor(study=self.study) as cursor:
            for statement in statements:
                cursor.execute(statement, (self.feature,))



class OnDemandRequester:
"""Entry point for requesting computation by the on-demand service."""

Expand Down Expand Up @@ -89,21 +135,14 @@ def _counts(
) -> tuple[str, Metrics1D, Metrics1D]:
get = CountsProvider.get_metrics_or_schedule

def get_results() -> tuple[Metrics1D, str]:
def get_results1() -> tuple[Metrics1D, str]:
counts, feature1 = get(
study_name,
phenotype=phenotype,
cells_selected=selected,
)
return (counts, feature1)

with DBConnection() as connection:
connection._set_autocommit(True)
connection.execute('LISTEN new_items_in_queue ;')
connection.execute('LISTEN one_job_complete ;')
cls._wait_for_wrappedup(connection, get_results)
counts, feature1 = get_results()

def get_results2() -> tuple[Metrics1D, str]:
counts_all, feature2 = get(
study_name,
Expand All @@ -114,44 +153,49 @@ def get_results2() -> tuple[Metrics1D, str]:

with DBConnection() as connection:
connection._set_autocommit(True)
connection.execute('LISTEN new_items_in_queue ;')
connection.execute('LISTEN one_job_complete ;')
cls._wait_for_wrappedup(connection, get_results2)
counts_all, _ = get_results2()

cls._request_check_for_failed_jobs()
return (feature1, counts, counts_all)
cls._wait_for_wrappedup(connection, get_results1, study_name)
counts, feature1 = get_results1()

@classmethod
def _request_check_for_failed_jobs(cls) -> None:
with DBConnection() as connection:
connection._set_autocommit(True)
connection.execute('NOTIFY check_for_failed_jobs ;')
cls._wait_for_wrappedup(connection, get_results2, study_name)
counts_all, _ = get_results2()

return (feature1, counts, counts_all)

@classmethod
def _wait_for_wrappedup(
cls,
connection: PsycopgConnection,
get_results: Callable[[], tuple[Metrics1D, str]],
study_name: str,
) -> None:
counts, feature = get_results()
if not get_results()[0].is_pending:
logger.debug(f'Feature {feature} already complete.')
return
logger.debug(f'Waiting for signal that feature {feature} may be ready.')
connection.execute('LISTEN new_items_in_queue ;')
connection.execute('LISTEN one_job_complete ;')
notifications = connection.notifies()
for notification in notifications:
if not get_results()[0].is_pending:
logger.debug(f'Closing notification processing, {feature} ready.')
notifications.close()
break
channel = notification.channel
if channel == 'one_job_complete':
logger.debug(f'A job is complete, so {feature} may be ready.')
if not get_results()[0].is_pending:
logger.debug(f'And {feature} was ready. Closing.')

counts, feature = get_results()
handler = FeatureComputationTimeoutHandler(feature, study_name)
generic_handler = create_timeout_handler(handler.handle)
try:
if not counts.is_pending:
logger.debug(f'Feature {feature} already complete.')
return
logger.debug(f'Waiting for signal that feature {feature} may be ready, because the result is not ready yet.')

for notification in notifications:
channel = notification.channel
if channel == 'one_job_complete':
logger.debug(f'A job is complete, so {feature} may be ready. (PID: {notification.pid})')
_result = get_results()
if not _result[0].is_pending:
logger.debug(f'Closing notification processing, {feature} ready.')
notifications.close()
break
except SPTTimeoutError:
pass
finally:
generic_handler.disalarm()

@classmethod
def get_proximity_metrics(
Expand Down
45 changes: 45 additions & 0 deletions spatialprofilingtoolbox/ondemand/timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""General-purpose one-time timeout functionality based on Unix signal alarm."""
from typing import Callable
import signal

from spatialprofilingtoolbox.db.database_connection import DBCursor
from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger

logger = colorized_logger(__name__)


TIMEOUT_SECONDS_DEFAULT = 300


class SPTTimeoutError(RuntimeError):
    """Raised when an SPT timeout alarm fires.

    The human-readable description is also kept on the instance as
    ``message`` for convenient programmatic access.
    """

    message: str

    def __init__(self, message: str):
        self.message = message
        super().__init__(message)


class TimeoutHandler:
    """One-time SIGALRM handler: run a cleanup callback, then raise.

    The signal handler may still be registered after the guarded work
    finishes; `disalarm` deactivates this handler so a late-firing alarm
    becomes a no-op.
    """
    active: bool  # False once disalarm() has been called.
    callback: Callable  # Cleanup action invoked (with no arguments) on timeout.
    timeout: int  # Timeout duration in seconds, recorded for the log message.

    def __init__(self, callback: Callable, timeout: int):
        self.active = True
        self.callback = callback
        self.timeout = timeout

    def handle(self, signum, frame) -> None:
        """Signal handler entry point (signum/frame per the signal module).

        Raises SPTTimeoutError — not the builtin TimeoutError — so that the
        callers' `except SPTTimeoutError:` clauses actually catch the timeout.
        """
        if self.active:
            message = f'Waited {self.timeout} seconds, timed out.'
            logger.error(message)
            self.callback()
            raise SPTTimeoutError(message)

    def disalarm(self) -> None:
        """Deactivate this handler; any subsequent alarm delivery does nothing."""
        self.active = False


def create_timeout_handler(callback: Callable, timeout_seconds: int = TIMEOUT_SECONDS_DEFAULT) -> TimeoutHandler:
    """Register a one-time SIGALRM timeout and return its handler.

    The caller should invoke ``disalarm()`` on the returned handler once the
    guarded work completes. Unix-only, since it relies on SIGALRM.
    """
    timeout_handler = TimeoutHandler(callback, timeout_seconds)
    # Install the handler before scheduling the alarm, so the signal can
    # never arrive unhandled.
    signal.signal(signal.SIGALRM, timeout_handler.handle)
    signal.alarm(timeout_seconds)
    return timeout_handler
51 changes: 31 additions & 20 deletions spatialprofilingtoolbox/ondemand/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from spatialprofilingtoolbox.db.describe_features import get_handle
from spatialprofilingtoolbox.ondemand.job_reference import ComputationJobReference
from spatialprofilingtoolbox.ondemand.scheduler import MetricComputationScheduler
from spatialprofilingtoolbox.ondemand.timeout import create_timeout_handler
from spatialprofilingtoolbox.ondemand.timeout import SPTTimeoutError
from spatialprofilingtoolbox.standalone_utilities.log_formats import colorized_logger
Job = ComputationJobReference

Expand All @@ -33,17 +35,17 @@ def start(self) -> None:

def _listen_for_queue_activity(self) -> None:
with DBConnection() as connection:
connection._set_autocommit(True)
self.connection = connection
self.connection._set_autocommit(True)
self.connection.execute('LISTEN new_items_in_queue ;')
logger.info('Listening on new_items_in_queue channel.')
self.notifications = self.connection.notifies()
while True:
self._wait_for_queue_activity_on(connection)
self._wait_for_queue_activity_on_connection()
self._work_until_complete()

def _wait_for_queue_activity_on(self, connection: PsycopgConnection) -> None:
connection.execute('LISTEN new_items_in_queue ;')
logger.info('Listening on new_items_in_queue channel.')
notifications = connection.notifies()
for notification in notifications:
notifications.close()
def _wait_for_queue_activity_on_connection(self) -> None:
    """Block until one notification arrives on the persistent LISTEN connection.

    `self.notifications` is the long-lived notification generator created at
    startup; consuming a single item and breaking — without closing the
    generator — lets later calls keep waiting on the same stream.
    """
    for _ in self.notifications:
        logger.info('Received notice of new items in the job queue.')
        break

Expand All @@ -57,26 +59,35 @@ def _work_until_complete(self) -> None:
logger.info(f'Finished jobs {" ".join(completed_pids)}.')

def _one_job(self) -> tuple[bool, int]:
with DBConnection() as connection:
connection._set_autocommit(True)
self.connection = connection
pid = self.connection.info.backend_pid
job = self.queue.pop_uncomputed()
if job is None:
return (False, pid)
logger.info(f'{pid} doing job {job.feature_specification} {job.sample}.')
self._compute(job)
self._notify_complete(job)
return (True, pid)
pid = self.connection.info.backend_pid
job = self.queue.pop_uncomputed()
if job is None:
return (False, pid)
logger.info(f'{pid} doing job {job.feature_specification} {job.sample}.')
self._compute(job)
self._notify_complete(job)
return (True, pid)

def _no_value_wrapup(self, job) -> None:
    """Record a no-value outcome for a job that could not produce a result.

    Used as the timeout callback for a stuck computation: logs a warning and
    inserts a null value for the job's sample (presumably so the feature's
    completion accounting still advances — confirm against the provider).
    """
    provider = self._get_provider(job)
    provider._warn_no_value()
    provider._insert_null()

def _compute(self, job: Job) -> None:
    """Run the metric computation for one job, bounded by a 150-second alarm.

    On timeout, `_no_value_wrapup` (invoked from the alarm handler) records a
    null result for the sample, and the resulting SPTTimeoutError is absorbed.
    On any other exception the error is logged with traceback and a no-value
    warning is recorded. The alarm is always cancelled before returning.
    """
    provider = self._get_provider(job)
    # The timeout handler invokes the callback with no arguments; accept and
    # ignore any, to stay compatible either way.
    generic_handler = create_timeout_handler(
        lambda *arg: self._no_value_wrapup(job),
        timeout_seconds=150,
    )
    try:
        provider.compute()
    except SPTTimeoutError:
        pass  # _no_value_wrapup already ran inside the alarm handler.
    except Exception as error:
        logger.error(error)
        print_exception(type(error), error, error.__traceback__)
        # Reuse the provider constructed above instead of building a second one.
        provider._warn_no_value()
    finally:
        generic_handler.disalarm()

def _notify_complete(self, job: Job) -> None:
self.connection.execute('NOTIFY one_job_complete ;')
Expand Down
Loading

0 comments on commit dcc5927

Please sign in to comment.