Skip to content
This repository has been archived by the owner on Aug 26, 2024. It is now read-only.

Commit

Permalink
Use correct type when observing result rows
Browse files Browse the repository at this point in the history
The type assertion never matched because we use pointer there.
Adding a test as well.
  • Loading branch information
martin-sucha committed Sep 22, 2023
1 parent 0770336 commit 6c94303
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
96 changes: 96 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,89 @@ func TestWriteCoalescing_WriteAfterClose(t *testing.T) {
}
}

type frameObserverFunc func(ctx context.Context, observedFrame ObservedFrame)

func (fn frameObserverFunc) ObserveFrame(ctx context.Context, observedFrame ObservedFrame) {
fn(ctx, observedFrame)
}

func TestFrameObserver(t *testing.T) {
srv := NewTestServer(t, protoVersion4, context.Background())
defer srv.Stop()

var frameCount int64
var rowsCount int64
var rowsBytes int64
var frameLengthBytes int64
var frameUncompressedBytes int64

cluster := testCluster(protoVersion4, srv.Address)
cluster.FrameObserver = frameObserverFunc(func(ctx context.Context, observedFrame ObservedFrame) {
if observedFrame.Opcode != FrameOpcodeResult {
return
}
atomic.AddInt64(&frameCount, 1)
atomic.AddInt64(&rowsCount, int64(observedFrame.RowCount))
atomic.AddInt64(&rowsBytes, int64(observedFrame.RowsSize))
atomic.AddInt64(&frameLengthBytes, int64(observedFrame.Length))
atomic.AddInt64(&frameUncompressedBytes, int64(observedFrame.UncompressedSize))
})
db, err := cluster.CreateSession()
if err != nil {
t.Fatalf("create session: %v", err)
}

// Reset the counters, we are not interested in session setup.
atomic.SwapInt64(&frameCount, 0)
atomic.SwapInt64(&rowsCount, 0)
atomic.SwapInt64(&rowsBytes, 0)
atomic.SwapInt64(&frameLengthBytes, 0)
atomic.SwapInt64(&frameUncompressedBytes, 0)

it := db.Query("rows").Iter()

var items []string

for {
var column1 string
if !it.Scan(&column1) {
break
}
items = append(items, column1)
}

if err := it.Close(); err != nil {
t.Fatalf("close: %v", err)
}

if len(items) != 2 || items[0] != "hello" || items[1] != "world" {
t.Errorf("unexpected items: %+v", items)
}

gotFrameCount := atomic.LoadInt64(&frameCount)
gotRowsCount := atomic.LoadInt64(&rowsCount)
gotRowsBytes := atomic.LoadInt64(&rowsBytes)
gotFrameLengthBytes := atomic.LoadInt64(&frameLengthBytes)
gotFrameUncompressedBytes := atomic.LoadInt64(&frameUncompressedBytes)

if gotFrameCount != 1 {
t.Errorf("unexpected frame count, got %d", gotFrameCount)
}
if gotRowsCount != 2 {
t.Errorf("unexpected row count, got %d", gotRowsCount)
}
if gotRowsBytes != 18 {
t.Errorf("unexpected rows bytes, got %d", gotRowsBytes)
}
if gotFrameLengthBytes != 61 {
t.Errorf("unexpected frame length, got %d", gotFrameLengthBytes)
}
// compression was not used.
if gotFrameUncompressedBytes != 0 {
t.Errorf("unexpected frame uncompressed size, got %d", gotFrameUncompressedBytes)
}
}

type recordingFrameHeaderObserver struct {
t *testing.T
mu sync.Mutex
Expand Down Expand Up @@ -1270,6 +1353,19 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer, exts map[string]
rand.Seed(time.Now().UnixNano())
<-time.After(time.Millisecond * 120)
}
case "rows":
// https://martin-sucha.github.io/cqlprotodoc/native_protocol_v4.html#s4.2.5.2
respFrame.writeHeader(0, FrameOpcodeResult, head.stream)
respFrame.writeInt(resultKindRows)
respFrame.writeInt(int32(flagGlobalTableSpec)) // flags
respFrame.writeInt(1) // column count
respFrame.writeString("keyspace") // global table spec: keyspace name
respFrame.writeString("table") // global table spec: table name
respFrame.writeString("column") // column1 name
respFrame.writeShort(uint16(TypeVarchar)) // column1 type
respFrame.writeInt(2) // rows_count
respFrame.writeBytes([]byte("hello")) // row1
respFrame.writeBytes([]byte("world")) // row2
default:
respFrame.writeHeader(0, FrameOpcodeResult, head.stream)
respFrame.writeInt(resultKindVoid)
Expand Down
2 changes: 1 addition & 1 deletion frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ func (fpo *frameParseObserver) observeFrame(ff *framer, f frame) {
ObservedFrameHeader: fpo.head,
UncompressedSize: ff.uncompressedSize,
}
if rows, ok := f.(resultRowsFrame); ok {
if rows, ok := f.(*resultRowsFrame); ok {
of.IsRowsResult = true
of.RowCount = rows.numRows
of.RowsSize = rows.rowsContentSize
Expand Down

0 comments on commit 6c94303

Please sign in to comment.