Skip to content

Commit

Permalink
XiaoMi#215 support multi-query
Browse files Browse the repository at this point in the history
  • Loading branch information
funnyAnt committed May 13, 2022
1 parent 9e16060 commit 7f6769f
Show file tree
Hide file tree
Showing 10 changed files with 413 additions and 26 deletions.
1 change: 1 addition & 0 deletions models/namespace.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type Namespace struct {
DefaultCollation string `json:"default_collation"`
MaxSqlExecuteTime int `json:"max_sql_execute_time"` // sql最大执行时间,大于该时间,进行熔断
MaxSqlResultSize int `json:"max_sql_result_size"` // 限制单分片返回结果集大小不超过max_select_rows
SupportMultiQuery bool `json:"support_multi_query"` //是否支持多语句
}

// Encode encode json
Expand Down
33 changes: 33 additions & 0 deletions mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,15 @@ type Conn struct {
// currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively.
// It can be allocated from bufPool or heap and should be recycled in the same manner.
currentEphemeralBuffer *[]byte

// Capabilities is the current set of features this connection
// is using. It is the features that are both supported by
// the client and the server, and currently in use.
// It is set during the initial handshake.
//
// It is only used for CapabilityClientDeprecateEOF
// and CapabilityClientFoundRows.
Capabilities uint32
}

// bufPool is used to allocate and free buffers in an efficient way.
Expand Down Expand Up @@ -702,6 +711,30 @@ func parseOKHeader(data []byte) (uint64, uint64, uint16, uint16, error) {
return affectedRows, lastInsertID, statusFlags, warnings, nil
}

func (c *Conn) HandleComSetOption(data []byte) (ret bool, err error) {
ret = true
//parseComSetOption
operation, _, ok := ReadUint16(data, 0)
//c.recycleReadPacket()
if ok {
switch operation {
case 0:
c.Capabilities |= ClientMultiStatements
case 1:
c.Capabilities &^= ClientMultiStatements
default:
ret = false
err = fmt.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data)

}
} else {
ret = false
err = fmt.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data)
}

return ret, err
}

// IsErrorPacket determines whether or not the packet is an error packet. Mostly here for
// consistency with isEOFHeader
func IsErrorPacket(data []byte) bool {
Expand Down
57 changes: 57 additions & 0 deletions parser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ const (
StmtComment
StmtSavepoint
)
const (
eofChar = 0x100
)

// Preview analyzes the beginning of the query using a simpler and faster
// textual comparison to identify the statement type.
Expand Down Expand Up @@ -146,3 +149,57 @@ func StmtType(stmtType int) string {
return "UNKNOWN"
}
}

// SplitStatementToPieces split raw sql statement that may have multi sql pieces to sql pieces
// returns the sql pieces blob contains; or error if sql cannot be parsed
func SplitStatementToPieces(blob string) (pieces []string, err error) {
// fast path: the vast majority of SQL statements do not have semicolons in them
if blob == "" {
return nil, nil
}
switch strings.IndexByte(blob, ';') {
case -1: // if there is no semicolon, return blob as a whole
return []string{blob}, nil
case len(blob) - 1: // if there's a single semicolon and it's the last character, return blob without it
return []string{blob[:len(blob)-1]}, nil
}

pieces = make([]string, 0, 16)
tokenizer := NewScanner(blob)

tkn := 0
var pos Pos
var stmt string
stmtBegin := 0
emptyStatement := true
loop:
for {
tkn, pos, _ = tokenizer.scan()
switch tkn {
case ';':
stmt = blob[stmtBegin:pos.Offset]
if !emptyStatement {
pieces = append(pieces, stmt)
emptyStatement = true
}
stmtBegin = pos.Offset + 1
case 0, eofChar:
blobTail := pos.Offset - 1
if stmtBegin < blobTail {
stmt = blob[stmtBegin : blobTail+1]
if !emptyStatement {
pieces = append(pieces, stmt)
}
}
break loop
default:
emptyStatement = false
}
}

if len(tokenizer.errs) > 0 {
err = tokenizer.errs[0]
}

return
}
16 changes: 10 additions & 6 deletions proxy/server/client_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ package server

import (
"fmt"
"strings"

"github.com/XiaoMi/Gaea/log"
"github.com/XiaoMi/Gaea/mysql"
"strings"
"github.com/XiaoMi/Gaea/util/sync2"
)

// ClientConn session client connection
Expand All @@ -29,11 +31,11 @@ type ClientConn struct {

manager *Manager

capability uint32

namespace string // TODO: remove it when refactor is done

proxy *Server

hasRecycledReadPacket sync2.AtomicBool
}

