diff --git a/common/libzkp/impl/src/batch.rs b/common/libzkp/impl/src/batch.rs index 9baf43c9ce..8629b12d96 100644 --- a/common/libzkp/impl/src/batch.rs +++ b/common/libzkp/impl/src/batch.rs @@ -1,4 +1,5 @@ use crate::utils::{c_char_to_str, c_char_to_vec, string_to_c_char, vec_to_c_char, OUTPUT_DIR}; +use crate::types::{CheckChunkProofsResponse, ProofResult}; use libc::c_char; use prover::{ aggregator::{Prover, Verifier}, @@ -54,13 +55,33 @@ pub unsafe extern "C" fn get_batch_vk() -> *const c_char { /// # Safety #[no_mangle] -pub unsafe extern "C" fn check_chunk_proofs(chunk_proofs: *const c_char) -> c_char { - let chunk_proofs = c_char_to_vec(chunk_proofs); - let chunk_proofs = serde_json::from_slice::>(&chunk_proofs).unwrap(); - assert!(!chunk_proofs.is_empty()); - - let valid = panic::catch_unwind(|| PROVER.get().unwrap().check_chunk_proofs(&chunk_proofs)); - valid.unwrap_or(false) as c_char +pub unsafe extern "C" fn check_chunk_proofs(chunk_proofs: *const c_char) -> *const c_char { + let check_result: Result = panic::catch_unwind(|| { + let chunk_proofs = c_char_to_vec(chunk_proofs); + let chunk_proofs = serde_json::from_slice::>(&chunk_proofs) + .map_err(|e| format!("Failed to deserialize chunk proofs: {:?}", e))?; + if chunk_proofs.is_empty() { + return Err("Provided chunk proofs are empty.".to_string()); + } + + PROVER.get() + .map_err(|_| "Failed to get reference to PROVER.".to_string())? + .check_chunk_proofs(&chunk_proofs) + .map_err(|e| format!("Error checking chunk proofs: {:?}", e)) + }).unwrap_or_else(|err| Err(format!("Unwind error: {:?}", err))); + + let r = match check_result { + Ok(valid) => CheckChunkProofsResponse { + ok: valid, + error: None, + }, + Err(err) => CheckChunkProofsResponse { + ok: false, + error: Some(err), + }, + }; + + serde_json::to_vec(&r).map_or(null(), vec_to_c_char) } /// # Safety @@ -69,28 +90,43 @@ pub unsafe extern "C" fn gen_batch_proof( chunk_hashes: *const c_char, chunk_proofs: *const c_char, ) -> *const c_char { - let chunk_hashes = c_char_to_vec(chunk_hashes); - let chunk_proofs = c_char_to_vec(chunk_proofs); + let proof_result: Result, String> = panic::catch_unwind(|| { + let chunk_hashes = c_char_to_vec(chunk_hashes); + let chunk_proofs = c_char_to_vec(chunk_proofs); + + let chunk_hashes = serde_json::from_slice::>(&chunk_hashes) + .map_err(|e| format!("Failed to deserialize chunk hashes: {:?}", e))?; + let chunk_proofs = serde_json::from_slice::>(&chunk_proofs) + .map_err(|e| format!("Failed to deserialize chunk proofs: {:?}", e))?; - let chunk_hashes = serde_json::from_slice::>(&chunk_hashes).unwrap(); - let chunk_proofs = serde_json::from_slice::>(&chunk_proofs).unwrap(); - assert_eq!(chunk_hashes.len(), chunk_proofs.len()); + if chunk_hashes.len() != chunk_proofs.len() { + return Err("Chunk hashes and chunk proofs lengths mismatch.".to_string()); + } - let chunk_hashes_proofs = chunk_hashes - .into_iter() - .zip(chunk_proofs.into_iter()) - .collect(); + let chunk_hashes_proofs = chunk_hashes.into_iter().zip(chunk_proofs.into_iter()).collect(); - let proof_result = panic::catch_unwind(|| { let proof = PROVER .get_mut() - .unwrap() + .map_err(|_| "Failed to get mutable reference to PROVER.".to_string())? .gen_agg_evm_proof(chunk_hashes_proofs, None, OUTPUT_DIR.as_deref()) - .unwrap(); - - serde_json::to_vec(&proof).unwrap() - }); - proof_result.map_or(null(), vec_to_c_char) + .map_err(|e| format!("Proof generation failed: {:?}", e))?; + + serde_json::to_vec(&proof) + .map_err(|e| format!("Failed to serialize the proof: {:?}", e)) + }).unwrap_or_else(|err| Err(format!("Unwind error: {:?}", err))); + + let r = match proof_result { + Ok(proof_bytes) => ProofResult { + message: Some(proof_bytes), + error: None, + }, + Err(err) => ProofResult { + message: None, + error: Some(err), + }, + }; + + serde_json::to_vec(&r).map_or(null(), vec_to_c_char) } /// # Safety diff --git a/common/libzkp/impl/src/chunk.rs b/common/libzkp/impl/src/chunk.rs index e273711f67..9c0ec9bfdc 100644 --- a/common/libzkp/impl/src/chunk.rs +++ b/common/libzkp/impl/src/chunk.rs @@ -1,4 +1,5 @@ use crate::utils::{c_char_to_str, c_char_to_vec, string_to_c_char, vec_to_c_char, OUTPUT_DIR}; +use crate::types::ProofResult; use libc::c_char; use prover::{ utils::init_env_and_log, @@ -55,20 +56,33 @@ pub unsafe extern "C" fn get_chunk_vk() -> *const c_char { /// # Safety #[no_mangle] pub unsafe extern "C" fn gen_chunk_proof(block_traces: *const c_char) -> *const c_char { - let block_traces = c_char_to_vec(block_traces); - let block_traces = serde_json::from_slice::>(&block_traces).unwrap(); + let proof_result: Result, String> = panic::catch_unwind(|| { + let block_traces = c_char_to_vec(block_traces); + let block_traces = serde_json::from_slice::>(&block_traces) + .map_err(|e| format!("Failed to deserialize block traces: {:?}", e))?; - let proof_result = panic::catch_unwind(|| { let proof = PROVER .get_mut() - .unwrap() + .map_err(|_| "Failed to get mutable reference to PROVER.".to_string())? .gen_chunk_proof(block_traces, None, OUTPUT_DIR.as_deref()) - .unwrap(); - - serde_json::to_vec(&proof).unwrap() - }); - - proof_result.map_or(null(), vec_to_c_char) + .map_err(|e| format!("Proof generation failed: {:?}", e))?; + + serde_json::to_vec(&proof) + .map_err(|e| format!("Failed to serialize the proof: {:?}", e)) + }).unwrap_or_else(|err| Err(format!("Unwind error: {:?}", err))); + + let r = match proof_result { + Ok(proof_bytes) => ProofResult { + message: Some(proof_bytes), + error: None, + }, + Err(err) => ProofResult { + message: None, + error: Some(err), + }, + }; + + serde_json::to_vec(&r).map_or(null(), vec_to_c_char) } /// # Safety diff --git a/common/libzkp/impl/src/lib.rs b/common/libzkp/impl/src/lib.rs index d60bbd5e18..c172cbabf5 100644 --- a/common/libzkp/impl/src/lib.rs +++ b/common/libzkp/impl/src/lib.rs @@ -3,3 +3,4 @@ mod batch; mod chunk; mod utils; +mod types; diff --git a/common/libzkp/impl/src/types.rs b/common/libzkp/impl/src/types.rs new file mode 100644 index 0000000000..c4163a51db --- /dev/null +++ b/common/libzkp/impl/src/types.rs @@ -0,0 +1,22 @@ +use serde::{Deserialize, Serialize}; + +// Represents the result of a chunk proof checking operation. +// `ok` indicates whether the proof checking was successful. +// `error` provides additional details in case the check failed. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CheckChunkProofsResponse { + ok: bool, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +// Encapsulates the result from generating a proof. +// `message` holds the generated proof in byte slice format. +// `error` provides additional details in case the proof generation failed. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ProofResult { + #[serde(skip_serializing_if = "Option::is_none")] + message: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} diff --git a/common/libzkp/interface/libzkp.h b/common/libzkp/interface/libzkp.h index 8967f0d8d7..1c84af15bd 100644 --- a/common/libzkp/interface/libzkp.h +++ b/common/libzkp/interface/libzkp.h @@ -1,7 +1,7 @@ void init_batch_prover(char* params_dir, char* assets_dir); void init_batch_verifier(char* params_dir, char* assets_dir); char* get_batch_vk(); -char check_chunk_proofs(char* chunk_proofs); +char* check_chunk_proofs(char* chunk_proofs); char* gen_batch_proof(char* chunk_hashes, char* chunk_proofs); char verify_batch_proof(char* proof); diff --git a/prover/core/prover.go b/prover/core/prover.go index c40ebb8f71..ef7c7a9d0a 100644 --- a/prover/core/prover.go +++ b/prover/core/prover.go @@ -92,11 +92,19 @@ func (p *ProverCore) ProveBatch(taskID string, chunkInfos []*message.ChunkInfo, return nil, err } - if !p.checkChunkProofs(chunkProofsByt) { + isValid, err := p.checkChunkProofs(chunkProofsByt) + if err != nil { + return nil, err + } + + if !isValid { return nil, fmt.Errorf("Non-match chunk protocol: task-id = %s", taskID) } - proofByt := p.proveBatch(chunkInfosByt, chunkProofsByt) + proofByt, err := p.proveBatch(chunkInfosByt, chunkProofsByt) + if err != nil { + return nil, fmt.Errorf("Error generating batch proof: %v", err) + } err = p.mayDumpProof(taskID, proofByt) if err != nil { @@ -140,21 +148,45 @@ func (p *ProverCore) TracesToChunkInfo(traces []*types.BlockTrace) (*message.Chu return chunkInfo, json.Unmarshal(chunkInfoByt, chunkInfo) } -func (p *ProverCore) checkChunkProofs(chunkProofsByt []byte) bool { - chunkProofsStr := C.CString(string(chunkProofsByt)) +// CheckChunkProofsResponse represents the result of a chunk proof checking operation. +// Ok indicates whether the proof checking was successful. +// Error provides additional details in case the check failed. +type CheckChunkProofsResponse struct { + Ok bool `json:"ok"` + Error string `json:"error"` +} - defer func() { - C.free(unsafe.Pointer(chunkProofsStr)) - }() +// ProofResult encapsulates the result from generating a proof. +// Message holds the generated proof in byte slice format. +// Error provides additional details in case the proof generation failed. +type ProofResult struct { + Message []byte `json:"message"` + Error string `json:"error"` +} + +func (p *ProverCore) checkChunkProofs(chunkProofsByt []byte) (bool, error) { + chunkProofsStr := C.CString(string(chunkProofsByt)) + defer C.free(unsafe.Pointer(chunkProofsStr)) log.Info("Start to check chunk proofs ...") - valid := C.check_chunk_proofs(chunkProofsStr) + cResult := C.check_chunk_proofs(chunkProofsStr) + defer C.free(unsafe.Pointer(cResult)) log.Info("Finish checking chunk proofs!") - return valid != 0 + var result CheckChunkProofsResponse + err := json.Unmarshal([]byte(C.GoString(cResult)), &result) + if err != nil { + return false, fmt.Errorf("Failed to parse check proof result: %v", err) + } + + if result.Error != "" { + return false, fmt.Errorf("Failed to check_chunk_proofs: %s", result.Error) + } + + return result.Ok, nil } -func (p *ProverCore) proveBatch(chunkInfosByt []byte, chunkProofsByt []byte) []byte { +func (p *ProverCore) proveBatch(chunkInfosByt []byte, chunkProofsByt []byte) ([]byte, error) { chunkInfosStr := C.CString(string(chunkInfosByt)) chunkProofsStr := C.CString(string(chunkProofsByt)) @@ -164,26 +196,43 @@ func (p *ProverCore) proveBatch(chunkInfosByt []byte, chunkProofsByt []byte) []b }() log.Info("Start to create batch proof ...") - cProof := C.gen_batch_proof(chunkInfosStr, chunkProofsStr) + bResult := C.gen_batch_proof(chunkInfosStr, chunkProofsStr) + defer C.free(unsafe.Pointer(bResult)) log.Info("Finish creating batch proof!") - proof := C.GoString(cProof) - return []byte(proof) + var result ProofResult + err := json.Unmarshal([]byte(C.GoString(bResult)), &result) + if err != nil { + return nil, fmt.Errorf("Failed to parse batch proof result: %v", err) + } + + if result.Error != "" { + return nil, fmt.Errorf("Failed to gen_batch_proof: %s", result.Error) + } + + return result.Message, nil } -func (p *ProverCore) proveChunk(tracesByt []byte) []byte { +func (p *ProverCore) proveChunk(tracesByt []byte) ([]byte, error) { tracesStr := C.CString(string(tracesByt)) - - defer func() { - C.free(unsafe.Pointer(tracesStr)) - }() + defer C.free(unsafe.Pointer(tracesStr)) log.Info("Start to create chunk proof ...") cProof := C.gen_chunk_proof(tracesStr) + defer C.free(unsafe.Pointer(cProof)) log.Info("Finish creating chunk proof!") - proof := C.GoString(cProof) - return []byte(proof) + var result ProofResult + err := json.Unmarshal([]byte(C.GoString(cProof)), &result) + if err != nil { + return nil, fmt.Errorf("Failed to parse chunk proof result: %v", err) + } + + if result.Error != "" { + return nil, fmt.Errorf("Failed to gen_chunk_proof: %s", result.Error) + } + + return result.Message, nil } func (p *ProverCore) mayDumpProof(id string, proofByt []byte) error {