Skip to content

Commit

Permalink
Add support for batch requests to JSON RPC server
Browse files Browse the repository at this point in the history
  • Loading branch information
RaoulSchaffranek committed Nov 21, 2024
1 parent 853eaa1 commit ef1692b
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 42 deletions.
153 changes: 111 additions & 42 deletions pyk/src/pyk/rpc/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import json
import logging
from dataclasses import dataclass
from functools import partial
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import TYPE_CHECKING, Any, Final
from typing import TYPE_CHECKING, Any, Final, Iterable

from typing_extensions import Protocol

Expand Down Expand Up @@ -71,37 +72,105 @@ class JsonRpcMethod(Protocol):
def __call__(self, **kwargs: Any) -> Any: ...


class JsonRpcRequestHandler(BaseHTTPRequestHandler):
methods: dict[str, JsonRpcMethod]
@dataclass(frozen=True)
class JsonRpcRequest:

def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any) -> None:
self.methods = methods
super().__init__(*args, **kwargs)
method: str
params: Any
id: Any

@staticmethod
def validate(request_dict: Any, valid_methods: Iterable[str]) -> JsonRpcRequest | JsonRpcError:
required_fields = ['jsonrpc', 'method', 'id']
for field in required_fields:
if field not in request_dict:
return JsonRpcError(-32600, f'Invalid request: missing field "{field}"', request_dict.get('id', None))

jsonrpc_version = request_dict['jsonrpc']
if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION:
return JsonRpcError(
-32600, f'Invalid request: bad version: "{jsonrpc_version}"', request_dict.get('id', None)
)

method_name = request_dict['method']
if method_name not in valid_methods:
return JsonRpcError(-32601, f'Method "{method_name}" not found.', request_dict.get('id', None))

