Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADR 031: Homedb Cache (Bolt 5.8) #1115

Draft
wants to merge 13 commits into
base: 5.0
Choose a base branch
from
7 changes: 4 additions & 3 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ Closing a driver will immediately shut down all connections in the pool.
.. autoclass:: neo4j.Driver()
:members: session, execute_query_bookmark_manager, encrypted, close,
verify_connectivity, get_server_info, verify_authentication,
supports_session_auth, supports_multi_db
supports_session_auth, supports_multi_db, force_home_database_resolution

.. method:: execute_query(query, parameters_=None,routing_=neo4j.RoutingControl.WRITE, database_=None, impersonated_user_=None, bookmark_manager_=self.execute_query_bookmark_manager, auth_=None, result_transformer_=Result.to_eager_result, **kwargs)

Expand Down Expand Up @@ -260,7 +260,8 @@ Closing a driver will immediately shut down all connections in the pool.
:param database\_:
Database to execute the query against.

None (default) uses the database configured on the server side.
:data:`None` (default) uses the database configured on the server
side.

.. Note::
It is recommended to always specify the database explicitly
Expand Down Expand Up @@ -1034,7 +1035,7 @@ Specifically, the following applies:
all queries within that session are executed with the explicit database
name 'movies' supplied. Any change to the user’s home database is
reflected only in sessions created after such change takes effect. This
behavior requires additional network communication. In clustered
behavior may require additional network communication. In clustered
environments, it is strongly recommended to avoid a single point of
failure. For instance, by ensuring that the connection URI resolves to
multiple endpoints. For older Bolt protocol versions the behavior is the
Expand Down
3 changes: 2 additions & 1 deletion docs/source/async_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ Closing a driver will immediately shut down all connections in the pool.
:param database\_:
Database to execute the query against.

None (default) uses the database configured on the server side.
:data:`None` (default) uses the database configured on the server
side.

.. Note::
It is recommended to always specify the database explicitly
Expand Down
3 changes: 2 additions & 1 deletion src/neo4j/_async/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,7 +799,8 @@ async def example(driver: neo4j.AsyncDriver) -> int:
:param database_:
Database to execute the query against.

None (default) uses the database configured on the server side.
:data:`None` (default) uses the database configured on the server
side.

.. Note::
It is recommended to always specify the database explicitly
Expand Down
138 changes: 138 additions & 0 deletions src/neo4j/_async/home_db_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import math
import typing as t
from time import monotonic

from .._async_compat.concurrency import AsyncCooperativeLock


if t.TYPE_CHECKING:
    # These aliases exist for static type checking only; they are never
    # evaluated at runtime.
    # TAuthKey = t.Tuple[t.Tuple[]]
    # Cache key: either a plain string, a tuple of ``(str, hashable)`` pairs,
    # or the one-element tuple ``(None,)``.
    TKey = str | tuple[tuple[str, t.Hashable], ...] | tuple[None]
    # Cache value: a ``(float, str)`` pair.
    TVal = tuple[float, str]


