Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: use tokio async pattern #15

Merged
merged 12 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ tiny-keccak = { version = "2.0.0", features = ["sha3", "keccak"] }
rand = "0.8.5"
eth-keystore = "0.5.0"
rlp = "0.5.2"
tokio = "1.37.0"
tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] }
sled = "0.34.7"
http = "1.1.0"
clap = { version = "4.5", features = ["derive"] }
Expand Down
7 changes: 5 additions & 2 deletions examples/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ impl CloudProver {
}
}

fn main() -> anyhow::Result<()> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
init_tracing();

let args = Args::parse();
Expand All @@ -59,5 +60,7 @@ fn main() -> anyhow::Result<()> {
.with_proving_service(Box::new(cloud_prover))
.build()?;

Arc::new(prover).run()
prover.run().await;

Ok(())
}
7 changes: 5 additions & 2 deletions examples/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ impl LocalProver {
}
}

fn main() -> anyhow::Result<()> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
init_tracing();

let args = Args::parse();
Expand All @@ -53,5 +54,7 @@ fn main() -> anyhow::Result<()> {
.with_proving_service(Box::new(local_prover))
.build()?;

Arc::new(prover).run()
prover.run().await;

Ok(())
}
68 changes: 21 additions & 47 deletions src/coordinator_handler/coordinator_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use super::{
LoginRequest, Response, SubmitProofRequest, SubmitProofResponseData,
};
use crate::{config::CoordinatorConfig, prover::CircuitType, utils::get_version};
use std::sync::{Mutex, MutexGuard};
use tokio::runtime::Runtime;
use tokio::sync::{Mutex, MutexGuard};

