Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 118 additions & 17 deletions oracle/parser/split.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ func Split(sql string) []Segment {

if !state.inPLSQL {
if cmd, ok := sqlPlusCommandAtLineStart(sql, tok); ok {
prefixEmpty := onlyIgnorableSQLPlusPrefix(sql, stmtStart, tok.Loc)
if cmd.flush {
if onlyIgnorableSQLPlusPrefix(sql, stmtStart, tok.Loc) {
if prefixEmpty {
if tok.Type == '/' && len(segments) > 0 {
stmtStart = lineEndBeforeBreak(sql, tok.End)
} else {
Expand All @@ -78,20 +79,25 @@ func Split(sql string) []Segment {
segments = appendSegment(segments, sql, stmtStart, trimRightSpace(sql, tok.Loc))
stmtStart = lineEndBeforeBreak(sql, tok.End)
}
lexer.pos = lineEndAfterBreak(sql, tok.End)
state.reset()
continue
} else {
lineEnd := lineEndBeforeBreak(sql, tok.End)
nextStart := lineEndAfterBreak(sql, tok.End)
commandStart := stmtStart
if !onlyIgnorableSQLPlusPrefix(sql, stmtStart, tok.Loc) {
segments = appendSegment(segments, sql, stmtStart, trimRightSpace(sql, tok.Loc))
commandStart = lineStartOffset(sql, tok.Loc)
if prefixEmpty || cmd.terminatesBufferedSQL {
lineEnd := lineEndBeforeBreak(sql, tok.End)
nextStart := lineEndAfterBreak(sql, tok.End)
commandStart := stmtStart
if !prefixEmpty {
segments = appendSegment(segments, sql, stmtStart, trimRightSpace(sql, tok.Loc))
commandStart = lineStartOffset(sql, tok.Loc)
}
segments = appendSegmentWithKind(segments, sql, commandStart, lineEnd, SegmentSQLPlusCommand)
stmtStart = nextStart
lexer.pos = lineEndAfterBreak(sql, tok.End)
state.reset()
continue
}
segments = appendSegmentWithKind(segments, sql, commandStart, lineEnd, SegmentSQLPlusCommand)
stmtStart = nextStart
}
lexer.pos = lineEndAfterBreak(sql, tok.End)
state.reset()
continue
}
}

Expand Down Expand Up @@ -431,7 +437,8 @@ func (s *splitState) canStartNestedSubprogram(tok Token) bool {
}

// sqlPlusCommand describes how Split should treat a SQL*Plus command that was
// recognised at the start of a line.
type sqlPlusCommand struct {
	// flush is set for commands that execute the buffered statement: a lone
	// "/" delimiter line and the flush commands such as RUN / R. Split ends
	// the pending statement at such a line.
	flush bool

	// terminatesBufferedSQL reports whether the command may end a statement
	// that is still being buffered (i.e. one that has not been closed by a
	// terminator). Commands without this flag are only recognised when no
	// SQL text precedes them in the current statement, so that SQL
	// continuation lines such as "START WITH ..." or "CONNECT BY ..." are
	// left inside the statement.
	terminatesBufferedSQL bool
}

func sqlPlusCommandAtLineStart(sql string, tok Token) (sqlPlusCommand, bool) {
Expand All @@ -442,21 +449,24 @@ func sqlPlusCommandAtLineStart(sql string, tok Token) (sqlPlusCommand, bool) {
}

if tok.Type == '/' && isSlashDelimiterLine(sql, tok.Loc, tok.End) {
return sqlPlusCommand{flush: true}, true
return sqlPlusCommand{flush: true, terminatesBufferedSQL: true}, true
}
if tok.Type == '@' || tok.Type == '!' {
return sqlPlusCommand{}, true
return sqlPlusCommand{terminatesBufferedSQL: true}, true
}

word := splitTokenWord(tok)
if word == "" {
return sqlPlusCommand{}, false
}
if isOracleSetStatement(word, sql, tok.End) {
return sqlPlusCommand{}, false
}
if isSQLPlusFlushCommand(word) {
return sqlPlusCommand{flush: true}, true
return sqlPlusCommand{flush: true, terminatesBufferedSQL: true}, true
}
if isSQLPlusLineCommand(word) {
return sqlPlusCommand{}, true
return sqlPlusCommand{terminatesBufferedSQL: isSQLPlusLineCommandThatTerminatesBufferedSQL(word, sql, tok.End)}, true
}
return sqlPlusCommand{}, false
}
Expand All @@ -468,6 +478,44 @@ func splitTokenWord(tok Token) string {
return ""
}

// isOracleSetStatement reports whether a SET keyword introduces one of the
// SET forms that must be kept as SQL instead of being classified as a
// SQL*Plus SET command: SET TRANSACTION, SET ROLE, SET CONSTRAINT and
// SET CONSTRAINTS.
//
// word is the token's word form and is compared against the upper-case
// literal "SET"; pos is the byte offset in sql just past that token. Only the
// word that nextWordOnLine finds after pos is inspected.
func isOracleSetStatement(word, sql string, pos int) bool {
	if word != "SET" {
		return false
	}
	following := nextWordOnLine(sql, pos)
	return following == "TRANSACTION" ||
		following == "ROLE" ||
		following == "CONSTRAINT" ||
		following == "CONSTRAINTS"
}

// nextWordOnLine returns, in upper case, the word that begins at the first
// position skipHorizontalSpace yields for pos. A word is a maximal run of
// ASCII letters, digits and underscores; any other byte (including a line
// break) ends it. The empty string is returned when no word starts there.
func nextWordOnLine(sql string, pos int) string {
	isWordByte := func(b byte) bool {
		switch {
		case b == '_',
			'0' <= b && b <= '9',
			'A' <= b && b <= 'Z',
			'a' <= b && b <= 'z':
			return true
		}
		return false
	}

	begin := skipHorizontalSpace(sql, pos)
	end := begin
	for end < len(sql) && isWordByte(sql[end]) {
		end++
	}
	if end == begin {
		return ""
	}

	// Upper-case by hand: the word is ASCII-only by construction.
	upper := []byte(sql[begin:end])
	for i, b := range upper {
		if 'a' <= b && b <= 'z' {
			upper[i] = b - ('a' - 'A')
		}
	}
	return string(upper)
}

func isSQLPlusFlushCommand(word string) bool {
switch word {
case "RUN", "R":
Expand Down Expand Up @@ -505,6 +553,59 @@ func isSQLPlusLineCommand(word string) bool {
}
}

func isSQLPlusLineCommandThatTerminatesBufferedSQL(word, sql string, pos int) bool {
switch word {
case "ACC", "ACCEPT",
"BRE", "BREAK", "BTI", "BTITLE",
"COL", "COLUMN", "COMP", "COMPUTE",
"DEF", "DEFINE",
"HO", "HOST",
"PRI", "PRINT", "PRO", "PROMPT",
"REM", "REMARK",
"SHO", "SHOW", "SPO", "SPOOL",
"TTI", "TTITLE",
"UNDEF", "UNDEFINE",
"VAR", "VARIABLE",
"WHENEVER":
return true
case "CONN", "CONNECT":
return isSQLPlusConnectCommandThatTerminatesBufferedSQL(sql, pos)
case "STA", "START":
return nextWordOnLine(sql, pos) != "WITH"
case "SET":
return isSQLPlusSetCommandThatTerminatesBufferedSQL(sql, pos)
default:
return false
}
}

// isSQLPlusConnectCommandThatTerminatesBufferedSQL reports whether a CONNECT
// keyword is a SQL*Plus CONNECT command rather than a SQL clause. pos is the
// byte offset in sql just past the keyword. "CONNECT BY" and "CONNECT TO" are
// kept as SQL; every other continuation counts as the SQL*Plus command.
func isSQLPlusConnectCommandThatTerminatesBufferedSQL(sql string, pos int) bool {
	following := nextWordOnLine(sql, pos)
	return following != "BY" && following != "TO"
}

// isSQLPlusSetCommandThatTerminatesBufferedSQL reports whether a line-leading
// SET is a SQL*Plus SET command that may end a SQL statement which is still
// being buffered. pos is the byte offset in sql just past the SET keyword.
//
// The word after SET must be one of the listed SQL*Plus system variables (or
// a listed abbreviation). Several of those names (HEADING, TIME, TAB, FLUSH,
// PAUSE, ...) are also perfectly ordinary column names, so a multi-line
//
//	UPDATE posts
//	SET heading = 'x'
//
// would otherwise be cut in two. SQL*Plus separates a variable from its value
// with whitespace, whereas an UPDATE assignment follows the column with "=",
// so a name directly followed by "=" is treated as SQL. The trade-off is that
// an unquoted "=" used as a SQL*Plus value (e.g. SET COLSEP =) after
// unterminated SQL is no longer recognised as a command.
func isSQLPlusSetCommandThatTerminatesBufferedSQL(sql string, pos int) bool {
	name := nextWordOnLine(sql, pos)
	switch name {
	case "APPINFO", "ARRAYSIZE", "AUTOCOMMIT", "AUTOPRINT", "AUTORECOVERY", "AUTOTRACE",
		"BLOCKTERMINATOR", "CMDSEP", "COLINVISIBLE", "COLSEP", "CONCAT", "COPYCOMMIT",
		"COPYTYPECHECK", "DEF", "DEFINE", "DESCRIBE", "ECHO", "EDITFILE", "EMBEDDED", "ERRORLOGGING",
		"ESCAPE", "ESCCHAR", "EXITCOMMIT", "FEEDBACK", "FLAGGER", "FLUSH", "HEADING",
		"HEADSEP", "INSTANCE", "LINESIZE", "LOBOFFSET", "LOGSOURCE", "LONG", "LONGCHUNKSIZE",
		"MARKUP", "NEWPAGE", "NULL", "NUMFORMAT", "NUMWIDTH", "PAGESIZE", "PAUSE",
		"RECSEP", "RECSEPCHAR", "SCAN", "SERVEROUT", "SERVEROUTPUT", "SHIFTINOUT", "SHOWMODE", "SQLBLANKLINES",
		"SQLCASE", "SQLCONTINUE", "SQLNUMBER", "SQLPLUSCOMPATIBILITY", "SQLPREFIX",
		"SQLPROMPT", "SQLTERMINATOR", "SUFFIX", "TAB", "TERMOUT", "TIME", "TIMING",
		"TRIMOUT", "TRIMSPOOL", "UNDERLINE", "VERIFY", "WRAP":
	default:
		return false
	}

	// Look at the first non-blank byte after the variable name. If the name
	// is not where this scan expects it (nextWordOnLine skipped something
	// other than spaces/tabs), the index lands before the end of the name, no
	// "=" is seen, and the original answer (true) is kept.
	i := pos
	for i < len(sql) && (sql[i] == ' ' || sql[i] == '\t') {
		i++
	}
	i += len(name)
	for i < len(sql) && (sql[i] == ' ' || sql[i] == '\t') {
		i++
	}
	return i >= len(sql) || sql[i] != '='
}

func onlyIgnorableSQLPlusPrefix(sql string, start, end int) bool {
seg := Segment{Text: sql[start:end], ByteStart: start, ByteEnd: end}
return seg.Empty()
Expand Down
150 changes: 149 additions & 1 deletion oracle/parser/split_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package parser

import "testing"
import (
"os"
"path/filepath"
"testing"
)

func TestSplitOrdinarySQL(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -616,6 +620,150 @@ func TestSplitClassifiesSQLPlusCommands(t *testing.T) {
}
}

func TestSplitClassifiesOracleSetStatementsAsSQL(t *testing.T) {
sql := "SET DEFINE OFF\n" +
"SET TRANSACTION READ ONLY;\n" +
"SET ROLE app_role;\n" +
"SET CONSTRAINTS ALL IMMEDIATE;"
got := Split(sql)
wantTexts := []string{
"SET DEFINE OFF",
"SET TRANSACTION READ ONLY",
"\nSET ROLE app_role",
"\nSET CONSTRAINTS ALL IMMEDIATE",
}
wantKinds := []SegmentKind{
SegmentSQLPlusCommand,
SegmentSQL,
SegmentSQL,
SegmentSQL,
}
if len(got) != len(wantKinds) {
t.Fatalf("got %d segments %q, want %d", len(got), splitTexts(got), len(wantKinds))
}
for i := range wantKinds {
if got[i].Text != wantTexts[i] {
t.Fatalf("segment[%d] Text = %q, want %q", i, got[i].Text, wantTexts[i])
}
if got[i].Kind != wantKinds[i] {
t.Fatalf("segment[%d] Kind = %v for %q, want %v", i, got[i].Kind, got[i].Text, wantKinds[i])
}
}
}

// TestSplitDoesNotClassifySQLContinuationLinesAsSQLPlus checks that lines
// beginning with START WITH, CONNECT BY, CONNECT TO or SET DEFAULT inside a
// multi-line statement stay part of that statement and that every resulting
// segment is SQL.
func TestSplitDoesNotClassifySQLContinuationLinesAsSQLPlus(t *testing.T) {
	const input = "SELECT employee_id\n" +
		"FROM employees\n" +
		"START WITH manager_id IS NULL\n" +
		"CONNECT BY PRIOR employee_id = manager_id;\n" +
		"CREATE DATABASE LINK remote_db\n" +
		"CONNECT TO remote_user IDENTIFIED BY remote_pass\n" +
		"USING 'remote_tns';\n" +
		"CREATE DATABASE mydb\n" +
		"SET DEFAULT BIGFILE TABLESPACE;"

	expected := []string{
		"SELECT employee_id\nFROM employees\nSTART WITH manager_id IS NULL\nCONNECT BY PRIOR employee_id = manager_id",
		"\nCREATE DATABASE LINK remote_db\nCONNECT TO remote_user IDENTIFIED BY remote_pass\nUSING 'remote_tns'",
		"\nCREATE DATABASE mydb\nSET DEFAULT BIGFILE TABLESPACE",
	}

	segments := Split(input)
	if len(segments) != len(expected) {
		t.Fatalf("got %d segments %q, want %d", len(segments), splitTexts(segments), len(expected))
	}
	for i, wantText := range expected {
		seg := segments[i]
		if seg.Text != wantText {
			t.Fatalf("segment[%d] Text = %q, want %q", i, seg.Text, wantText)
		}
		if seg.Kind != SegmentSQL {
			t.Fatalf("segment[%d] Kind = %v for %q, want %v", i, seg.Kind, seg.Text, SegmentSQL)
		}
	}
}

func TestSplitClassifiesUnambiguousSQLPlusCommandsAfterBufferedSQL(t *testing.T) {
sql := "SELECT 1 FROM dual\n" +
"PROMPT running next query\n" +
"SPOOL install.log\n" +
"SELECT 2 FROM dual\n" +
"SET DEFINE OFF\n" +
"SELECT 3 FROM dual\n" +
"SET DEF OFF\n" +
"SELECT 4 FROM dual\n" +
"SET SERVEROUT ON\n" +
"SELECT 3 FROM dual\n" +
"CONNECT scott/tiger@db\n" +
"SELECT 2 FROM dual;"
got := Split(sql)
wantTexts := []string{
"SELECT 1 FROM dual",
"PROMPT running next query",
"SPOOL install.log",
"SELECT 2 FROM dual",
"SET DEFINE OFF",
"SELECT 3 FROM dual",
"SET DEF OFF",
"SELECT 4 FROM dual",
"SET SERVEROUT ON",
"SELECT 3 FROM dual",
"CONNECT scott/tiger@db",
"SELECT 2 FROM dual",
}
wantKinds := []SegmentKind{
SegmentSQL,
SegmentSQLPlusCommand,
SegmentSQLPlusCommand,
SegmentSQL,
SegmentSQLPlusCommand,
SegmentSQL,
SegmentSQLPlusCommand,
SegmentSQL,
SegmentSQLPlusCommand,
SegmentSQL,
SegmentSQLPlusCommand,
SegmentSQL,
}
if len(got) != len(wantTexts) {
t.Fatalf("got %d segments %q, want %d", len(got), splitTexts(got), len(wantTexts))
}
for i := range wantTexts {
if got[i].Text != wantTexts[i] {
t.Fatalf("segment[%d] Text = %q, want %q", i, got[i].Text, wantTexts[i])
}
if got[i].Kind != wantKinds[i] {
t.Fatalf("segment[%d] Kind = %v for %q, want %v", i, got[i].Kind, got[i].Text, wantKinds[i])
}
}
}

// TestSplitDoesNotClassifyValidCorpusStatementsAsSQLPlus runs Split over every
// statement marked valid in the .sql corpus files and fails if any resulting
// segment is classified as a SQL*Plus command.
//
// The corpus directory is looked up at ../quality/corpus first and at
// oracle/quality/corpus second, so the test works from either working
// directory (presumably the package directory and the repository root).
func TestSplitDoesNotClassifyValidCorpusStatementsAsSQLPlus(t *testing.T) {
	candidates := []string{
		filepath.Join("..", "quality", "corpus"),
		filepath.Join("oracle", "quality", "corpus"),
	}

	var (
		corpusDir string
		entries   []os.DirEntry
		err       error
	)
	for _, dir := range candidates {
		corpusDir = dir
		if entries, err = os.ReadDir(dir); err == nil {
			break
		}
	}
	if err != nil {
		// Reports the error from the last location tried.
		t.Fatalf("cannot read corpus directory: %v", err)
	}

	for _, entry := range entries {
		fileName := entry.Name()
		if entry.IsDir() || filepath.Ext(fileName) != ".sql" {
			continue
		}
		for _, stmt := range loadCorpusFile(t, filepath.Join(corpusDir, fileName)) {
			if stmt.valid != "true" {
				continue
			}
			for _, seg := range Split(stmt.sql) {
				if seg.Kind != SegmentSQLPlusCommand {
					continue
				}
				t.Fatalf("%s/%s classified valid SQL as SQL*Plus command: %q", fileName, stmt.name, seg.Text)
			}
		}
	}
}

func splitTexts(segs []Segment) []string {
if len(segs) == 0 {
return nil
Expand Down
Loading