diff --git a/mssql/ast/loc.go b/mssql/ast/loc.go index 3e9d0009..9cd98704 100644 --- a/mssql/ast/loc.go +++ b/mssql/ast/loc.go @@ -21,6 +21,14 @@ func NodeLoc(n Node) Loc { return v.Loc case *FuncCallExpr: return v.Loc + case *DatePart: + return v.Loc + case *NextValueForExpr: + return v.Loc + case *ParseExpr: + return v.Loc + case *JsonKeyValueExpr: + return v.Loc case *CaseExpr: return v.Loc case *CaseWhen: diff --git a/mssql/ast/loc_test.go b/mssql/ast/loc_test.go index 81a0173a..6f7d4c92 100644 --- a/mssql/ast/loc_test.go +++ b/mssql/ast/loc_test.go @@ -30,6 +30,10 @@ func TestNodeLocExpressions(t *testing.T) { }{ {"BinaryExpr", &BinaryExpr{Loc: Loc{Start: 1, End: 2}}, Loc{Start: 1, End: 2}}, {"FuncCallExpr", &FuncCallExpr{Loc: Loc{Start: 2, End: 5}}, Loc{Start: 2, End: 5}}, + {"DatePart", &DatePart{Loc: Loc{Start: 2, End: 6}}, Loc{Start: 2, End: 6}}, + {"NextValueForExpr", &NextValueForExpr{Loc: Loc{Start: 2, End: 7}}, Loc{Start: 2, End: 7}}, + {"ParseExpr", &ParseExpr{Loc: Loc{Start: 2, End: 8}}, Loc{Start: 2, End: 8}}, + {"JsonKeyValueExpr", &JsonKeyValueExpr{Loc: Loc{Start: 2, End: 9}}, Loc{Start: 2, End: 9}}, {"ColumnRef", &ColumnRef{Loc: Loc{Start: 3, End: 4}}, Loc{Start: 3, End: 4}}, {"LikeExpr", &LikeExpr{Loc: Loc{Start: 4, End: 9}}, Loc{Start: 4, End: 9}}, } diff --git a/mssql/ast/outfuncs.go b/mssql/ast/outfuncs.go index 6783ad1c..f0e92301 100644 --- a/mssql/ast/outfuncs.go +++ b/mssql/ast/outfuncs.go @@ -170,6 +170,14 @@ func writeNode(sb *strings.Builder, node Node) { writeUnaryExpr(sb, n) case *FuncCallExpr: writeFuncCallExpr(sb, n) + case *DatePart: + writeDatePart(sb, n) + case *NextValueForExpr: + writeNextValueForExpr(sb, n) + case *ParseExpr: + writeParseExpr(sb, n) + case *JsonKeyValueExpr: + writeJsonKeyValueExpr(sb, n) case *CaseExpr: writeCaseExpr(sb, n) case *CaseWhen: @@ -1901,6 +1909,22 @@ func writeFuncCallExpr(sb *strings.Builder, n *FuncCallExpr) { if n.Distinct { sb.WriteString(" :distinct true") } + if n.TrimOption != "" { + sb.WriteString(fmt.Sprintf(" :trimOption \"%s\"", escapeString(n.TrimOption))) + } + if n.NullTreatment != "" { + sb.WriteString(fmt.Sprintf(" :nullTreatment \"%s\"", escapeString(n.NullTreatment))) + } + if n.JsonNullClause != "" { + sb.WriteString(fmt.Sprintf(" :jsonNullClause \"%s\"", escapeString(n.JsonNullClause))) + } + if n.ReturnType != nil { + sb.WriteString(" :returnType ") + writeNode(sb, n.ReturnType) + } + if n.WithArrayWrapper { + sb.WriteString(" :withArrayWrapper true") + } if n.Over != nil { sb.WriteString(" :over ") writeNode(sb, n.Over) @@ -2129,6 +2153,64 @@ func writeIifExpr(sb *strings.Builder, n *IifExpr) { sb.WriteString("}") } +func writeDatePart(sb *strings.Builder, n *DatePart) { + sb.WriteString("{DATEPART") + if n.Name != "" { + sb.WriteString(fmt.Sprintf(" :name \"%s\"", escapeString(n.Name))) + } + sb.WriteString(fmt.Sprintf(" :loc %d %d", n.Loc.Start, n.Loc.End)) + sb.WriteString("}") +} + +func writeNextValueForExpr(sb *strings.Builder, n *NextValueForExpr) { + sb.WriteString("{NEXTVALUEFOR") + if n.Sequence != nil { + sb.WriteString(" :sequence ") + writeNode(sb, n.Sequence) + } + if n.Over != nil { + sb.WriteString(" :over ") + writeNode(sb, n.Over) + } + sb.WriteString(fmt.Sprintf(" :loc %d %d", n.Loc.Start, n.Loc.End)) + sb.WriteString("}") +} + +func writeParseExpr(sb *strings.Builder, n *ParseExpr) { + sb.WriteString("{PARSE") + if n.Try { + sb.WriteString(" :try true") + } + if n.Expr != nil { + sb.WriteString(" :expr ") + writeNode(sb, n.Expr) + } + if n.DataType != nil { + sb.WriteString(" :dataType ") + writeNode(sb, n.DataType) + } + if n.Culture != nil { + sb.WriteString(" :culture ") + writeNode(sb, n.Culture) + } + sb.WriteString(fmt.Sprintf(" :loc %d %d", n.Loc.Start, n.Loc.End)) + sb.WriteString("}") +} + +func writeJsonKeyValueExpr(sb *strings.Builder, n *JsonKeyValueExpr) { + sb.WriteString("{JSONKEYVALUE") + if n.Key != nil { + sb.WriteString(" :key ") + writeNode(sb, n.Key) + } + if n.Value != nil { + sb.WriteString(" :value ") + writeNode(sb, n.Value) + } + sb.WriteString(fmt.Sprintf(" :loc %d %d", n.Loc.Start, n.Loc.End)) + sb.WriteString("}") +} + func writeColumnRef(sb *strings.Builder, n *ColumnRef) { sb.WriteString("{COLREF") if n.Server != "" { diff --git a/mssql/ast/parsenodes.go b/mssql/ast/parsenodes.go index 40c66cdf..406d363a 100644 --- a/mssql/ast/parsenodes.go +++ b/mssql/ast/parsenodes.go @@ -1625,19 +1625,66 @@ const ( // FuncCallExpr represents a function call. type FuncCallExpr struct { - Name *TableRef // potentially schema-qualified - Args *List - Distinct bool - Star bool // e.g., COUNT(*) - Over *OverClause - Within *List // WITHIN GROUP (ORDER BY ...) - Loc Loc + Name *TableRef // potentially schema-qualified + Args *List + Distinct bool + Star bool // e.g., COUNT(*) + TrimOption string // LEADING, TRAILING, or BOTH for TRIM + NullTreatment string // IGNORE or RESPECT for window null treatment + JsonNullClause string // NULL or ABSENT for JSON constructors + ReturnType *DataType + WithArrayWrapper bool + Over *OverClause + Within *List // WITHIN GROUP (ORDER BY ...) + Loc Loc } func (n *FuncCallExpr) nodeTag() {} func (n *FuncCallExpr) exprNode() {} func (n *FuncCallExpr) tableExpr() {} // table-valued functions can appear in FROM +// DatePart represents a T-SQL datepart token used by date/time functions such +// as DATEADD, DATEDIFF, DATEPART, DATENAME, DATETRUNC, and DATE_BUCKET. +type DatePart struct { + Name string + Loc Loc +} + +func (n *DatePart) nodeTag() {} +func (n *DatePart) exprNode() {} + +// NextValueForExpr represents NEXT VALUE FOR sequence_name. +type NextValueForExpr struct { + Sequence *TableRef + Over *OverClause + Loc Loc +} + +func (n *NextValueForExpr) nodeTag() {} +func (n *NextValueForExpr) exprNode() {} + +// ParseExpr represents PARSE or TRY_PARSE. +type ParseExpr struct { + Try bool + Expr ExprNode + DataType *DataType + Culture ExprNode + Loc Loc +} + +func (n *ParseExpr) nodeTag() {} +func (n *ParseExpr) exprNode() {} + +// JsonKeyValueExpr represents JSON_OBJECT key:value entries. +type JsonKeyValueExpr struct { + Key ExprNode + Value ExprNode + Loc Loc +} + +func (n *JsonKeyValueExpr) nodeTag() {} +func (n *JsonKeyValueExpr) exprNode() {} + // CaseExpr represents a CASE expression. // Ref: https://learn.microsoft.com/en-us/sql/t-sql/language-elements/case-transact-sql type CaseExpr struct { diff --git a/mssql/ast/walk_generated.go b/mssql/ast/walk_generated.go index 174901bf..b2dec328 100644 --- a/mssql/ast/walk_generated.go +++ b/mssql/ast/walk_generated.go @@ -464,6 +464,9 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.Name) } walkList(v, n.Args) + if n.ReturnType != nil { + Walk(v, n.ReturnType) + } if n.Over != nil { Walk(v, n.Over) } @@ -526,6 +529,9 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.Right) Walk(v, n.Condition) walkList(v, n.Using) + case *JsonKeyValueExpr: + Walk(v, n.Key) + Walk(v, n.Value) case *KillQueryNotificationStmt: Walk(v, n.SubscriptionID) case *KillStatsJobStmt: @@ -568,6 +574,13 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.Type) } walkList(v, n.Args) + case *NextValueForExpr: + if n.Sequence != nil { + Walk(v, n.Sequence) + } + if n.Over != nil { + Walk(v, n.Over) + } case *NullifExpr: Walk(v, n.Left) Walk(v, n.Right) @@ -594,6 +607,12 @@ func walkChildren(v Visitor, node Node) { Walk(v, n.Default) case *ParenExpr: Walk(v, n.Expr) + case *ParseExpr: + Walk(v, n.Expr) + if n.DataType != nil { + Walk(v, n.DataType) + } + Walk(v, n.Culture) case *PivotExpr: Walk(v, n.Source) Walk(v, n.AggFunc) diff --git a/mssql/ast/walk_test.go b/mssql/ast/walk_test.go index 82a34e01..4c7e49cd 100644 --- a/mssql/ast/walk_test.go +++ b/mssql/ast/walk_test.go @@ -209,6 +209,7 @@ func allKnownNodes() []Node { &LikeExpr{}, &IsExpr{}, &ExistsExpr{}, &CastExpr{}, &ConvertExpr{}, &TryCastExpr{}, &TryConvertExpr{}, &CoalesceExpr{}, &NullifExpr{}, &IifExpr{}, + &DatePart{}, &NextValueForExpr{}, &ParseExpr{}, &JsonKeyValueExpr{}, &ColumnRef{}, &VariableRef{}, &StarExpr{}, &Literal{}, &SubqueryExpr{}, &SubqueryComparisonExpr{}, &CollateExpr{}, &AtTimeZoneExpr{}, &ParenExpr{}, diff --git a/mssql/parser/compare_test.go b/mssql/parser/compare_test.go index f74a50cb..27d7ea5d 100644 --- a/mssql/parser/compare_test.go +++ b/mssql/parser/compare_test.go @@ -222,6 +222,224 @@ func TestParseFunctions(t *testing.T) { } } +func TestParseDateTimeFunctionDatepartArgs(t *testing.T) { + tests := []struct { + sql string + name string + datepart string + }{ + {"SELECT DATEADD(HOUR, 7, GETUTCDATE())", "DATEADD", "HOUR"}, + {"SELECT DATEDIFF(day, start_date, end_date)", "DATEDIFF", "day"}, + {"SELECT DATEDIFF_BIG(ns, start_date, end_date)", "DATEDIFF_BIG", "ns"}, + {"SELECT DATEPART(month, created_at)", "DATEPART", "month"}, + {"SELECT DATENAME(weekday, created_at)", "DATENAME", "weekday"}, + {"SELECT DATETRUNC(quarter, created_at)", "DATETRUNC", "quarter"}, + {"SELECT DATE_BUCKET(WEEK, 1, created_at)", "DATE_BUCKET", "WEEK"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := ParseAndCheck(t, tc.sql) + stmt := result.Items[0].(*ast.SelectStmt) + target := stmt.TargetList.Items[0].(*ast.ResTarget) + fc, ok := target.Val.(*ast.FuncCallExpr) + if !ok { + t.Fatalf("expected *FuncCallExpr, got %T", target.Val) + } + if fc.Args == nil || fc.Args.Len() == 0 { + t.Fatal("expected function arguments") + } + if cr, ok := fc.Args.Items[0].(*ast.ColumnRef); ok { + t.Fatalf("datepart argument parsed as ColumnRef: %+v", cr) + } + dp, ok := fc.Args.Items[0].(*ast.DatePart) + if !ok { + t.Fatalf("expected first argument to be *ast.DatePart, got %T", fc.Args.Items[0]) + } + if dp.Name != tc.datepart { + t.Fatalf("expected datepart %q, got %q", tc.datepart, dp.Name) + } + }) + } + + t.Run("ordinary function preserves column ref", func(t *testing.T) { + result := ParseAndCheck(t, "SELECT MYFUNC(HOUR)") + stmt := result.Items[0].(*ast.SelectStmt) + target := stmt.TargetList.Items[0].(*ast.ResTarget) + fc := target.Val.(*ast.FuncCallExpr) + if _, ok := fc.Args.Items[0].(*ast.ColumnRef); !ok { + t.Fatalf("expected ordinary function argument to remain ColumnRef, got %T", fc.Args.Items[0]) + } + }) +} + +func TestParseSpecialFunctionSyntax(t *testing.T) { + t.Run("next value for sequence", func(t *testing.T) { + result := ParseAndCheck(t, "SELECT NEXT VALUE FOR dbo.seq") + stmt := result.Items[0].(*ast.SelectStmt) + target := stmt.TargetList.Items[0].(*ast.ResTarget) + if cr, ok := target.Val.(*ast.ColumnRef); ok { + t.Fatalf("NEXT VALUE FOR parsed as ColumnRef: %+v", cr) + } + nv, ok := target.Val.(*ast.NextValueForExpr) + if !ok { + t.Fatalf("expected *ast.NextValueForExpr, got %T", target.Val) + } + if nv.Sequence == nil || nv.Sequence.Schema != "dbo" || nv.Sequence.Object != "seq" { + t.Fatalf("unexpected sequence ref: %+v", nv.Sequence) + } + }) + + t.Run("trim options", func(t *testing.T) { + tests := []struct { + sql string + wantOption string + wantArgs int + }{ + {"SELECT TRIM(name)", "", 1}, + {"SELECT TRIM('x' FROM name)", "", 2}, + {"SELECT TRIM(LEADING 'x' FROM name)", "LEADING", 2}, + {"SELECT TRIM(TRAILING 'x' FROM name)", "TRAILING", 2}, + {"SELECT TRIM(BOTH 'x' FROM name)", "BOTH", 2}, + } + for _, tc := range tests { + t.Run(tc.sql, func(t *testing.T) { + result := ParseAndCheck(t, tc.sql) + stmt := result.Items[0].(*ast.SelectStmt) + target := stmt.TargetList.Items[0].(*ast.ResTarget) + fc, ok := target.Val.(*ast.FuncCallExpr) + if !ok { + t.Fatalf("expected *ast.FuncCallExpr, got %T", target.Val) + } + if fc.TrimOption != tc.wantOption { + t.Fatalf("expected TrimOption=%q, got %q", tc.wantOption, fc.TrimOption) + } + if fc.Args == nil || fc.Args.Len() != tc.wantArgs { + t.Fatalf("expected %d args, got %#v", tc.wantArgs, fc.Args) + } + if tc.wantOption != "" { + if cr, ok := fc.Args.Items[0].(*ast.ColumnRef); ok && cr.Column == tc.wantOption { + t.Fatalf("trim option parsed as ColumnRef: %+v", cr) + } + } + }) + } + + result := ParseAndCheck(t, "SELECT MYFUNC(LEADING)") + stmt := result.Items[0].(*ast.SelectStmt) + target := stmt.TargetList.Items[0].(*ast.ResTarget) + fc := target.Val.(*ast.FuncCallExpr) + if _, ok := fc.Args.Items[0].(*ast.ColumnRef); !ok { + t.Fatalf("expected ordinary function argument to remain ColumnRef, got %T", fc.Args.Items[0]) + } + }) + + t.Run("ignore respect nulls", func(t *testing.T) { + tests := []struct { + sql string + want string + }{ + {"SELECT FIRST_VALUE(name) IGNORE NULLS OVER (ORDER BY id)", "IGNORE"}, + {"SELECT LAST_VALUE(name) RESPECT NULLS OVER (ORDER BY id)", "RESPECT"}, + {"SELECT LAG(name) IGNORE NULLS OVER (ORDER BY id)", "IGNORE"}, + {"SELECT LEAD(name) RESPECT NULLS OVER (ORDER BY id)", "RESPECT"}, + } + for _, tc := range tests { + t.Run(tc.want, func(t *testing.T) { + result := ParseAndCheck(t, tc.sql) + stmt := result.Items[0].(*ast.SelectStmt) + target := stmt.TargetList.Items[0].(*ast.ResTarget) + fc, ok := target.Val.(*ast.FuncCallExpr) + if !ok { + t.Fatalf("expected *ast.FuncCallExpr, got %T", target.Val) + } + if fc.NullTreatment != tc.want { + t.Fatalf("expected NullTreatment=%q, got %q", tc.want, fc.NullTreatment) + } + if fc.Over == nil { + t.Fatal("expected OVER clause") + } + }) + } + }) + + t.Run("parse and try_parse", func(t *testing.T) { + tests := []struct { + sql string + wantTry bool + }{ + {"SELECT PARSE('Monday, 13 December 2010' AS datetime2 USING 'en-US')", false}, + {"SELECT TRY_PARSE('Jabberwokkie' AS datetime2 USING 'en-US')", true}, + } + for _, tc := range tests { + t.Run(tc.sql, func(t *testing.T) { + result := ParseAndCheck(t, tc.sql) + stmt := result.Items[0].(*ast.SelectStmt) + target := stmt.TargetList.Items[0].(*ast.ResTarget) + pe, ok := target.Val.(*ast.ParseExpr) + if !ok { + t.Fatalf("expected *ast.ParseExpr, got %T", target.Val) + } + if pe.Try != tc.wantTry { + t.Fatalf("expected Try=%v, got %v", tc.wantTry, pe.Try) + } + if pe.DataType == nil || pe.DataType.Name != "datetime2" { + t.Fatalf("unexpected data type: %+v", pe.DataType) + } + if pe.Culture == nil { + t.Fatal("expected culture expression") + } + }) + } + }) + + t.Run("json functions", func(t *testing.T) { + tests := []struct { + sql string + wantNullClause string + wantReturnType string + wantArrayWrap bool + wantKeyValueArg bool + }{ + {"SELECT JSON_OBJECT('name': name, 'type': NULL ABSENT ON NULL)", "ABSENT", "", false, true}, + {"SELECT JSON_ARRAY(1, 2 NULL ON NULL RETURNING JSON)", "NULL", "JSON", false, false}, + {"SELECT JSON_VALUE(doc, '$.name' RETURNING nvarchar(100))", "", "nvarchar", false, false}, + {"SELECT JSON_QUERY(doc, '$.items' WITH ARRAY WRAPPER)", "", "", true, false}, + } + for _, tc := range tests { + t.Run(tc.sql, func(t *testing.T) { + result := ParseAndCheck(t, tc.sql) + stmt := result.Items[0].(*ast.SelectStmt) + target := stmt.TargetList.Items[0].(*ast.ResTarget) + fc, ok := target.Val.(*ast.FuncCallExpr) + if !ok { + t.Fatalf("expected *ast.FuncCallExpr, got %T", target.Val) + } + if fc.JsonNullClause != tc.wantNullClause { + t.Fatalf("expected JsonNullClause=%q, got %q", tc.wantNullClause, fc.JsonNullClause) + } + if tc.wantReturnType == "" { + if fc.ReturnType != nil { + t.Fatalf("expected nil ReturnType, got %+v", fc.ReturnType) + } + } else if fc.ReturnType == nil || fc.ReturnType.Name != tc.wantReturnType { + t.Fatalf("expected ReturnType=%q, got %+v", tc.wantReturnType, fc.ReturnType) + } + if fc.WithArrayWrapper != tc.wantArrayWrap { + t.Fatalf("expected WithArrayWrapper=%v, got %v", tc.wantArrayWrap, fc.WithArrayWrapper) + } + if tc.wantKeyValueArg { + if fc.Args == nil || fc.Args.Len() == 0 { + t.Fatal("expected JSON key-value args") + } + if _, ok := fc.Args.Items[0].(*ast.JsonKeyValueExpr); !ok { + t.Fatalf("expected *ast.JsonKeyValueExpr, got %T", fc.Args.Items[0]) + } + } + }) + } + }) +} + // TestParseVariables tests variable references. func TestParseVariables(t *testing.T) { tests := []string{ @@ -1189,6 +1407,49 @@ func TestParseSelect(t *testing.T) { } }) + t.Run("rowset non-expression identifiers", func(t *testing.T) { + tests := []struct { + name string + sql string + argIndex int + want string + }{ + { + name: "openquery linked server", + sql: "SELECT * FROM OPENQUERY(LinkedServer, 'SELECT 1') AS t", + argIndex: 0, + want: "LinkedServer", + }, + { + name: "openrowset remote object", + sql: "SELECT * FROM OPENROWSET('MSOLEDBSQL', 'Server=Seattle1;Trusted_Connection=yes;', Department) AS t", + argIndex: 2, + want: "Department", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := ParseAndCheck(t, tc.sql) + stmt := result.Items[0].(*ast.SelectStmt) + atr := stmt.FromClause.Items[0].(*ast.AliasedTableRef) + fc := atr.Table.(*ast.FuncCallExpr) + if fc.Args == nil || fc.Args.Len() <= tc.argIndex { + t.Fatalf("expected at least %d args, got %#v", tc.argIndex+1, fc.Args) + } + if cr, ok := fc.Args.Items[tc.argIndex].(*ast.ColumnRef); ok { + t.Fatalf("rowset identifier parsed as ColumnRef: %+v", cr) + } + tr, ok := fc.Args.Items[tc.argIndex].(*ast.TableRef) + if !ok { + t.Fatalf("expected rowset identifier to be *ast.TableRef, got %T", fc.Args.Items[tc.argIndex]) + } + if tr.Object != tc.want { + t.Fatalf("expected object %q, got %q", tc.want, tr.Object) + } + }) + } + }) + t.Run("openjson with", func(t *testing.T) { sql := "SELECT * FROM OPENJSON(@json) WITH (name NVARCHAR(100), age INT) AS t" result := ParseAndCheck(t, sql) diff --git a/mssql/parser/expr.go b/mssql/parser/expr.go index e5af1b49..001080cc 100644 --- a/mssql/parser/expr.go +++ b/mssql/parser/expr.go @@ -7,6 +7,11 @@ import ( nodes "github.com/bytebase/omni/mssql/ast" ) +var ( + jsonArrayOption = newOptionSet().withIdents("ARRAY") + jsonWrapperOption = newOptionSet().withIdents("WRAPPER") +) + // parseExpr parses an expression using precedence climbing. // // Ref: https://learn.microsoft.com/en-us/sql/t-sql/language-elements/expressions-transact-sql @@ -1237,14 +1242,62 @@ func (p *Parser) parseNextValueFor() (nodes.ExprNode, error) { if ref == nil { return nil, p.unexpectedToken() } - return &nodes.ColumnRef{ - Server: ref.Server, - Database: ref.Database, - Schema: ref.Schema, - Table: ref.Schema, - Column: ref.Object, + n := &nodes.NextValueForExpr{ + Sequence: ref, Loc: nodes.Loc{Start: loc, End: p.prevEnd()}, - }, nil + } + if p.cur.Type == kwOVER { + over, err := p.parseOverClause() + if err != nil { + return nil, err + } + n.Over = over + n.Loc.End = p.prevEnd() + } + return n, nil +} + +func (p *Parser) parseParseExpr(name string, loc int) (nodes.ExprNode, error) { + p.advance() // consume ( + n := &nodes.ParseExpr{ + Try: strings.EqualFold(name, "TRY_PARSE"), + Loc: nodes.Loc{Start: loc, End: -1}, + } + expr, err := p.parseExpr() + if err != nil { + return nil, err + } + if expr == nil { + return nil, p.unexpectedToken() + } + n.Expr = expr + if _, err := p.expect(kwAS); err != nil { + return nil, err + } + dt, err := p.parseDataType() + if err != nil { + return nil, err + } + if dt == nil { + return nil, p.unexpectedToken() + } + n.DataType = dt + if p.cur.Type == kwUSING { + p.advance() + culture, err := p.parseExpr() + if err != nil { + return nil, err + } + if culture == nil { + return nil, p.unexpectedToken() + } + n.Culture = culture + } + if _, err := p.expect(')'); err != nil { + return nil, err + } + n.Loc.End = p.prevEnd() + return n, nil } // parseFuncCall parses a function call after the opening paren has been seen. @@ -1295,16 +1348,32 @@ func (p *Parser) parseFuncCall(name string, loc int) (nodes.ExprNode, error) { return fc, nil } + if strings.EqualFold(name, "TRIM") { + return p.parseTrimFuncCall(fc) + } + if isJsonFunction(name) { + return p.parseJsonFuncCall(fc, name) + } + // Check for DISTINCT if _, ok := p.match(kwDISTINCT); ok { fc.Distinct = true } + argIndex := 0 args, err := p.parseCommaList(')', commaListAllowEmpty, func() (nodes.Node, error) { if p.collectMode() { p.addExpressionCandidates() return nil, errCollecting } + if argIndex == 0 && isDatePartFunction(name) && p.isIdentLike() { + tok := p.advance() + argIndex++ + return &nodes.DatePart{ + Name: tok.Str, + Loc: nodes.Loc{Start: tok.Loc, End: tok.End}, + }, nil + } arg, err := p.parseExpr() if err != nil { return nil, err @@ -1312,6 +1381,7 @@ func (p *Parser) parseFuncCall(name string, loc int) (nodes.ExprNode, error) { if arg == nil { return nil, p.unexpectedToken() } + argIndex++ return arg, nil }) if err != nil { @@ -1324,6 +1394,8 @@ func (p *Parser) parseFuncCall(name string, loc int) (nodes.ExprNode, error) { } fc.Loc.End = p.prevEnd() + p.parseNullTreatment(fc, name) + // Check for WITHIN GROUP (ORDER BY ...) clause // // Ref: https://learn.microsoft.com/en-us/sql/t-sql/functions/string-agg-transact-sql @@ -1351,6 +1423,287 @@ func (p *Parser) parseFuncCall(name string, loc int) (nodes.ExprNode, error) { return fc, nil } +func (p *Parser) parseTrimFuncCall(fc *nodes.FuncCallExpr) (nodes.ExprNode, error) { + if trimOption, ok := p.consumeTrimOption(); ok { + fc.TrimOption = trimOption + } + first, err := p.parseExpr() + if err != nil { + return nil, err + } + if first == nil { + return nil, p.unexpectedToken() + } + items := []nodes.Node{first} + if p.cur.Type == kwFROM { + p.advance() + second, err := p.parseExpr() + if err != nil { + return nil, err + } + if second == nil { + return nil, p.unexpectedToken() + } + items = append(items, second) + } + fc.Args = &nodes.List{Items: items} + if _, err := p.expect(')'); err != nil { + return nil, err + } + fc.Loc.End = p.prevEnd() + return fc, nil +} + +func (p *Parser) consumeTrimOption() (string, bool) { + if !p.isKeywordOrIdent() { + return "", false + } + switch strings.ToUpper(p.cur.Str) { + case "LEADING", "TRAILING", "BOTH": + option := strings.ToUpper(p.cur.Str) + p.advance() + return option, true + default: + return "", false + } +} + +func (p *Parser) parseNullTreatment(fc *nodes.FuncCallExpr, name string) { + if !isNullTreatmentFunction(name) || !p.isKeywordOrIdent() { + return + } + treatment := strings.ToUpper(p.cur.Str) + if treatment != "IGNORE" && treatment != "RESPECT" { + return + } + next := p.peekNext() + if next.Type != kwNULLS { + return + } + p.advance() + p.advance() + fc.NullTreatment = treatment + fc.Loc.End = p.prevEnd() +} + +func isDatePartFunction(name string) bool { + switch strings.ToUpper(name) { + case "DATEADD", "DATEDIFF", "DATEDIFF_BIG", "DATEPART", "DATENAME", "DATETRUNC", "DATE_BUCKET": + return true + default: + return false + } +} + +func isNullTreatmentFunction(name string) bool { + switch strings.ToUpper(name) { + case "FIRST_VALUE", "LAST_VALUE", "LAG", "LEAD": + return true + default: + return false + } +} + +func isJsonFunction(name string) bool { + switch strings.ToUpper(name) { + case "JSON_OBJECT", "JSON_ARRAY", "JSON_VALUE", "JSON_QUERY": + return true + default: + return false + } +} + +func (p *Parser) parseJsonFuncCall(fc *nodes.FuncCallExpr, name string) (nodes.ExprNode, error) { + switch strings.ToUpper(name) { + case "JSON_OBJECT": + if err := p.parseJsonObjectArgs(fc); err != nil { + return nil, err + } + case "JSON_ARRAY": + if err := p.parseJsonArrayArgs(fc); err != nil { + return nil, err + } + case "JSON_VALUE": + if err := p.parseJsonValueArgs(fc); err != nil { + return nil, err + } + case "JSON_QUERY": + if err := p.parseJsonQueryArgs(fc); err != nil { + return nil, err + } + } + if _, err := p.expect(')'); err != nil { + return nil, err + } + fc.Loc.End = p.prevEnd() + return fc, nil +} + +func (p *Parser) parseJsonObjectArgs(fc *nodes.FuncCallExpr) error { + var items []nodes.Node + for p.cur.Type != ')' && p.cur.Type != tokEOF && !p.isJsonNullClauseStart() && !p.isReturningStart() { + loc := p.pos() + key, err := p.parseExpr() + if err != nil { + return err + } + if key == nil { + return p.unexpectedToken() + } + if _, err := p.expect(':'); err != nil { + return err + } + value, err := p.parseExpr() + if err != nil { + return err + } + if value == nil { + return p.unexpectedToken() + } + items = append(items, &nodes.JsonKeyValueExpr{ + Key: key, + Value: value, + Loc: nodes.Loc{Start: loc, End: p.prevEnd()}, + }) + if _, ok := p.match(','); !ok { + break + } + } + fc.Args = &nodes.List{Items: items} + if err := p.parseJsonNullClause(fc); err != nil { + return err + } + return p.parseJsonReturningClause(fc) +} + +func (p *Parser) parseJsonArrayArgs(fc *nodes.FuncCallExpr) error { + var items []nodes.Node + for p.cur.Type != ')' && p.cur.Type != tokEOF && !p.isJsonNullClauseStart() && !p.isReturningStart() { + expr, err := p.parseExpr() + if err != nil { + return err + } + if expr == nil { + return p.unexpectedToken() + } + items = append(items, expr) + if _, ok := p.match(','); !ok { + break + } + } + fc.Args = &nodes.List{Items: items} + if err := p.parseJsonNullClause(fc); err != nil { + return err + } + return p.parseJsonReturningClause(fc) +} + +func (p *Parser) parseJsonValueArgs(fc *nodes.FuncCallExpr) error { + var items []nodes.Node + expr, err := p.parseExpr() + if err != nil { + return err + } + if expr == nil { + return p.unexpectedToken() + } + items = append(items, expr) + if _, ok := p.match(','); ok { + path, err := p.parseExpr() + if err != nil { + return err + } + if path == nil { + return p.unexpectedToken() + } + items = append(items, path) + } + fc.Args = &nodes.List{Items: items} + return p.parseJsonReturningClause(fc) +} + +func (p *Parser) parseJsonQueryArgs(fc *nodes.FuncCallExpr) error { + var items []nodes.Node + expr, err := p.parseExpr() + if err != nil { + return err + } + if expr == nil { + return p.unexpectedToken() + } + items = append(items, expr) + if _, ok := p.match(','); ok { + path, err := p.parseExpr() + if err != nil { + return err + } + if path == nil { + return p.unexpectedToken() + } + items = append(items, path) + } + fc.Args = &nodes.List{Items: items} + return p.parseJsonWithArrayWrapperClause(fc) +} + +func (p *Parser) parseJsonNullClause(fc *nodes.FuncCallExpr) error { + if !p.isJsonNullClauseStart() { + return nil + } + clause := strings.ToUpper(p.cur.Str) + p.advance() + if _, err := p.expect(kwON); err != nil { + return err + } + if _, err := p.expect(kwNULL); err != nil { + return err + } + fc.JsonNullClause = clause + return nil +} + +func (p *Parser) parseJsonReturningClause(fc *nodes.FuncCallExpr) error { + if !p.isReturningStart() { + return nil + } + p.advance() + dt, err := p.parseDataType() + if err != nil { + return err + } + if dt == nil { + return p.unexpectedToken() + } + fc.ReturnType = dt + return nil +} + +func (p *Parser) parseJsonWithArrayWrapperClause(fc *nodes.FuncCallExpr) error { + if p.cur.Type != kwWITH { + return nil + } + p.advance() + if _, err := p.expectOption(jsonArrayOption); err != nil { + return err + } + if _, err := p.expectOption(jsonWrapperOption); err != nil { + return err + } + fc.WithArrayWrapper = true + return nil +} + +func (p *Parser) isJsonNullClauseStart() bool { + if p.cur.Type != kwNULL && p.cur.Type != kwABSENT { + return false + } + return p.peekNext().Type == kwON +} + +func (p *Parser) isReturningStart() bool { + return p.cur.Type == kwRETURNING +} + // parseWithinGroupClause parses WITHIN GROUP (ORDER BY ...). // // Ref: https://learn.microsoft.com/en-us/sql/t-sql/functions/string-agg-transact-sql diff --git a/mssql/parser/keyword_classification_test.go b/mssql/parser/keyword_classification_test.go index 58aa65f3..c81827f0 100644 --- a/mssql/parser/keyword_classification_test.go +++ b/mssql/parser/keyword_classification_test.go @@ -412,6 +412,7 @@ var sqlServerContextKeywords = []string{ "noreset", "notification", "nowait", + "nulls", "numanode", "object", "offline", @@ -432,6 +433,7 @@ var sqlServerContextKeywords = []string{ "password", "path", "pause", + "parse", "period", "permission_set", "persisted", @@ -481,6 +483,7 @@ var sqlServerContextKeywords = []string{ "result", "resume", "retention", + "returning", "returns", "robust", "role", @@ -551,8 +554,10 @@ var sqlServerContextKeywords = []string{ "transfer", "timeout", "timer", + "trim", "try", "try_cast", + "try_parse", "type", "type_warning", "unbounded", @@ -744,7 +749,9 @@ func TestCoreKeywordNotIdentifier(t *testing.T) { // Unquoted must fail sql := fmt.Sprintf(pat.tmpl, kw) p := &Parser{} - p.lexer = NewLexer(sql); p.source = sql; p.advance() + p.lexer = NewLexer(sql) + p.source = sql + p.advance() _, err := p.parseStmt() // We expect either a parse error or a misparse (not a clean identifier parse). // For now, just check that the keyword token is NOT accepted by parseIdentifier. @@ -754,7 +761,9 @@ func TestCoreKeywordNotIdentifier(t *testing.T) { // Bracket-quoted must succeed quotedSQL := fmt.Sprintf(pat.quoted, kw) pq := &Parser{} - pq.lexer = NewLexer(quotedSQL); pq.source = quotedSQL; pq.advance() + pq.lexer = NewLexer(quotedSQL) + pq.source = quotedSQL + pq.advance() _, errq := pq.parseStmt() if errq != nil { t.Errorf("core keyword [%s] as %s: bracket-quoted should succeed but got: %v", kw, pat.name, errq) @@ -798,7 +807,9 @@ func TestContextKeywordAsIdentifier(t *testing.T) { for _, pos := range positions { sql := fmt.Sprintf(pos.tmpl, kw) p := &Parser{} - p.lexer = NewLexer(sql); p.source = sql; p.advance() + p.lexer = NewLexer(sql) + p.source = sql + p.advance() _, err := p.parseStmt() if err != nil { t.Errorf("context keyword %q as %s: %q should parse but got: %v", kw, pos.name, sql, err) diff --git a/mssql/parser/lexer.go b/mssql/parser/lexer.go index 0c881b9d..403e398f 100644 --- a/mssql/parser/lexer.go +++ b/mssql/parser/lexer.go @@ -338,6 +338,7 @@ const ( kwNOWAIT kwNULL kwNULLIF + kwNULLS kwNUMANODE kwOBJECT kwOF @@ -372,6 +373,7 @@ const ( kwPASSWORD kwPATH kwPAUSE + kwPARSE kwPERCENT kwPERIOD kwPERMISSION_SET @@ -438,6 +440,7 @@ const ( kwRESUME kwRETENTION kwRETURN + kwRETURNING kwRETURNS kwREVERT kwREVOKE @@ -538,10 +541,12 @@ const ( kwTRANSACTION kwTRANSFER kwTRIGGER + kwTRIM kwTRUNCATE kwTRY kwTRY_CAST kwTRY_CONVERT + kwTRY_PARSE kwTSEQUAL kwTYPE kwTYPE_WARNING @@ -668,7 +673,7 @@ func init() { "name": kwNAME, "national": kwNATIONAL, "native_compilation": kwNATIVE_COMPILATION, "next": kwNEXT, "no": kwNO, "nocheck": kwNOCHECK, "nocount": kwNOCOUNT, "node": kwNODE, "nolock": kwNOLOCK, "nonclustered": kwNONCLUSTERED, "none": kwNONE, "not": kwNOT, - "notification": kwNOTIFICATION, "nowait": kwNOWAIT, "null": kwNULL, "nullif": kwNULLIF, + "notification": kwNOTIFICATION, "nowait": kwNOWAIT, "null": kwNULL, "nullif": kwNULLIF, "nulls": kwNULLS, "numanode": kwNUMANODE, "object": kwOBJECT, "of": kwOF, "off": kwOFF, "offline": kwOFFLINE, "offset": kwOFFSET, "offsets": kwOFFSETS, "old_password": kwOLD_PASSWORD, "on": kwON, @@ -678,7 +683,7 @@ func init() { "outer": kwOUTER, "output": kwOUTPUT, "over": kwOVER, "override": kwOVERRIDE, "owner": kwOWNER, "page": kwPAGE, "parameterization": kwPARAMETERIZATION, "partition": kwPARTITION, "partitions": kwPARTITIONS, - "password": kwPASSWORD, "path": kwPATH, "pause": kwPAUSE, "percent": kwPERCENT, + "password": kwPASSWORD, "path": kwPATH, "pause": kwPAUSE, "parse": kwPARSE, "percent": kwPERCENT, "period": kwPERIOD, "permission_set": kwPERMISSION_SET, "persisted": kwPERSISTED, "pivot": kwPIVOT, "plan": kwPLAN, "platform": kwPLATFORM, "poison_message_handling": kwPOISON_MESSAGE_HANDLING, "policy": kwPOLICY, "pool": kwPOOL, "population": kwPOPULATION, "preceding": kwPRECEDING, "precision": kwPRECISION, @@ -694,7 +699,7 @@ func init() { "remove": kwREMOVE, "rename": kwRENAME, "reorganize": kwREORGANIZE, "repeatable": kwREPEATABLE, "replica": kwREPLICA, "replication": kwREPLICATION, "resample": kwRESAMPLE, "resource": kwRESOURCE, "resource_pool": kwRESOURCE_POOL, "restart": kwRESTART, "restore": kwRESTORE, "restrict": kwRESTRICT, - "result": kwRESULT, "resume": kwRESUME, "retention": kwRETENTION, "return": kwRETURN, + "result": kwRESULT, "resume": kwRESUME, "retention": kwRETENTION, "return": kwRETURN, "returning": kwRETURNING, "returns": kwRETURNS, "revert": kwREVERT, "revoke": kwREVOKE, "right": kwRIGHT, "robust": kwROBUST, "role": kwROLE, "rollback": kwROLLBACK, "rollup": kwROLLUP, "root": kwROOT, "round_robin": kwROUND_ROBIN, "route": kwROUTE, "row": kwROW, @@ -718,8 +723,8 @@ func init() { "tcp": kwTCP, "tempdb_metadata": kwTEMPDB_METADATA, "textimage_on": kwTEXTIMAGE_ON, "textsize": kwTEXTSIZE, "then": kwTHEN, "throw": kwTHROW, "ties": kwTIES, "time": kwTIME, "timeout": kwTIMEOUT, "timer": kwTIMER, "to": kwTO, "top": kwTOP, - "tran": kwTRAN, "transaction": kwTRANSACTION, "trigger": kwTRIGGER, "truncate": kwTRUNCATE, - "try": kwTRY, "try_cast": kwTRY_CAST, "try_convert": kwTRY_CONVERT, "tsequal": kwTSEQUAL, + "tran": kwTRAN, "transaction": kwTRANSACTION, "trigger": kwTRIGGER, "trim": kwTRIM, "truncate": kwTRUNCATE, + "try": kwTRY, "try_cast": kwTRY_CAST, "try_convert": kwTRY_CONVERT, "try_parse": kwTRY_PARSE, "tsequal": kwTSEQUAL, "type": kwTYPE, "type_warning": kwTYPE_WARNING, "unbounded": kwUNBOUNDED, "uncommitted": kwUNCOMMITTED, "undefined": kwUNDEFINED, "union": kwUNION, "unique": kwUNIQUE, "unknown": kwUNKNOWN, "unlimited": kwUNLIMITED, "unlock": kwUNLOCK, @@ -1072,6 +1077,7 @@ var keywordClassification = map[int]Keyword{ kwNOWAIT: {Name: "NOWAIT", Token: kwNOWAIT, Category: ContextKeyword}, kwNULL: {Name: "NULL", Token: kwNULL, Category: CoreKeyword}, kwNULLIF: {Name: "NULLIF", Token: kwNULLIF, Category: CoreKeyword}, + kwNULLS: {Name: "NULLS", Token: kwNULLS, Category: ContextKeyword}, kwNUMANODE: {Name: "NUMANODE", Token: kwNUMANODE, Category: ContextKeyword}, kwOBJECT: {Name: "OBJECT", Token: kwOBJECT, Category: ContextKeyword}, kwOF: {Name: "OF", Token: kwOF, Category: CoreKeyword}, @@ -1106,6 +1112,7 @@ var keywordClassification = map[int]Keyword{ kwPASSWORD: {Name: "PASSWORD", Token: kwPASSWORD, Category: ContextKeyword}, kwPATH: {Name: "PATH", Token: kwPATH, Category: ContextKeyword}, kwPAUSE: {Name: "PAUSE", Token: kwPAUSE, Category: ContextKeyword}, + kwPARSE: {Name: "PARSE", Token: kwPARSE, Category: ContextKeyword}, kwPERCENT: {Name: "PERCENT", Token: kwPERCENT, Category: CoreKeyword}, kwPERIOD: {Name: "PERIOD", Token: kwPERIOD, Category: ContextKeyword}, kwPERMISSION_SET: {Name: "PERMISSION_SET", Token: kwPERMISSION_SET, Category: ContextKeyword}, @@ -1172,6 +1179,7 @@ var keywordClassification = map[int]Keyword{ kwRESUME: {Name: "RESUME", Token: kwRESUME, Category: ContextKeyword}, kwRETENTION: {Name: "RETENTION", Token: kwRETENTION, Category: ContextKeyword}, kwRETURN: {Name: "RETURN", Token: kwRETURN, Category: CoreKeyword}, + kwRETURNING: {Name: "RETURNING", Token: kwRETURNING, Category: ContextKeyword}, kwRETURNS: {Name: "RETURNS", Token: kwRETURNS, Category: ContextKeyword}, kwREVERT: {Name: "REVERT", Token: kwREVERT, Category: CoreKeyword}, kwREVOKE: {Name: "REVOKE", Token: kwREVOKE, Category: CoreKeyword}, @@ -1272,10 +1280,12 @@ var keywordClassification = map[int]Keyword{ kwTRANSACTION: {Name: "TRANSACTION", Token: kwTRANSACTION, Category: CoreKeyword}, kwTRANSFER: {Name: "TRANSFER", Token: kwTRANSFER, Category: ContextKeyword}, kwTRIGGER: {Name: "TRIGGER", Token: kwTRIGGER, Category: CoreKeyword}, + kwTRIM: {Name: "TRIM", Token: kwTRIM, Category: ContextKeyword}, kwTRUNCATE: {Name: "TRUNCATE", Token: kwTRUNCATE, Category: CoreKeyword}, kwTRY: {Name: "TRY", Token: kwTRY, Category: ContextKeyword}, kwTRY_CAST: {Name: "TRY_CAST", Token: kwTRY_CAST, Category: ContextKeyword}, kwTRY_CONVERT: {Name: "TRY_CONVERT", Token: kwTRY_CONVERT, Category: CoreKeyword}, + kwTRY_PARSE: {Name: "TRY_PARSE", Token: kwTRY_PARSE, Category: ContextKeyword}, kwTSEQUAL: {Name: "TSEQUAL", Token: kwTSEQUAL, Category: CoreKeyword}, kwTYPE: {Name: "TYPE", Token: kwTYPE, Category: ContextKeyword}, kwTYPE_WARNING: {Name: "TYPE_WARNING", Token: kwTYPE_WARNING, Category: ContextKeyword}, diff --git a/mssql/parser/name.go b/mssql/parser/name.go index c5658ed2..e0f3f32a 100644 --- a/mssql/parser/name.go +++ b/mssql/parser/name.go @@ -2,6 +2,8 @@ package parser import ( + "strings" + nodes "github.com/bytebase/omni/mssql/ast" ) @@ -243,6 +245,9 @@ func (p *Parser) parseIdentExpr() (nodes.ExprNode, error) { // Function call: ident(...) if p.cur.Type == '(' { + if strings.EqualFold(name, "PARSE") || strings.EqualFold(name, "TRY_PARSE") { + return p.parseParseExpr(name, loc) + } return p.parseFuncCall(name, loc) } diff --git a/mssql/parser/rowset_functions.go b/mssql/parser/rowset_functions.go index 70a64880..bdd61a8a 100644 --- a/mssql/parser/rowset_functions.go +++ b/mssql/parser/rowset_functions.go @@ -4,6 +4,8 @@ package parser import ( + "strings" + nodes "github.com/bytebase/omni/mssql/ast" ) @@ -23,7 +25,19 @@ func (p *Parser) parseRowsetFunction() (nodes.TableExpr, error) { } p.advance() // consume ( + argIndex := 0 args, err := p.parseCommaList(')', commaListStrict, func() (nodes.Node, error) { + if isRowsetObjectArg(funcName, argIndex) && p.isIdentLike() { + ref, err := p.parseTableRef() + if err != nil { + return nil, err + } + if ref == nil { + return nil, p.unexpectedToken() + } + argIndex++ + return ref, nil + } arg, err := p.parseExpr() if err != nil { return nil, err @@ -31,6 +45,7 @@ func (p *Parser) parseRowsetFunction() (nodes.TableExpr, error) { if arg == nil { return nil, p.unexpectedToken() } + argIndex++ return arg, nil }) if err != nil { @@ -76,6 +91,17 @@ func (p *Parser) parseRowsetFunction() (nodes.TableExpr, error) { }, nil } +func isRowsetObjectArg(funcName string, argIndex int) bool { + switch { + case strings.EqualFold(funcName, "OPENQUERY"): + return argIndex == 0 + case strings.EqualFold(funcName, "OPENROWSET"): + return argIndex == 2 + default: + return false + } +} + // parseRowsetWithClause parses the WITH (...) column definitions for OPENJSON/OPENXML. // The opening '(' has already been consumed. func (p *Parser) parseRowsetWithClause() (*nodes.List, error) {