// HandshakeResponseInfo handshake response information
Expand All @@ -49,11 +51,13 @@ type HandshakeResponseInfo struct {
// NewClientConn constructor of ClientConn
func NewClientConn(c *mysql.Conn, manager *Manager) *ClientConn {
salt, _ := mysql.RandomBuf(20)
return &ClientConn{
cc := &ClientConn{
Conn: c,
salt: salt,
manager: manager,
manager: manager,
}
cc.hasRecycledReadPacket.Set(false)
return cc
}

func (cc *ClientConn) CompactVersion(sv string) string {
Expand Down Expand Up @@ -177,7 +181,7 @@ func (cc *ClientConn) readHandshakeResponse() (HandshakeResponseInfo, error) {
return info, fmt.Errorf("readHandshakeResponse: only support protocol 4.1")
}

cc.capability = capability
cc.Capabilities = capability
// Max packet size. Don't do anything with this now.
_, pos, ok = mysql.ReadUint32(data, pos)
if !ok {
Expand Down
16 changes: 13 additions & 3 deletions proxy/server/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type SessionExecutor struct {
stmts map[uint32]*Stmt //prepare相关,client端到proxy的stmt

parser *parser.Parser
session *Session
}

// Response response info
Expand Down Expand Up @@ -279,10 +280,13 @@ func (se *SessionExecutor) ExecuteCommand(cmd byte, data []byte) Response {
sql := string(data)
// handle phase
r, err := se.handleQuery(sql)
if err != nil {
if se.session.IsClosed() {
return CreateNoopResponse()
} else if err != nil {
return CreateErrorResponse(se.status, err)
} else {
return CreateResultResponse(se.status, r)
}
return CreateResultResponse(se.status, r)
case mysql.ComPing:
return CreateOKResponse(se.status)
case mysql.ComInitDB:
Expand Down Expand Up @@ -332,7 +336,13 @@ func (se *SessionExecutor) ExecuteCommand(cmd byte, data []byte) Response {
}
return CreateOKResponse(se.status)
case mysql.ComSetOption:
return CreateEOFResponse(se.status)
ok, err := se.session.c.HandleComSetOption(data)
if ok {
return CreateEOFResponse(se.status)
} else {
log.Warn("dispatch command failed, error: %v", err)
return CreateErrorResponse(se.status, mysql.NewError(mysql.ErrUnknown, err.Error()))
}
default:
msg := fmt.Sprintf("command %d not supported now", cmd)
log.Warn("dispatch command failed, error: %s", msg)
Expand Down
52 changes: 51 additions & 1 deletion proxy/server/executor_handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,61 @@ func (se *SessionExecutor) handleQuery(sql string) (r *mysql.Result, err error)
stmtType := parser.Preview(sql)
reqCtx.Set(util.StmtType, stmtType)

r, err = se.doQuery(reqCtx, sql)
if ns.supportMultiQuery {
r, err = se.doMultiStmts(reqCtx, sql)
} else {
r, err = se.doQuery(reqCtx, sql)
}

se.manager.RecordSessionSQLMetrics(reqCtx, se, sql, startTime, err)
return r, err
}

//handle multi-stmts,like `select 1;set autcommit=0;insert into...;`
func (se *SessionExecutor) doMultiStmts(reqCtx *util.RequestContext, sql string) (r *mysql.Result, errRet error) {
if !se.session.c.hasRecycledReadPacket.Get() {
se.session.c.RecycleReadPacket()
se.session.c.hasRecycledReadPacket.Set(true)
}

piecesSql, err := parser.SplitStatementToPieces(sql)
if err != nil {
log.Warn("parse sql error. sql: [%s], err: %v", sql, err)
return nil, err
}

stmtsNum := len(piecesSql)
if stmtsNum == 1 { //single statements
return se.doQuery(reqCtx, sql)
} else if stmtsNum > 1 && se.session.c.Capabilities&mysql.ClientMultiStatements == 0 {
errRet = fmt.Errorf("client's Capabilities not support multi statements,but proxy receive multi statements:[%s]", sql)
return nil, errRet
}

//multi-query
for index, piece := range piecesSql {
reqCtx.Set(util.StmtType, parser.Preview(piece))
reqCtx.Set(util.FromSlave, 0)

r, errRet = se.doQuery(reqCtx, piece)
if errRet != nil {
return nil, errRet
}

if index < stmtsNum-1 {
//write result to client
response := CreateResultResponse(se.status|mysql.ServerMoreResultsExists, r)
if err = se.session.writeResponse(response); err != nil {
log.Warn("session write response error, error: %v", err)
se.session.Close()
return r, errRet
}
}
}

return r, errRet
}

func (se *SessionExecutor) doQuery(reqCtx *util.RequestContext, sql string) (*mysql.Result, error) {
stmtType := reqCtx.Get("stmtType").(int)

Expand Down
Loading

0 comments on commit 7f6769f

Please sign in to comment.