Skip to content

Commit

Permalink
Tweak dbutil.ScanJSON to work with sql.Row as well as sql.Rows
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Jul 3, 2024
1 parent df20f62 commit e669474
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
20 changes: 12 additions & 8 deletions dbutil/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,33 @@ import (

var validate = validator.New()

// Scanable is an interface to allow scanning of sql.Row or sql.Rows
type Scannable interface {
Scan(dest ...any) error
}

// ScanJSON scans a row which is JSON into a destination struct
func ScanJSON(rows *sql.Rows, destination any) error {
func ScanJSON(src Scannable, dest any) error {
var raw json.RawMessage
err := rows.Scan(&raw)
if err != nil {

if err := src.Scan(&raw); err != nil {
return fmt.Errorf("error scanning row JSON: %w", err)
}

err = json.Unmarshal(raw, destination)
if err != nil {
if err := json.Unmarshal(raw, dest); err != nil {
return fmt.Errorf("error unmarshalling row JSON: %w", err)
}

return nil
}

// ScanAndValidateJSON scans a row which is JSON into a destination struct and validates it
func ScanAndValidateJSON(rows *sql.Rows, destination any) error {
if err := ScanJSON(rows, destination); err != nil {
func ScanAndValidateJSON(src Scannable, dest any) error {
if err := ScanJSON(src, dest); err != nil {
return err
}

err := validate.Struct(destination)
err := validate.Struct(dest)
if err != nil {
return fmt.Errorf("error validating unmarsalled JSON: %w", err)
}
Expand Down
6 changes: 6 additions & 0 deletions dbutil/scan_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ func TestScanJSON(t *testing.T) {

rows.Close()

// can also scan as a single row
row := db.QueryRowContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f WHERE id = 3) r`)
err = dbutil.ScanJSON(row, f)
assert.NoError(t, err)
assert.Equal(t, "a5850c89-dd29-46f6-9de1-d068b3c2db94", f.UUID)

// can all scan all rows with ScanAllJSON
rows = queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT f.uuid, f.name, f.age FROM foo f) r`)

Expand Down

0 comments on commit e669474

Please sign in to comment.