Skip to content

Commit

Permalink
refactor: use tokio async pattern (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
0xmountaintop authored Sep 26, 2024
1 parent 3bb1e18 commit 4322c71
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 94 deletions.
13 changes: 13 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ tiny-keccak = { version = "2.0.0", features = ["sha3", "keccak"] }
rand = "0.8.5"
eth-keystore = "0.5.0"
rlp = "0.5.2"
tokio = "1.37.0"
tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] }
sled = "0.34.7"
http = "1.1.0"
clap = { version = "4.5", features = ["derive"] }
Expand Down
7 changes: 5 additions & 2 deletions examples/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ impl CloudProver {
}
}

fn main() -> anyhow::Result<()> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
init_tracing();

let args = Args::parse();
Expand All @@ -59,5 +60,7 @@ fn main() -> anyhow::Result<()> {
.with_proving_service(Box::new(cloud_prover))
.build()?;

Arc::new(prover).run()
prover.run().await;

Ok(())
}
7 changes: 5 additions & 2 deletions examples/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ impl LocalProver {
}
}

fn main() -> anyhow::Result<()> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
init_tracing();

let args = Args::parse();
Expand All @@ -53,5 +54,7 @@ fn main() -> anyhow::Result<()> {
.with_proving_service(Box::new(local_prover))
.build()?;

Arc::new(prover).run()
prover.run().await;

Ok(())
}
68 changes: 21 additions & 47 deletions src/coordinator_handler/coordinator_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use super::{
LoginRequest, Response, SubmitProofRequest, SubmitProofResponseData,
};
use crate::{config::CoordinatorConfig, prover::CircuitType, utils::get_version};
use std::sync::{Mutex, MutexGuard};
use tokio::runtime::Runtime;
use tokio::sync::{Mutex, MutexGuard};

