Skip to content

Commit

Permalink
Fix single character escaping
Browse files Browse the repository at this point in the history
  • Loading branch information
kenshaw committed Oct 26, 2024
1 parent 14d9625 commit a358020
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 44 deletions.
22 changes: 10 additions & 12 deletions stmt/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (b *Stmt) Ready() bool {
// Reset resets the statement buffer.
func (b *Stmt) Reset(r []rune) {
// reset buf
b.Buf, b.Len, b.Prefix, b.Vars = nil, 0, "", nil
b.Buf, b.Len, b.Prefix, b.Vars = nil, 0, "", b.Vars[:0]
// quote state
b.quote, b.quoteDollarTag = 0, ""
// multicomment state
Expand Down Expand Up @@ -187,7 +187,7 @@ func (b *Stmt) Next(unquote func(string, bool) (bool, string, error)) (string, s
var ok bool
parse:
for ; i < b.rlen; i++ {
// log.Printf(">> (%c) %d", b.r[i], i)
// fmt.Fprintf(os.Stderr, "> %d: `%s`\n", i, string(b.r[i:]))
// grab c, next
c, next := b.r[i], grab(b.r, i+1, b.rlen)
switch {
Expand Down Expand Up @@ -231,20 +231,11 @@ parse:
ok, z, _ := unquote(v.Name, true)
if v.Defined = ok || v.Quote == '?'; v.Defined {
b.r, b.rlen = v.Substitute(b.r, z, ok)
i--
}
if b.Len != 0 {
v.I += b.Len + 1
}
}
// unbalance
case c == '(':
b.balanceCount++
// balance
case c == ')':
b.balanceCount = max(0, b.balanceCount-1)
// continue processing quoted string, multiline comment, or unbalanced statements
case b.quote != 0 || b.multilineComment || b.balanceCount != 0:
// skip escaped backslash, semicolon, colon
case c == '\\' && (next == '\\' || next == ';' || next == ':'):
v := &Var{
Expand All @@ -257,7 +248,14 @@ parse:
if b.r, b.rlen = v.Substitute(b.r, string(next), false); b.Len != 0 {
v.I += b.Len + 1
}
i++
// unbalance
case c == '(':
b.balanceCount++
// balance
case c == ')':
b.balanceCount = max(0, b.balanceCount-1)
// continue processing quoted string, multiline comment, or unbalanced statements
case b.quote != 0 || b.multilineComment || b.balanceCount != 0:
// start of command
case c == '\\':
// parse command and params end positions
Expand Down
82 changes: 50 additions & 32 deletions stmt/stmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ func TestNextResetState(t *testing.T) {
{"select $1\\bind a b c\\g", []string{`select $1`}, []string{`\bind| a b c`, `\g|`}, "=", nil},
{"select $1 \\bind a b c \\g", []string{`select $1 `}, []string{`\bind| a b c `, `\g|`}, "=", nil},
{"select $2, $a$ foo $a$, $1 \\bind a b \\g", []string{`select $2, $a$ foo $a$, $1 `}, []string{`\bind| a b `, `\g|`}, "=", nil},
{"select \\;\\\\\\:\\; \n;", []string{"select ;\\:; \n;"}, []string{`|`, `|`}, `=`, []string{`;`, `\`, `:`, `;`}},
{"select \\;\\;\\;\\; \n;", []string{"select ;;;; \n;"}, []string{`|`, `|`}, `=`, []string{`;`, `;`, `;`, `;`}},
{`select \;\;\;\;;`, []string{`select ;;;;;`}, []string{`|`}, `=`, []string{`;`, `;`, `;`, `;`}},
{`select \\\;\\\;\\\;\\\;;`, []string{`select \;\;\;\;;`}, []string{`|`}, `=`, []string{`\`, `;`, `\`, `;`, `\`, `;`, `\`, `;`}},
{`select \:foo;`, []string{`select :foo;`}, []string{`|`}, `=`, []string{":"}},
}
for i, test := range tests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
Expand All @@ -207,61 +212,87 @@ func TestNextResetState(t *testing.T) {
loop:
for {
cmd, params, err := b.Next(unquote)
t.Logf("next buf:%q", string(b.Buf))
t.Logf("next cmd:%q params:%q", cmd, params)
switch {
case err == io.EOF:
break loop
case err != nil:
t.Fatalf("expected no error, got: %v", err)
}
vars = append(vars, b.Vars...)
if b.Ready() || cmd == `\g` {
stmts = append(stmts, b.String())
vars = append(vars, b.Vars...)
b.Reset(nil)
}
cmds = append(cmds, cmd)
aparams = append(aparams, params)
}
if len(stmts) != len(test.stmts) {
t.Logf(">> %#v // %#v", test.stmts, stmts)
t.Fatalf("expected %d statements, got: %d", len(test.stmts), len(stmts))
}
if !reflect.DeepEqual(stmts, test.stmts) {
t.Logf(">> %#v // %#v", test.stmts, stmts)
t.Fatalf("expected statements %s, got: %s", jj(test.stmts), jj(stmts))
vars = append(vars, b.Vars...)
if len(stmts) != len(test.stmts) || !reflect.DeepEqual(stmts, test.stmts) {
t.Errorf("expected %d statements, got: %d", len(test.stmts), len(stmts))
t.Logf("expected:")
for _, s := range test.stmts {
t.Logf(" %q", s)
}
t.Logf("got:")
for _, s := range stmts {
t.Logf(" %q", s)
}
}
if cz := cc(cmds, aparams); !reflect.DeepEqual(cz, test.cmds) {
t.Logf(">> cmds: %#v, aparams: %#v, cz: %#v, test.cmds: %#v", cmds, aparams, cz, test.cmds)
t.Fatalf("expected commands %v, got: %v", jj(test.cmds), jj(cz))
t.Errorf("commands do not match")
t.Logf("expected:")
for _, s := range test.cmds {
t.Logf(" %q", s)
}
t.Logf("got:")
for _, s := range cz {
t.Logf(" %q", s)
}
}
if st := b.State(); st != test.state {
t.Fatalf("expected end parse state %q, got: %q", test.state, st)
t.Errorf("expected end parse state %q, got: %q", test.state, st)
}
if len(vars) != len(test.vars) {
t.Fatalf("expected %d vars, got: %d", len(test.vars), len(vars))
t.Errorf("expected %d vars, got: %d", len(test.vars), len(vars))
t.Logf("expected:")
for _, v := range test.vars {
t.Logf(" %#v", v)
}
t.Logf("got:")
for _, v := range vars {
t.Logf(" %#v", v)
}
}
for _, n := range test.vars {
if !hasVar(vars, n) {
t.Fatalf("missing variable %q", n)
for i, v := range test.vars {
if len(vars) < i {
t.Logf("expected var %d: %#v", i, v)
continue
}
if vars[i].Name != v {
t.Errorf("expected var %d: %q, got: %q", i, v, vars[i].Name)
}
}
b.Reset(nil)
if len(b.Buf) != 0 {
t.Fatalf("after reset b.Buf should have len %d, got: %d", 0, len(b.Buf))
t.Errorf("after reset b.Buf should have len %d, got: %d", 0, len(b.Buf))
}
if b.Len != 0 {
t.Fatalf("after reset should have len %d, got: %d", 0, b.Len)
t.Errorf("after reset should have len %d, got: %d", 0, b.Len)
}
if len(b.Vars) != 0 {
t.Fatalf("after reset should have len(vars) == 0, got: %d", len(b.Vars))
t.Errorf("after reset should have len(vars) == 0, got: %d", len(b.Vars))
}
if b.Prefix != "" {
t.Fatalf("after reset should have empty prefix, got: %s", b.Prefix)
t.Errorf("after reset should have empty prefix, got: %s", b.Prefix)
}
if b.quote != 0 || b.quoteDollarTag != "" || b.multilineComment || b.balanceCount != 0 {
t.Fatal("after reset should have a cleared parse state")
}
if st := b.State(); st != "=" {
t.Fatalf("after reset should have state `=`, got: %q", st)
t.Errorf("after reset should have state `=`, got: %q", st)
}
if b.ready {
t.Fatal("after reset should not be ready")
Expand Down Expand Up @@ -403,10 +434,6 @@ func cc(cmds []string, params []string) []string {
return z
}

func jj(s []string) string {
return "[`" + strings.Join(s, "`,`") + "`]"
}

func sp(a, sep string) func() ([]rune, error) {
s := strings.Split(a, sep)
return func() ([]rune, error) {
Expand All @@ -418,12 +445,3 @@ func sp(a, sep string) func() ([]rune, error) {
return nil, io.EOF
}
}

func hasVar(vars []*Var, n string) bool {
for _, v := range vars {
if v.Name == n {
return true
}
}
return false
}

0 comments on commit a358020

Please sign in to comment.