Skip to content

Commit

Permalink
added retries
Browse files Browse the repository at this point in the history
  • Loading branch information
georgebv committed Mar 12, 2024
1 parent f2ec9a1 commit fdd0020
Show file tree
Hide file tree
Showing 9 changed files with 421 additions and 139 deletions.
8 changes: 8 additions & 0 deletions src/aiosalesforce/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
"ResponseEvent",
"RestApiCallConsumptionEvent",
"RetryEvent",
"ExceptionRule",
"ResponseRule",
"RetryPolicy",
]

from .client import Salesforce
Expand All @@ -17,3 +20,8 @@
RestApiCallConsumptionEvent,
RetryEvent,
)
from .retries import (
ExceptionRule,
ResponseRule,
RetryPolicy,
)
68 changes: 51 additions & 17 deletions src/aiosalesforce/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import itertools
import logging
import re
import warnings
Expand All @@ -15,8 +17,10 @@
RequestEvent,
ResponseEvent,
RestApiCallConsumptionEvent,
RetryEvent,
)
from aiosalesforce.exceptions import SalesforceWarning, raise_salesforce_error
from aiosalesforce.retries import POLICY_DEFAULT, RetryPolicy
from aiosalesforce.sobject import SobjectClient

logger = logging.getLogger(__name__)
Expand All @@ -33,23 +37,26 @@ class Salesforce:
base_url : str
Base URL of the Salesforce instance.
Must be in the format:
- Production : https://[MyDomainName].my.salesforce.com
- Sandbox : https://[MyDomainName]-[SandboxName].sandbox.my.salesforce.com
- Developer org : https://[MyDomainName].develop.my.salesforce.com
Production : https://[MyDomainName].my.salesforce.com
Sandbox : https://[MyDomainName]-[SandboxName].sandbox.my.salesforce.com
Developer org : https://[MyDomainName].develop.my.salesforce.com
auth : Auth
Authentication object.
version : str, optional
Salesforce API version.
By default, uses the latest version.
event_hooks : list[Callable[[Event], Awaitable[None] | None]], optional
List of event hooks.
An event hook is a function taking a single argument which contains
information (type and context) about the event.
When an event occurs, all event hooks are called concurrently.
Therefore, the order of execution is not guaranteed and it is the
responsibility of the event hook to determine if it should react to the event.
Asynchronous event hooks are awaited concurrently and synchronous hooks
are executed using the running event loop's executor.
List of functions or coroutines executed when an event occurs.
Hooks are executed concurrently and order of execution is not guaranteed.
All hooks must be thread-safe.
retry_policy : RetryPolicy, optional
Retry policy for requests.
The default policy retries requests up to 3 times with exponential backoff
and retries the following:
httpx Transport errors (excluding timeouts)
Server errors (5xx)
Row lock errors
Rate limit errors
"""

Expand All @@ -59,6 +66,7 @@ class Salesforce:
auth: Auth
version: str
event_bus: EventBus
retry_policy: RetryPolicy

def __init__(
self,
Expand All @@ -67,6 +75,7 @@ def __init__(
auth: Auth,
version: str = "60.0",
event_hooks: list[Callable[[Event], Awaitable[None] | None]] | None = None,
retry_policy: RetryPolicy = POLICY_DEFAULT,
) -> None:
self.httpx_client = httpx_client
self.auth = auth
Expand All @@ -92,6 +101,7 @@ def __init__(
self.base_url = str(match_.groups()[0])

self.event_bus = EventBus(event_hooks)
self.retry_policy = retry_policy

@wraps(httpx.AsyncClient.request)
async def request(self, *args, **kwargs) -> httpx.Response:
Expand All @@ -117,9 +127,22 @@ async def request(self, *args, **kwargs) -> httpx.Response:
await self.event_bus.publish_event(
RequestEvent(type="request", request=request)
)

retry_context = self.retry_policy.create_context()
refreshed: bool = False
while True:
response = await self.httpx_client.send(request)
for attempt in itertools.count():
try:
response = await self.httpx_client.send(request)
except Exception as exc:
if await retry_context.should_retry(exc):
await asyncio.gather(
self.retry_policy.sleep(attempt),
self.event_bus.publish_event(
RetryEvent(type="retry", request=request)
),
)
continue
raise
await self.event_bus.publish_event(
RestApiCallConsumptionEvent(
type="rest_api_call_consumption", response=response
Expand All @@ -139,7 +162,18 @@ async def request(self, *args, **kwargs) -> httpx.Response:
request.headers["Authorization"] = f"Bearer {access_token}"
refreshed = True
continue
# TODO Check retry policies; emit retry event
if await retry_context.should_retry(response):
await asyncio.gather(
self.retry_policy.sleep(attempt),
self.event_bus.publish_event(
RetryEvent(
type="retry",
request=request,
response=response,
)
),
)
continue
raise_salesforce_error(response)
if "Warning" in response.headers:
warnings.warn(response.headers["Warning"], SalesforceWarning)
Expand All @@ -151,7 +185,7 @@ async def request(self, *args, **kwargs) -> httpx.Response:
async def query(
self,
query: str,
include_deleted_records: bool = False,
include_all_records: bool = False,
) -> AsyncIterator[dict]:
"""
Execute a SOQL query.
Expand All @@ -160,7 +194,7 @@ async def query(
----------
query : str
SOQL query.
include_deleted_records : bool, optional
include_all_records : bool, optional
If True, includes all (active/deleted/archived) records.
Returns
Expand All @@ -169,7 +203,7 @@ async def query(
An asynchronous iterator of query results.
"""
operation = "query" if not include_deleted_records else "queryAll"
operation = "query" if not include_all_records else "queryAll"

next_url: str | None = None
while True:
Expand Down
10 changes: 10 additions & 0 deletions src/aiosalesforce/events/event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@


class EventBus:
"""
Event bus used to dispatch events to subscribed callbacks.
Parameters
----------
callbacks : list[Callable[[Event], Awaitable[None] | None]], optional
List of callbacks to subscribe to the event bus.
"""

_callbacks: set[CallbackType]

def __init__(self, callbacks: list[CallbackType] | None = None) -> None:
Expand Down
43 changes: 33 additions & 10 deletions src/aiosalesforce/events/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class ResponseMixin:
response: Response
response: Response | None

@property
def consumed(self) -> int | None:
Expand All @@ -20,6 +20,8 @@ def remaining(self) -> int | None:

@cached_property
def __api_usage(self) -> tuple[int, int] | tuple[None, None]:
if self.response is None:
return (None, None)
if "application/json" not in self.response.headers.get("content-type", None):
return (None, None)
try:
Expand Down Expand Up @@ -48,31 +50,52 @@ class Event:

@dataclass
class RequestEvent(Event):
"""Emitted before a request is sent for the first time."""
"""
Emitted before a request is sent for the first time.
Is not emitted by authentication requests.
"""

type: Literal["request"]
request: Request


@dataclass
class ResponseEvent(Event, ResponseMixin):
"""Emitted after an OK response is received after the last retry attempt."""
class RetryEvent(Event, ResponseMixin):
"""
Emitted immediately before a request is retried.
type: Literal["response"]
response: Response
Is not emitted by authentication requests.
"""

type: Literal["retry"]
request: Request
response: Response | None = None


@dataclass
class RetryEvent(Event, ResponseMixin):
"""Emitted immediately before a request is retried."""
class ResponseEvent(Event, ResponseMixin):
"""
Emitted after an OK (status code < 300) response is received.
type: Literal["retry"]
Is not emitted by authentication requests.
"""

type: Literal["response"]
response: Response


@dataclass
class RestApiCallConsumptionEvent(Event, ResponseMixin):
"""Emitted after a REST API call is consumed."""
"""
Emitted after a REST API call is consumed.
Emitted by all requests, including authentication requests.
"""

type: Literal["rest_api_call_consumption"]
response: Response
Expand Down
107 changes: 39 additions & 68 deletions src/aiosalesforce/retries/__init__.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,42 @@
__all__ = [
"Always",
"Retry",
"RetryPolicy",
"ExceptionRule",
"ResponseRule",
"RULE_EXCEPTION_RETRY_TRANSPORT_ERRORS",
"RULE_RESPONSE_RETRY_SERVER_ERRORS",
"RULE_RESPONSE_RETRY_UNABLE_TO_LOCK_ROW",
"RULE_RESPONSE_RETRY_TOO_MANY_REQUESTS",
"POLICY_DEFAULT",
]

import asyncio
import logging
import random
import time

from httpx import Response

from .always import Always
from .base import RetryHook

logger = logging.getLogger(__name__)


class Retry:
def __init__(
self,
retry_hooks: list[RetryHook],
max_retries: int = 10,
timeout: float = 60.0,
backoff_base: float = 0.5,
backoff_factor: float = 2.0,
backoff_max: float = 20.0,
backoff_jitter: bool = True,
) -> None:
self.retry_hooks = retry_hooks
self.max_retries = max_retries
self.timeout = timeout
self.backoff_base = backoff_base
self.backoff_factor = backoff_factor
self.backoff_max = backoff_max
self.backoff_jitter = backoff_jitter

self._start = time.time()
self._attempt_count = 0

def should_retry(self, response: Response) -> bool:
if self._attempt_count >= self.max_retries:
logger.debug("Max retries reached")
return False
if time.time() - self._start > self.timeout:
logger.debug("Timeout reached")
return False
for hook in self.retry_hooks:
if hook.should_retry(response):
logger.debug(
"Retrying '%s %s' due to %s, this is attempt %d/%d",
response.request.method,
response.request.url,
hook.__class__.__name__,
self._attempt_count + 1,
self.max_retries,
)
self._attempt_count += 1
return True
return False

async def sleep(self) -> None:
delay = min(
self.backoff_max,
self.backoff_base
* (self.backoff_factor ** max(0, self._attempt_count - 1)),
)
if self.backoff_jitter:
delay = random.uniform(0, delay) # noqa: S311
logger.debug("Sleeping for %s seconds", delay)
await asyncio.sleep(delay)
from httpx import TimeoutException, TransportError

from .policy import RetryPolicy
from .rules import ExceptionRule, ResponseRule

RULE_EXCEPTION_RETRY_TRANSPORT_ERRORS = ExceptionRule(
TransportError,
lambda exc: not isinstance(exc, TimeoutException),
max_retries=3,
)
RULE_RESPONSE_RETRY_SERVER_ERRORS = ResponseRule(
lambda response: response.status_code >= 500,
max_retries=3,
)
RULE_RESPONSE_RETRY_UNABLE_TO_LOCK_ROW = ResponseRule(
lambda response: "UNABLE_TO_LOCK_ROW" in response.text,
max_retries=3,
)
RULE_RESPONSE_RETRY_TOO_MANY_REQUESTS = ResponseRule(
lambda response: response.status_code == 429,
max_retries=3,
)

POLICY_DEFAULT = RetryPolicy(
response_rules=[
RULE_RESPONSE_RETRY_SERVER_ERRORS,
RULE_RESPONSE_RETRY_UNABLE_TO_LOCK_ROW,
RULE_RESPONSE_RETRY_TOO_MANY_REQUESTS,
],
exception_rules=[RULE_EXCEPTION_RETRY_TRANSPORT_ERRORS],
)
13 changes: 0 additions & 13 deletions src/aiosalesforce/retries/always.py

This file was deleted.

Loading

0 comments on commit fdd0020

Please sign in to comment.