Skip to content

Commit

Permalink
feat(prover): add chunk & batch proving circuit error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
colinlyguo committed Aug 28, 2023
1 parent 826e847 commit 371c60f
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 63 deletions.
89 changes: 61 additions & 28 deletions common/libzkp/impl/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,32 @@ pub unsafe extern "C" fn get_batch_vk() -> *const c_char {

/// # Safety
#[no_mangle]
pub unsafe extern "C" fn check_chunk_proofs(chunk_proofs: *const c_char) -> c_char {
let chunk_proofs = c_char_to_vec(chunk_proofs);
let chunk_proofs = serde_json::from_slice::<Vec<ChunkProof>>(&chunk_proofs).unwrap();
assert!(!chunk_proofs.is_empty());

let valid = panic::catch_unwind(|| PROVER.get().unwrap().check_chunk_proofs(&chunk_proofs));
valid.unwrap_or(false) as c_char
pub unsafe extern "C" fn check_chunk_proofs(chunk_proofs: *const c_char) -> *const c_char {
let check_result: Result<bool, String> = panic::catch_unwind(|| {
let chunk_proofs = c_char_to_vec(chunk_proofs);
let chunk_proofs = serde_json::from_slice::<Vec<ChunkProof>>(&chunk_proofs)
.map_err(|e| format!("Failed to deserialize chunk proofs: {:?}", e))?;
if chunk_proofs.is_empty() {
return Err("Provided chunk proofs are empty.".to_string());
}
PROVER.get()
.map_err(|_| "Failed to get reference to PROVER.".to_string())?
.check_chunk_proofs(&chunk_proofs)
.map_err(|e| format!("Error checking chunk proofs: {:?}", e))
}).unwrap_or_else(|err| Err(format!("Unwind error: {:?}", err)));

let r = match check_result {
Ok(valid) => ChunkProofResult {
message: Some(vec![valid as u8]),
error: None,
},
Err(err) => ChunkProofResult {
message: None,
error: Some(err),
},
};

serde_json::to_vec(&r).map_or(null(), vec_to_c_char)
}

/// # Safety
Expand All @@ -69,28 +88,42 @@ pub unsafe extern "C" fn gen_batch_proof(
chunk_hashes: *const c_char,
chunk_proofs: *const c_char,
) -> *const c_char {
let chunk_hashes = c_char_to_vec(chunk_hashes);
let chunk_proofs = c_char_to_vec(chunk_proofs);

let chunk_hashes = serde_json::from_slice::<Vec<ChunkHash>>(&chunk_hashes).unwrap();
let chunk_proofs = serde_json::from_slice::<Vec<ChunkProof>>(&chunk_proofs).unwrap();
assert_eq!(chunk_hashes.len(), chunk_proofs.len());

let chunk_hashes_proofs = chunk_hashes
.into_iter()
.zip(chunk_proofs.into_iter())
.collect();

let proof_result = panic::catch_unwind(|| {
let proof = PROVER
.get_mut()
.unwrap()
.gen_agg_evm_proof(chunk_hashes_proofs, None, OUTPUT_DIR.as_deref())
.unwrap();
let proof_result: Result<Vec<u8>, String> = panic::catch_unwind(|| {
let chunk_hashes = c_char_to_vec(chunk_hashes);
let chunk_proofs = c_char_to_vec(chunk_proofs);

let chunk_hashes = serde_json::from_slice::<Vec<ChunkHash>>(&chunk_hashes)
.map_err(|e| format!("Failed to deserialize chunk hashes: {:?}", e))?;
let chunk_proofs = serde_json::from_slice::<Vec<ChunkProof>>(&chunk_proofs)
.map_err(|e| format!("Failed to deserialize chunk proofs: {:?}", e))?;

serde_json::to_vec(&proof).unwrap()
});
proof_result.map_or(null(), vec_to_c_char)
if chunk_hashes.len() != chunk_proofs.len() {
return Err("Chunk hashes and chunk proofs lengths mismatch.".to_string());
}

let chunk_hashes_proofs = chunk_hashes.into_iter().zip(chunk_proofs.into_iter()).collect();

let proof = PROVER.get_mut()
.map_err(|_| "Failed to get mutable reference to PROVER.".to_string())?
.gen_agg_evm_proof(chunk_hashes_proofs, None, OUTPUT_DIR.as_deref())
.map_err(|e| format!("Proof generation failed: {:?}", e))?;

serde_json::to_vec(&proof)
.map_err(|e| format!("Failed to serialize the proof: {:?}", e))
}).unwrap_or_else(|err| Err(format!("Unwind error: {:?}", err)));

let r = match proof_result {
Ok(proof_bytes) => ChunkProofResult {
message: Some(proof_bytes),
error: None,
},
Err(err) => ChunkProofResult {
message: None,
error: Some(err),
},
};

serde_json::to_vec(&r).map_or(null(), vec_to_c_char)
}

/// # Safety
Expand Down
48 changes: 34 additions & 14 deletions common/libzkp/impl/src/chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,43 @@ pub unsafe extern "C" fn get_chunk_vk() -> *const c_char {
.map_or(null(), |vk| string_to_c_char(base64::encode(vk)))
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChunkProofResult {
message: Option<Vec<u8>>,
error: Option<String>,
}

/// # Safety
#[no_mangle]
pub unsafe extern "C" fn gen_chunk_proof(block_traces: *const c_char) -> *const c_char {
let block_traces = c_char_to_vec(block_traces);
let block_traces = serde_json::from_slice::<Vec<BlockTrace>>(&block_traces).unwrap();

let proof_result = panic::catch_unwind(|| {
let proof = PROVER
.get_mut()
.unwrap()
.gen_chunk_proof(block_traces, None, OUTPUT_DIR.as_deref())
.unwrap();

serde_json::to_vec(&proof).unwrap()
});

proof_result.map_or(null(), vec_to_c_char)
let proof_result: Result<Vec<u8>, String> = panic::catch_unwind(|| {
let block_traces = c_char_to_vec(block_traces);
let block_traces = serde_json::from_slice::<Vec<BlockTrace>>(&block_traces)
.map_err(|e| format!("Failed to deserialize block traces: {:?}", e))?;

let prover = PROVER.get_mut().unwrap_or_else(|| {
panic!("Failed to get mutable reference to PROVER.");
});

let proof = prover.gen_chunk_proof(block_traces, None, OUTPUT_DIR.as_deref())
.map_err(|e| format!("Proof generation failed: {:?}", e))?;

serde_json::to_vec(&proof)
.map_err(|e| format!("Failed to serialize the proof: {:?}", e))
}).unwrap_or_else(|err| Err(format!("Unwind error: {:?}", err)));

let r = match proof_result {
Ok(proof_bytes) => ChunkProofResult {
message: Some(proof_bytes),
error: None,
},
Err(err) => ChunkProofResult {
message: None,
error: Some(err),
},
};

serde_json::to_vec(&r).map_or(null(), vec_to_c_char)
}

/// # Safety
Expand Down
92 changes: 71 additions & 21 deletions prover/core/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import "C" //nolint:typecheck

import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -92,11 +93,19 @@ func (p *ProverCore) ProveBatch(taskID string, chunkInfos []*message.ChunkInfo,
return nil, err
}

if !p.checkChunkProofs(chunkProofsByt) {
isValid, err := p.checkChunkProofs(chunkProofsByt)
if err != nil {
return nil, err
}

if !isValid {
return nil, fmt.Errorf("Non-match chunk protocol: task-id = %s", taskID)
}

proofByt := p.proveBatch(chunkInfosByt, chunkProofsByt)
proofByt, err := p.proveBatch(chunkInfosByt, chunkProofsByt)
if err != nil {
return nil, fmt.Errorf("Error generating batch proof: %v", err)
}

err = p.mayDumpProof(taskID, proofByt)
if err != nil {
Expand Down Expand Up @@ -140,21 +149,35 @@ func (p *ProverCore) TracesToChunkInfo(traces []*types.BlockTrace) (*message.Chu
return chunkInfo, json.Unmarshal(chunkInfoByt, chunkInfo)
}

func (p *ProverCore) checkChunkProofs(chunkProofsByt []byte) bool {
func (p *ProverCore) checkChunkProofs(chunkProofsByt []byte) (bool, error) {
chunkProofsStr := C.CString(string(chunkProofsByt))

defer func() {
C.free(unsafe.Pointer(chunkProofsStr))
}()
defer C.free(unsafe.Pointer(chunkProofsStr))

log.Info("Start to check chunk proofs ...")
valid := C.check_chunk_proofs(chunkProofsStr)
cResult := C.check_chunk_proofs(chunkProofsStr)
defer C.free(unsafe.Pointer(cResult))
log.Info("Finish checking chunk proofs!")

return valid != 0
resultJson := C.GoString(cResult)

var result ChunkProofResult
err := json.Unmarshal([]byte(resultJson), &result)
if err != nil {
return false, errors.New("Failed to parse check proof result: " + err.Error())
}

if result.Error != nil {
return false, errors.New(*result.Error)
}

if result.Message != nil && len(*result.Message) > 0 {
return (*result.Message)[0] != 0, nil
}

return false, errors.New("Unexpected state: No message returned")
}

func (p *ProverCore) proveBatch(chunkInfosByt []byte, chunkProofsByt []byte) []byte {
func (p *ProverCore) proveBatch(chunkInfosByt []byte, chunkProofsByt []byte) ([]byte, error) {
chunkInfosStr := C.CString(string(chunkInfosByt))
chunkProofsStr := C.CString(string(chunkProofsByt))

Expand All @@ -167,23 +190,50 @@ func (p *ProverCore) proveBatch(chunkInfosByt []byte, chunkProofsByt []byte) []b
cProof := C.gen_batch_proof(chunkInfosStr, chunkProofsStr)
log.Info("Finish creating batch proof!")

proof := C.GoString(cProof)
return []byte(proof)
proofResult := &ChunkProofResult{}
err := json.Unmarshal([]byte(C.GoString(cProof)), proofResult)
if err != nil {
return nil, err
}

if proofResult.Error != nil {
return nil, errors.New(*proofResult.Error)
}

return *proofResult.Message, nil
}

func (p *ProverCore) proveChunk(tracesByt []byte) []byte {
tracesStr := C.CString(string(tracesByt))
type ChunkProofResult struct {
Message *[]byte `json:"message"`
Error *string `json:"error"`
}

defer func() {
C.free(unsafe.Pointer(tracesStr))
}()
func (p *ProverCore) proveChunk(tracesByt []byte) ([]byte, error) {
tracesStr := C.CString(string(tracesByt))
defer C.free(unsafe.Pointer(tracesStr))

log.Info("Start to create chunk proof ...")
log.Println("Start to create chunk proof ...")
cProof := C.gen_chunk_proof(tracesStr)
log.Info("Finish creating chunk proof!")
defer C.free(unsafe.Pointer(cProof))
log.Println("Finish creating chunk proof!")

proofJson := C.GoString(cProof)

var result ChunkProofResult
err := json.Unmarshal([]byte(proofJson), &result)
if err != nil {
return nil, errors.New("Failed to parse chunk proof result: " + err.Error())
}

if result.Error != nil {
return nil, errors.New(*result.Error)
}

if result.Message != nil {
return *result.Message, nil
}

proof := C.GoString(cProof)
return []byte(proof)
return nil, errors.New("Unexpected state: No message or error returned")
}

func (p *ProverCore) mayDumpProof(id string, proofByt []byte) error {
Expand Down

0 comments on commit 371c60f

Please sign in to comment.