Skip to content

Commit

Permalink
Refactor to get rid of nested ifs
Browse files Browse the repository at this point in the history
  • Loading branch information
mtaner committed Aug 7, 2024
1 parent 034ff72 commit ed9a90b
Showing 1 changed file with 87 additions and 110 deletions.
197 changes: 87 additions & 110 deletions inline_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,19 @@ package ghostferry
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"time"
"encoding/json"
"reflect"
sqlorig "database/sql"


sql "github.com/Shopify/ghostferry/sqlwrapper"

"github.com/golang/snappy"
"github.com/go-mysql-org/go-mysql/schema"
"github.com/golang/snappy"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -730,10 +728,6 @@ func (v *InlineVerifier) verifyAllEventsInStore() (bool, map[string]map[string][
// a union of mismatches of fingerprints and mismatches due to decompressed
// data.
func (v *InlineVerifier) verifyBinlogBatch(batch BinlogVerifyBatch, skipJsonColumnCheck bool) ([]InlineVerifierMismatches, error) {
if !skipJsonColumnCheck {
skipJsonColumnCheck = false
}

targetSchema := batch.SchemaName
if targetSchemaName, exists := v.DatabaseRewrites[targetSchema]; exists {
targetSchema = targetSchemaName
Expand Down Expand Up @@ -809,127 +803,124 @@ func (v *InlineVerifier) compareJsonColumnValues(batch BinlogVerifyBatch, mismat
}
}

if len(jsonColumnNames) > 0 {
addJsonColumnNamesToIgnoredColumnsForVerification(sourceTableSchema, jsonColumnNames)

mismatches, err := v.verifyBinlogBatch(batch, true)

if err != nil {
return nil, err
}
if len(jsonColumnNames) == 0 {
return mismatches, nil
}

if len(mismatches) == 0 {
args := make([]interface{}, len(batch.PaginationKeys))
for i, paginationKey := range batch.PaginationKeys {
args[i] = paginationKey
}
addJsonColumnNamesToIgnoredColumnsForVerification(sourceTableSchema, jsonColumnNames)

sourceQuery := jsonColumnValueQuery(sourceTableSchema, batch.SchemaName, batch.TableName, jsonColumnNames, len(batch.PaginationKeys))
targetQuery := jsonColumnValueQuery(sourceTableSchema, targetSchema, targetTable, jsonColumnNames, len(batch.PaginationKeys))
mismatches, err := v.verifyBinlogBatch(batch, true)
if err != nil {
return nil, err
}

sourceStatement, _ := v.sourceStmtCache.StmtFor(v.SourceDB, sourceQuery)
targetStatement, _ := v.targetStmtCache.StmtFor(v.TargetDB, targetQuery)
if len(mismatches) > 0 {
return mismatches, nil
}

sourceRows, _ := sourceStatement.Query(args...)
targetRows, _ := targetStatement.Query(args...)
args := make([]interface{}, len(batch.PaginationKeys))
for i, paginationKey := range batch.PaginationKeys {
args[i] = paginationKey
}

defer sourceRows.Close()
defer targetRows.Close()
sourceQuery := jsonColumnValueQuery(sourceTableSchema, batch.SchemaName, batch.TableName, jsonColumnNames, len(batch.PaginationKeys))
targetQuery := jsonColumnValueQuery(sourceTableSchema, targetSchema, targetTable, jsonColumnNames, len(batch.PaginationKeys))

sourceRowBatch := [][][]byte{}
targetRowBatch := [][][]byte{}
sourceStatement, _ := v.sourceStmtCache.StmtFor(v.SourceDB, sourceQuery)
targetStatement, _ := v.targetStmtCache.StmtFor(v.TargetDB, targetQuery)

sourceRowBatch, err = addRowsToBatch(sourceRows, sourceRowBatch, len(jsonColumnNames))
sourceRows, _ := sourceStatement.Query(args...)
targetRows, _ := targetStatement.Query(args...)

if err != nil {
return nil, err
}
defer sourceRows.Close()
defer targetRows.Close()

targetRowBatch, err = addRowsToBatch(targetRows, targetRowBatch, len(jsonColumnNames))
mismatchedJsonColumns := []string{}
paginationKeysWithMismatchedJson := []uint64{}

if err != nil {
return nil, err
}
for {
hasSourceRows := sourceRows.Next()
hasTargetRows := targetRows.Next()

var sourceJsonColumnValue map[string]interface{}
var sourcePaginationKey uint64
if !hasSourceRows && !hasTargetRows {
break
}

var targetJsonColumnValue map[string]interface{}
var targetPaginationKey uint64
if (hasSourceRows && !hasTargetRows) || (!hasSourceRows && hasTargetRows) {
return nil, fmt.Errorf("Number of source and target rows are different")
}

for i := 0; i < len(args); i++ {
sourceRowData := sourceRowBatch[i]
targetRowData := targetRowBatch[i]
sourceRowData, err := ScanByteRow(sourceRows, len(jsonColumnNames) + 1)
if err != nil {
return nil, err
}

sourcePaginationKey, _ = strconv.ParseUint(string(sourceRowData[0]), 10, 64)
targetPaginationKey, _ = strconv.ParseUint(string(targetRowData[0]), 10, 64)
targetRowData, err := ScanByteRow(targetRows, len(jsonColumnNames) + 1)
if err != nil {
return nil, err
}

mismatchJsonColumns := []string{}
paginationKeysWithMismatchedJson := []uint64{}
var sourceJsonColumnValue map[string]interface{}
var sourcePaginationKey uint64

for j, jsonColumn := range jsonColumnNames {
err := json.Unmarshal([]byte(sourceRowData[j+1]), &sourceJsonColumnValue)
var targetJsonColumnValue map[string]interface{}
var targetPaginationKey uint64

if err != nil {
fmt.Println("Error parsing JSON:", err)
return nil, err
}
sourcePaginationKey, _ = strconv.ParseUint(string(sourceRowData[0]), 10, 64)
targetPaginationKey, _ = strconv.ParseUint(string(targetRowData[0]), 10, 64)

err = json.Unmarshal([]byte(targetRowData[j+1]), &targetJsonColumnValue)
for j, jsonColumn := range jsonColumnNames {
err := json.Unmarshal([]byte(sourceRowData[j+1]), &sourceJsonColumnValue)
if err != nil {
return nil, fmt.Errorf("unmarshalling target rowdata: %w")
}

if err != nil {
fmt.Println("Error parsing JSON:", err)
return nil, err
}
err = json.Unmarshal([]byte(targetRowData[j+1]), &targetJsonColumnValue)
if err != nil {
return nil, fmt.Errorf("unmarshalling target rowdata: %w")
}

if sourcePaginationKey == targetPaginationKey && reflect.DeepEqual(sourceJsonColumnValue, targetJsonColumnValue) {
continue
}
if sourcePaginationKey == targetPaginationKey && reflect.DeepEqual(sourceJsonColumnValue, targetJsonColumnValue) {
continue
}

if !uint64SliceContains(paginationKeysWithMismatchedJson, sourcePaginationKey) {
paginationKeysWithMismatchedJson = append(paginationKeysWithMismatchedJson, sourcePaginationKey)
}
if !uint64SliceContains(paginationKeysWithMismatchedJson, sourcePaginationKey) {
paginationKeysWithMismatchedJson = append(paginationKeysWithMismatchedJson, sourcePaginationKey)
}

if !stringSliceContains(mismatchJsonColumns, jsonColumn) {
mismatchJsonColumns = append(mismatchJsonColumns, jsonColumn)
}
}
if !stringSliceContains(mismatchedJsonColumns, jsonColumn) {
mismatchedJsonColumns = append(mismatchedJsonColumns, jsonColumn)
}
}
}

if len(mismatchJsonColumns) > 0 {
removeJsonColumnsFromIgnoredColumnsForVerification(sourceTableSchema, mismatchJsonColumns)
if len(mismatchedJsonColumns) == 0 {
return mismatches, nil
}

mismatched, err := v.verifyBinlogBatch(batch, true)
removeJsonColumnsFromIgnoredColumnsForVerification(sourceTableSchema, mismatchedJsonColumns)

if err != nil {
return nil, err
}
mismatched, err := v.verifyBinlogBatch(batch, true)

filteredMismatches := []InlineVerifierMismatches{}
if err != nil {
return nil, err
}

// filtering out the mismatches that have successful json value comparison
for _, mismatch := range mismatched {
for _, mismatchedJsonPK := range paginationKeysWithMismatchedJson {
if mismatch.Pk == mismatchedJsonPK {
filteredMismatches = append(filteredMismatches, mismatch)
}
}
}
filteredMismatches := []InlineVerifierMismatches{}

return filteredMismatches, nil
}
// filtering out the mismatches that have successful json value comparison
for _, mismatch := range mismatched {
for _, mismatchedJsonPK := range paginationKeysWithMismatchedJson {
if mismatch.Pk == mismatchedJsonPK {
filteredMismatches = append(filteredMismatches, mismatch)
}

return mismatches, nil
}

return mismatches, nil
}

return mismatches, nil
return filteredMismatches, nil
}


func jsonColumnValueQuery(sourceTableSchema *TableSchema, schemaName string, tableName string, jsonColumnNames []string, paginationKeysCount int) (string) {
func jsonColumnValueQuery(sourceTableSchema *TableSchema, schemaName string, tableName string, jsonColumnNames []string, paginationKeysCount int) string {
paginationColumn := QuoteField(sourceTableSchema.GetPaginationColumn().Name)

return fmt.Sprintf(
Expand All @@ -955,19 +946,6 @@ func addJsonColumnNamesToIgnoredColumnsForVerification(sourceTableSchema *TableS
sourceTableSchema.RowMd5Query()
}

func addRowsToBatch(rows *sqlorig.Rows, batch [][][]byte, jsonColumnsCount int) ([][][]byte, error) {
for rows.Next() {
rowData, err := ScanByteRow(rows, jsonColumnsCount + 1)
if err != nil {
return nil, err
}

batch = append(batch, rowData)
}

return batch, nil
}

func removeJsonColumnsFromIgnoredColumnsForVerification(sourceTableSchema *TableSchema, jsonColumnNames []string){
for _, jsonColumn := range jsonColumnNames {
delete(sourceTableSchema.IgnoredColumnsForVerification, jsonColumn)
Expand Down Expand Up @@ -996,4 +974,3 @@ func stringSliceContains(s []string, item string) bool {

return false
}

0 comments on commit ed9a90b

Please sign in to comment.