pub struct CoordinatorClient {
circuit_type: CircuitType,
Expand All @@ -14,7 +13,6 @@ pub struct CoordinatorClient {
key_signer: KeySigner,
api: Api,
token: Mutex<Option<String>>,
rt: Runtime,
}

impl CoordinatorClient {
Expand All @@ -26,9 +24,6 @@ impl CoordinatorClient {
prover_name: String,
key_signer: KeySigner,
) -> anyhow::Result<Self> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let api = Api::new(cfg)?;
let client = Self {
circuit_type,
Expand All @@ -38,79 +33,58 @@ impl CoordinatorClient {
key_signer,
api,
token: Mutex::new(None),
rt,
};
Ok(client)
}

pub fn get_task(&self, req: &GetTaskRequest) -> anyhow::Result<Response<GetTaskResponseData>> {
let token = self.get_token_sync(false)?;
let response = self.get_task_sync(req, &token)?;
pub async fn get_task(
&self,
req: &GetTaskRequest,
) -> anyhow::Result<Response<GetTaskResponseData>> {
let token = self.get_token(false).await?;
let response = self.api.get_task(req, &token).await?;

if response.errcode == ErrorCode::ErrJWTTokenExpired {
let token = self.get_token_sync(true)?;
self.get_task_sync(req, &token)
let token = self.get_token(true).await?;
self.api.get_task(req, &token).await
} else {
Ok(response)
}
}

pub fn submit_proof(
pub async fn submit_proof(
&self,
req: &SubmitProofRequest,
) -> anyhow::Result<Response<SubmitProofResponseData>> {
let token = self.get_token_sync(false)?;
let response = self.submit_proof_sync(req, &token)?;
let token = self.get_token(false).await?;
let response = self.api.submit_proof(req, &token).await?;

if response.errcode == ErrorCode::ErrJWTTokenExpired {
let token = self.get_token_sync(true)?;
self.submit_proof_sync(req, &token)
let token = self.get_token(true).await?;
self.api.submit_proof(req, &token).await
} else {
Ok(response)
}
}

fn get_task_sync(
&self,
req: &GetTaskRequest,
token: &String,
) -> anyhow::Result<Response<GetTaskResponseData>> {
self.rt.block_on(self.api.get_task(req, token))
}

fn submit_proof_sync(
&self,
req: &SubmitProofRequest,
token: &String,
) -> anyhow::Result<Response<SubmitProofResponseData>> {
self.rt.block_on(self.api.submit_proof(req, token))
}

fn get_token_sync(&self, force_relogin: bool) -> anyhow::Result<String> {
self.rt.block_on(self.get_token_async(force_relogin))
}

/// Retrieves a token for authentication, optionally forcing a re-login.
///
/// If a token is already stored and `force_relogin` is `false`, the stored token is returned.
///
/// Otherwise — when `force_relogin` is `true` (callers pass this after the coordinator reports
/// an expired token) or a login was never performed before — it will authenticate and fetch a
/// new token.
async fn get_token_async(&self, force_relogin: bool) -> anyhow::Result<String> {
let token_guard = self
.token
.lock()
.expect("Mutex locking only occurs within `get_token` fn, so there can be no double `lock` for one thread");

match token_guard.as_deref() {
Some(token) if !force_relogin => return Ok(token.to_string()),
async fn get_token(&self, force_relogin: bool) -> anyhow::Result<String> {
let token_guard = self.token.lock().await;

match *token_guard {
Some(ref token) if !force_relogin => return Ok(token.to_string()),
_ => (),
}

self.login_async(token_guard).await
self.login(token_guard).await
}

async fn login_async<'t>(
async fn login<'t>(
&self,
mut token_guard: MutexGuard<'t, Option<String>>,
) -> anyhow::Result<String> {
Expand Down
50 changes: 29 additions & 21 deletions src/prover/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod builder;
pub mod proving_service;
pub mod types;
use tokio::task::JoinSet;
pub use {builder::ProverBuilder, proving_service::ProvingService, types::*};

use crate::{
Expand All @@ -25,31 +26,33 @@ pub struct Prover {
}

impl Prover {
pub fn run(self: std::sync::Arc<Self>) -> anyhow::Result<()> {
pub async fn run(self) {
assert!(self.n_workers == self.coordinator_clients.len());
if self.circuit_type == CircuitType::Chunk {
assert!(self.l2geth_client.is_some());
}

for i in 0..self.n_workers {
let self_clone = std::sync::Arc::clone(&self);
thread::spawn(move || {
self_clone.working_loop(i);
let mut provers = JoinSet::new();
let self_arc = std::sync::Arc::new(self);
for i in 0..self_arc.n_workers {
let self_clone = std::sync::Arc::clone(&self_arc);
provers.spawn(async move {
self_clone.working_loop(i).await;
});
}

Ok(())
while provers.join_next().await.is_some() {}
}

fn working_loop(&self, i: usize) {
async fn working_loop(&self, i: usize) {
loop {
let coordinator_client = &self.coordinator_clients[i];
let prover_name = coordinator_client.prover_name.clone();

log::info!("{:?}: getting task from coordinator", prover_name);

let get_task_request = self.build_get_task_request();
let coordinator_task = coordinator_client.get_task(&get_task_request);
let get_task_request = self.build_get_task_request().await;
let coordinator_task = coordinator_client.get_task(&get_task_request).await;

if let Err(e) = coordinator_task {
log::error!("{:?}: failed to get task: {:?}", prover_name, e);
Expand All @@ -75,7 +78,7 @@ impl Prover {
let coordinator_task_id = coordinator_task.task_id.clone();
let task_type = coordinator_task.task_type;

let proving_input = match self.build_proving_input(&coordinator_task) {
let proving_input = match self.build_proving_input(&coordinator_task).await {
Ok(input) => input,
Err(e) => {
log::error!(
Expand Down Expand Up @@ -148,7 +151,7 @@ impl Prover {
failure_type: None,
failure_msg: None,
};
match coordinator_client.submit_proof(&submit_proof_req) {
match coordinator_client.submit_proof(&submit_proof_req).await {
Ok(_) => {
log::info!(
"{:?}: proof submitted. task_type: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}, proving_service_task_id: {:?}",
Expand Down Expand Up @@ -193,7 +196,7 @@ impl Prover {
failure_type: Some(ProofFailureType::Panic), // TODO: handle ProofFailureType::NoPanic
failure_msg: Some(task_err),
};
match coordinator_client.submit_proof(&submit_proof_req) {
match coordinator_client.submit_proof(&submit_proof_req).await {
Ok(_) => {
log::info!(
"{:?}: proof_err submitted. task_type: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}, proving_service_task_id: {:?}",
Expand Down Expand Up @@ -224,21 +227,25 @@ impl Prover {
}
}

fn build_get_task_request(&self) -> GetTaskRequest {
let prover_height = self.l2geth_client.as_ref().and_then(|l2geth_client| {
l2geth_client
.block_number_sync()
.ok()
.and_then(|block_number| block_number.as_number())
});
/// Builds the `GetTaskRequest` sent to the coordinator for this prover's circuit type.
///
/// `prover_height` is filled from the l2geth client's `block_number()` when a client is
/// configured. It is `None` when no client is configured, when the query fails, or when
/// `as_number()` yields `None`.
async fn build_get_task_request(&self) -> GetTaskRequest {
    let prover_height = if let Some(l2geth_client) = self.l2geth_client.as_ref() {
        // Best effort: a failed block-number query is dropped, not propagated.
        l2geth_client
            .block_number()
            .await
            .ok()
            .and_then(|block_number| block_number.as_number())
    } else {
        None
    };

    GetTaskRequest {
        task_types: vec![self.circuit_type],
        prover_height,
    }
}

fn build_proving_input(&self, task: &GetTaskResponseData) -> anyhow::Result<ProveRequest> {
async fn build_proving_input(
&self,
task: &GetTaskResponseData,
) -> anyhow::Result<ProveRequest> {
anyhow::ensure!(
task.task_type == self.circuit_type,
"task type mismatch. self: {:?}, task: {:?}, coordinator_task_uuid: {:?}, coordinator_task_id: {:?}",
Expand All @@ -258,7 +265,8 @@ impl Prover {
.l2geth_client
.as_ref()
.unwrap()
.get_sorted_traces_by_hashes(&chunk_task_detail.block_hashes)?;
.get_sorted_traces_by_hashes(&chunk_task_detail.block_hashes)
.await?;
let input = serde_json::to_string(&traces)?;

Ok(ProveRequest {
Expand Down
26 changes: 5 additions & 21 deletions src/tracing_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,20 @@ use prover_darwin_v2::BlockTrace;
use serde::{de::DeserializeOwned, Serialize};
use std::cmp::Ordering;
use std::fmt::Debug;
use tokio::runtime::Runtime;

pub type CommonHash = H256;

pub struct L2gethClient {
provider: Provider<Http>,
rt: Runtime,
}

impl L2gethClient {
pub fn new(cfg: L2GethConfig) -> anyhow::Result<Self> {
let provider = Provider::<Http>::try_from(cfg.endpoint)?;
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
Ok(Self { provider, rt })
Ok(Self { provider })
}

async fn get_block_trace_by_hash_async<T>(&self, hash: &CommonHash) -> anyhow::Result<T>
pub async fn get_block_trace_by_hash<T>(&self, hash: &CommonHash) -> anyhow::Result<T>
where
T: Serialize + DeserializeOwned + Debug + Send,
{
Expand All @@ -40,25 +35,14 @@ impl L2gethClient {
Ok(trace)
}

pub fn get_block_trace_by_hash_sync<T>(&self, hash: &CommonHash) -> anyhow::Result<T>
where
T: Serialize + DeserializeOwned + Debug + Send,
{
self.rt.block_on(self.get_block_trace_by_hash_async(hash))
}

async fn block_number_async(&self) -> anyhow::Result<BlockNumber> {
/// Issues an `eth_blockNumber` request through the provider and returns the result.
///
/// # Errors
///
/// Returns an error if the provider request fails.
pub async fn block_number(&self) -> anyhow::Result<BlockNumber> {
    log::info!("l2geth_client calling block_number");

    Ok(self.provider.request("eth_blockNumber", ()).await?)
}

pub fn block_number_sync(&self) -> anyhow::Result<BlockNumber> {
self.rt.block_on(self.block_number_async())
}

pub fn get_sorted_traces_by_hashes(
pub async fn get_sorted_traces_by_hashes(
&self,
block_hashes: &[CommonHash],
) -> anyhow::Result<Vec<BlockTrace>> {
Expand All @@ -69,7 +53,7 @@ impl L2gethClient {

let mut block_traces = Vec::new();
for hash in block_hashes.iter() {
let trace = self.get_block_trace_by_hash_sync(hash)?;
let trace = self.get_block_trace_by_hash(hash).await?;
block_traces.push(trace);
}

Expand Down