Skip to content

Commit

Permalink
Refactor parsing and handling of batched Electrum RPC
Browse files Browse the repository at this point in the history
Related to #464
  • Loading branch information
romanz authored and Roman Zeyde committed Oct 27, 2021
1 parent f4afe88 commit f5ceb05
Showing 1 changed file with 122 additions and 72 deletions.
194 changes: 122 additions & 72 deletions src/electrum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,65 +396,58 @@ impl Rpc {
}

pub fn handle_requests(&self, client: &mut Client, lines: &[String]) -> Vec<String> {
lines
let parsed: Vec<Result<Calls, Value>> = lines
.iter()
.map(|line| self.handle_request(client, &line))
.map(|line| {
parse_requests(&line)
.map(Calls::parse)
.map_err(error_msg_no_id)
})
.collect();

parsed
.into_iter()
.map(|calls| self.handle_calls(client, calls).to_string())
.collect()
}

fn handle_request(&self, client: &mut Client, line: &str) -> String {
let error_msg_no_id = |err| error_msg(Value::Null, RpcError::Standard(err));
let response: Value = match serde_json::from_str(line) {
// parse JSON from str
Ok(value) => match serde_json::from_value(value) {
// parse RPC from JSON
Ok(requests) => match requests {
Requests::Single(request) => self.call(client, request),
Requests::Batch(requests) => json!(requests
.into_iter()
.map(|request| self.call(client, request))
.collect::<Vec<Value>>()),
},
Err(err) => {
warn!("invalid RPC request ({:?}): {}", line, err);
error_msg_no_id(StandardError::InvalidRequest)
}
},
Err(err) => {
warn!("invalid JSON ({:?}): {}", line, err);
error_msg_no_id(StandardError::ParseError)
}
};
response.to_string()
fn handle_calls(&self, client: &mut Client, calls: Result<Calls, Value>) -> Value {
match calls {
Ok(Calls::Batch(batch)) => json!(batch
.into_iter()
.map(|call| self.single_call(client, call))
.collect::<Vec<_>>()),
Ok(Calls::Single(call)) => json!(self.single_call(client, call)),
Err(response) => response, // JSON parsing may fail - the response does not contain request id
}
}

fn call(&self, client: &mut Client, request: Request) -> Value {
let Request { id, method, params } = request;
let call = match Call::parse(&method, params) {
fn single_call(&self, client: &mut Client, call: Result<Call, Value>) -> Value {
let Call { id, method, params } = match call {
Ok(call) => call,
Err(err) => return error_msg(id, RpcError::Standard(err)),
Err(response) => return response, // params parsing may fail - the response contains request id
};
self.rpc_duration.observe_duration(&method, || {
let result = match call {
Call::Banner => Ok(json!(self.banner)),
Call::BlockHeader(args) => self.block_header(args),
Call::BlockHeaders(args) => self.block_headers(args),
Call::Donation => Ok(Value::Null),
Call::EstimateFee(args) => self.estimate_fee(args),
Call::Features => self.features(),
Call::HeadersSubscribe => self.headers_subscribe(client),
Call::MempoolFeeHistogram => self.get_fee_histogram(),
Call::PeersSubscribe => Ok(json!([])),
Call::Ping => Ok(Value::Null),
Call::RelayFee => self.relayfee(),
Call::ScriptHashGetBalance(args) => self.scripthash_get_balance(client, args),
Call::ScriptHashGetHistory(args) => self.scripthash_get_history(client, args),
Call::ScriptHashListUnspent(args) => self.scripthash_list_unspent(client, args),
Call::ScriptHashSubscribe(args) => self.scripthash_subscribe(client, args),
Call::TransactionBroadcast(args) => self.transaction_broadcast(args),
Call::TransactionGet(args) => self.transaction_get(args),
Call::TransactionGetMerkle(args) => self.transaction_get_merkle(args),
Call::Version(args) => self.version(args),
let result = match params {
Params::Banner => Ok(json!(self.banner)),
Params::BlockHeader(args) => self.block_header(args),
Params::BlockHeaders(args) => self.block_headers(args),
Params::Donation => Ok(Value::Null),
Params::EstimateFee(args) => self.estimate_fee(args),
Params::Features => self.features(),
Params::HeadersSubscribe => self.headers_subscribe(client),
Params::MempoolFeeHistogram => self.get_fee_histogram(),
Params::PeersSubscribe => Ok(json!([])),
Params::Ping => Ok(Value::Null),
Params::RelayFee => self.relayfee(),
Params::ScriptHashGetBalance(args) => self.scripthash_get_balance(client, args),
Params::ScriptHashGetHistory(args) => self.scripthash_get_history(client, args),
Params::ScriptHashListUnspent(args) => self.scripthash_list_unspent(client, args),
Params::ScriptHashSubscribe(args) => self.scripthash_subscribe(client, args),
Params::TransactionBroadcast(args) => self.transaction_broadcast(args),
Params::TransactionGet(args) => self.transaction_get(args),
Params::TransactionGetMerkle(args) => self.transaction_get_merkle(args),
Params::Version(args) => self.version(args),
};
match result {
Ok(value) => result_msg(id, value),
Expand All @@ -474,7 +467,7 @@ impl Rpc {
}

#[derive(Deserialize)]
enum Call {
enum Params {
Banner,
BlockHeader((usize,)),
BlockHeaders((usize, usize)),
Expand All @@ -496,28 +489,28 @@ enum Call {
Version((String, Version)),
}

impl Call {
fn parse(method: &str, params: Value) -> std::result::Result<Call, StandardError> {
impl Params {
fn parse(method: &str, params: Value) -> std::result::Result<Params, StandardError> {
Ok(match method {
"blockchain.block.header" => Call::BlockHeader(convert(params)?),
"blockchain.block.headers" => Call::BlockHeaders(convert(params)?),
"blockchain.estimatefee" => Call::EstimateFee(convert(params)?),
"blockchain.headers.subscribe" => Call::HeadersSubscribe,
"blockchain.relayfee" => Call::RelayFee,
"blockchain.scripthash.get_balance" => Call::ScriptHashGetBalance(convert(params)?),
"blockchain.scripthash.get_history" => Call::ScriptHashGetHistory(convert(params)?),
"blockchain.scripthash.listunspent" => Call::ScriptHashListUnspent(convert(params)?),
"blockchain.scripthash.subscribe" => Call::ScriptHashSubscribe(convert(params)?),
"blockchain.transaction.broadcast" => Call::TransactionBroadcast(convert(params)?),
"blockchain.transaction.get" => Call::TransactionGet(convert(params)?),
"blockchain.transaction.get_merkle" => Call::TransactionGetMerkle(convert(params)?),
"mempool.get_fee_histogram" => Call::MempoolFeeHistogram,
"server.banner" => Call::Banner,
"server.donation_address" => Call::Donation,
"server.features" => Call::Features,
"server.peers.subscribe" => Call::PeersSubscribe,
"server.ping" => Call::Ping,
"server.version" => Call::Version(convert(params)?),
"blockchain.block.header" => Params::BlockHeader(convert(params)?),
"blockchain.block.headers" => Params::BlockHeaders(convert(params)?),
"blockchain.estimatefee" => Params::EstimateFee(convert(params)?),
"blockchain.headers.subscribe" => Params::HeadersSubscribe,
"blockchain.relayfee" => Params::RelayFee,
"blockchain.scripthash.get_balance" => Params::ScriptHashGetBalance(convert(params)?),
"blockchain.scripthash.get_history" => Params::ScriptHashGetHistory(convert(params)?),
"blockchain.scripthash.listunspent" => Params::ScriptHashListUnspent(convert(params)?),
"blockchain.scripthash.subscribe" => Params::ScriptHashSubscribe(convert(params)?),
"blockchain.transaction.broadcast" => Params::TransactionBroadcast(convert(params)?),
"blockchain.transaction.get" => Params::TransactionGet(convert(params)?),
"blockchain.transaction.get_merkle" => Params::TransactionGetMerkle(convert(params)?),
"mempool.get_fee_histogram" => Params::MempoolFeeHistogram,
"server.banner" => Params::Banner,
"server.donation_address" => Params::Donation,
"server.features" => Params::Features,
"server.peers.subscribe" => Params::PeersSubscribe,
"server.ping" => Params::Ping,
"server.version" => Params::Version(convert(params)?),
_ => {
warn!("unknown method {}", method);
return Err(StandardError::MethodNotFound);
Expand All @@ -526,6 +519,41 @@ impl Call {
}
}

struct Call {
id: Value,
method: String,
params: Params,
}

impl Call {
fn parse(request: Request) -> Result<Call, Value> {
match Params::parse(&request.method, request.params) {
Ok(params) => Ok(Call {
id: request.id,
method: request.method,
params,
}),
Err(e) => Err(error_msg(request.id, RpcError::Standard(e))),
}
}
}

enum Calls {
Batch(Vec<Result<Call, Value>>),
Single(Result<Call, Value>),
}

impl Calls {
fn parse(requests: Requests) -> Calls {
match requests {
Requests::Single(request) => Calls::Single(Call::parse(request)),
Requests::Batch(batch) => {
Calls::Batch(batch.into_iter().map(Call::parse).collect::<Vec<_>>())
}
}
}
}

fn convert<T>(params: Value) -> std::result::Result<T, StandardError>
where
T: serde::de::DeserializeOwned,
Expand All @@ -548,3 +576,25 @@ fn result_msg(id: Value, result: Value) -> Value {
fn error_msg(id: Value, error: RpcError) -> Value {
json!({"jsonrpc": "2.0", "id": id, "error": error.to_value()})
}

fn error_msg_no_id(err: StandardError) -> Value {
error_msg(Value::Null, RpcError::Standard(err))
}

fn parse_requests(line: &str) -> Result<Requests, StandardError> {
match serde_json::from_str(line) {
// parse JSON from str
Ok(value) => match serde_json::from_value(value) {
// parse RPC from JSON
Ok(requests) => Ok(requests),
Err(err) => {
warn!("invalid RPC request ({:?}): {}", line, err);
Err(StandardError::InvalidRequest)
}
},
Err(err) => {
warn!("invalid JSON ({:?}): {}", line, err);
Err(StandardError::ParseError)
}
}
}

0 comments on commit f5ceb05

Please sign in to comment.