diff --git a/clickhouse_test.go b/clickhouse_test.go index 6492459..0965365 100644 --- a/clickhouse_test.go +++ b/clickhouse_test.go @@ -11,6 +11,7 @@ import ( "os" "os/exec" "path/filepath" + "slices" "sort" "strings" "testing" @@ -51,7 +52,11 @@ func TestClickhouseLocal(t *testing.T) { if err != nil { t.Fatal(err) } - query, err := Compile(pqlInput) + compileOptions, testOptions, err := test.options() + if err != nil { + t.Fatal(err) + } + query, err := compileOptions.Compile(pqlInput) if err != nil { t.Fatal("Compile:", err) } @@ -68,6 +73,7 @@ func TestClickhouseLocal(t *testing.T) { stmt := fmt.Sprintf("CREATE TABLE \"%s\" AS file(%s, %s);", tab.name, fnameBuf, formatBuf) args = append(args, "--query", stmt) } + args = appendClickhouseParameterArgs(args, testOptions.parameterValues) args = append(args, "--query", query) c := exec.Command(clickhouseExe, args...) @@ -160,3 +166,26 @@ func findLocalTables(dir string) ([]localTable, error) { } return result, nil } + +func appendClickhouseParameterArgs(dst []string, params map[string]string) []string { + if len(params) == 0 { + return dst + } + + keys := make([]string, 0, len(params)) + for k := range params { + keys = append(keys, k) + } + slices.Sort(keys) + + sb := new(strings.Builder) + for _, k := range keys { + sb.WriteString("SET param_") + sb.WriteString(k) + sb.WriteString(" = ") + quoteSQLString(sb, params[k]) + sb.WriteString(";") + } + dst = append(dst, "--query", sb.String()) + return dst +} diff --git a/go.mod b/go.mod index bdece98..445cd73 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21.6 require ( github.com/google/go-cmp v0.6.0 github.com/spf13/cobra v1.8.0 + github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a golang.org/x/exp v0.0.0-20240213143201-ec583247a57a golang.org/x/term v0.17.0 zombiezen.com/go/bass v0.0.0-20230823162859-0399f01327dd diff --git a/go.sum b/go.sum index f567900..1e103e9 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29XwJucQo73FrleVK6t4kYz4NVhp34Yw= +github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= golang.org/x/exp v0.0.0-20240213143201-ec583247a57a h1:HinSgX1tJRX3KsL//Gxynpw5CTOAIPhgL4W8PNiIpVE= golang.org/x/exp v0.0.0-20240213143201-ec583247a57a/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= diff --git a/golden_test.go b/golden_test.go index 6881e42..349cd30 100644 --- a/golden_test.go +++ b/golden_test.go @@ -5,6 +5,7 @@ package pql import ( "bytes" + "encoding/json" "errors" "flag" "fmt" @@ -14,6 +15,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/tailscale/hujson" ) var recordGoldens = flag.Bool("record", false, "output golden files") @@ -34,8 +36,12 @@ func TestGoldens(t *testing.T) { if err != nil { t.Fatal(err) } + compileOptions, _, err := test.options() + if err != nil { + t.Fatal(err) + } - got, err := Compile(input) + got, err := compileOptions.Compile(input) if err != nil { t.Error("Compile(...):", err) } @@ -112,6 +118,47 @@ func (test goldenTest) input() (string, error) { return string(input), err } +type testOptions struct { + parameterValues map[string]string +} + +func (test goldenTest) options() (*CompileOptions, *testOptions, error) { + type testParameter struct { + Value string `json:"value"` + SQL string `json:"clickhouse"` + } + + path := filepath.Join(test.dir, "options.jwcc") + input, err := os.ReadFile(path) + if os.IsNotExist(err) { + return nil, new(testOptions), nil + } + if err != nil { + return nil, nil, err + } + input, err = hujson.Standardize(input) + if err != nil { + return nil, nil, fmt.Errorf("parse %s: %v", path, err) + } + var parsed struct { + Parameters map[string]testParameter `json:"parameters"` + } + if err := json.Unmarshal(input, &parsed); err != nil { + return nil, nil, fmt.Errorf("parse %s: %v", path, err) + } + opts := &CompileOptions{ + Parameters: make(map[string]string, len(parsed.Parameters)), + } + testOpts := &testOptions{ + parameterValues: make(map[string]string, len(parsed.Parameters)), + } + for name, p := range parsed.Parameters { + opts.Parameters[name] = p.SQL + testOpts.parameterValues[name] = p.Value + } + return opts, testOpts, nil +} + func shouldIgnoreFilename(name string) bool { return strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") } diff --git a/pql.go b/pql.go index aa0846c..d7add43 100644 --- a/pql.go +++ b/pql.go @@ -14,7 +14,24 @@ import ( // Compile converts the given Pipeline Query Language statement // into the equivalent SQL. +// This is equivalent to new(CompileOptions).Compile(source). func Compile(source string) (string, error) { + return ((*CompileOptions)(nil)).Compile(source) +} + +// CompileOptions a set of optional parameters +// that configure compilation. +// nil is treated the same as the zero value. +type CompileOptions struct { + // Parameters is a map of identifiers to SQL snippets to substitute in. + // For example, a "foo": "$1" entry would replace unquoted "foo" identifiers + // with "$1" in the resulting SQL. + Parameters map[string]string +} + +// Compile converts the given Pipeline Query Language statement +// into the equivalent SQL. +func (opts *CompileOptions) Compile(source string) (string, error) { expr, err := parser.Parse(source) if err != nil { return "", err @@ -31,6 +48,9 @@ func Compile(source string) (string, error) { ctx := &exprContext{ source: source, } + if opts != nil { + ctx.scope = opts.Parameters + } if len(ctes) > 0 { sb.WriteString("WITH ") for i, sub := range ctes { @@ -491,6 +511,7 @@ const ( type exprContext struct { source string + scope map[string]string mode exprMode } @@ -510,6 +531,10 @@ func writeExpression(ctx *exprContext, sb *strings.Builder, x parser.Expr) error if len(x.Parts) == 1 { part := x.Parts[0] if !part.Quoted { + if sql, ok := ctx.scope[part.Name]; ok { + sb.WriteString(sql) + return nil + } if sql, ok := builtinIdentifiers[part.Name]; ok { sb.WriteString(sql) return nil diff --git a/testdata/Goldens/Params/input.pql b/testdata/Goldens/Params/input.pql new file mode 100644 index 0000000..40a5006 --- /dev/null +++ b/testdata/Goldens/Params/input.pql @@ -0,0 +1,2 @@ +Tokens +| where Kind == desiredKind diff --git a/testdata/Goldens/Params/options.jwcc b/testdata/Goldens/Params/options.jwcc new file mode 100644 index 0000000..42c172f --- /dev/null +++ b/testdata/Goldens/Params/options.jwcc @@ -0,0 +1,8 @@ +{ + "parameters": { + "desiredKind": { + "clickhouse": "{desiredKind: Int32}", + "value": "1", + }, + }, +} diff --git a/testdata/Goldens/Params/output.csv b/testdata/Goldens/Params/output.csv new file mode 100644 index 0000000..565e14f --- /dev/null +++ b/testdata/Goldens/Params/output.csv @@ -0,0 +1,2 @@ +Kind,TokenConstant +1,TokenIdentifier diff --git a/testdata/Goldens/Params/output.sql b/testdata/Goldens/Params/output.sql new file mode 100644 index 0000000..6e8cff3 --- /dev/null +++ b/testdata/Goldens/Params/output.sql @@ -0,0 +1 @@ +SELECT * FROM "Tokens" WHERE coalesce("Kind" = {desiredKind: Int32}, FALSE);