pub struct CoordinatorClient {
circuit_type: CircuitType,
Expand All @@ -14,7 +13,6 @@ pub struct CoordinatorClient {
key_signer: KeySigner,
api: Api,
token: Mutex<Option<String>>,
rt: Runtime,
}

impl CoordinatorClient {
Expand All @@ -26,9 +24,6 @@ impl CoordinatorClient {
prover_name: String,
key_signer: KeySigner,
) -> anyhow::Result<Self> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let api = Api::new(cfg)?;
let client = Self {
circuit_type,
Expand All @@ -38,79 +33,58 @@ impl CoordinatorClient {
key_signer,
api,
token: Mutex::new(None),
rt,
};
Ok(client)
}

pub fn get_task(&self, req: &GetTaskRequest) -> anyhow::Result<Response<GetTaskResponseData>> {
let token = self.get_token_sync(false)?;
let response = self.get_task_sync(req, &token)?;
pub async fn get_task(
&self,
req: &GetTaskRequest,
) -> anyhow::Result<Response<GetTaskResponseData>> {
let token = self.get_token(false).await?;
let response = self.api.get_task(req, &token).await?;

if response.errcode == ErrorCode::ErrJWTTokenExpired {
let token = self.get_token_sync(true)?;
self.get_task_sync(req, &token)
let token = self.get_token(true).await?;
self.api.get_task(req, &token).await
} else {
Ok(response)
}
}

pub fn submit_proof(
pub async fn submit_proof(
&self,
req: &SubmitProofRequest,
) -> anyhow::Result<Response<SubmitProofResponseData>> {
let token = self.get_token_sync(false)?;
let response = self.submit_proof_sync(req, &token)?;
let token = self.get_token(false).await?;
let response = self.api.submit_proof(req, &token).await?;

if response.errcode == ErrorCode::ErrJWTTokenExpired {
let token = self.get_token_sync(true)?;
self.submit_proof_sync(req, &token)
let token = self.get_token(true).await?;
self.api.submit_proof(req, &token).await
} else {
Ok(response)
}
}

fn get_task_sync(
&self,
req: &GetTaskRequest,
token: &String,
) -> anyhow::Result<Response<GetTaskResponseData>> {
self.rt.block_on(self.api.get_task(req, token))
}

fn submit_proof_sync(
&self,
req: &SubmitProofRequest,
token: &String,
) -> anyhow::Result<Response<SubmitProofResponseData>> {
self.rt.block_on(self.api.submit_proof(req, token))
}

fn get_token_sync(&self, force_relogin: bool) -> anyhow::Result<String> {
self.rt.block_on(self.get_token_async(force_relogin))
}

/// Retrieves a token for authentication, optionally forcing a re-login.
///
/// This function attempts to get the stored token if `force_relogin` is set to `false`.
///
/// If the token is expired, `force_relogin` is set to `true`, or a login was never performed
/// before, it will authenticate and fetch a new token.
async fn get_token_async(&self, force_relogin: bool) -> anyhow::Result<String> {
let token_guard = self
.token
.lock()
.expect("Mutex locking only occurs within `get_token` fn, so there can be no double `lock` for one thread");

match token_guard.as_deref() {
Some(token) if !force_relogin => return Ok(token.to_string()),
async fn get_token(&self, force_relogin: bool) -> anyhow::Result<String> {
let token_guard = self.token.lock().await;

match *token_guard {
Some(ref token) if !force_relogin => return Ok(token.to_string()),
_ => (),
}

self.login_async(token_guard).await
self.login(token_guard).await
}

async fn login_async<'t>(
async fn login<'t>(
&self,
mut token_guard: MutexGuard<'t, Option<String>>,
) -> anyhow::Result<String> {
Expand Down
50 changes: 29 additions & 21 deletions src/prover/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod builder;
pub mod proving_service;
pub mod types;
use tokio::task::JoinSet;
pub use {builder::ProverBuilder, proving_service::ProvingService, types::*};

use crate::{
Expand All @@ -25,31 +26,33 @@ pub struct Prover {
}

impl Prover {
pub fn run(self: std::sync::Arc<Self>) -> anyhow::Result<()> {
pub async fn run(self) {
assert!(self.n_workers == self.coordinator_clients.len());
if self.circuit_type == CircuitType::Chunk {
assert!(self.l2geth_client.is_some());
}

for i in 0..self.n_workers {
let self_clone = std::sync::Arc::clone(&self);
thread::spawn(move || {
self_clone.working_loop(i);
let mut provers = JoinSet::new();
let self_arc = std::sync::Arc::new(self);
for i in 0..self_arc.n_workers {
let self_clone = std::sync::Arc::clone(&self_arc);
provers.spawn(async move {
self_clone.working_loop(i).await;
});
}

Ok(())
while provers.join_next().await.is_some() {}
}

fn working_loop(&self, i: usize) {
async fn working_loop(&self, i: usize) {
loop {
let coordinator_client = &self.coordinator_clients[i];
let prover_name = coordinator_client.prover_name.clone();

log::info!("{:?}: getting task from coordinator", prover_name);

let get_task_request = self.build_get_task_request();
let coordinator_task = coordinator_client.get_task(&get_task_request);
let get_task_request = self.build_get_task_request().await;
let coordinator_task = coordinator_client.get_task(&get_task_request).await;

if let Err(e) = coordinator_task {
log::error!("{:?}: failed to get task: {:?}", prover_name, e);
Expand All @@ -75,7 +78,7 @@ impl Prover {
let coordinator_task_id = coordinator_task.task_id.clone();
let task_type = coordinator_task.task_type;

let proving_input = match self.build_proving_input(&coordinator_task) {
let proving_input = match self.build_proving_input(&coordinator_task).await {
Ok(input) => input,
Err(e) => {
log::error!(
Expand Down Expand Up @@ -148,7 +151,7 @@ impl Prover {
failure_type: None,
failure_msg: None,
};
match coordinator_client.submit_proof(&submit_proof_req) {
match coordinator_client.submit_proof(&submit_proof_req).await {
Ok(_) => {
log::info!(
"{:?}: proof submitted. task_type: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}, proving_service_task_id: {:?}",
Expand Down Expand Up @@ -193,7 +196,7 @@ impl Prover {
failure_type: Some(ProofFailureType::Panic), // TODO: handle ProofFailureType::NoPanic
failure_msg: Some(task_err),
};
match coordinator_client.submit_proof(&submit_proof_req) {
match coordinator_client.submit_proof(&submit_proof_req).await {
Ok(_) => {
log::info!(
"{:?}: proof_err submitted. task_type: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}, proving_service_task_id: {:?}",
Expand Down Expand Up @@ -224,21 +227,25 @@ impl Prover {
}
}

fn build_get_task_request(&self) -> GetTaskRequest {
let prover_height = self.l2geth_client.as_ref().and_then(|l2geth_client| {
l2geth_client
.block_number_sync()
.ok()
.and_then(|block_number| block_number.as_number())
});
async fn build_get_task_request(&self) -> GetTaskRequest {
let prover_height = match &self.l2geth_client {
None => None,
Some(l2geth_client) => match l2geth_client.block_number().await {
Ok(block_number) => block_number.as_number(),
Err(_) => None,
},
};

GetTaskRequest {
task_types: vec![self.circuit_type],
prover_height,
}
}

fn build_proving_input(&self, task: &GetTaskResponseData) -> anyhow::Result<ProveRequest> {
async fn build_proving_input(
&self,
task: &GetTaskResponseData,
) -> anyhow::Result<ProveRequest> {
anyhow::ensure!(
task.task_type == self.circuit_type,
"task type mismatch. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}",
Expand All @@ -258,7 +265,8 @@ impl Prover {
.l2geth_client
.as_ref()
.unwrap()
.get_sorted_traces_by_hashes(&chunk_task_detail.block_hashes)?;
.get_sorted_traces_by_hashes(&chunk_task_detail.block_hashes)
.await?;
let input = serde_json::to_string(&traces)?;

Ok(ProveRequest {
Expand Down
26 changes: 5 additions & 21 deletions src/tracing_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,20 @@ use prover_darwin_v2::BlockTrace;
use serde::{de::DeserializeOwned, Serialize};
use std::cmp::Ordering;
use std::fmt::Debug;
use tokio::runtime::Runtime;

pub type CommonHash = H256;

pub struct L2gethClient {
provider: Provider<Http>,
rt: Runtime,
}

impl L2gethClient {
pub fn new(cfg: L2GethConfig) -> anyhow::Result<Self> {
let provider = Provider::<Http>::try_from(cfg.endpoint)?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
Ok(Self { provider, rt })
Ok(Self { provider })
}

async fn get_block_trace_by_hash_async<T>(&self, hash: &CommonHash) -> anyhow::Result<T>
pub async fn get_block_trace_by_hash<T>(&self, hash: &CommonHash) -> anyhow::Result<T>
where
T: Serialize + DeserializeOwned + Debug + Send,
{
Expand All @@ -40,25 +35,14 @@ impl L2gethClient {
Ok(trace)
}

pub fn get_block_trace_by_hash_sync<T>(&self, hash: &CommonHash) -> anyhow::Result<T>
where
T: Serialize + DeserializeOwned + Debug + Send,
{
self.rt.block_on(self.get_block_trace_by_hash_async(hash))
}

async fn block_number_async(&self) -> anyhow::Result<BlockNumber> {
pub async fn block_number(&self) -> anyhow::Result<BlockNumber> {
log::info!("l2geth_client calling block_number");

let trace = self.provider.request("eth_blockNumber", ()).await?;
Ok(trace)
}

pub fn block_number_sync(&self) -> anyhow::Result<BlockNumber> {
self.rt.block_on(self.block_number_async())
}

pub fn get_sorted_traces_by_hashes(
pub async fn get_sorted_traces_by_hashes(
&self,
block_hashes: &[CommonHash],
) -> anyhow::Result<Vec<BlockTrace>> {
Expand All @@ -69,7 +53,7 @@ impl L2gethClient {

let mut block_traces = Vec::new();
for hash in block_hashes.iter() {
let trace = self.get_block_trace_by_hash_sync(hash)?;
let trace = self.get_block_trace_by_hash(hash).await?;
block_traces.push(trace);
}

Expand Down

0 comments on commit 4322c71

Please sign in to comment.