def send_json_error(self, code: int, message: str, id: Any = None) -> None:
error_dict = {
return JsonRpcRequest(
method=request_dict['method'], params=request_dict.get('params', None), id=request_dict.get('id', None)
)


@dataclass(frozen=True)
class JsonRpcBatchRequest:
requests: tuple[JsonRpcRequest]


@dataclass(frozen=True)
class JsonRpcResult:

def encode(self) -> bytes:
raise NotImplementedError('Subclasses must implement this method')


@dataclass(frozen=True)
class JsonRpcError(JsonRpcResult):

code: int
message: str
id: Any

def to_json(self) -> dict[str, Any]:
return {
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
'error': {
'code': code,
'message': message,
'code': self.code,
'message': self.message,
},
'id': id,
'id': self.id,
}
error_bytes = json.dumps(error_dict).encode('ascii')
self.set_response()
self.wfile.write(error_bytes)

def send_json_success(self, result: Any, id: Any) -> None:
response_dict = {
def encode(self) -> bytes:
return json.dumps(self.to_json()).encode('ascii')


@dataclass(frozen=True)
class JsonRpcSuccess(JsonRpcResult):
payload: Any
id: Any

def to_json(self) -> dict[str, Any]:
return {
'jsonrpc': JsonRpcServer.JSONRPC_VERSION,
'result': result,
'id': id,
'result': self.payload,
'id': self.id,
}
response_bytes = json.dumps(response_dict).encode('ascii')
self.set_response()

def encode(self) -> bytes:
return json.dumps(self.to_json()).encode('ascii')


@dataclass(frozen=True)
class JsonRpcBatchResult(JsonRpcResult):
results: tuple[JsonRpcError | JsonRpcSuccess, ...]

def encode(self) -> bytes:
return json.dumps([result.to_json() for result in self.results]).encode('ascii')


class JsonRpcRequestHandler(BaseHTTPRequestHandler):
methods: dict[str, JsonRpcMethod]

def __init__(self, methods: dict[str, JsonRpcMethod], *args: Any, **kwargs: Any) -> None:
self.methods = methods
super().__init__(*args, **kwargs)

def _send_response(self, response: JsonRpcResult) -> None:
self.send_response_headers()
response_bytes = response.encode()
self.wfile.write(response_bytes)

def set_response(self) -> None:
def send_response_headers(self) -> None:
self.send_response(200)
self.send_header('Content-type', 'text/html')
self.end_headers()
Expand All @@ -113,44 +182,44 @@ def do_POST(self) -> None: # noqa: N802
content = self.rfile.read(int(content_len))
_LOGGER.debug(f'Received bytes: {content.decode()}')

request: dict
request: dict[str, Any] | list[dict[str, Any]]
try:
request = json.loads(content)
_LOGGER.info(f'Received request: {request}')
except json.JSONDecodeError:
_LOGGER.warning(f'Invalid JSON: {content.decode()}')
self.send_json_error(-32700, 'Invalid JSON')
json_error = JsonRpcError(-32700, 'Invalid JSON', None)
self._send_response(json_error)
return

required_fields = ['jsonrpc', 'method', 'id']
for field in required_fields:
if field not in request:
_LOGGER.warning(f'Missing required field "{field}": {request}')
self.send_json_error(-32600, f'Invalid request: missing field "{field}"', request.get('id', None))
return
response: JsonRpcResult
if isinstance(request, list):
response = self._batch_request(request)
else:
response = self._single_request(request)

jsonrpc_version = request['jsonrpc']
if jsonrpc_version != JsonRpcServer.JSONRPC_VERSION:
_LOGGER.warning(f'Bad JSON-RPC version: {jsonrpc_version}')
self.send_json_error(-32600, f'Invalid request: bad version: "{jsonrpc_version}"', request['id'])
return
self._send_response(response)

method_name = request['method']
if method_name not in self.methods:
_LOGGER.warning(f'Method not found: {method_name}')
self.send_json_error(-32601, f'Method "{method_name}" not found.', request['id'])
return
def _batch_request(self, requests: list[dict[str, Any]]) -> JsonRpcBatchResult:
return JsonRpcBatchResult(tuple(self._single_request(request) for request in requests))

def _single_request(self, request: dict[str, Any]) -> JsonRpcError | JsonRpcSuccess:
validation_result = JsonRpcRequest.validate(request, self.methods.keys())
if isinstance(validation_result, JsonRpcError):
return validation_result

method_name = request['method']
method = self.methods[method_name]
params = request.get('params', None)
params = validation_result.params
_LOGGER.info(f'Executing method {method_name}')
result: Any
if type(params) is dict:
result = method(**params)
elif type(params) is list:
result = method(*params)
elif params is None:
result = method()
else:
self.send_json_error(-32602, 'Unrecognized method parameter format.')
return JsonRpcError(-32602, 'Unrecognized method parameter format.', validation_result.id)
_LOGGER.debug(f'Got response {result}')
self.send_json_success(result, request['id'])
return JsonRpcSuccess(result, validation_result.id)
128 changes: 128 additions & 0 deletions pyk/src/tests/unit/test_json_rpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from __future__ import annotations

import json
from http.client import HTTPConnection
from threading import Thread
from time import sleep
from typing import Any

from pyk.rpc.rpc import JsonRpcServer, ServeRpcOptions
from pyk.testing import KRunTest


class StatefulKJsonRpcServer(JsonRpcServer):

x: int = 42
y: int = 43

def __init__(self, options: ServeRpcOptions) -> None:
super().__init__(options)

self.register_method('get_x', self.exec_get_x)
self.register_method('get_y', self.exec_get_y)
self.register_method('set_x', self.exec_set_x)
self.register_method('set_y', self.exec_set_y)
self.register_method('add', self.exec_add)

def exec_get_x(self) -> int:
return self.x

def exec_get_y(self) -> int:
return self.y

def exec_set_x(self, n: int) -> None:
self.x = n

def exec_set_y(self, n: int) -> None:
self.y = n

def exec_add(self) -> int:
return self.x + self.y


class TestJsonRPCServer(KRunTest):

def test_json_rpc_server(self) -> None:
server = StatefulKJsonRpcServer(ServeRpcOptions({'port': 0}))

def run_server() -> None:
server.serve()

def wait_until_server_is_up() -> None:
while True:
try:
server.port()
return
except ValueError:
sleep(0.1)

thread = Thread(target=run_server)
thread.start()

wait_until_server_is_up()

http_client = HTTPConnection('localhost', server.port())
rpc_client = SimpleClient(http_client)

def wait_until_ready() -> None:
while True:
try:
rpc_client.request('get_x', [])
except ConnectionRefusedError:
sleep(0.1)
continue
break

wait_until_ready()

rpc_client.request('set_x', [123])
res = rpc_client.request('get_x')
assert res == 123

rpc_client.request('set_y', [456])
res = rpc_client.request('get_y')
assert res == 456

res = rpc_client.request('add', [])
assert res == (123 + 456)

res = rpc_client.batch_request(('set_x', [1]), ('set_y', [2]), ('add', []))
assert len(res) == 3
assert res[2]['result'] == 1 + 2

server.shutdown()
thread.join()


class SimpleClient:

client: HTTPConnection
_request_id: int = 0

def __init__(self, client: HTTPConnection) -> None:
self.client = client

def request_id(self) -> int:
self._request_id += 1
return self._request_id

def request(self, method: str, params: Any = None) -> Any:
body = json.dumps({'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()})

self.client.request('POST', '/', body)
response = self.client.getresponse()
result = json.loads(response.read())
return result['result']

def batch_request(self, *requests: tuple[str, Any]) -> list[Any]:
body = json.dumps(
[
{'jsonrpc': '2.0', 'method': method, 'params': params, 'id': self.request_id()}
for method, params in requests
]
)

self.client.request('POST', '/', body)
response = self.client.getresponse()
result = json.loads(response.read())
return result

0 comments on commit ef1692b

Please sign in to comment.