From 30e2c3cbf50ef867c0a98bad85f99ad208267015 Mon Sep 17 00:00:00 2001 From: rebelice Date: Wed, 13 May 2026 15:57:24 +0900 Subject: [PATCH] fix oracle sqlplus command splitting --- oracle/parser/split.go | 135 ++++++++++++++++++++++++++++---- oracle/parser/split_test.go | 150 +++++++++++++++++++++++++++++++++++- 2 files changed, 267 insertions(+), 18 deletions(-) diff --git a/oracle/parser/split.go b/oracle/parser/split.go index 7495f164..44b7f3b9 100644 --- a/oracle/parser/split.go +++ b/oracle/parser/split.go @@ -67,8 +67,9 @@ func Split(sql string) []Segment { if !state.inPLSQL { if cmd, ok := sqlPlusCommandAtLineStart(sql, tok); ok { + prefixEmpty := onlyIgnorableSQLPlusPrefix(sql, stmtStart, tok.Loc) if cmd.flush { - if onlyIgnorableSQLPlusPrefix(sql, stmtStart, tok.Loc) { + if prefixEmpty { if tok.Type == '/' && len(segments) > 0 { stmtStart = lineEndBeforeBreak(sql, tok.End) } else { @@ -78,20 +79,25 @@ func Split(sql string) []Segment { segments = appendSegment(segments, sql, stmtStart, trimRightSpace(sql, tok.Loc)) stmtStart = lineEndBeforeBreak(sql, tok.End) } + lexer.pos = lineEndAfterBreak(sql, tok.End) + state.reset() + continue } else { - lineEnd := lineEndBeforeBreak(sql, tok.End) - nextStart := lineEndAfterBreak(sql, tok.End) - commandStart := stmtStart - if !onlyIgnorableSQLPlusPrefix(sql, stmtStart, tok.Loc) { - segments = appendSegment(segments, sql, stmtStart, trimRightSpace(sql, tok.Loc)) - commandStart = lineStartOffset(sql, tok.Loc) + if prefixEmpty || cmd.terminatesBufferedSQL { + lineEnd := lineEndBeforeBreak(sql, tok.End) + nextStart := lineEndAfterBreak(sql, tok.End) + commandStart := stmtStart + if !prefixEmpty { + segments = appendSegment(segments, sql, stmtStart, trimRightSpace(sql, tok.Loc)) + commandStart = lineStartOffset(sql, tok.Loc) + } + segments = appendSegmentWithKind(segments, sql, commandStart, lineEnd, SegmentSQLPlusCommand) + stmtStart = nextStart + lexer.pos = lineEndAfterBreak(sql, tok.End) + state.reset() + continue } - segments = appendSegmentWithKind(segments, sql, commandStart, lineEnd, SegmentSQLPlusCommand) - stmtStart = nextStart } - lexer.pos = lineEndAfterBreak(sql, tok.End) - state.reset() - continue } } @@ -431,7 +437,8 @@ func (s *splitState) canStartNestedSubprogram(tok Token) bool { } type sqlPlusCommand struct { - flush bool + flush bool + terminatesBufferedSQL bool } func sqlPlusCommandAtLineStart(sql string, tok Token) (sqlPlusCommand, bool) { @@ -442,21 +449,24 @@ func sqlPlusCommandAtLineStart(sql string, tok Token) (sqlPlusCommand, bool) { } if tok.Type == '/' && isSlashDelimiterLine(sql, tok.Loc, tok.End) { - return sqlPlusCommand{flush: true}, true + return sqlPlusCommand{flush: true, terminatesBufferedSQL: true}, true } if tok.Type == '@' || tok.Type == '!' { - return sqlPlusCommand{}, true + return sqlPlusCommand{terminatesBufferedSQL: true}, true } word := splitTokenWord(tok) if word == "" { return sqlPlusCommand{}, false } + if isOracleSetStatement(word, sql, tok.End) { + return sqlPlusCommand{}, false + } if isSQLPlusFlushCommand(word) { - return sqlPlusCommand{flush: true}, true + return sqlPlusCommand{flush: true, terminatesBufferedSQL: true}, true } if isSQLPlusLineCommand(word) { - return sqlPlusCommand{}, true + return sqlPlusCommand{terminatesBufferedSQL: isSQLPlusLineCommandThatTerminatesBufferedSQL(word, sql, tok.End)}, true } return sqlPlusCommand{}, false } @@ -468,6 +478,44 @@ func splitTokenWord(tok Token) string { return "" } +func isOracleSetStatement(word, sql string, pos int) bool { + if word != "SET" { + return false + } + next := nextWordOnLine(sql, pos) + switch next { + case "TRANSACTION", "ROLE", "CONSTRAINT", "CONSTRAINTS": + return true + default: + return false + } +} + +func nextWordOnLine(sql string, pos int) string { + pos = skipHorizontalSpace(sql, pos) + start := pos + for pos < len(sql) { + c := sql[pos] + if c == '\n' || c == '\r' || !(c == '_' || c >= '0' && c <= '9' || c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z') { + break + } + pos++ + } + if pos == start { + return "" + } + + buf := make([]byte, pos-start) + for i := start; i < pos; i++ { + c := sql[i] + if c >= 'a' && c <= 'z' { + c -= 'a' - 'A' + } + buf[i-start] = c + } + return string(buf) +} + func isSQLPlusFlushCommand(word string) bool { switch word { case "RUN", "R": @@ -505,6 +553,59 @@ func isSQLPlusLineCommand(word string) bool { } } +func isSQLPlusLineCommandThatTerminatesBufferedSQL(word, sql string, pos int) bool { + switch word { + case "ACC", "ACCEPT", + "BRE", "BREAK", "BTI", "BTITLE", + "COL", "COLUMN", "COMP", "COMPUTE", + "DEF", "DEFINE", + "HO", "HOST", + "PRI", "PRINT", "PRO", "PROMPT", + "REM", "REMARK", + "SHO", "SHOW", "SPO", "SPOOL", + "TTI", "TTITLE", + "UNDEF", "UNDEFINE", + "VAR", "VARIABLE", + "WHENEVER": + return true + case "CONN", "CONNECT": + return isSQLPlusConnectCommandThatTerminatesBufferedSQL(sql, pos) + case "STA", "START": + return nextWordOnLine(sql, pos) != "WITH" + case "SET": + return isSQLPlusSetCommandThatTerminatesBufferedSQL(sql, pos) + default: + return false + } +} + +func isSQLPlusConnectCommandThatTerminatesBufferedSQL(sql string, pos int) bool { + switch nextWordOnLine(sql, pos) { + case "BY", "TO": + return false + default: + return true + } +} + +func isSQLPlusSetCommandThatTerminatesBufferedSQL(sql string, pos int) bool { + switch nextWordOnLine(sql, pos) { + case "APPINFO", "ARRAYSIZE", "AUTOCOMMIT", "AUTOPRINT", "AUTORECOVERY", "AUTOTRACE", + "BLOCKTERMINATOR", "CMDSEP", "COLINVISIBLE", "COLSEP", "CONCAT", "COPYCOMMIT", + "COPYTYPECHECK", "DEF", "DEFINE", "DESCRIBE", "ECHO", "EDITFILE", "EMBEDDED", "ERRORLOGGING", + "ESCAPE", "ESCCHAR", "EXITCOMMIT", "FEEDBACK", "FLAGGER", "FLUSH", "HEADING", + "HEADSEP", "INSTANCE", "LINESIZE", "LOBOFFSET", "LOGSOURCE", "LONG", "LONGCHUNKSIZE", + "MARKUP", "NEWPAGE", "NULL", "NUMFORMAT", "NUMWIDTH", "PAGESIZE", "PAUSE", + "RECSEP", "RECSEPCHAR", "SCAN", "SERVEROUT", "SERVEROUTPUT", "SHIFTINOUT", "SHOWMODE", "SQLBLANKLINES", + "SQLCASE", "SQLCONTINUE", "SQLNUMBER", "SQLPLUSCOMPATIBILITY", "SQLPREFIX", + "SQLPROMPT", "SQLTERMINATOR", "SUFFIX", "TAB", "TERMOUT", "TIME", "TIMING", + "TRIMOUT", "TRIMSPOOL", "UNDERLINE", "VERIFY", "WRAP": + return true + default: + return false + } +} + func onlyIgnorableSQLPlusPrefix(sql string, start, end int) bool { seg := Segment{Text: sql[start:end], ByteStart: start, ByteEnd: end} return seg.Empty() diff --git a/oracle/parser/split_test.go b/oracle/parser/split_test.go index 1c04e74c..4a2e7d8a 100644 --- a/oracle/parser/split_test.go +++ b/oracle/parser/split_test.go @@ -1,6 +1,10 @@ package parser -import "testing" +import ( + "os" + "path/filepath" + "testing" +) func TestSplitOrdinarySQL(t *testing.T) { tests := []struct { @@ -616,6 +620,150 @@ func TestSplitClassifiesSQLPlusCommands(t *testing.T) { } } +func TestSplitClassifiesOracleSetStatementsAsSQL(t *testing.T) { + sql := "SET DEFINE OFF\n" + + "SET TRANSACTION READ ONLY;\n" + + "SET ROLE app_role;\n" + + "SET CONSTRAINTS ALL IMMEDIATE;" + got := Split(sql) + wantTexts := []string{ + "SET DEFINE OFF", + "SET TRANSACTION READ ONLY", + "\nSET ROLE app_role", + "\nSET CONSTRAINTS ALL IMMEDIATE", + } + wantKinds := []SegmentKind{ + SegmentSQLPlusCommand, + SegmentSQL, + SegmentSQL, + SegmentSQL, + } + if len(got) != len(wantKinds) { + t.Fatalf("got %d segments %q, want %d", len(got), splitTexts(got), len(wantKinds)) + } + for i := range wantKinds { + if got[i].Text != wantTexts[i] { + t.Fatalf("segment[%d] Text = %q, want %q", i, got[i].Text, wantTexts[i]) + } + if got[i].Kind != wantKinds[i] { + t.Fatalf("segment[%d] Kind = %v for %q, want %v", i, got[i].Kind, got[i].Text, wantKinds[i]) + } + } +} + +func TestSplitDoesNotClassifySQLContinuationLinesAsSQLPlus(t *testing.T) { + sql := "SELECT employee_id\n" + + "FROM employees\n" + + "START WITH manager_id IS NULL\n" + + "CONNECT BY PRIOR employee_id = manager_id;\n" + + "CREATE DATABASE LINK remote_db\n" + + "CONNECT TO remote_user IDENTIFIED BY remote_pass\n" + + "USING 'remote_tns';\n" + + "CREATE DATABASE mydb\n" + + "SET DEFAULT BIGFILE TABLESPACE;" + got := Split(sql) + wantTexts := []string{ + "SELECT employee_id\nFROM employees\nSTART WITH manager_id IS NULL\nCONNECT BY PRIOR employee_id = manager_id", + "\nCREATE DATABASE LINK remote_db\nCONNECT TO remote_user IDENTIFIED BY remote_pass\nUSING 'remote_tns'", + "\nCREATE DATABASE mydb\nSET DEFAULT BIGFILE TABLESPACE", + } + if len(got) != len(wantTexts) { + t.Fatalf("got %d segments %q, want %d", len(got), splitTexts(got), len(wantTexts)) + } + for i := range wantTexts { + if got[i].Text != wantTexts[i] { + t.Fatalf("segment[%d] Text = %q, want %q", i, got[i].Text, wantTexts[i]) + } + if got[i].Kind != SegmentSQL { + t.Fatalf("segment[%d] Kind = %v for %q, want %v", i, got[i].Kind, got[i].Text, SegmentSQL) + } + } +} + +func TestSplitClassifiesUnambiguousSQLPlusCommandsAfterBufferedSQL(t *testing.T) { + sql := "SELECT 1 FROM dual\n" + + "PROMPT running next query\n" + + "SPOOL install.log\n" + + "SELECT 2 FROM dual\n" + + "SET DEFINE OFF\n" + + "SELECT 3 FROM dual\n" + + "SET DEF OFF\n" + + "SELECT 4 FROM dual\n" + + "SET SERVEROUT ON\n" + + "SELECT 3 FROM dual\n" + + "CONNECT scott/tiger@db\n" + + "SELECT 2 FROM dual;" + got := Split(sql) + wantTexts := []string{ + "SELECT 1 FROM dual", + "PROMPT running next query", + "SPOOL install.log", + "SELECT 2 FROM dual", + "SET DEFINE OFF", + "SELECT 3 FROM dual", + "SET DEF OFF", + "SELECT 4 FROM dual", + "SET SERVEROUT ON", + "SELECT 3 FROM dual", + "CONNECT scott/tiger@db", + "SELECT 2 FROM dual", + } + wantKinds := []SegmentKind{ + SegmentSQL, + SegmentSQLPlusCommand, + SegmentSQLPlusCommand, + SegmentSQL, + SegmentSQLPlusCommand, + SegmentSQL, + SegmentSQLPlusCommand, + SegmentSQL, + SegmentSQLPlusCommand, + SegmentSQL, + SegmentSQLPlusCommand, + SegmentSQL, + } + if len(got) != len(wantTexts) { + t.Fatalf("got %d segments %q, want %d", len(got), splitTexts(got), len(wantTexts)) + } + for i := range wantTexts { + if got[i].Text != wantTexts[i] { + t.Fatalf("segment[%d] Text = %q, want %q", i, got[i].Text, wantTexts[i]) + } + if got[i].Kind != wantKinds[i] { + t.Fatalf("segment[%d] Kind = %v for %q, want %v", i, got[i].Kind, got[i].Text, wantKinds[i]) + } + } +} + +func TestSplitDoesNotClassifyValidCorpusStatementsAsSQLPlus(t *testing.T) { + corpusDir := filepath.Join("..", "quality", "corpus") + entries, err := os.ReadDir(corpusDir) + if err != nil { + corpusDir = filepath.Join("oracle", "quality", "corpus") + entries, err = os.ReadDir(corpusDir) + if err != nil { + t.Fatalf("cannot read corpus directory: %v", err) + } + } + + for _, entry := range entries { + if entry.IsDir() || filepath.Ext(entry.Name()) != ".sql" { + continue + } + path := filepath.Join(corpusDir, entry.Name()) + for _, stmt := range loadCorpusFile(t, path) { + if stmt.valid != "true" { + continue + } + for _, seg := range Split(stmt.sql) { + if seg.Kind == SegmentSQLPlusCommand { + t.Fatalf("%s/%s classified valid SQL as SQL*Plus command: %q", entry.Name(), stmt.name, seg.Text) + } + } + } + } +} + func splitTexts(segs []Segment) []string { if len(segs) == 0 { return nil