Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mask-pp committed Aug 30, 2023
1 parent 90f8a28 commit 26ba4d9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 183 deletions.
4 changes: 2 additions & 2 deletions prover-stats-api/internal/orm/orm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import (
"github.com/stretchr/testify/assert"
"gorm.io/gorm"

"scroll-tech/database/migrate"

"scroll-tech/common/database"
"scroll-tech/common/docker"
"scroll-tech/common/types"
"scroll-tech/common/types/message"
"scroll-tech/common/utils"

"scroll-tech/database/migrate"
)

var (
Expand Down
217 changes: 36 additions & 181 deletions prover-stats-api/internal/orm/prover_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@ package orm
import (
"context"
"fmt"
"github.com/google/uuid"
"math/big"
"scroll-tech/common/types/message"
"time"

"github.com/google/uuid"
"github.com/shopspring/decimal"
"gorm.io/gorm"
"gorm.io/gorm/clause"

"scroll-tech/common/types"
"scroll-tech/common/types/message"
"scroll-tech/common/utils"
"github.com/shopspring/decimal"
"gorm.io/gorm"
)

// ProverTask is assigned provers info of chunk/batch proof prover task
Expand Down Expand Up @@ -54,46 +53,47 @@ func (*ProverTask) TableName() string {
return "prover_task"
}

// IsProverAssigned checks if a prover with the given public key has been assigned a task.
func (o *ProverTask) IsProverAssigned(ctx context.Context, publicKey string) (bool, error) {
// GetProverTasksByProver get all prover tasks by the given prover's public key.
func (o *ProverTask) GetProverTasksByProver(ctx context.Context, pubkey string, offset, limit int) ([]*ProverTask, error) {
var proverTasks []*ProverTask
db := o.db.WithContext(ctx)
var task ProverTask
err := db.Where("prover_public_key = ? AND proving_status = ?", publicKey, types.ProverAssigned).First(&task).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return false, nil
}
return false, err
db = db.Model(&ProverTask{})
db = db.Where("prover_public_key", pubkey)
db = db.Order("id desc")
db = db.Offset(offset)
db = db.Limit(limit)
if err := db.Find(&proverTasks).Error; err != nil {
return nil, fmt.Errorf("ProverTask.GetProverTasksByProver error: %w, prover %s", err, pubkey)
}
return true, nil
return proverTasks, nil
}

// GetProverTasks get prover tasks
func (o *ProverTask) GetProverTasks(ctx context.Context, fields map[string]interface{}, orderByList []string, offset, limit int) ([]ProverTask, error) {
// GetProverTotalReward get prover all reward by the given prover's public key.
func (o *ProverTask) GetProverTotalReward(ctx context.Context, pubkey string) (*big.Int, error) {
var totalReward decimal.Decimal
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})

for k, v := range fields {
db = db.Where(k, v)
}

for _, orderBy := range orderByList {
db = db.Order(orderBy)
}

if limit != 0 {
db = db.Limit(limit)
db = db.Select("sum(reward)")
db = db.Where("prover_public_key", pubkey)
if err := db.Scan(&totalReward).Error; err != nil {
return nil, fmt.Errorf("ProverTask.GetProverTotalReward error:%w, prover:%s", err, pubkey)
}
return totalReward.BigInt(), nil
}

if offset != 0 {
db = db.Offset(offset)
}
// GetProverTasksByHash retrieves the ProverTask records associated with the specified hashes.
// The returned prover task objects are sorted in ascending order by their ids.
func (o *ProverTask) GetProverTasksByHash(ctx context.Context, hash string) (*ProverTask, error) {
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("task_id", hash)
db = db.Order("id asc")

var proverTasks []ProverTask
if err := db.Find(&proverTasks).Error; err != nil {
return nil, err
var proverTask *ProverTask
if err := db.Find(&proverTask).Error; err != nil {
return nil, fmt.Errorf("ProverTask.GetProverTasksByHash error: %w, hash: %v", err, hash)
}
return proverTasks, nil
return proverTask, nil
}

// GetProverTasksByHashes retrieves the ProverTask records associated with the specified hashes.
Expand All @@ -116,108 +116,6 @@ func (o *ProverTask) GetProverTasksByHashes(ctx context.Context, taskType messag
return proverTasks, nil
}

// GetAssignedProverTaskByTaskIDAndProver get prover task taskID and public key
// TODO: when prover all upgrade need DEPRECATED this function
func (o *ProverTask) GetAssignedProverTaskByTaskIDAndProver(ctx context.Context, taskType message.ProofType, taskID, proverPublicKey, proverVersion string) (*ProverTask, error) {
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("task_type", int(taskType))
db = db.Where("task_id", taskID)
db = db.Where("prover_public_key", proverPublicKey)
db = db.Where("prover_version", proverVersion)
db = db.Where("proving_status", types.ProverAssigned)

var proverTask ProverTask
err := db.First(&proverTask).Error
if err != nil {
return nil, fmt.Errorf("ProverTask.GetProverTaskByTaskIDAndProver err:%w, taskID:%s, pubkey:%s, prover_version:%s", err, taskID, proverPublicKey, proverVersion)
}
return &proverTask, nil
}

// GetProverTaskByUUIDAndPublicKey get prover task taskID by uuid and public key
func (o *ProverTask) GetProverTaskByUUIDAndPublicKey(ctx context.Context, uuid, publicKey string) (*ProverTask, error) {
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("uuid", uuid)
db = db.Where("prover_public_key", publicKey)

var proverTask ProverTask
err := db.First(&proverTask).Error
if err != nil {
return nil, fmt.Errorf("ProverTask.GetProverTaskByUUID err:%w, uuid:%s publicKey:%s", err, uuid, publicKey)
}
return &proverTask, nil
}

// GetAssignedTaskOfOtherProvers get the chunk/batch task assigned other provers
func (o *ProverTask) GetAssignedTaskOfOtherProvers(ctx context.Context, taskType message.ProofType, taskID, proverPublicKey string) ([]ProverTask, error) {
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("task_type", int(taskType))
db = db.Where("task_id", taskID)
db = db.Where("prover_public_key != ?", proverPublicKey)
db = db.Where("proving_status = ?", int(types.ProverAssigned))

var proverTasks []ProverTask
if err := db.Find(&proverTasks).Error; err != nil {
return nil, fmt.Errorf("ProverTask.GetAssignedProverTask error: %w, taskID: %v", err, taskID)
}
return proverTasks, nil
}

// GetProvingStatusByTaskID retrieves the proving status of a prover task
func (o *ProverTask) GetProvingStatusByTaskID(ctx context.Context, taskType message.ProofType, taskID string) (types.ProverProveStatus, error) {
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Select("proving_status")
db = db.Where("task_type", int(taskType))
db = db.Where("task_id = ?", taskID)

var proverTask ProverTask
if err := db.Find(&proverTask).Error; err != nil {
return types.ProverProofInvalid, fmt.Errorf("ProverTask.GetProvingStatusByTaskID error: %w, taskID: %v", err, taskID)
}
return types.ProverProveStatus(proverTask.ProvingStatus), nil
}

// GetTimeoutAssignedProverTasks get the timeout and assigned proving_status prover task
func (o *ProverTask) GetTimeoutAssignedProverTasks(ctx context.Context, limit int, taskType message.ProofType, timeout time.Duration) ([]ProverTask, error) {
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("proving_status", int(types.ProverAssigned))
db = db.Where("task_type", int(taskType))
db = db.Where("assigned_at < ?", utils.NowUTC().Add(-timeout))
db = db.Limit(limit)

var proverTasks []ProverTask
err := db.Find(&proverTasks).Error
if err != nil {
return nil, fmt.Errorf("ProverTask.GetAssignedProverTasks error:%w", err)
}
return proverTasks, nil
}

// TaskTimeoutMoreThanOnce get the timeout twice task. a temp design
func (o *ProverTask) TaskTimeoutMoreThanOnce(ctx context.Context, taskType message.ProofType, taskID string) bool {
db := o.db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("task_type", int(taskType))
db = db.Where("task_id", taskID)
db = db.Where("proving_status", int(types.ProverProofInvalid))

var count int64
if err := db.Count(&count).Error; err != nil {
return true
}

if count >= 1 {
return true
}

return false
}

// InsertProverTask insert a prover Task record
func (o *ProverTask) InsertProverTask(ctx context.Context, proverTask *ProverTask, dbTX ...*gorm.DB) error {
db := o.db.WithContext(ctx)
Expand All @@ -231,46 +129,3 @@ func (o *ProverTask) InsertProverTask(ctx context.Context, proverTask *ProverTas
}
return nil
}

// UpdateProverTaskProof update the prover task's proof
func (o *ProverTask) UpdateProverTaskProof(ctx context.Context, uuid uuid.UUID, proof []byte) error {
db := o.db
db = db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("uuid = ?", uuid)
if err := db.Update("proof", proof).Error; err != nil {
return fmt.Errorf("ProverTask.UpdateProverTaskProof error: %w, uuid: %v", err, uuid)
}
return nil
}

// UpdateProverTaskProvingStatus updates the proving_status of a specific ProverTask record.
func (o *ProverTask) UpdateProverTaskProvingStatus(ctx context.Context, uuid uuid.UUID, status types.ProverProveStatus, dbTX ...*gorm.DB) error {
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db = db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("uuid = ?", uuid)

if err := db.Update("proving_status", status).Error; err != nil {
return fmt.Errorf("ProverTask.UpdateProverTaskProvingStatus error: %w, uuid:%s, status: %v", err, uuid, status.String())
}
return nil
}

// UpdateProverTaskFailureType update the prover task failure type
func (o *ProverTask) UpdateProverTaskFailureType(ctx context.Context, uuid uuid.UUID, failureType types.ProverTaskFailureType, dbTX ...*gorm.DB) error {
db := o.db
if len(dbTX) > 0 && dbTX[0] != nil {
db = dbTX[0]
}
db = db.WithContext(ctx)
db = db.Model(&ProverTask{})
db = db.Where("uuid", uuid)
if err := db.Update("failure_type", int(failureType)).Error; err != nil {
return fmt.Errorf("ProverTask.UpdateProverTaskFailureType error: %w, uuid:%s, failure type: %v", err, uuid.String(), failureType.String())
}
return nil
}

0 comments on commit 26ba4d9

Please sign in to comment.