From 4322c7184d3dcf17ef674db77dd3bc7ee805c83c Mon Sep 17 00:00:00 2001 From: HAOYUatHZ <37070449+HAOYUatHZ@users.noreply.github.com> Date: Thu, 26 Sep 2024 19:39:54 +1000 Subject: [PATCH] refactor: use `tokio` async pattern (#15) --- Cargo.lock | 13 ++++ Cargo.toml | 2 +- examples/cloud.rs | 7 +- examples/local.rs | 7 +- src/coordinator_handler/coordinator_client.rs | 68 ++++++------------- src/prover/mod.rs | 50 ++++++++------ src/tracing_handler.rs | 26 ++----- 7 files changed, 79 insertions(+), 94 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e892792..b1a3049 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4684,11 +4684,24 @@ dependencies = [ "bytes", "libc", "mio", + "num_cpus", "pin-project-lite", "socket2", + "tokio-macros", "windows-sys 0.48.0", ] +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "tokio-native-tls" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index 517e9b7..10f2af4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,7 @@ tiny-keccak = { version = "2.0.0", features = ["sha3", "keccak"] } rand = "0.8.5" eth-keystore = "0.5.0" rlp = "0.5.2" -tokio = "1.37.0" +tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] } sled = "0.34.7" http = "1.1.0" clap = { version = "4.5", features = ["derive"] } diff --git a/examples/cloud.rs b/examples/cloud.rs index acdddbf..df59545 100644 --- a/examples/cloud.rs +++ b/examples/cloud.rs @@ -49,7 +49,8 @@ impl CloudProver { } } -fn main() -> anyhow::Result<()> { +#[tokio::main] +async fn main() -> anyhow::Result<()> { init_tracing(); let args = Args::parse(); @@ -59,5 +60,7 @@ fn main() -> anyhow::Result<()> { .with_proving_service(Box::new(cloud_prover)) .build()?; - Arc::new(prover).run() + prover.run().await; + + Ok(()) } diff --git a/examples/local.rs b/examples/local.rs index 22e612d..b913150 100644 --- a/examples/local.rs +++ b/examples/local.rs @@ -43,7 +43,8 @@ impl LocalProver { } } -fn main() -> anyhow::Result<()> { +#[tokio::main] +async fn main() -> anyhow::Result<()> { init_tracing(); let args = Args::parse(); @@ -53,5 +54,7 @@ fn main() -> anyhow::Result<()> { .with_proving_service(Box::new(local_prover)) .build()?; - Arc::new(prover).run() + prover.run().await; + + Ok(()) } diff --git a/src/coordinator_handler/coordinator_client.rs b/src/coordinator_handler/coordinator_client.rs index a687430..fa61635 100644 --- a/src/coordinator_handler/coordinator_client.rs +++ b/src/coordinator_handler/coordinator_client.rs @@ -3,8 +3,7 @@ use super::{ LoginRequest, Response, SubmitProofRequest, SubmitProofResponseData, }; use crate::{config::CoordinatorConfig, prover::CircuitType, utils::get_version}; -use std::sync::{Mutex, MutexGuard}; -use tokio::runtime::Runtime; +use tokio::sync::{Mutex, MutexGuard}; pub struct CoordinatorClient { circuit_type: CircuitType, @@ -14,7 +13,6 @@ pub struct CoordinatorClient { key_signer: KeySigner, api: Api, token: Mutex>, - rt: Runtime, } impl CoordinatorClient { @@ -26,9 +24,6 @@ impl CoordinatorClient { prover_name: String, key_signer: KeySigner, ) -> anyhow::Result { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; let api = Api::new(cfg)?; let client = Self { circuit_type, @@ -38,79 +33,58 @@ impl CoordinatorClient { key_signer, api, token: Mutex::new(None), - rt, }; Ok(client) } - pub fn get_task(&self, req: &GetTaskRequest) -> anyhow::Result> { - let token = self.get_token_sync(false)?; - let response = self.get_task_sync(req, &token)?; + pub async fn get_task( + &self, + req: &GetTaskRequest, + ) -> anyhow::Result> { + let token = self.get_token(false).await?; + let response = self.api.get_task(req, &token).await?; if response.errcode == ErrorCode::ErrJWTTokenExpired { - let token = self.get_token_sync(true)?; - self.get_task_sync(req, &token) + let token = self.get_token(true).await?; + self.api.get_task(req, &token).await } else { Ok(response) } } - pub fn submit_proof( + pub async fn submit_proof( &self, req: &SubmitProofRequest, ) -> anyhow::Result> { - let token = self.get_token_sync(false)?; - let response = self.submit_proof_sync(req, &token)?; + let token = self.get_token(false).await?; + let response = self.api.submit_proof(req, &token).await?; if response.errcode == ErrorCode::ErrJWTTokenExpired { - let token = self.get_token_sync(true)?; - self.submit_proof_sync(req, &token) + let token = self.get_token(true).await?; + self.api.submit_proof(req, &token).await } else { Ok(response) } } - fn get_task_sync( - &self, - req: &GetTaskRequest, - token: &String, - ) -> anyhow::Result> { - self.rt.block_on(self.api.get_task(req, token)) - } - - fn submit_proof_sync( - &self, - req: &SubmitProofRequest, - token: &String, - ) -> anyhow::Result> { - self.rt.block_on(self.api.submit_proof(req, token)) - } - - fn get_token_sync(&self, force_relogin: bool) -> anyhow::Result { - self.rt.block_on(self.get_token_async(force_relogin)) - } - /// Retrieves a token for authentication, optionally forcing a re-login. /// /// This function attempts to get the stored token if `force_relogin` is set to `false`. /// /// If the token is expired, `force_relogin` is set to `true`, or a login was never performed /// before, it will authenticate and fetch a new token. - async fn get_token_async(&self, force_relogin: bool) -> anyhow::Result { - let token_guard = self - .token - .lock() - .expect("Mutex locking only occurs within `get_token` fn, so there can be no double `lock` for one thread"); - - match token_guard.as_deref() { - Some(token) if !force_relogin => return Ok(token.to_string()), + async fn get_token(&self, force_relogin: bool) -> anyhow::Result { + let token_guard = self.token.lock().await; + + match *token_guard { + Some(ref token) if !force_relogin => return Ok(token.to_string()), _ => (), } - self.login_async(token_guard).await + self.login(token_guard).await } - async fn login_async<'t>( + async fn login<'t>( &self, mut token_guard: MutexGuard<'t, Option>, ) -> anyhow::Result { diff --git a/src/prover/mod.rs b/src/prover/mod.rs index 6e214c7..4e08906 100644 --- a/src/prover/mod.rs +++ b/src/prover/mod.rs @@ -1,6 +1,7 @@ pub mod builder; pub mod proving_service; pub mod types; +use tokio::task::JoinSet; pub use {builder::ProverBuilder, proving_service::ProvingService, types::*}; use crate::{ @@ -25,31 +26,33 @@ pub struct Prover { } impl Prover { - pub fn run(self: std::sync::Arc) -> anyhow::Result<()> { + pub async fn run(self) { assert!(self.n_workers == self.coordinator_clients.len()); if self.circuit_type == CircuitType::Chunk { assert!(self.l2geth_client.is_some()); } - for i in 0..self.n_workers { - let self_clone = std::sync::Arc::clone(&self); - thread::spawn(move || { - self_clone.working_loop(i); + let mut provers = JoinSet::new(); + let self_arc = std::sync::Arc::new(self); + for i in 0..self_arc.n_workers { + let self_clone = std::sync::Arc::clone(&self_arc); + provers.spawn(async move { + self_clone.working_loop(i).await; }); } - Ok(()) + while provers.join_next().await.is_some() {} } - fn working_loop(&self, i: usize) { + async fn working_loop(&self, i: usize) { loop { let coordinator_client = &self.coordinator_clients[i]; let prover_name = coordinator_client.prover_name.clone(); log::info!("{:?}: getting task from coordinator", prover_name); - let get_task_request = self.build_get_task_request(); - let coordinator_task = coordinator_client.get_task(&get_task_request); + let get_task_request = self.build_get_task_request().await; + let coordinator_task = coordinator_client.get_task(&get_task_request).await; if let Err(e) = coordinator_task { log::error!("{:?}: failed to get task: {:?}", prover_name, e); @@ -75,7 +78,7 @@ impl Prover { let coordinator_task_id = coordinator_task.task_id.clone(); let task_type = coordinator_task.task_type; - let proving_input = match self.build_proving_input(&coordinator_task) { + let proving_input = match self.build_proving_input(&coordinator_task).await { Ok(input) => input, Err(e) => { log::error!( @@ -148,7 +151,7 @@ impl Prover { failure_type: None, failure_msg: None, }; - match coordinator_client.submit_proof(&submit_proof_req) { + match coordinator_client.submit_proof(&submit_proof_req).await { Ok(_) => { log::info!( "{:?}: proof submitted. task_type: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}, proving_service_task_id: {:?}", @@ -193,7 +196,7 @@ impl Prover { failure_type: Some(ProofFailureType::Panic), // TODO: handle ProofFailureType::NoPanic failure_msg: Some(task_err), }; - match coordinator_client.submit_proof(&submit_proof_req) { + match coordinator_client.submit_proof(&submit_proof_req).await { Ok(_) => { log::info!( "{:?}: proof_err submitted. task_type: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}, proving_service_task_id: {:?}", @@ -224,13 +227,14 @@ impl Prover { } } - fn build_get_task_request(&self) -> GetTaskRequest { - let prover_height = self.l2geth_client.as_ref().and_then(|l2geth_client| { - l2geth_client - .block_number_sync() - .ok() - .and_then(|block_number| block_number.as_number()) - }); + async fn build_get_task_request(&self) -> GetTaskRequest { + let prover_height = match &self.l2geth_client { + None => None, + Some(l2geth_client) => match l2geth_client.block_number().await { + Ok(block_number) => block_number.as_number(), + Err(_) => None, + }, + }; GetTaskRequest { task_types: vec![self.circuit_type], @@ -238,7 +242,10 @@ impl Prover { } } - fn build_proving_input(&self, task: &GetTaskResponseData) -> anyhow::Result { + async fn build_proving_input( + &self, + task: &GetTaskResponseData, + ) -> anyhow::Result { anyhow::ensure!( task.task_type == self.circuit_type, "task type mismatch. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}", @@ -258,7 +265,8 @@ impl Prover { .l2geth_client .as_ref() .unwrap() - .get_sorted_traces_by_hashes(&chunk_task_detail.block_hashes)?; + .get_sorted_traces_by_hashes(&chunk_task_detail.block_hashes) + .await?; let input = serde_json::to_string(&traces)?; Ok(ProveRequest { diff --git a/src/tracing_handler.rs b/src/tracing_handler.rs index d495e74..b1a980c 100644 --- a/src/tracing_handler.rs +++ b/src/tracing_handler.rs @@ -6,25 +6,20 @@ use prover_darwin_v2::BlockTrace; use serde::{de::DeserializeOwned, Serialize}; use std::cmp::Ordering; use std::fmt::Debug; -use tokio::runtime::Runtime; pub type CommonHash = H256; pub struct L2gethClient { provider: Provider, - rt: Runtime, } impl L2gethClient { pub fn new(cfg: L2GethConfig) -> anyhow::Result { let provider = Provider::::try_from(cfg.endpoint)?; - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build()?; - Ok(Self { provider, rt }) + Ok(Self { provider }) } - async fn get_block_trace_by_hash_async(&self, hash: &CommonHash) -> anyhow::Result + pub async fn get_block_trace_by_hash(&self, hash: &CommonHash) -> anyhow::Result where T: Serialize + DeserializeOwned + Debug + Send, { @@ -40,25 +35,14 @@ impl L2gethClient { Ok(trace) } - pub fn get_block_trace_by_hash_sync(&self, hash: &CommonHash) -> anyhow::Result - where - T: Serialize + DeserializeOwned + Debug + Send, - { - self.rt.block_on(self.get_block_trace_by_hash_async(hash)) - } - - async fn block_number_async(&self) -> anyhow::Result { + pub async fn block_number(&self) -> anyhow::Result { log::info!("l2geth_client calling block_number"); let trace = self.provider.request("eth_blockNumber", ()).await?; Ok(trace) } - pub fn block_number_sync(&self) -> anyhow::Result { - self.rt.block_on(self.block_number_async()) - } - - pub fn get_sorted_traces_by_hashes( + pub async fn get_sorted_traces_by_hashes( &self, block_hashes: &[CommonHash], ) -> anyhow::Result> { @@ -69,7 +53,7 @@ impl L2gethClient { let mut block_traces = Vec::new(); for hash in block_hashes.iter() { - let trace = self.get_block_trace_by_hash_sync(hash)?; + let trace = self.get_block_trace_by_hash(hash).await?; block_traces.push(trace); }