class AsyncHomeDbCache:
    """
    Cache of home db names, bounded by entry age and (optionally) by size.

    Each entry maps a key (see :meth:`compute_key`) to a tuple of the
    ``monotonic()`` timestamp at which it was stored and the cached string.
    Housekeeping is lazy: stale or surplus entries are only removed when
    :meth:`get` or :meth:`set` is called.

    :param enabled:
        If false, the cache stores nothing: :meth:`get` always returns
        :data:`None`, :meth:`set` and :meth:`clear` are no-ops, and
        :meth:`compute_key` always returns ``(None,)``.
    :param ttl:
        Maximum age of an entry in seconds (as measured by ``monotonic()``).
        Must be greater than 0 and not NaN.
    :param max_size:
        Maximum number of entries tolerated before pruning, or :data:`None`
        for no size limit. Must be greater than 0 if given.

    :raises ValueError: if ``ttl`` or ``max_size`` is out of range.
    """

    _ttl: float
    _enabled: bool
    _max_size: int | None

    def __init__(
        self,
        enabled: bool = True,
        ttl: float = float("inf"),
        max_size: int | None = None,
    ) -> None:
        if math.isnan(ttl) or ttl <= 0:
            raise ValueError(f"home db cache ttl must be greater 0, got {ttl}")
        self._enabled = enabled
        self._ttl = ttl
        self._cache: dict[TKey, TVal] = {}
        self._lock = AsyncCooperativeLock()
        # Lower bound for the timestamp of every entry currently cached.
        # Used to skip the TTL sweep while nothing can have expired yet.
        self._oldest_entry = monotonic()
        if max_size is not None and max_size <= 0:
            raise ValueError(
                f"home db cache max_size must be greater 0 or None, "
                f"got {max_size}"
            )
        self._max_size = max_size

    def compute_key(
        self,
        imp_user: str | None,
        auth: dict | None,
    ) -> TKey:
        """
        Derive the cache key for an impersonated user and/or an auth dict.

        :param imp_user: impersonated user name; takes precedence if given.
        :param auth: auth token as a dict; used if ``imp_user`` is None.

        :returns:
            ``(None,)`` if the cache is disabled or neither argument is
            given, ``imp_user`` if it is not None, otherwise the result of
            ``_consolidate_auth_token(auth)``.
        """
        if not self._enabled:
            return (None,)
        if imp_user is not None:
            return imp_user
        if auth is not None:
            return _consolidate_auth_token(auth)
        return (None,)

    def get(self, key: TKey) -> str | None:
        """
        Look up the cached value for ``key``.

        :returns:
            the cached string, or :data:`None` if the cache is disabled or
            holds no (remaining) entry for ``key``.
        """
        if not self._enabled:
            return None
        with self._lock:
            self._clean(monotonic())
            val = self._cache.get(key)
            if val is None:
                return None
            return val[1]

    def set(self, key: TKey, value: str | None) -> None:
        """
        Store ``value`` under ``key``, or drop the entry if ``value`` is None.

        Does nothing if the cache is disabled.
        """
        if not self._enabled:
            return
        with self._lock:
            now = monotonic()
            self._clean(now)
            if value is None:
                self._cache.pop(key, None)
            else:
                self._cache[key] = (now, value)

    def clear(self) -> None:
        """Remove all entries. Does nothing if the cache is disabled."""
        if not self._enabled:
            return
        with self._lock:
            self._cache = {}
            self._oldest_entry = monotonic()

    def _clean(self, now: float | None = None) -> None:
        """
        Evict expired entries and prune the cache if it grew too large.

        :param now:
            the current ``monotonic()`` reading; taken fresh if None.
        """
        now = monotonic() if now is None else now
        if now - self._oldest_entry > self._ttl:
            # At least one entry may have expired: sweep and re-establish
            # the lower bound from what is left.
            self._cache = {
                k: v for k, v in self._cache.items() if now - v[0] < self._ttl
            }
            self._oldest_entry = min(
                (v[0] for v in self._cache.values()), default=now
            )
        if self._max_size and len(self._cache) > self._max_size:
            # Prune to 90% of max_size, keeping the most recently stored
            # entries. Never prune below one entry: int(max_size * 0.9) is 0
            # for max_size == 1, which would throw away the whole cache
            # (including the entry that was just stored).
            keep = max(1, int(self._max_size * 0.9))
            newest_first = sorted(
                self._cache.items(),
                key=lambda item: item[1][0],
                reverse=True,
            )
            self._cache = dict(newest_first[:keep])

    def __len__(self) -> int:
        """Return the number of entries currently held."""
        return len(self._cache)

    @property
    def enabled(self) -> bool:
        """Whether the cache is enabled."""
        return self._enabled


def _consolidate_auth_token(auth: dict) -> tuple | str:
if auth.get("scheme") == "basic" and isinstance(
auth.get("principal"), str
):
return auth["principal"]
return _hashable_dict(auth)


