Skip to content

Commit

Permalink
redo "use tokio async pattern" (#18)
Browse files Browse the repository at this point in the history
Co-authored-by: Zhuo Zhang <mycinbrin@gmail.com>
  • Loading branch information
0xmountaintop and lispc authored Oct 4, 2024
1 parent 0b2105c commit 08bcf6b
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 115 deletions.
25 changes: 25 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ tiny-keccak = { version = "2.0.0", features = ["sha3", "keccak"] }
rand = "0.8.5"
eth-keystore = "0.5.0"
rlp = "0.5.2"
tokio = "1.37.0"
tokio = { version = "1.37.0", features = ["full"] }
async-trait = "0.1"
sled = "0.34.7"
http = "1.1.0"
clap = { version = "4.5", features = ["derive"] }
Expand Down
19 changes: 11 additions & 8 deletions examples/cloud.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use async_trait::async_trait;
use clap::Parser;
use std::sync::Arc;

use scroll_proving_sdk::{
config::{CloudProverConfig, Config},
Expand All @@ -26,17 +26,18 @@ struct CloudProver {
api_key: String,
}

#[async_trait]
impl ProvingService for CloudProver {
fn is_local(&self) -> bool {
false
}
fn get_vk(&self, req: GetVkRequest) -> GetVkResponse {
async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse {
todo!()
}
fn prove(&self, req: ProveRequest) -> ProveResponse {
async fn prove(&self, req: ProveRequest) -> ProveResponse {
todo!()
}
fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse {
async fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse {
todo!()
}
}
Expand All @@ -50,17 +51,19 @@ impl CloudProver {
}
}

fn main() -> anyhow::Result<()> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
init_tracing();

let args = Args::parse();
let cfg: Config = Config::from_file(args.config_file)?;
let cloud_prover = CloudProver::new(cfg.prover.cloud.clone().unwrap());
let prover = ProverBuilder::new(cfg)
.with_proving_service(Box::new(cloud_prover))
.build()?;
.build()
.await?;

Arc::new(prover).run()?;
prover.run().await;

loop {}
Ok(())
}
19 changes: 11 additions & 8 deletions examples/local.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use async_trait::async_trait;
use clap::Parser;
use std::sync::Arc;

use scroll_proving_sdk::{
config::{Config, LocalProverConfig},
Expand All @@ -23,17 +23,18 @@ struct Args {

struct LocalProver {}

#[async_trait]
impl ProvingService for LocalProver {
fn is_local(&self) -> bool {
true
}
fn get_vk(&self, req: GetVkRequest) -> GetVkResponse {
async fn get_vk(&self, req: GetVkRequest) -> GetVkResponse {
todo!()
}
fn prove(&self, req: ProveRequest) -> ProveResponse {
async fn prove(&self, req: ProveRequest) -> ProveResponse {
todo!()
}
fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse {
async fn query_task(&self, req: QueryTaskRequest) -> QueryTaskResponse {
todo!()
}
}
Expand All @@ -44,17 +45,19 @@ impl LocalProver {
}
}

fn main() -> anyhow::Result<()> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
init_tracing();

let args = Args::parse();
let cfg: Config = Config::from_file(args.config_file)?;
let local_prover = LocalProver::new(cfg.prover.local.clone().unwrap());
let prover = ProverBuilder::new(cfg)
.with_proving_service(Box::new(local_prover))
.build()?;
.build()
.await?;

Arc::new(prover).run()?;
prover.run().await;

loop {}
Ok(())
}
68 changes: 21 additions & 47 deletions src/coordinator_handler/coordinator_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use super::{
LoginRequest, Response, SubmitProofRequest, SubmitProofResponseData,
};
use crate::{config::CoordinatorConfig, prover::CircuitType, utils::get_version};
use std::sync::{Mutex, MutexGuard};
use tokio::runtime::Runtime;
use tokio::sync::{Mutex, MutexGuard};

pub struct CoordinatorClient {
circuit_type: CircuitType,
Expand All @@ -14,7 +13,6 @@ pub struct CoordinatorClient {
key_signer: KeySigner,
api: Api,
token: Mutex<Option<String>>,
rt: Runtime,
}

impl CoordinatorClient {
Expand All @@ -26,9 +24,6 @@ impl CoordinatorClient {
prover_name: String,
key_signer: KeySigner,
) -> anyhow::Result<Self> {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let api = Api::new(cfg)?;
let client = Self {
circuit_type,
Expand All @@ -38,79 +33,58 @@ impl CoordinatorClient {
key_signer,
api,
token: Mutex::new(None),
rt,
};
Ok(client)
}

pub fn get_task(&self, req: &GetTaskRequest) -> anyhow::Result<Response<GetTaskResponseData>> {
let token = self.get_token_sync(false)?;
let response = self.get_task_sync(req, &token)?;
pub async fn get_task(
&self,
req: &GetTaskRequest,
) -> anyhow::Result<Response<GetTaskResponseData>> {
let token = self.get_token(false).await?;
let response = self.api.get_task(req, &token).await?;

if response.errcode == ErrorCode::ErrJWTTokenExpired {
let token = self.get_token_sync(true)?;
self.get_task_sync(req, &token)
let token = self.get_token(true).await?;
self.api.get_task(req, &token).await
} else {
Ok(response)
}
}

pub fn submit_proof(
pub async fn submit_proof(
&self,
req: &SubmitProofRequest,
) -> anyhow::Result<Response<SubmitProofResponseData>> {
let token = self.get_token_sync(false)?;
let response = self.submit_proof_sync(req, &token)?;
let token = self.get_token(false).await?;
let response = self.api.submit_proof(req, &token).await?;

if response.errcode == ErrorCode::ErrJWTTokenExpired {
let token = self.get_token_sync(true)?;
self.submit_proof_sync(req, &token)
let token = self.get_token(true).await?;
self.api.submit_proof(req, &token).await
} else {
Ok(response)
}
}

fn get_task_sync(
&self,
req: &GetTaskRequest,
token: &String,
) -> anyhow::Result<Response<GetTaskResponseData>> {
self.rt.block_on(self.api.get_task(req, token))
}

fn submit_proof_sync(
&self,
req: &SubmitProofRequest,
token: &String,
) -> anyhow::Result<Response<SubmitProofResponseData>> {
self.rt.block_on(self.api.submit_proof(req, token))
}

fn get_token_sync(&self, force_relogin: bool) -> anyhow::Result<String> {
self.rt.block_on(self.get_token_async(force_relogin))
}

/// Retrieves a token for authentication, optionally forcing a re-login.
///
/// This function attempts to get the stored token if `force_relogin` is set to `false`.
///
/// If the token is expired, `force_relogin` is set to `true`, or a login was never performed
/// before, it will authenticate and fetch a new token.
async fn get_token_async(&self, force_relogin: bool) -> anyhow::Result<String> {
let token_guard = self
.token
.lock()
.expect("Mutex locking only occurs within `get_token` fn, so there can be no double `lock` for one thread");

match token_guard.as_deref() {
Some(token) if !force_relogin => return Ok(token.to_string()),
async fn get_token(&self, force_relogin: bool) -> anyhow::Result<String> {
let token_guard = self.token.lock().await;

match *token_guard {
Some(ref token) if !force_relogin => return Ok(token.to_string()),
_ => (),
}

self.login_async(token_guard).await
self.login(token_guard).await
}

async fn login_async<'t>(
async fn login<'t>(
&self,
mut token_guard: MutexGuard<'t, Option<String>>,
) -> anyhow::Result<String> {
Expand Down
5 changes: 3 additions & 2 deletions src/prover/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl ProverBuilder {
self
}

pub fn build(self) -> anyhow::Result<Prover> {
pub async fn build(self) -> anyhow::Result<Prover> {
if self.proving_service.is_none() {
anyhow::bail!("proving_service is not provided");
}
Expand All @@ -51,7 +51,8 @@ impl ProverBuilder {
.proving_service
.as_ref()
.unwrap()
.get_vk(get_vk_request);
.get_vk(get_vk_request)
.await;
if let Some(error) = get_vk_response.error {
anyhow::bail!("failed to get vk: {}", error);
}
Expand Down
Loading

0 comments on commit 08bcf6b

Please sign in to comment.