diff --git a/internal/engine/dolphin/convert.go b/internal/engine/dolphin/convert.go index cfd83b5c4d..6c761802e4 100644 --- a/internal/engine/dolphin/convert.go +++ b/internal/engine/dolphin/convert.go @@ -1556,6 +1556,12 @@ func (c *cc) convertTruncateTableStmt(n *pcast.TruncateTableStmt) *ast.TruncateS } func (c *cc) convertUnaryOperationExpr(n *pcast.UnaryOperationExpr) ast.Node { + if n.Op == opcode.Not || n.Op == opcode.Not2 { + return &ast.BoolExpr{ + Boolop: ast.BoolExprTypeNot, + Args: &ast.List{Items: []ast.Node{c.convert(n.V)}}, + } + } return todo(n) } diff --git a/internal/engine/dolphin/convert_test.go b/internal/engine/dolphin/convert_test.go new file mode 100644 index 0000000000..8c552023bc --- /dev/null +++ b/internal/engine/dolphin/convert_test.go @@ -0,0 +1,74 @@ +package dolphin + +import ( + "strings" + "testing" + + "github.com/sqlc-dev/sqlc/internal/sql/ast" + "github.com/sqlc-dev/sqlc/internal/sql/astutils" +) + +// walkStmts collects every node reachable from the parsed statements. +func walkStmts(t *testing.T, query string) []ast.Node { + t.Helper() + + p := NewParser() + stmts, err := p.Parse(strings.NewReader(query)) + if err != nil { + t.Fatalf("parse %q: %v", query, err) + } + + var nodes []ast.Node + for i := range stmts { + astutils.Walk(astutils.VisitorFunc(func(n ast.Node) { + nodes = append(nodes, n) + }), stmts[i].Raw.Stmt) + } + return nodes +} + +func TestConvertUnaryExpr_Not(t *testing.T) { + for _, tc := range []struct { + name string + query string + }{ + { + name: "NOT keyword", + query: "SELECT personid FROM persons WHERE NOT sqlc.arg('foo')", + }, + { + name: "bang operator", + query: "SELECT personid FROM persons WHERE ! sqlc.arg('foo')", + }, + } { + t.Run(tc.name, func(t *testing.T) { + nodes := walkStmts(t, tc.query) + + var ( + foundNot bool + foundArg bool + ) + for _, n := range nodes { + switch v := n.(type) { + case *ast.TODO: + t.Fatalf("query produced an opaque TODO node, sqlc.arg() would not be rewritten") + case *ast.BoolExpr: + if v.Boolop == ast.BoolExprTypeNot { + foundNot = true + } + case *ast.FuncCall: + if v.Func != nil && v.Func.Schema == "sqlc" && v.Func.Name == "arg" { + foundArg = true + } + } + } + + if !foundNot { + t.Errorf("expected a BoolExpr with BoolExprTypeNot in the AST") + } + if !foundArg { + t.Errorf("expected the nested sqlc.arg() FuncCall to be reachable in the AST") + } + }) + } +}