Skip to content

Commit

Permalink
feat(prover): add chunk & batch proving circuit error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
colinlyguo committed Aug 28, 2023
1 parent f8d48a6 commit 97801c5
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 54 deletions.
82 changes: 59 additions & 23 deletions common/libzkp/impl/src/batch.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::utils::{c_char_to_str, c_char_to_vec, string_to_c_char, vec_to_c_char, OUTPUT_DIR};
use crate::types::{CheckChunkProofsResponse, ProofResult};
use libc::c_char;
use prover::{
aggregator::{Prover, Verifier},
Expand Down Expand Up @@ -54,13 +55,33 @@ pub unsafe extern "C" fn get_batch_vk() -> *const c_char {

/// # Safety
#[no_mangle]
pub unsafe extern "C" fn check_chunk_proofs(chunk_proofs: *const c_char) -> c_char {
let chunk_proofs = c_char_to_vec(chunk_proofs);
let chunk_proofs = serde_json::from_slice::<Vec<ChunkProof>>(&chunk_proofs).unwrap();
assert!(!chunk_proofs.is_empty());

let valid = panic::catch_unwind(|| PROVER.get().unwrap().check_chunk_proofs(&chunk_proofs));
valid.unwrap_or(false) as c_char
pub unsafe extern "C" fn check_chunk_proofs(chunk_proofs: *const c_char) -> *const c_char {
let check_result: Result<bool, String> = panic::catch_unwind(|| {
let chunk_proofs = c_char_to_vec(chunk_proofs);
let chunk_proofs = serde_json::from_slice::<Vec<ChunkProof>>(&chunk_proofs)
.map_err(|e| format!("Failed to deserialize chunk proofs: {:?}", e))?;
if chunk_proofs.is_empty() {
return Err("Provided chunk proofs are empty.".to_string());
}

PROVER.get()
.map_err(|_| "Failed to get reference to PROVER.".to_string())?
.check_chunk_proofs(&chunk_proofs)
.map_err(|e| format!("Error checking chunk proofs: {:?}", e))
}).unwrap_or_else(|err| Err(format!("Unwind error: {:?}", err)));

let r = match check_result {
Ok(valid) => CheckChunkProofsResponse {
ok: valid,
error: None,
},
Err(err) => CheckChunkProofsResponse {
ok: false,
error: Some(err),
},
};

serde_json::to_vec(&r).map_or(null(), vec_to_c_char)
}

/// # Safety
Expand All @@ -69,28 +90,43 @@ pub unsafe extern "C" fn gen_batch_proof(
chunk_hashes: *const c_char,
chunk_proofs: *const c_char,
) -> *const c_char {
let chunk_hashes = c_char_to_vec(chunk_hashes);
let chunk_proofs = c_char_to_vec(chunk_proofs);
let proof_result: Result<Vec<u8>, String> = panic::catch_unwind(|| {
let chunk_hashes = c_char_to_vec(chunk_hashes);
let chunk_proofs = c_char_to_vec(chunk_proofs);

let chunk_hashes = serde_json::from_slice::<Vec<ChunkHash>>(&chunk_hashes)
.map_err(|e| format!("Failed to deserialize chunk hashes: {:?}", e))?;
let chunk_proofs = serde_json::from_slice::<Vec<ChunkProof>>(&chunk_proofs)
.map_err(|e| format!("Failed to deserialize chunk proofs: {:?}", e))?;

let chunk_hashes = serde_json::from_slice::<Vec<ChunkHash>>(&chunk_hashes).unwrap();
let chunk_proofs = serde_json::from_slice::<Vec<ChunkProof>>(&chunk_proofs).unwrap();
assert_eq!(chunk_hashes.len(), chunk_proofs.len());
if chunk_hashes.len() != chunk_proofs.len() {
return Err("Chunk hashes and chunk proofs lengths mismatch.".to_string());
}

let chunk_hashes_proofs = chunk_hashes
.into_iter()
.zip(chunk_proofs.into_iter())
.collect();
let chunk_hashes_proofs = chunk_hashes.into_iter().zip(chunk_proofs.into_iter()).collect();

let proof_result = panic::catch_unwind(|| {
let proof = PROVER
.get_mut()
.unwrap()
.map_err(|_| "Failed to get mutable reference to PROVER.".to_string())?
.gen_agg_evm_proof(chunk_hashes_proofs, None, OUTPUT_DIR.as_deref())
.unwrap();

serde_json::to_vec(&proof).unwrap()
});
proof_result.map_or(null(), vec_to_c_char)
.map_err(|e| format!("Proof generation failed: {:?}", e))?;

serde_json::to_vec(&proof)
.map_err(|e| format!("Failed to serialize the proof: {:?}", e))
}).unwrap_or_else(|err| Err(format!("Unwind error: {:?}", err)));

let r = match proof_result {
Ok(proof_bytes) => ProofResult {
message: Some(proof_bytes),
error: None,
},
Err(err) => ProofResult {
message: None,
error: Some(err),
},
};

serde_json::to_vec(&r).map_or(null(), vec_to_c_char)
}

/// # Safety
Expand Down
34 changes: 24 additions & 10 deletions common/libzkp/impl/src/chunk.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::utils::{c_char_to_str, c_char_to_vec, string_to_c_char, vec_to_c_char, OUTPUT_DIR};
use crate::types::ProofResult;
use libc::c_char;
use prover::{
utils::init_env_and_log,
Expand Down Expand Up @@ -55,20 +56,33 @@ pub unsafe extern "C" fn get_chunk_vk() -> *const c_char {
/// # Safety
#[no_mangle]
pub unsafe extern "C" fn gen_chunk_proof(block_traces: *const c_char) -> *const c_char {
let block_traces = c_char_to_vec(block_traces);
let block_traces = serde_json::from_slice::<Vec<BlockTrace>>(&block_traces).unwrap();
let proof_result: Result<Vec<u8>, String> = panic::catch_unwind(|| {
let block_traces = c_char_to_vec(block_traces);
let block_traces = serde_json::from_slice::<Vec<BlockTrace>>(&block_traces)
.map_err(|e| format!("Failed to deserialize block traces: {:?}", e))?;

let proof_result = panic::catch_unwind(|| {
let proof = PROVER
.get_mut()
.unwrap()
.map_err(|_| "Failed to get mutable reference to PROVER.".to_string())?
.gen_chunk_proof(block_traces, None, OUTPUT_DIR.as_deref())
.unwrap();

serde_json::to_vec(&proof).unwrap()
});

proof_result.map_or(null(), vec_to_c_char)
.map_err(|e| format!("Proof generation failed: {:?}", e))?;

serde_json::to_vec(&proof)
.map_err(|e| format!("Failed to serialize the proof: {:?}", e))
}).unwrap_or_else(|err| Err(format!("Unwind error: {:?}", err)));

let r = match proof_result {
Ok(proof_bytes) => ProofResult {
message: Some(proof_bytes),
error: None,
},
Err(err) => ProofResult {
message: None,
error: Some(err),
},
};

serde_json::to_vec(&r).map_or(null(), vec_to_c_char)
}

/// # Safety
Expand Down
1 change: 1 addition & 0 deletions common/libzkp/impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
mod batch;
mod chunk;
mod utils;
mod types;
22 changes: 22 additions & 0 deletions common/libzkp/impl/src/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use serde::{Deserialize, Serialize};

// Represents the result of a chunk proof checking operation.
// `ok` indicates whether the proof checking was successful.
// `error` provides additional details in case the check failed.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CheckChunkProofsResponse {
ok: bool,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
}

// Encapsulates the result from generating a proof.
// `message` holds the generated proof in byte slice format.
// `error` provides additional details in case the proof generation failed.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProofResult {
#[serde(skip_serializing_if = "Option::is_none")]
message: Option<Vec<u8>>,
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
}
2 changes: 1 addition & 1 deletion common/libzkp/interface/libzkp.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
void init_batch_prover(char* params_dir, char* assets_dir);
void init_batch_verifier(char* params_dir, char* assets_dir);
char* get_batch_vk();
char check_chunk_proofs(char* chunk_proofs);
char* check_chunk_proofs(char* chunk_proofs);
char* gen_batch_proof(char* chunk_hashes, char* chunk_proofs);
char verify_batch_proof(char* proof);

Expand Down
89 changes: 69 additions & 20 deletions prover/core/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,19 @@ func (p *ProverCore) ProveBatch(taskID string, chunkInfos []*message.ChunkInfo,
return nil, err
}

if !p.checkChunkProofs(chunkProofsByt) {
isValid, err := p.checkChunkProofs(chunkProofsByt)
if err != nil {
return nil, err
}

if !isValid {
return nil, fmt.Errorf("Non-match chunk protocol: task-id = %s", taskID)
}

proofByt := p.proveBatch(chunkInfosByt, chunkProofsByt)
proofByt, err := p.proveBatch(chunkInfosByt, chunkProofsByt)
if err != nil {
return nil, fmt.Errorf("Error generating batch proof: %v", err)
}

err = p.mayDumpProof(taskID, proofByt)
if err != nil {
Expand Down Expand Up @@ -140,21 +148,45 @@ func (p *ProverCore) TracesToChunkInfo(traces []*types.BlockTrace) (*message.Chu
return chunkInfo, json.Unmarshal(chunkInfoByt, chunkInfo)
}

func (p *ProverCore) checkChunkProofs(chunkProofsByt []byte) bool {
chunkProofsStr := C.CString(string(chunkProofsByt))
// CheckChunkProofsResponse represents the result of a chunk proof checking operation.
// Ok indicates whether the proof checking was successful.
// Error provides additional details in case the check failed.
type CheckChunkProofsResponse struct {
Ok bool `json:"ok"`
Error string `json:"error"`
}

defer func() {
C.free(unsafe.Pointer(chunkProofsStr))
}()
// ProofResult encapsulates the result from generating a proof.
// Message holds the generated proof in byte slice format.
// Error provides additional details in case the proof generation failed.
type ProofResult struct {
Message []byte `json:"message"`
Error string `json:"error"`
}

func (p *ProverCore) checkChunkProofs(chunkProofsByt []byte) (bool, error) {
chunkProofsStr := C.CString(string(chunkProofsByt))
defer C.free(unsafe.Pointer(chunkProofsStr))

log.Info("Start to check chunk proofs ...")
valid := C.check_chunk_proofs(chunkProofsStr)
cResult := C.check_chunk_proofs(chunkProofsStr)
defer C.free(unsafe.Pointer(cResult))
log.Info("Finish checking chunk proofs!")

return valid != 0
var result CheckChunkProofsResponse
err := json.Unmarshal([]byte(C.GoString(cResult)), &result)
if err != nil {
return false, fmt.Errorf("Failed to parse check proof result: %v", err)
}

if result.Error != "" {
return false, fmt.Errorf("Failed to check_chunk_proofs: %s", result.Error)
}

return result.Ok, nil
}

func (p *ProverCore) proveBatch(chunkInfosByt []byte, chunkProofsByt []byte) []byte {
func (p *ProverCore) proveBatch(chunkInfosByt []byte, chunkProofsByt []byte) ([]byte, error) {
chunkInfosStr := C.CString(string(chunkInfosByt))
chunkProofsStr := C.CString(string(chunkProofsByt))

Expand All @@ -164,26 +196,43 @@ func (p *ProverCore) proveBatch(chunkInfosByt []byte, chunkProofsByt []byte) []b
}()

log.Info("Start to create batch proof ...")
cProof := C.gen_batch_proof(chunkInfosStr, chunkProofsStr)
bResult := C.gen_batch_proof(chunkInfosStr, chunkProofsStr)
defer C.free(unsafe.Pointer(bResult))
log.Info("Finish creating batch proof!")

proof := C.GoString(cProof)
return []byte(proof)
var result ProofResult
err := json.Unmarshal([]byte(C.GoString(bResult)), &result)
if err != nil {
return nil, fmt.Errorf("Failed to parse batch proof result: %v", err)
}

if result.Error != "" {
return nil, fmt.Errorf("Failed to gen_batch_proof: %s", result.Error)
}

return result.Message, nil
}

func (p *ProverCore) proveChunk(tracesByt []byte) []byte {
func (p *ProverCore) proveChunk(tracesByt []byte) ([]byte, error) {
tracesStr := C.CString(string(tracesByt))

defer func() {
C.free(unsafe.Pointer(tracesStr))
}()
defer C.free(unsafe.Pointer(tracesStr))

log.Info("Start to create chunk proof ...")
cProof := C.gen_chunk_proof(tracesStr)
defer C.free(unsafe.Pointer(cProof))
log.Info("Finish creating chunk proof!")

proof := C.GoString(cProof)
return []byte(proof)
var result ProofResult
err := json.Unmarshal([]byte(C.GoString(cProof)), &result)
if err != nil {
return nil, fmt.Errorf("Failed to parse chunk proof result: %v", err)
}

if result.Error != "" {
return nil, fmt.Errorf("Failed to gen_chunk_proof: %s", result.Error)
}

return result.Message, nil
}

func (p *ProverCore) mayDumpProof(id string, proofByt []byte) error {
Expand Down

0 comments on commit 97801c5

Please sign in to comment.