Add a view table for user-job view
TorecLuik committed Aug 6, 2024
1 parent 399c661 commit 1557f45
Showing 4 changed files with 218 additions and 27 deletions.
2 changes: 1 addition & 1 deletion biomero/__init__.py
@@ -13,4 +13,4 @@
 except pkg_resources.DistributionNotFound:
     __version__ = "Version not found"

-from .aggregates import *
+from .eventsourcing import *
227 changes: 206 additions & 21 deletions biomero/aggregates.py → biomero/eventsourcing.py
@@ -12,17 +12,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from eventsourcing.domain import Aggregate, event
from eventsourcing.application import Application
from uuid import UUID
from eventsourcing.application import Application, AggregateNotFound
from eventsourcing.system import ProcessApplication
from eventsourcing.dispatch import singledispatchmethod
from uuid import NAMESPACE_URL, UUID, uuid5
from typing import Any, Dict, List
from fabric import Result
import logging
from sqlalchemy import create_engine, text, Column, Integer, String, URL
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.exc import IntegrityError
from sqlalchemy.schema import CreateTable


# Create a logger for this module
logger = logging.getLogger(__name__)

# -------------------- DOMAIN MODEL -------------------- #


class ResultDict(dict):
    def __init__(self, result: Result):
@@ -49,8 +59,8 @@ class WorkflowInitiated(Aggregate.Created):
        group: int

    @event(WorkflowInitiated)
    def __init__(self, name: str,
                 description: str,
                 user: int,
                 group: int):
        self.name = name
Expand Down Expand Up @@ -104,9 +114,9 @@ class TaskCreated(Aggregate.Created):
params: Dict[str, Any]

@event(TaskCreated)
def __init__(self,
workflow_id: UUID,
task_name: str,
def __init__(self,
workflow_id: UUID,
task_name: str,
task_version: str,
input_data: Dict[str, Any],
params: Dict[str, Any]
@@ -129,7 +139,7 @@ class JobIdAdded(Aggregate.Event):
    def add_job_id(self, job_id):
        logger.debug(f"Adding job_id to Task: task_id={self.id}, job_id={job_id}")
        self.job_ids.append(job_id)

    class StatusUpdated(Aggregate.Event):
        status: str
@@ -176,29 +186,50 @@ def fail_task(self, error_message: str):
        pass


class JobAccount(Aggregate):
    INITIAL_VERSION = 0

    def __init__(self, user_id, group_id):
        self.user_id = user_id
        self.group_id = group_id
        self.jobs = []

    @classmethod
    def create_id(cls, user_id, group_id):
        return uuid5(NAMESPACE_URL, f'/jobaccount/{group_id}/{user_id}')

    @event('JobAdded')
    def add_job(self, job_id):
        logger.debug(f"Adding job: id={self.id}, job={job_id}, user={self.user_id}, group={self.group_id}")
        self.jobs.append(job_id)
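
Worth noting: create_id derives the aggregate ID deterministically with uuid5, so a given group/user pair always maps onto the same JobAccount stream. A minimal sketch of that property (illustrative values, not part of this commit):

    from uuid import NAMESPACE_URL, uuid5

    # The same (group, user) path always yields the same aggregate UUID,
    # so the application can find the account without a lookup table.
    a = uuid5(NAMESPACE_URL, '/jobaccount/7/42')
    assert a == uuid5(NAMESPACE_URL, '/jobaccount/7/42')
    assert a != uuid5(NAMESPACE_URL, '/jobaccount/7/43')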


# -------------------- APPLICATIONS -------------------- #


class WorkflowTracker(Application):

    def initiate_workflow(self,
                          name: str,
                          description: str,
                          user: int,
                          group: int) -> UUID:
        logger.debug(f"Initiating workflow: name={name}, description={description}, user={user}, group={group}")
        workflow = WorkflowRun(name, description, user, group)
        self.save(workflow)
        return workflow.id

    def add_task_to_workflow(self,
                             workflow_id: UUID,
                             task_name: str,
                             task_version: str,
                             input_data: Dict[str, Any],
                             kwargs: Dict[str, Any]
                             ) -> UUID:
        logger.debug(f"Adding task to workflow: workflow_id={workflow_id}, task_name={task_name}, task_version={task_version}")

        task = Task(workflow_id,
                    task_name,
                    task_version,
                    input_data,
                    kwargs)
@@ -228,7 +259,7 @@ def fail_workflow(self, workflow_id: UUID, error_message: str):
        workflow = self.repository.get(workflow_id)
        workflow.fail_workflow(error_message)
        self.save(workflow)

    def start_task(self, task_id: UUID):
        logger.debug(f"Starting task: task_id={task_id}")
@@ -249,24 +280,178 @@ def fail_task(self, task_id: UUID, error_message: str):
        task = self.repository.get(task_id)
        task.fail_task(error_message)
        self.save(task)

    def add_job_id(self, task_id, slurm_job_id):
        logger.debug(f"Adding job_id to task: task_id={task_id}, slurm_job_id={slurm_job_id}")

        task = self.repository.get(task_id)
        task.add_job_id(slurm_job_id)
        self.save(task)

    def add_result(self, task_id, result):
        logger.debug(f"Adding result to task: task_id={task_id}, result={result}")

        task = self.repository.get(task_id)
        task.add_result(result)
        self.save(task)

    def update_task_status(self, task_id, status):
        logger.debug(f"Updating task status: task_id={task_id}, status={status}")

        task = self.repository.get(task_id)
        task.update_task_status(status)
        self.save(task)

# --------------------- VIEWS ---------------------------- #

# Base class for declarative class definitions
Base = declarative_base()


class JobView(Base):
    __tablename__ = 'biomero_job_view'

    slurm_job_id = Column(Integer, primary_key=True)
    user = Column(Integer, nullable=False)
    group = Column(Integer, nullable=False)
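
Since the diff imports CreateTable, the DDL that create_all will emit for this model can be inspected directly. Roughly (note that user and group are reserved words, which is why PostgreSQL gets them quoted):

    print(CreateTable(JobView.__table__))
    # CREATE TABLE biomero_job_view (
    #     slurm_job_id INTEGER NOT NULL,
    #     "user" INTEGER NOT NULL,
    #     "group" INTEGER NOT NULL,
    #     PRIMARY KEY (slurm_job_id)
    # )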


class JobAccounting(ProcessApplication):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Read database configuration from environment variables
        database_url = URL.create(
            drivername="postgresql+psycopg2",
            username=os.getenv('POSTGRES_USER'),
            password=os.getenv('POSTGRES_PASSWORD'),
            host=os.getenv('POSTGRES_HOST', 'localhost'),
            port=int(os.getenv('POSTGRES_PORT', 5432)),
            database=os.getenv('POSTGRES_DBNAME')
        )

        # Set up SQLAlchemy engine and session factory
        self.engine = create_engine(database_url)
        self.SessionLocal = sessionmaker(bind=self.engine)

        # In-memory state tracking
        self.workflows = {}  # {wf_id: {"user": user, "group": group}}
        self.tasks = {}      # {task_id: wf_id}
        self.jobs = {}       # {job_id: (task_id, user, group)}

        # Create the defined tables (subclasses of Base) if they don't exist
        Base.metadata.create_all(self.engine)

    @singledispatchmethod
    def policy(self, domain_event, process_event):
        """Default policy"""

    @policy.register(WorkflowRun.WorkflowInitiated)
    def _(self, domain_event, process_event):
        """Handle WorkflowInitiated event"""
        user = domain_event.user
        group = domain_event.group
        wf_id = domain_event.originator_id

        # Track workflow
        self.workflows[wf_id] = {"user": user, "group": group}
        logger.debug(f"Workflow initiated: wf_id={wf_id}, user={user}, group={group}")

        # Optionally, persist this state if needed:
        # add an event to do that, then save via collect
        # process_event.collect_events(jobaccount, wfView)

    @policy.register(WorkflowRun.TaskAdded)
    def _(self, domain_event, process_event):
        """Handle TaskAdded event"""
        task_id = domain_event.task_id
        wf_id = domain_event.originator_id

        # Track task
        self.tasks[task_id] = wf_id
        logger.debug(f"Task added: task_id={task_id}, wf_id={wf_id}")

        # Optionally, persist this state if needed:
        # use .collect_events(agg) instead of .save(agg)
        # process_event.collect_events(taskView)

    @policy.register(Task.JobIdAdded)
    def _(self, domain_event, process_event):
        """Handle JobIdAdded event"""
        # Grab the event payload
        job_id = domain_event.job_id
        task_id = domain_event.originator_id

        # Find the workflow and user/group for the task
        wf_id = self.tasks.get(task_id)
        if wf_id:
            workflow_info = self.workflows.get(wf_id)
            if workflow_info:
                user = workflow_info["user"]
                group = workflow_info["group"]

                # Track job
                self.jobs[job_id] = (task_id, user, group)
                logger.debug(f"Job added: job_id={job_id}, task_id={task_id}, user={user}, group={group}")

                # Update view table
                self.update_view_table(job_id, user, group)
        else:
            logger.debug(f"JobIdAdded event ignored: task_id={task_id} not found in tasks")

        # use .collect_events(agg) instead of .save(agg)
        # process_event.collect_events(jobaccount)

    def update_view_table(self, job_id, user, group):
        """Update the view table with new job information."""
        with self.SessionLocal() as session:
            try:
                new_job = JobView(slurm_job_id=job_id, user=user, group=group)
                session.add(new_job)
                session.commit()
                logger.debug(f"Inserted job into view table: job_id={job_id}, user={user}, group={group}")
            except IntegrityError:
                session.rollback()
                # Handle the case where the job already exists in the table, if necessary
                logger.error(f"Failed to insert job into view table (already exists?): job_id={job_id}, user={user}, group={group}")

    def get_jobs(self, user=None, group=None):
        """Retrieve jobs, optionally filtered by user and/or group.

        Parameters:
        - user (int, optional): The user ID to filter by.
        - group (int, optional): The group ID to filter by.

        Returns:
        - Dictionary of user IDs to lists of job IDs when neither filter
          is given.
        - Dictionary with the given user ID as key (None if only group
          was specified) and the list of matching job IDs as value.
        """
        if user is None and group is None:
            # Retrieve all jobs, grouped by user
            with self.SessionLocal() as session:
                jobs = session.query(JobView.user, JobView.slurm_job_id).all()
                user_jobs = {}
                for user_id, job_id in jobs:
                    user_jobs.setdefault(user_id, []).append(job_id)
                return user_jobs
        else:
            with self.SessionLocal() as session:
                query = session.query(JobView.slurm_job_id)

                if user is not None:
                    query = query.filter_by(user=user)

                if group is not None:
                    query = query.filter_by(group=group)

                jobs = query.all()
                result = {user: [job.slurm_job_id for job in jobs]}
                logger.debug(f"Retrieved jobs for user={user} and group={group}: {result}")
                return result
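
A hedged usage sketch of this view API, given a running JobAccounting instance named accounting (hypothetical IDs, assuming the matching events have already been processed):

    accounting.get_jobs()                 # all jobs grouped by user, e.g. {5: [12345], 8: [99001]}
    accounting.get_jobs(user=5)           # {5: [12345]}
    accounting.get_jobs(user=5, group=7)  # jobs for user 5 restricted to group 7
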
13 changes: 9 additions & 4 deletions biomero/slurm_client.py
@@ -30,7 +30,8 @@
 from importlib_resources import files
 import io
 import os
-from biomero.aggregates import WorkflowTracker
+from biomero.eventsourcing import WorkflowTracker, JobAccounting
+from eventsourcing.system import System, SingleThreadedRunner

 logger = logging.getLogger(__name__)

@@ -392,13 +393,17 @@ def __init__(self,
         self.init_workflows()
         self.validate(validate_slurm_setup=init_slurm)

-        # Setup workflow tracking
+        # Setup workflow tracking and accounting
         self.track_workflows = track_workflows
+        system = System(pipes=[[WorkflowTracker, JobAccounting]])
         if self.track_workflows:  # use configured persistence from env
-            self.workflowTracker = WorkflowTracker()
+            runner = SingleThreadedRunner(system)
         else:  # turn off persistence, override
-            self.workflowTracker = WorkflowTracker(env={
+            runner = SingleThreadedRunner(system, env={
                 "PERSISTENCE_MODULE": ""})
+        runner.start()
+        self.workflowTracker = runner.get(WorkflowTracker)
+        self.jobAccounting = runner.get(JobAccounting)

    def init_workflows(self, force_update: bool = False):
        """
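
To illustrate the wiring above: the System pipe means every event saved by WorkflowTracker is also fed to JobAccounting's policy, which is what keeps the job view table current. A minimal end-to-end sketch (hypothetical names and IDs, assuming persistence is configured through the environment):

    from eventsourcing.system import System, SingleThreadedRunner
    from biomero.eventsourcing import WorkflowTracker, JobAccounting

    runner = SingleThreadedRunner(System(pipes=[[WorkflowTracker, JobAccounting]]))
    runner.start()
    tracker = runner.get(WorkflowTracker)
    accounting = runner.get(JobAccounting)

    wf_id = tracker.initiate_workflow("segmentation", "demo run", user=5, group=7)
    task_id = tracker.add_task_to_workflow(wf_id, "cellpose", "v1.0", {}, {})
    tracker.add_job_id(task_id, 12345)   # JobIdAdded -> JobAccounting -> view table

    accounting.get_jobs(user=5)          # {5: [12345]}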
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -30,7 +30,8 @@ dependencies = [
     "fabric==3.1.0",
     "paramiko==3.4.0",
     "importlib_resources>=5.4.0",
-    "eventsourcing[crypto,postgres-dev]==9.2.22"
+    "eventsourcing[crypto,postgres-dev]==9.2.22",
+    "sqlalchemy==2.0.32"
 ]

 [tool.setuptools.packages]
