Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yet Another Async Branch #154

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions dbos/_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import inspect
import sys
from contextlib import AbstractContextManager
from typing import Any, Callable, Coroutine, Generic, TypeVar, Union, cast

if sys.version_info < (3, 10):
from typing_extensions import TypeAlias
else:
from typing import TypeAlias

T = TypeVar("T", covariant=True) # A generic type for OK Result values
R = TypeVar("R", covariant=True) # A generic type for OK Result values


# OK branch of functional Result type
class Ok(Generic[T]):
__slots__ = "_value"

def __init__(self, value: T) -> None:
self._value = value

def is_ok(self) -> bool:
return True

def is_err(self) -> bool:
return False

def __call__(self) -> T:
return self._value


# Err branch of functional Result type
class Err:
__slots__ = "_value"

def __init__(self, value: BaseException) -> None:
self._value = value

def is_ok(self) -> bool:
return False

def is_err(self) -> bool:
return True

def __call__(self) -> Any:
raise self._value


Result: TypeAlias = Union[Ok[T], Err]


def _to_result_sync(func: Callable[[], T]) -> Result[T]:
try:
result = func()
return Ok(result)
except Exception as e:
return Err(e)


async def _to_result_async(func: Callable[[], Coroutine[Any, Any, T]]) -> Result[T]:
try:
result = await func()
return Ok(result)
except Exception as e:
return Err(e)


def to_result(
func: Callable[[], Union[T, Coroutine[Any, Any, T]]]
) -> Union[Result[T], Coroutine[Any, Any, Result[T]]]:

return (
_to_result_async(func)
if inspect.iscoroutinefunction(func)
else _to_result_sync(cast(Callable[[], T], func))
)


# def chain_result(
# result: Union[Result[T], Coroutine[Any, Any, Result[T]]],
# next_func: Callable[[Result[T]], R],
# ) -> Union[R, Coroutine[Any, Any, R]]:

# def chain_result_sync(result: Result[T]) -> R:
# return next_func(result)

# async def chain_result_async(coro: Coroutine[Any, Any, Result[T]]) -> R:
# result = await coro
# return next_func(result)

# return (
# chain_result_async(result)
# if inspect.iscoroutine(result)
# else chain_result_sync(cast(Result[T], result))
# )


def chain_result(
func: Callable[[], Union[T, Coroutine[Any, Any, T]]],
next_func: Callable[[Result[T]], R],
) -> Union[R, Coroutine[Any, Any, R]]:

def chain_result_sync(
func: Callable[[], T],
) -> R:
result = _to_result_sync(func)
return next_func(result)

async def chain_result_async(
func: Callable[[], Coroutine[Any, Any, T]],
) -> R:
result = await _to_result_async(func)
return next_func(result)

return (
chain_result_async(func)
if inspect.iscoroutinefunction(func)
else chain_result_sync(cast(Callable[[], T], func))
)


# def chain_ctx_mgr(
# result: Union[Result[T], Coroutine[Any, Any, Result[T]]],
# acm: AbstractContextManager,
# ) -> Union[T, Coroutine[Any, Any, T]]:
# exc = True
# try:
# try:
# return result()
# except:
# exc = False
# if not acm.__exit__(*sys.exc_info()):
# raise

# finally:
# if exc:
# acm.__exit__(None, None, None)
16 changes: 15 additions & 1 deletion pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ dev = [
"pytest-order>=1.3.0",
"pyjwt>=2.9.0",
"pdm-backend>=2.4.2",
"pytest-asyncio>=0.24.0",
]

[tool.black]
Expand Down
103 changes: 103 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import uuid

import pytest
import sqlalchemy as sa

# Public API
from dbos import DBOS, SetWorkflowID
from dbos._context import get_local_dbos_context

# Private API because this is a test


@pytest.mark.skip(reason="This test is not working")
@pytest.mark.asyncio
async def test_async_workflow(dbos: DBOS) -> None:
txn_counter: int = 0
wf_counter: int = 0
step_counter: int = 0

@DBOS.workflow()
async def test_workflow(var1: str, var2: str) -> str:
ctx = get_local_dbos_context()
nonlocal wf_counter
wf_counter += 1
res1 = test_transaction(var1)
res2 = test_step(var2)
DBOS.logger.info("I'm test_workflow")
return res1 + res2

@DBOS.step()
def test_step(var: str) -> str:
nonlocal step_counter
step_counter += 1
DBOS.logger.info("I'm test_step")
return var + f"step{step_counter}"

@DBOS.transaction(isolation_level="REPEATABLE READ")
def test_transaction(var: str) -> str:
rows = (DBOS.sql_session.execute(sa.text("SELECT 1"))).fetchall()
nonlocal txn_counter
txn_counter += 1
DBOS.logger.info("I'm test_transaction")
return var + f"txn{txn_counter}{rows[0][0]}"

wfuuid = str(uuid.uuid4())
with SetWorkflowID(wfuuid):
result = await test_workflow("alice", "bob")
assert result == "alicetxn11bobstep1"
dbos._sys_db.wait_for_buffer_flush()

with SetWorkflowID(wfuuid):
result = await test_workflow("alice", "bob")
assert result == "alicetxn11bobstep1"

assert wf_counter == 2
assert step_counter == 1
assert txn_counter == 1


@pytest.mark.skip(reason="This test is not working")
@pytest.mark.asyncio
async def test_sync_workflow(dbos: DBOS) -> None:
txn_counter: int = 0
wf_counter: int = 0
step_counter: int = 0

@DBOS.workflow()
def test_workflow(var1: str, var2: str) -> str:
nonlocal wf_counter
wf_counter += 1
res1 = test_transaction(var1)
res2 = test_step(var2)
DBOS.logger.info("I'm test_workflow")
return res1 + res2

@DBOS.step()
def test_step(var: str) -> str:
nonlocal step_counter
step_counter += 1
DBOS.logger.info("I'm test_step")
return var + f"step{step_counter}"

@DBOS.transaction(isolation_level="REPEATABLE READ")
def test_transaction(var: str) -> str:
rows = (DBOS.sql_session.execute(sa.text("SELECT 1"))).fetchall()
nonlocal txn_counter
txn_counter += 1
DBOS.logger.info("I'm test_transaction")
return var + f"txn{txn_counter}{rows[0][0]}"

wfuuid = str(uuid.uuid4())
with SetWorkflowID(wfuuid):
result = test_workflow("alice", "bob")
assert result == "alicetxn11bobstep1"
dbos._sys_db.wait_for_buffer_flush()

with SetWorkflowID(wfuuid):
result = test_workflow("alice", "bob")
assert result == "alicetxn11bobstep1"

assert wf_counter == 2
assert step_counter == 1
assert txn_counter == 1
Loading