def _hashable_dict(d: dict) -> tuple:
return tuple(
(k, _hashable_dict(v) if isinstance(v, dict) else v)
for k, v in sorted(d.items())
)
6 changes: 4 additions & 2 deletions src/neo4j/_async/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
"""

__all__ = [
"AcquireAuth",
"AcquisitionAuth",
"AcquisitionDatabase",
"AsyncBolt",
"AsyncBoltPool",
"AsyncNeo4jPool",
Expand All @@ -37,7 +38,8 @@
ConnectionErrorHandler,
)
from ._pool import (
AcquireAuth,
AcquisitionAuth,
AcquisitionDatabase,
AsyncBoltPool,
AsyncNeo4jPool,
)
39 changes: 14 additions & 25 deletions src/neo4j/_async/io/_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from ..._async_compat.network import AsyncBoltSocket
from ..._async_compat.util import AsyncUtil
from ..._auth_management import to_auth_dict
from ..._codec.hydration import (
HydrationHandlerABC,
v1 as hydration_v1,
Expand All @@ -39,12 +40,10 @@
from ..._sync.config import PoolConfig
from ...addressing import ResolvedAddress
from ...api import (
Auth,
ServerInfo,
Version,
)
from ...exceptions import (
AuthError,
ConfigurationError,
DriverError,
IncompleteCommit,
Expand Down Expand Up @@ -158,10 +157,7 @@ def __init__(
),
self.PROTOCOL_VERSION,
)
# so far `connection.recv_timeout_seconds` is the only available
# configuration hint that exists. Therefore, all hints can be stored at
# connection level. This might change in the future.
self.configuration_hints = {}
self.connection_hints = {}
self.patch = {}
self.outbox = AsyncOutbox(
self.socket,
Expand All @@ -187,7 +183,7 @@ def __init__(
self.user_agent = USER_AGENT

self.auth = auth
self.auth_dict = self._to_auth_dict(auth)
self.auth_dict = to_auth_dict(auth)
self.auth_manager = auth_manager
self.telemetry_disabled = telemetry_disabled

Expand All @@ -206,26 +202,14 @@ def _get_server_state_manager(self) -> ServerStateManagerBase: ...
@abc.abstractmethod
def _get_client_state_manager(self) -> ClientStateManagerBase: ...

@classmethod
def _to_auth_dict(cls, auth):
# Determine auth details
if not auth:
return {}
elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
return vars(Auth("basic", *auth))
else:
try:
return vars(auth)
except (KeyError, TypeError) as e:
# TODO: 6.0 - change this to be a DriverError (or subclass)
raise AuthError(
f"Cannot determine auth details from {auth!r}"
) from e

@property
def connection_id(self):
return self.server_info._metadata.get("connection_id", "<unknown id>")

@property
@abc.abstractmethod
def ssr_enabled(self) -> bool: ...

@property
@abc.abstractmethod
def supports_multiple_results(self):
Expand Down Expand Up @@ -308,6 +292,7 @@ def protocol_handlers(cls, protocol_version=None):
AsyncBolt5x5,
AsyncBolt5x6,
AsyncBolt5x7,
AsyncBolt5x8,
)

handlers = {
Expand All @@ -325,6 +310,7 @@ def protocol_handlers(cls, protocol_version=None):
AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5,
AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6,
AsyncBolt5x7.PROTOCOL_VERSION: AsyncBolt5x7,
AsyncBolt5x8.PROTOCOL_VERSION: AsyncBolt5x8,
}

if protocol_version is None:
Expand Down Expand Up @@ -461,7 +447,10 @@ async def open(

# avoid new lines after imports for better readability and conciseness
# fmt: off
if protocol_version == (5, 7):
if protocol_version == (5, 8):
from ._bolt5 import AsyncBolt5x8
bolt_cls = AsyncBolt5x8
elif protocol_version == (5, 7):
from ._bolt5 import AsyncBolt5x7
bolt_cls = AsyncBolt5x7
elif protocol_version == (5, 6):
Expand Down Expand Up @@ -626,7 +615,7 @@ def re_auth(

:returns: whether the auth was changed
"""
new_auth_dict = self._to_auth_dict(auth)
new_auth_dict = to_auth_dict(auth)
if not force and new_auth_dict == self.auth_dict:
self.auth_manager = auth_manager
self.auth = auth
Expand Down
2 changes: 2 additions & 0 deletions src/neo4j/_async/io/_bolt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class AsyncBolt3(AsyncBolt):

PROTOCOL_VERSION = Version(3, 0)

ssr_enabled = False

supports_multiple_results = False

supports_multiple_databases = False
Expand Down
8 changes: 5 additions & 3 deletions src/neo4j/_async/io/_bolt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class AsyncBolt4x0(AsyncBolt):

PROTOCOL_VERSION = Version(4, 0)

ssr_enabled = False

supports_multiple_results = True

supports_multiple_databases = True
Expand Down Expand Up @@ -614,10 +616,10 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None):
)

def on_success(metadata):
self.configuration_hints.update(metadata.pop("hints", {}))
self.connection_hints.update(metadata.pop("hints", {}))
self.server_info.update(metadata)
if "connection.recv_timeout_seconds" in self.configuration_hints:
recv_timeout = self.configuration_hints[
if "connection.recv_timeout_seconds" in self.connection_hints:
recv_timeout = self.connection_hints[
"connection.recv_timeout_seconds"
]
if isinstance(recv_timeout, int) and recv_timeout > 0:
Expand Down
Loading