diff --git a/doris/ast/loc.go b/doris/ast/loc.go index cfc31c2d..35e6d473 100644 --- a/doris/ast/loc.go +++ b/doris/ast/loc.go @@ -124,6 +124,40 @@ func NodeLoc(n Node) Loc { return v.Loc case *MergeClause: return v.Loc + case *CreateRowPolicyStmt: + return v.Loc + case *DropRowPolicyStmt: + return v.Loc + case *CreateEncryptKeyStmt: + return v.Loc + case *DropEncryptKeyStmt: + return v.Loc + case *DictionaryColumn: + return v.Loc + case *CreateDictionaryStmt: + return v.Loc + case *AlterDictionaryStmt: + return v.Loc + case *DropDictionaryStmt: + return v.Loc + case *RefreshDictionaryStmt: + return v.Loc + case *CreateRoleStmt: + return v.Loc + case *AlterRoleStmt: + return v.Loc + case *DropRoleStmt: + return v.Loc + case *UserIdentity: + return v.Loc + case *CreateUserStmt: + return v.Loc + case *AlterUserStmt: + return v.Loc + case *DropUserStmt: + return v.Loc + case *SetPasswordStmt: + return v.Loc default: return NoLoc() } diff --git a/doris/ast/nodetags.go b/doris/ast/nodetags.go index 1cc7c3e6..55d553fa 100644 --- a/doris/ast/nodetags.go +++ b/doris/ast/nodetags.go @@ -206,6 +206,59 @@ const ( // T_MergeClause is the tag for *MergeClause (one WHEN clause inside MERGE). T_MergeClause + + // Security DDL nodes (T5.5). + + // T_CreateRowPolicyStmt is the tag for *CreateRowPolicyStmt. + T_CreateRowPolicyStmt + + // T_DropRowPolicyStmt is the tag for *DropRowPolicyStmt. + T_DropRowPolicyStmt + + // T_CreateEncryptKeyStmt is the tag for *CreateEncryptKeyStmt. + T_CreateEncryptKeyStmt + + // T_DropEncryptKeyStmt is the tag for *DropEncryptKeyStmt. + T_DropEncryptKeyStmt + + // T_DictionaryColumn is the tag for *DictionaryColumn. + T_DictionaryColumn + + // T_CreateDictionaryStmt is the tag for *CreateDictionaryStmt. + T_CreateDictionaryStmt + + // T_AlterDictionaryStmt is the tag for *AlterDictionaryStmt. + T_AlterDictionaryStmt + + // T_DropDictionaryStmt is the tag for *DropDictionaryStmt. + T_DropDictionaryStmt + + // T_RefreshDictionaryStmt is the tag for *RefreshDictionaryStmt. + T_RefreshDictionaryStmt + + // T_CreateRoleStmt is the tag for *CreateRoleStmt. + T_CreateRoleStmt + + // T_AlterRoleStmt is the tag for *AlterRoleStmt. + T_AlterRoleStmt + + // T_DropRoleStmt is the tag for *DropRoleStmt. + T_DropRoleStmt + + // T_UserIdentity is the tag for *UserIdentity ('user'@'host'). + T_UserIdentity + + // T_CreateUserStmt is the tag for *CreateUserStmt. + T_CreateUserStmt + + // T_AlterUserStmt is the tag for *AlterUserStmt. + T_AlterUserStmt + + // T_DropUserStmt is the tag for *DropUserStmt. + T_DropUserStmt + + // T_SetPasswordStmt is the tag for *SetPasswordStmt. + T_SetPasswordStmt ) // String returns a human-readable representation of the tag. @@ -327,6 +380,40 @@ func (t NodeTag) String() string { return "MergeStmt" case T_MergeClause: return "MergeClause" + case T_CreateRowPolicyStmt: + return "CreateRowPolicyStmt" + case T_DropRowPolicyStmt: + return "DropRowPolicyStmt" + case T_CreateEncryptKeyStmt: + return "CreateEncryptKeyStmt" + case T_DropEncryptKeyStmt: + return "DropEncryptKeyStmt" + case T_DictionaryColumn: + return "DictionaryColumn" + case T_CreateDictionaryStmt: + return "CreateDictionaryStmt" + case T_AlterDictionaryStmt: + return "AlterDictionaryStmt" + case T_DropDictionaryStmt: + return "DropDictionaryStmt" + case T_RefreshDictionaryStmt: + return "RefreshDictionaryStmt" + case T_CreateRoleStmt: + return "CreateRoleStmt" + case T_AlterRoleStmt: + return "AlterRoleStmt" + case T_DropRoleStmt: + return "DropRoleStmt" + case T_UserIdentity: + return "UserIdentity" + case T_CreateUserStmt: + return "CreateUserStmt" + case T_AlterUserStmt: + return "AlterUserStmt" + case T_DropUserStmt: + return "DropUserStmt" + case T_SetPasswordStmt: + return "SetPasswordStmt" default: return "Unknown" } diff --git a/doris/ast/securitynodes.go b/doris/ast/securitynodes.go new file mode 100644 index 00000000..edc97437 --- /dev/null +++ b/doris/ast/securitynodes.go @@ -0,0 +1,309 @@ +package ast + +// This file holds security DDL AST node types (T5.5): +// - ROW POLICY +// - ENCRYPTION KEY +// - DICTIONARY +// - ROLE +// - USER / SET PASSWORD + +// --------------------------------------------------------------------------- +// ROW POLICY +// --------------------------------------------------------------------------- + +// CreateRowPolicyStmt represents: +// +// CREATE ROW POLICY [IF NOT EXISTS] name ON table_name +// [AS {RESTRICTIVE | PERMISSIVE}] +// TO user_or_role +// USING (expr) +type CreateRowPolicyStmt struct { + Name string + IfNotExists bool + Type string // "RESTRICTIVE" or "PERMISSIVE"; empty = default + On *ObjectName // ON table_name + To string // TO user_or_role + Using string // USING (expr) — stored as raw text + Loc Loc +} + +// Tag implements Node. +func (n *CreateRowPolicyStmt) Tag() NodeTag { return T_CreateRowPolicyStmt } + +var _ Node = (*CreateRowPolicyStmt)(nil) + +// DropRowPolicyStmt represents: +// +// DROP ROW POLICY name ON table_name +type DropRowPolicyStmt struct { + Name string + On *ObjectName + Loc Loc +} + +// Tag implements Node. +func (n *DropRowPolicyStmt) Tag() NodeTag { return T_DropRowPolicyStmt } + +var _ Node = (*DropRowPolicyStmt)(nil) + +// --------------------------------------------------------------------------- +// ENCRYPTION KEY +// --------------------------------------------------------------------------- + +// CreateEncryptKeyStmt represents: +// +// CREATE ENCRYPTKEY [IF NOT EXISTS] name AS 'key_value' +type CreateEncryptKeyStmt struct { + Name *ObjectName + IfNotExists bool + Key string // AS 'key_value' + Loc Loc +} + +// Tag implements Node. +func (n *CreateEncryptKeyStmt) Tag() NodeTag { return T_CreateEncryptKeyStmt } + +var _ Node = (*CreateEncryptKeyStmt)(nil) + +// DropEncryptKeyStmt represents: +// +// DROP ENCRYPTKEY [IF EXISTS] name +type DropEncryptKeyStmt struct { + Name *ObjectName + IfExists bool + Loc Loc +} + +// Tag implements Node. +func (n *DropEncryptKeyStmt) Tag() NodeTag { return T_DropEncryptKeyStmt } + +var _ Node = (*DropEncryptKeyStmt)(nil) + +// --------------------------------------------------------------------------- +// DICTIONARY +// --------------------------------------------------------------------------- + +// DictionaryColumn represents one column entry in a CREATE DICTIONARY column list: +// +// col_name KEY | VALUE +type DictionaryColumn struct { + Name string + Role string // "KEY" or "VALUE" + Loc Loc +} + +// Tag implements Node. +func (n *DictionaryColumn) Tag() NodeTag { return T_DictionaryColumn } + +var _ Node = (*DictionaryColumn)(nil) + +// CreateDictionaryStmt represents: +// +// CREATE DICTIONARY [IF NOT EXISTS] name +// USING table_name +// (col1 KEY, col2 VALUE, ...) +// LAYOUT(HASH_MAP|IP_TRIE|...) +// PROPERTIES(...) +type CreateDictionaryStmt struct { + Name *ObjectName + IfNotExists bool + UsingTable *ObjectName + Columns []*DictionaryColumn + Layout string // HASH_MAP, IP_TRIE, etc. + Properties []*Property + Loc Loc +} + +// Tag implements Node. +func (n *CreateDictionaryStmt) Tag() NodeTag { return T_CreateDictionaryStmt } + +var _ Node = (*CreateDictionaryStmt)(nil) + +// AlterDictionaryStmt represents: +// +// ALTER DICTIONARY name PROPERTIES(...) +type AlterDictionaryStmt struct { + Name *ObjectName + Properties []*Property + Loc Loc +} + +// Tag implements Node. +func (n *AlterDictionaryStmt) Tag() NodeTag { return T_AlterDictionaryStmt } + +var _ Node = (*AlterDictionaryStmt)(nil) + +// DropDictionaryStmt represents: +// +// DROP DICTIONARY [IF EXISTS] name +type DropDictionaryStmt struct { + Name *ObjectName + IfExists bool + Loc Loc +} + +// Tag implements Node. +func (n *DropDictionaryStmt) Tag() NodeTag { return T_DropDictionaryStmt } + +var _ Node = (*DropDictionaryStmt)(nil) + +// RefreshDictionaryStmt represents: +// +// REFRESH DICTIONARY name +type RefreshDictionaryStmt struct { + Name *ObjectName + Loc Loc +} + +// Tag implements Node. +func (n *RefreshDictionaryStmt) Tag() NodeTag { return T_RefreshDictionaryStmt } + +var _ Node = (*RefreshDictionaryStmt)(nil) + +// --------------------------------------------------------------------------- +// ROLE +// --------------------------------------------------------------------------- + +// CreateRoleStmt represents: +// +// CREATE ROLE [IF NOT EXISTS] name [COMMENT 'text'] +type CreateRoleStmt struct { + Name string + IfNotExists bool + Comment string + Loc Loc +} + +// Tag implements Node. +func (n *CreateRoleStmt) Tag() NodeTag { return T_CreateRoleStmt } + +var _ Node = (*CreateRoleStmt)(nil) + +// AlterRoleStmt represents: +// +// ALTER ROLE name COMMENT 'text' +type AlterRoleStmt struct { + Name string + Comment string + Loc Loc +} + +// Tag implements Node. +func (n *AlterRoleStmt) Tag() NodeTag { return T_AlterRoleStmt } + +var _ Node = (*AlterRoleStmt)(nil) + +// DropRoleStmt represents: +// +// DROP ROLE [IF EXISTS] name +type DropRoleStmt struct { + Name string + IfExists bool + Loc Loc +} + +// Tag implements Node. +func (n *DropRoleStmt) Tag() NodeTag { return T_DropRoleStmt } + +var _ Node = (*DropRoleStmt)(nil) + +// --------------------------------------------------------------------------- +// USER +// --------------------------------------------------------------------------- + +// UserIdentity represents the 'user'@'host' form used in user statements. +// Host defaults to '%' when omitted. +type UserIdentity struct { + Username string + Host string // '%' when not specified + Loc Loc +} + +// Tag implements Node. +func (n *UserIdentity) Tag() NodeTag { return T_UserIdentity } + +var _ Node = (*UserIdentity)(nil) + +// CreateUserStmt represents: +// +// CREATE USER [IF NOT EXISTS] 'user'@'host' +// [IDENTIFIED BY 'password' | IDENTIFIED BY PASSWORD 'hash'] +// [DEFAULT ROLE 'role'] +// [password policy options...] +// [COMMENT 'text'] +type CreateUserStmt struct { + Name *UserIdentity + IfNotExists bool + Password string // IDENTIFIED BY 'password' + PasswordHash string // IDENTIFIED BY PASSWORD 'hash' + DefaultRole string + Comment string + // Password policy options (stored as raw token text for forward-compat) + PasswordExpire bool + PasswordExpireInterval int // INTERVAL n DAY; 0 = not set + FailedLoginAttempts int + PasswordLockTime int // lock time in days; 0 = not set + PasswordHistory int // HISTORY n; 0 = not set + Loc Loc +} + +// Tag implements Node. +func (n *CreateUserStmt) Tag() NodeTag { return T_CreateUserStmt } + +var _ Node = (*CreateUserStmt)(nil) + +// AlterUserStmt represents: +// +// ALTER USER [IF EXISTS] 'user'@'host' +// [IDENTIFIED BY 'password'] +// [password policy options...] +// [ACCOUNT_LOCK | ACCOUNT_UNLOCK] +// [COMMENT 'text'] +type AlterUserStmt struct { + Name *UserIdentity + IfExists bool + Password string + PasswordHash string + Comment string + AccountLock bool // ACCOUNT_LOCK + AccountUnlock bool // ACCOUNT_UNLOCK + // Password policy options + FailedLoginAttempts int + PasswordLockTime int + Loc Loc +} + +// Tag implements Node. +func (n *AlterUserStmt) Tag() NodeTag { return T_AlterUserStmt } + +var _ Node = (*AlterUserStmt)(nil) + +// DropUserStmt represents: +// +// DROP USER [IF EXISTS] 'user'@'host' +type DropUserStmt struct { + Name *UserIdentity + IfExists bool + Loc Loc +} + +// Tag implements Node. +func (n *DropUserStmt) Tag() NodeTag { return T_DropUserStmt } + +var _ Node = (*DropUserStmt)(nil) + +// SetPasswordStmt represents: +// +// SET PASSWORD [FOR 'user'@'host'] = 'password' +// SET PASSWORD [FOR 'user'@'host'] = PASSWORD('cleartext') +type SetPasswordStmt struct { + For *UserIdentity // nil when no FOR clause + Password string // final password value (cleartext or hash) + IsHash bool // true when the RHS was a bare string (hash), false when PASSWORD(...) + Loc Loc +} + +// Tag implements Node. +func (n *SetPasswordStmt) Tag() NodeTag { return T_SetPasswordStmt } + +var _ Node = (*SetPasswordStmt)(nil) diff --git a/doris/ast/walk_children.go b/doris/ast/walk_children.go index e3b3c3c7..e1a6b07b 100644 --- a/doris/ast/walk_children.go +++ b/doris/ast/walk_children.go @@ -348,5 +348,77 @@ func walkChildren(v Visitor, node Node) { Walk(v, val) } } + + // Security DDL nodes (T5.5). + case *CreateRowPolicyStmt: + if n.On != nil { + Walk(v, n.On) + } + case *DropRowPolicyStmt: + if n.On != nil { + Walk(v, n.On) + } + case *CreateEncryptKeyStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *DropEncryptKeyStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *DictionaryColumn: + // leaf-ish node, no Node children + case *CreateDictionaryStmt: + if n.Name != nil { + Walk(v, n.Name) + } + if n.UsingTable != nil { + Walk(v, n.UsingTable) + } + for _, col := range n.Columns { + Walk(v, col) + } + for _, prop := range n.Properties { + Walk(v, prop) + } + case *AlterDictionaryStmt: + if n.Name != nil { + Walk(v, n.Name) + } + for _, prop := range n.Properties { + Walk(v, prop) + } + case *DropDictionaryStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *RefreshDictionaryStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *CreateRoleStmt: + // leaf-ish node, no Node children + case *AlterRoleStmt: + // leaf-ish node, no Node children + case *DropRoleStmt: + // leaf-ish node, no Node children + case *UserIdentity: + // leaf node, no Node children + case *CreateUserStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *AlterUserStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *DropUserStmt: + if n.Name != nil { + Walk(v, n.Name) + } + case *SetPasswordStmt: + if n.For != nil { + Walk(v, n.For) + } } } diff --git a/doris/parser/parser.go b/doris/parser/parser.go index 1d23e24b..09398339 100644 --- a/doris/parser/parser.go +++ b/doris/parser/parser.go @@ -190,6 +190,17 @@ func (p *Parser) parseStmt() (ast.Node, error) { return nil, err } return p.parseCreateView(createTok.Loc, true) + case kwROW: + p.advance() // consume ROW + return p.parseCreateRowPolicy(createTok.Loc) + case kwENCRYPTKEY: + return p.parseCreateEncryptKey(createTok.Loc) + case kwDICTIONARY: + return p.parseCreateDictionary(createTok.Loc) + case kwROLE: + return p.parseCreateRole(createTok.Loc) + case kwUSER: + return p.parseCreateUser(createTok.Loc) default: return p.unsupported("CREATE") } @@ -202,6 +213,12 @@ func (p *Parser) parseStmt() (ast.Node, error) { return p.parseAlterTable() case kwVIEW: return p.parseAlterView() + case kwDICTIONARY: + return p.parseAlterDictionary(p.prev.Loc) + case kwROLE: + return p.parseAlterRole(p.prev.Loc) + case kwUSER: + return p.parseAlterUser(p.prev.Loc) default: return p.unsupported("ALTER") } @@ -214,6 +231,17 @@ func (p *Parser) parseStmt() (ast.Node, error) { return p.parseDropDatabase() case kwVIEW: return p.parseDropView(dropTok.Loc) + case kwROW: + p.advance() // consume ROW + return p.parseDropRowPolicy(dropTok.Loc) + case kwENCRYPTKEY: + return p.parseDropEncryptKey(dropTok.Loc) + case kwDICTIONARY: + return p.parseDropDictionary(dropTok.Loc) + case kwROLE: + return p.parseDropRole(dropTok.Loc) + case kwUSER: + return p.parseDropUser(dropTok.Loc) default: return p.unsupported("DROP") } @@ -279,6 +307,10 @@ func (p *Parser) parseStmt() (ast.Node, error) { // Set / Unset case kwSET: + setTok := p.advance() // consume SET + if p.cur.Kind == kwPASSWORD { + return p.parseSetPassword(setTok.Loc) + } return p.unsupported("SET") case kwUNSET: return p.unsupported("UNSET") @@ -307,6 +339,10 @@ func (p *Parser) parseStmt() (ast.Node, error) { // Materialized View / Refresh case kwREFRESH: + refreshTok := p.advance() // consume REFRESH + if p.cur.Kind == kwDICTIONARY { + return p.parseRefreshDictionary(refreshTok.Loc) + } return p.unsupported("REFRESH") // Job control diff --git a/doris/parser/security.go b/doris/parser/security.go new file mode 100644 index 00000000..8178c0a2 --- /dev/null +++ b/doris/parser/security.go @@ -0,0 +1,961 @@ +package parser + +import ( + "strings" + + "github.com/bytebase/omni/doris/ast" +) + +// --------------------------------------------------------------------------- +// ROW POLICY +// --------------------------------------------------------------------------- + +// parseCreateRowPolicy parses: +// +// CREATE ROW POLICY [IF NOT EXISTS] name ON table_name +// [AS {RESTRICTIVE | PERMISSIVE}] +// TO user_or_role +// USING (expr) +// +// On entry, CREATE and ROW have been consumed; cur is POLICY. +func (p *Parser) parseCreateRowPolicy(startLoc ast.Loc) (ast.Node, error) { + if _, err := p.expect(kwPOLICY); err != nil { + return nil, err + } + + stmt := &ast.CreateRowPolicyStmt{} + + // Optional IF NOT EXISTS + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwNOT); err != nil { + return nil, err + } + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfNotExists = true + } + + // Policy name + name, nameLoc, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + endLoc := nameLoc + + // ON table_name + if _, err := p.expect(kwON); err != nil { + return nil, err + } + onTable, err := p.parseMultipartIdentifier() + if err != nil { + return nil, err + } + stmt.On = onTable + endLoc = ast.NodeLoc(onTable) + + // Optional AS {RESTRICTIVE | PERMISSIVE} + if p.cur.Kind == kwAS { + p.advance() + switch p.cur.Kind { + case kwRESTRICTIVE: + stmt.Type = "RESTRICTIVE" + endLoc = p.cur.Loc + p.advance() + case kwPERMISSIVE: + stmt.Type = "PERMISSIVE" + endLoc = p.cur.Loc + p.advance() + default: + return nil, p.syntaxErrorAtCur() + } + } + + // TO user_or_role + if _, err := p.expect(kwTO); err != nil { + return nil, err + } + toName, toLoc, err := p.parseIdentifierOrString() + if err != nil { + return nil, err + } + stmt.To = toName + endLoc = toLoc + + // USING (expr) — consume as raw text until end of paren group + if p.cur.Kind == kwUSING { + p.advance() + raw, loc, err := p.consumeParenGroup() + if err != nil { + return nil, err + } + stmt.Using = raw + endLoc = loc + } + + stmt.Loc = startLoc.Merge(endLoc) + return stmt, nil +} + +// parseDropRowPolicy parses: +// +// DROP ROW POLICY name ON table_name +// +// On entry, DROP and ROW have been consumed; cur is POLICY. +func (p *Parser) parseDropRowPolicy(startLoc ast.Loc) (ast.Node, error) { + if _, err := p.expect(kwPOLICY); err != nil { + return nil, err + } + + stmt := &ast.DropRowPolicyStmt{} + + // Policy name + name, _, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + + // ON table_name + if _, err := p.expect(kwON); err != nil { + return nil, err + } + onTable, err := p.parseMultipartIdentifier() + if err != nil { + return nil, err + } + stmt.On = onTable + + stmt.Loc = startLoc.Merge(ast.NodeLoc(onTable)) + return stmt, nil +} + +// --------------------------------------------------------------------------- +// ENCRYPTION KEY +// --------------------------------------------------------------------------- + +// parseCreateEncryptKey parses: +// +// CREATE ENCRYPTKEY [IF NOT EXISTS] name AS 'key_value' +// +// On entry, CREATE has been consumed; cur is ENCRYPTKEY. +func (p *Parser) parseCreateEncryptKey(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume ENCRYPTKEY + + stmt := &ast.CreateEncryptKeyStmt{} + + // Optional IF NOT EXISTS + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwNOT); err != nil { + return nil, err + } + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfNotExists = true + } + + // Key name (can be qualified: db.key_name) + name, err := p.parseMultipartIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + + // AS 'key_value' + if _, err := p.expect(kwAS); err != nil { + return nil, err + } + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.Key = p.cur.Str + endLoc := p.cur.Loc + p.advance() + + stmt.Loc = startLoc.Merge(endLoc) + return stmt, nil +} + +// parseDropEncryptKey parses: +// +// DROP ENCRYPTKEY [IF EXISTS] name +// +// On entry, DROP has been consumed; cur is ENCRYPTKEY. +func (p *Parser) parseDropEncryptKey(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume ENCRYPTKEY + + stmt := &ast.DropEncryptKeyStmt{} + + // Optional IF EXISTS + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfExists = true + } + + // Key name (can be qualified: db.key_name) + name, err := p.parseMultipartIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + + stmt.Loc = startLoc.Merge(ast.NodeLoc(name)) + return stmt, nil +} + +// --------------------------------------------------------------------------- +// DICTIONARY +// --------------------------------------------------------------------------- + +// parseCreateDictionary parses: +// +// CREATE DICTIONARY [IF NOT EXISTS] name +// USING table_name +// (col1 KEY, col2 VALUE, ...) +// LAYOUT(layout_type) +// PROPERTIES(...) +// +// On entry, CREATE has been consumed; cur is DICTIONARY. +func (p *Parser) parseCreateDictionary(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume DICTIONARY + + stmt := &ast.CreateDictionaryStmt{} + + // Optional IF NOT EXISTS + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwNOT); err != nil { + return nil, err + } + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfNotExists = true + } + + // Dictionary name + name, err := p.parseMultipartIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + endLoc := ast.NodeLoc(name) + + // USING table_name + if _, err := p.expect(kwUSING); err != nil { + return nil, err + } + usingTable, err := p.parseMultipartIdentifier() + if err != nil { + return nil, err + } + stmt.UsingTable = usingTable + endLoc = ast.NodeLoc(usingTable) + + // (col1 KEY, col2 VALUE, ...) + if p.cur.Kind == int('(') { + cols, loc, err := p.parseDictionaryColumns() + if err != nil { + return nil, err + } + stmt.Columns = cols + endLoc = loc + } + + // LAYOUT(layout_type) + if p.cur.Kind == kwLAYOUT { + p.advance() + if _, err := p.expect(int('(')); err != nil { + return nil, err + } + // Layout type is an identifier/keyword; normalize to lowercase. + layoutName, layoutLoc, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.Layout = strings.ToLower(layoutName) + endLoc = layoutLoc + rparen, err := p.expect(int(')')) + if err != nil { + return nil, err + } + endLoc = rparen.Loc + } + + // PROPERTIES(...) + if p.cur.Kind == kwPROPERTIES { + props, err := p.parseProperties() + if err != nil { + return nil, err + } + stmt.Properties = props + if len(props) > 0 { + endLoc = ast.NodeLoc(props[len(props)-1]) + } + } + + stmt.Loc = startLoc.Merge(endLoc) + return stmt, nil +} + +// parseDictionaryColumns parses: (col1 KEY, col2 VALUE, ...) +// Returns the column list and the location of the closing ')'. +func (p *Parser) parseDictionaryColumns() ([]*ast.DictionaryColumn, ast.Loc, error) { + if _, err := p.expect(int('(')); err != nil { + return nil, ast.NoLoc(), err + } + + var cols []*ast.DictionaryColumn + + for p.cur.Kind != int(')') && p.cur.Kind != tokEOF { + colStart := p.cur.Loc + + colName, _, err := p.parseIdentifier() + if err != nil { + return nil, ast.NoLoc(), err + } + + col := &ast.DictionaryColumn{ + Name: colName, + Loc: colStart, + } + + switch p.cur.Kind { + case kwKEY: + col.Role = "KEY" + col.Loc.End = p.cur.Loc.End + p.advance() + case kwVALUE: + col.Role = "VALUE" + col.Loc.End = p.cur.Loc.End + p.advance() + default: + return nil, ast.NoLoc(), p.syntaxErrorAtCur() + } + + cols = append(cols, col) + + if p.cur.Kind == int(',') { + p.advance() + } + } + + rparen, err := p.expect(int(')')) + if err != nil { + return nil, ast.NoLoc(), err + } + + return cols, rparen.Loc, nil +} + +// parseAlterDictionary parses: +// +// ALTER DICTIONARY name PROPERTIES(...) +// +// On entry, ALTER has been consumed; cur is DICTIONARY. +func (p *Parser) parseAlterDictionary(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume DICTIONARY + + stmt := &ast.AlterDictionaryStmt{} + + name, err := p.parseMultipartIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + endLoc := ast.NodeLoc(name) + + if p.cur.Kind == kwPROPERTIES { + props, err := p.parseProperties() + if err != nil { + return nil, err + } + stmt.Properties = props + if len(props) > 0 { + endLoc = ast.NodeLoc(props[len(props)-1]) + } + } + + stmt.Loc = startLoc.Merge(endLoc) + return stmt, nil +} + +// parseDropDictionary parses: +// +// DROP DICTIONARY [IF EXISTS] name +// +// On entry, DROP has been consumed; cur is DICTIONARY. +func (p *Parser) parseDropDictionary(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume DICTIONARY + + stmt := &ast.DropDictionaryStmt{} + + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfExists = true + } + + name, err := p.parseMultipartIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + + stmt.Loc = startLoc.Merge(ast.NodeLoc(name)) + return stmt, nil +} + +// parseRefreshDictionary parses: +// +// REFRESH DICTIONARY name +// +// On entry, REFRESH has been consumed; cur is DICTIONARY. +func (p *Parser) parseRefreshDictionary(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume DICTIONARY + + stmt := &ast.RefreshDictionaryStmt{} + + name, err := p.parseMultipartIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + + stmt.Loc = startLoc.Merge(ast.NodeLoc(name)) + return stmt, nil +} + +// --------------------------------------------------------------------------- +// ROLE +// --------------------------------------------------------------------------- + +// parseCreateRole parses: +// +// CREATE ROLE [IF NOT EXISTS] name [COMMENT 'text'] +// +// On entry, CREATE has been consumed; cur is ROLE. +func (p *Parser) parseCreateRole(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume ROLE + + stmt := &ast.CreateRoleStmt{} + + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwNOT); err != nil { + return nil, err + } + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfNotExists = true + } + + name, nameLoc, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + endLoc := nameLoc + + if p.cur.Kind == kwCOMMENT { + p.advance() + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.Comment = p.cur.Str + endLoc = p.cur.Loc + p.advance() + } + + stmt.Loc = startLoc.Merge(endLoc) + return stmt, nil +} + +// parseAlterRole parses: +// +// ALTER ROLE name COMMENT 'text' +// +// On entry, ALTER has been consumed; cur is ROLE. +func (p *Parser) parseAlterRole(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume ROLE + + stmt := &ast.AlterRoleStmt{} + + name, _, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + + if _, err := p.expect(kwCOMMENT); err != nil { + return nil, err + } + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.Comment = p.cur.Str + endLoc := p.cur.Loc + p.advance() + + stmt.Loc = startLoc.Merge(endLoc) + return stmt, nil +} + +// parseDropRole parses: +// +// DROP ROLE [IF EXISTS] name +// +// On entry, DROP has been consumed; cur is ROLE. +func (p *Parser) parseDropRole(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume ROLE + + stmt := &ast.DropRoleStmt{} + + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfExists = true + } + + name, nameLoc, err := p.parseIdentifier() + if err != nil { + return nil, err + } + stmt.Name = name + + stmt.Loc = startLoc.Merge(nameLoc) + return stmt, nil +} + +// --------------------------------------------------------------------------- +// USER +// --------------------------------------------------------------------------- + +// parseUserIdentity parses: 'user'@'host' or user@'host' or 'user' or user +// +// Returns a *ast.UserIdentity. Host defaults to '%' when no @'host' suffix. +func (p *Parser) parseUserIdentity() (*ast.UserIdentity, error) { + startLoc := p.cur.Loc + + // Username — string literal or bare identifier + username, userLoc, err := p.parseIdentifierOrString() + if err != nil { + return nil, err + } + endLoc := userLoc + + host := "%" + + // Optional @'host' + if p.cur.Kind == int('@') { + p.advance() // consume @ + if p.cur.Kind != tokString && !isIdentifierToken(p.cur.Kind) { + return nil, p.syntaxErrorAtCur() + } + var hostLoc ast.Loc + host, hostLoc, err = p.parseIdentifierOrString() + if err != nil { + return nil, err + } + endLoc = hostLoc + } + + return &ast.UserIdentity{ + Username: username, + Host: host, + Loc: ast.Loc{Start: startLoc.Start, End: endLoc.End}, + }, nil +} + +// parseCreateUser parses: +// +// CREATE USER [IF NOT EXISTS] 'user'@'host' +// [IDENTIFIED BY 'password' | IDENTIFIED BY PASSWORD 'hash'] +// [DEFAULT ROLE 'role'] +// [PASSWORD_EXPIRE [INTERVAL n DAY]] +// [FAILED_LOGIN_ATTEMPTS n] +// [PASSWORD_LOCK_TIME n DAY] +// [PASSWORD_HISTORY n] +// [COMMENT 'text'] +// +// On entry, CREATE has been consumed; cur is USER. +func (p *Parser) parseCreateUser(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume USER + + stmt := &ast.CreateUserStmt{} + + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwNOT); err != nil { + return nil, err + } + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfNotExists = true + } + + identity, err := p.parseUserIdentity() + if err != nil { + return nil, err + } + stmt.Name = identity + endLoc := identity.Loc + + // Optional IDENTIFIED BY ['PASSWORD'] 'value' + if p.cur.Kind == kwIDENTIFIED { + p.advance() // consume IDENTIFIED + if _, err := p.expect(kwBY); err != nil { + return nil, err + } + if p.cur.Kind == kwPASSWORD { + // IDENTIFIED BY PASSWORD 'hash' + p.advance() + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.PasswordHash = p.cur.Str + endLoc = p.cur.Loc + p.advance() + } else { + // IDENTIFIED BY 'password' + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.Password = p.cur.Str + endLoc = p.cur.Loc + p.advance() + } + } + + // Optional DEFAULT ROLE 'role' + if p.cur.Kind == kwDEFAULT { + p.advance() + if _, err := p.expect(kwROLE); err != nil { + return nil, err + } + roleName, roleLoc, err := p.parseIdentifierOrString() + if err != nil { + return nil, err + } + stmt.DefaultRole = roleName + endLoc = roleLoc + } + + // Optional password policy clauses + endLoc = p.parsePasswordPolicyCreate(stmt, endLoc) + + // Optional COMMENT + if p.cur.Kind == kwCOMMENT { + p.advance() + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.Comment = p.cur.Str + endLoc = p.cur.Loc + p.advance() + } + + stmt.Loc = startLoc.Merge(endLoc) + return stmt, nil +} + +// parsePasswordPolicyCreate parses optional password policy clauses for CREATE USER. +// Modifies stmt in place and returns the updated endLoc. +func (p *Parser) parsePasswordPolicyCreate(stmt *ast.CreateUserStmt, endLoc ast.Loc) ast.Loc { + for { + switch p.cur.Kind { + case kwPASSWORD_EXPIRE: + stmt.PasswordExpire = true + endLoc = p.cur.Loc + p.advance() + // Optional INTERVAL n DAY + if p.cur.Kind == kwINTERVAL { + p.advance() + if p.cur.Kind == tokInt { + n := int(p.cur.Ival) + stmt.PasswordExpireInterval = n + endLoc = p.cur.Loc + p.advance() + } + // Consume DAY/DAYS + if p.cur.Kind == kwDAY || p.cur.Kind == kwDAYS { + endLoc = p.cur.Loc + p.advance() + } + } + case kwFAILED_LOGIN_ATTEMPTS: + endLoc = p.cur.Loc + p.advance() + if p.cur.Kind == tokInt { + n := int(p.cur.Ival) + stmt.FailedLoginAttempts = n + endLoc = p.cur.Loc + p.advance() + } + case kwPASSWORD_LOCK_TIME: + endLoc = p.cur.Loc + p.advance() + if p.cur.Kind == tokInt { + n := int(p.cur.Ival) + stmt.PasswordLockTime = n + endLoc = p.cur.Loc + p.advance() + } + if p.cur.Kind == kwDAY || p.cur.Kind == kwDAYS { + endLoc = p.cur.Loc + p.advance() + } + case kwPASSWORD_HISTORY: + endLoc = p.cur.Loc + p.advance() + if p.cur.Kind == tokInt { + n := int(p.cur.Ival) + stmt.PasswordHistory = n + endLoc = p.cur.Loc + p.advance() + } + default: + return endLoc + } + } +} + +// parseAlterUser parses: +// +// ALTER USER [IF EXISTS] 'user'@'host' +// [IDENTIFIED BY 'password'] +// [FAILED_LOGIN_ATTEMPTS n] +// [PASSWORD_LOCK_TIME n DAY] +// [ACCOUNT_LOCK | ACCOUNT_UNLOCK] +// [COMMENT 'text'] +// +// On entry, ALTER has been consumed; cur is USER. +func (p *Parser) parseAlterUser(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume USER + + stmt := &ast.AlterUserStmt{} + + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfExists = true + } + + identity, err := p.parseUserIdentity() + if err != nil { + return nil, err + } + stmt.Name = identity + endLoc := identity.Loc + + // Optional IDENTIFIED BY ['PASSWORD'] 'value' + if p.cur.Kind == kwIDENTIFIED { + p.advance() + if _, err := p.expect(kwBY); err != nil { + return nil, err + } + if p.cur.Kind == kwPASSWORD { + p.advance() + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.PasswordHash = p.cur.Str + endLoc = p.cur.Loc + p.advance() + } else { + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.Password = p.cur.Str + endLoc = p.cur.Loc + p.advance() + } + } + + // Optional password policy / account clauses + for { + switch p.cur.Kind { + case kwFAILED_LOGIN_ATTEMPTS: + endLoc = p.cur.Loc + p.advance() + if p.cur.Kind == tokInt { + n := int(p.cur.Ival) + stmt.FailedLoginAttempts = n + endLoc = p.cur.Loc + p.advance() + } + case kwPASSWORD_LOCK_TIME: + endLoc = p.cur.Loc + p.advance() + if p.cur.Kind == tokInt { + n := int(p.cur.Ival) + stmt.PasswordLockTime = n + endLoc = p.cur.Loc + p.advance() + } + if p.cur.Kind == kwDAY || p.cur.Kind == kwDAYS { + endLoc = p.cur.Loc + p.advance() + } + case kwACCOUNT_LOCK: + stmt.AccountLock = true + endLoc = p.cur.Loc + p.advance() + case kwACCOUNT_UNLOCK: + stmt.AccountUnlock = true + endLoc = p.cur.Loc + p.advance() + case kwCOMMENT: + p.advance() + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.Comment = p.cur.Str + endLoc = p.cur.Loc + p.advance() + default: + goto doneAlterUser + } + } +doneAlterUser: + + stmt.Loc = startLoc.Merge(endLoc) + return stmt, nil +} + +// parseDropUser parses: +// +// DROP USER [IF EXISTS] 'user'@'host' +// +// On entry, DROP has been consumed; cur is USER. +func (p *Parser) parseDropUser(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume USER + + stmt := &ast.DropUserStmt{} + + if p.cur.Kind == kwIF { + p.advance() + if _, err := p.expect(kwEXISTS); err != nil { + return nil, err + } + stmt.IfExists = true + } + + identity, err := p.parseUserIdentity() + if err != nil { + return nil, err + } + stmt.Name = identity + + stmt.Loc = startLoc.Merge(identity.Loc) + return stmt, nil +} + +// parseSetPassword parses: +// +// SET PASSWORD [FOR 'user'@'host'] = 'hash' +// SET PASSWORD [FOR 'user'@'host'] = PASSWORD('cleartext') +// +// On entry, SET has been consumed; cur is PASSWORD. +func (p *Parser) parseSetPassword(startLoc ast.Loc) (ast.Node, error) { + p.advance() // consume PASSWORD + + stmt := &ast.SetPasswordStmt{} + + // Optional FOR 'user'@'host' + if p.cur.Kind == kwFOR { + p.advance() + identity, err := p.parseUserIdentity() + if err != nil { + return nil, err + } + stmt.For = identity + } + + // '=' + if _, err := p.expect(int('=')); err != nil { + return nil, err + } + + endLoc := p.cur.Loc + + if p.cur.Kind == kwPASSWORD { + // PASSWORD('cleartext') + p.advance() + if _, err := p.expect(int('(')); err != nil { + return nil, err + } + if p.cur.Kind != tokString { + return nil, p.syntaxErrorAtCur() + } + stmt.Password = p.cur.Str + stmt.IsHash = false + p.advance() + rparen, err := p.expect(int(')')) + if err != nil { + return nil, err + } + endLoc = rparen.Loc + } else if p.cur.Kind == tokString { + // bare hash string + stmt.Password = p.cur.Str + stmt.IsHash = true + endLoc = p.cur.Loc + p.advance() + } else { + return nil, p.syntaxErrorAtCur() + } + + stmt.Loc = startLoc.Merge(endLoc) + return stmt, nil +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +// consumeParenGroup consumes a parenthesized group (including nested parens) +// and returns its raw text and location. +func (p *Parser) consumeParenGroup() (string, ast.Loc, error) { + startTok, err := p.expect(int('(')) + if err != nil { + return "", ast.NoLoc(), err + } + + depth := 1 + startOff := startTok.Loc.Start + endOff := startTok.Loc.End + + for depth > 0 && p.cur.Kind != tokEOF { + endOff = p.cur.Loc.End + switch p.cur.Kind { + case int('('): + depth++ + case int(')'): + depth-- + if depth == 0 { + endOff = p.cur.Loc.End + p.advance() + return p.input[startOff:endOff], ast.Loc{Start: startOff, End: endOff}, nil + } + } + p.advance() + } + + return "", ast.NoLoc(), p.syntaxErrorAtCur() +} diff --git a/doris/parser/security_test.go b/doris/parser/security_test.go new file mode 100644 index 00000000..a9417dd2 --- /dev/null +++ b/doris/parser/security_test.go @@ -0,0 +1,696 @@ +package parser + +import ( + "testing" + + "github.com/bytebase/omni/doris/ast" +) + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +func parseOne(t *testing.T, sql string) ast.Node { + t.Helper() + file, errs := Parse(sql) + if len(errs) != 0 { + t.Fatalf("Parse(%q) errors: %v", sql, errs) + } + if len(file.Stmts) != 1 { + t.Fatalf("Parse(%q): got %d stmts, want 1", sql, len(file.Stmts)) + } + return file.Stmts[0] +} + +func parseCreateEncryptKeyStmt(t *testing.T, sql string) *ast.CreateEncryptKeyStmt { + t.Helper() + n := parseOne(t, sql) + stmt, ok := n.(*ast.CreateEncryptKeyStmt) + if !ok { + t.Fatalf("expected *ast.CreateEncryptKeyStmt, got %T", n) + } + return stmt +} + +func parseDropEncryptKeyStmt(t *testing.T, sql string) *ast.DropEncryptKeyStmt { + t.Helper() + n := parseOne(t, sql) + stmt, ok := n.(*ast.DropEncryptKeyStmt) + if !ok { + t.Fatalf("expected *ast.DropEncryptKeyStmt, got %T", n) + } + return stmt +} + +func parseCreateRoleStmt(t *testing.T, sql string) *ast.CreateRoleStmt { + t.Helper() + n := parseOne(t, sql) + stmt, ok := n.(*ast.CreateRoleStmt) + if !ok { + t.Fatalf("expected *ast.CreateRoleStmt, got %T", n) + } + return stmt +} + +func parseDropRoleStmt(t *testing.T, sql string) *ast.DropRoleStmt { + t.Helper() + n := parseOne(t, sql) + stmt, ok := n.(*ast.DropRoleStmt) + if !ok { + t.Fatalf("expected *ast.DropRoleStmt, got %T", n) + } + return stmt +} + +func parseCreateUserStmt(t *testing.T, sql string) *ast.CreateUserStmt { + t.Helper() + n := parseOne(t, sql) + stmt, ok := n.(*ast.CreateUserStmt) + if !ok { + t.Fatalf("expected *ast.CreateUserStmt, got %T", n) + } + return stmt +} + +func parseAlterUserStmt(t *testing.T, sql string) *ast.AlterUserStmt { + t.Helper() + n := parseOne(t, sql) + stmt, ok := n.(*ast.AlterUserStmt) + if !ok { + t.Fatalf("expected *ast.AlterUserStmt, got %T", n) + } + return stmt +} + +func parseDropUserStmt(t *testing.T, sql string) *ast.DropUserStmt { + t.Helper() + n := parseOne(t, sql) + stmt, ok := n.(*ast.DropUserStmt) + if !ok { + t.Fatalf("expected *ast.DropUserStmt, got %T", n) + } + return stmt +} + +func parseSetPasswordStmt(t *testing.T, sql string) *ast.SetPasswordStmt { + t.Helper() + n := parseOne(t, sql) + stmt, ok := n.(*ast.SetPasswordStmt) + if !ok { + t.Fatalf("expected *ast.SetPasswordStmt, got %T", n) + } + return stmt +} + +// --------------------------------------------------------------------------- +// ENCRYPTION KEY — legacy corpus: security_encryptkey.sql +// --------------------------------------------------------------------------- + +func TestCreateEncryptKey_Simple(t *testing.T) { + // CREATE ENCRYPTKEY my_key AS "ABCD123456789"; + stmt := parseCreateEncryptKeyStmt(t, `CREATE ENCRYPTKEY my_key AS "ABCD123456789"`) + if stmt.Name.String() != "my_key" { + t.Errorf("Name = %q, want %q", stmt.Name.String(), "my_key") + } + if stmt.Key != "ABCD123456789" { + t.Errorf("Key = %q, want %q", stmt.Key, "ABCD123456789") + } + if stmt.IfNotExists { + t.Error("IfNotExists should be false") + } +} + +func TestCreateEncryptKey_Qualified(t *testing.T) { + // CREATE ENCRYPTKEY testdb.test_key AS "ABCD123456789"; + stmt := parseCreateEncryptKeyStmt(t, `CREATE ENCRYPTKEY testdb.test_key AS "ABCD123456789"`) + if stmt.Name.String() != "testdb.test_key" { + t.Errorf("Name = %q, want %q", stmt.Name.String(), "testdb.test_key") + } + if stmt.Key != "ABCD123456789" { + t.Errorf("Key = %q, want %q", stmt.Key, "ABCD123456789") + } +} + +func TestDropEncryptKey_Simple(t *testing.T) { + // DROP ENCRYPTKEY my_key; + stmt := parseDropEncryptKeyStmt(t, `DROP ENCRYPTKEY my_key`) + if stmt.Name.String() != "my_key" { + t.Errorf("Name = %q, want %q", stmt.Name.String(), "my_key") + } + if stmt.IfExists { + t.Error("IfExists should be false") + } +} + +func TestDropEncryptKey_IfExists(t *testing.T) { + // DROP ENCRYPTKEY IF EXISTS testdb.my_key + stmt := parseDropEncryptKeyStmt(t, `DROP ENCRYPTKEY IF EXISTS testdb.my_key`) + if stmt.Name.String() != "testdb.my_key" { + t.Errorf("Name = %q, want %q", stmt.Name.String(), "testdb.my_key") + } + if !stmt.IfExists { + t.Error("IfExists should be true") + } +} + +// --------------------------------------------------------------------------- +// ROLE — legacy corpus: account_role.sql +// --------------------------------------------------------------------------- + +func TestCreateRole_Simple(t *testing.T) { + // CREATE ROLE role1; + stmt := parseCreateRoleStmt(t, `CREATE ROLE role1`) + if stmt.Name != "role1" { + t.Errorf("Name = %q, want %q", stmt.Name, "role1") + } + if stmt.IfNotExists { + t.Error("IfNotExists should be false") + } + if stmt.Comment != "" { + t.Errorf("Comment = %q, want empty", stmt.Comment) + } +} + +func TestCreateRole_WithComment(t *testing.T) { + // CREATE ROLE role2 COMMENT "this is my first role"; + stmt := parseCreateRoleStmt(t, `CREATE ROLE role2 COMMENT "this is my first role"`) + if stmt.Name != "role2" { + t.Errorf("Name = %q, want %q", stmt.Name, "role2") + } + if stmt.Comment != "this is my first role" { + t.Errorf("Comment = %q, want %q", stmt.Comment, "this is my first role") + } +} + +func TestDropRole_Simple(t *testing.T) { + // DROP ROLE role1; + stmt := parseDropRoleStmt(t, `DROP ROLE role1`) + if stmt.Name != "role1" { + t.Errorf("Name = %q, want %q", stmt.Name, "role1") + } + if stmt.IfExists { + t.Error("IfExists should be false") + } +} + +func TestDropRole_IfExists(t *testing.T) { + // DROP ROLE IF EXISTS role1 + stmt := parseDropRoleStmt(t, `DROP ROLE IF EXISTS role1`) + if stmt.Name != "role1" { + t.Errorf("Name = %q, want %q", stmt.Name, "role1") + } + if !stmt.IfExists { + t.Error("IfExists should be true") + } +} + +// --------------------------------------------------------------------------- +// USER — CREATE USER — legacy corpus: account_create_user.sql +// --------------------------------------------------------------------------- + +func TestCreateUser_Simple(t *testing.T) { + // CREATE USER 'jack'; + stmt := parseCreateUserStmt(t, `CREATE USER 'jack'`) + if stmt.Name.Username != "jack" { + t.Errorf("Username = %q, want %q", stmt.Name.Username, "jack") + } + if stmt.Name.Host != "%" { + t.Errorf("Host = %q, want %%", stmt.Name.Host) + } + if stmt.IfNotExists { + t.Error("IfNotExists should be false") + } +} + +func TestCreateUser_WithHost(t *testing.T) { + // CREATE USER jack@'172.10.1.10' IDENTIFIED BY '123456'; + stmt := parseCreateUserStmt(t, `CREATE USER jack@'172.10.1.10' IDENTIFIED BY '123456'`) + if stmt.Name.Username != "jack" { + t.Errorf("Username = %q, want %q", stmt.Name.Username, "jack") + } + if stmt.Name.Host != "172.10.1.10" { + t.Errorf("Host = %q, want %q", stmt.Name.Host, "172.10.1.10") + } + if stmt.Password != "123456" { + t.Errorf("Password = %q, want %q", stmt.Password, "123456") + } +} + +func TestCreateUser_IdentifiedByPasswordHash(t *testing.T) { + // CREATE USER jack@'172.10.1.10' IDENTIFIED BY PASSWORD '*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9'; + stmt := parseCreateUserStmt(t, `CREATE USER jack@'172.10.1.10' IDENTIFIED BY PASSWORD '*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9'`) + if stmt.Name.Username != "jack" { + t.Errorf("Username = %q, want %q", stmt.Name.Username, "jack") + } + if stmt.PasswordHash != "*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9" { + t.Errorf("PasswordHash = %q, unexpected", stmt.PasswordHash) + } + if stmt.Password != "" { + t.Errorf("Password should be empty, got %q", stmt.Password) + } +} + +func TestCreateUser_DefaultRole(t *testing.T) { + // CREATE USER 'jack'@'192.168.%' DEFAULT ROLE 'example_role'; + stmt := parseCreateUserStmt(t, `CREATE USER 'jack'@'192.168.%' DEFAULT ROLE 'example_role'`) + if stmt.Name.Username != "jack" { + t.Errorf("Username = %q, want %q", stmt.Name.Username, "jack") + } + if stmt.Name.Host != "192.168.%" { + t.Errorf("Host = %q, want %q", stmt.Name.Host, "192.168.%") + } + if stmt.DefaultRole != "example_role" { + t.Errorf("DefaultRole = %q, want %q", stmt.DefaultRole, "example_role") + } +} + +func TestCreateUser_PasswordAndDefaultRole(t *testing.T) { + // CREATE USER 'jack'@'%' IDENTIFIED BY '12345' DEFAULT ROLE 'my_role'; + stmt := parseCreateUserStmt(t, `CREATE USER 'jack'@'%' IDENTIFIED BY '12345' DEFAULT ROLE 'my_role'`) + if stmt.Password != "12345" { + t.Errorf("Password = %q, want %q", stmt.Password, "12345") + } + if stmt.DefaultRole != "my_role" { + t.Errorf("DefaultRole = %q, want %q", stmt.DefaultRole, "my_role") + } +} + +func TestCreateUser_PasswordPolicy(t *testing.T) { + // CREATE USER 'jack' IDENTIFIED BY '12345' PASSWORD_EXPIRE INTERVAL 10 DAY FAILED_LOGIN_ATTEMPTS 3 PASSWORD_LOCK_TIME 1 DAY; + stmt := parseCreateUserStmt(t, `CREATE USER 'jack' IDENTIFIED BY '12345' PASSWORD_EXPIRE INTERVAL 10 DAY FAILED_LOGIN_ATTEMPTS 3 PASSWORD_LOCK_TIME 1 DAY`) + if stmt.Password != "12345" { + t.Errorf("Password = %q, want %q", stmt.Password, "12345") + } + if !stmt.PasswordExpire { + t.Error("PasswordExpire should be true") + } + if stmt.PasswordExpireInterval != 10 { + t.Errorf("PasswordExpireInterval = %d, want 10", stmt.PasswordExpireInterval) + } + if stmt.FailedLoginAttempts != 3 { + t.Errorf("FailedLoginAttempts = %d, want 3", stmt.FailedLoginAttempts) + } + if stmt.PasswordLockTime != 1 { + t.Errorf("PasswordLockTime = %d, want 1", stmt.PasswordLockTime) + } +} + +func TestCreateUser_PasswordHistory(t *testing.T) { + // CREATE USER 'jack' IDENTIFIED BY '12345' PASSWORD_HISTORY 8; + stmt := parseCreateUserStmt(t, `CREATE USER 'jack' IDENTIFIED BY '12345' PASSWORD_HISTORY 8`) + if stmt.PasswordHistory != 8 { + t.Errorf("PasswordHistory = %d, want 8", stmt.PasswordHistory) + } +} + +func TestCreateUser_Comment(t *testing.T) { + // CREATE USER 'jack' COMMENT "this is my first user" + stmt := parseCreateUserStmt(t, `CREATE USER 'jack' COMMENT "this is my first user"`) + if stmt.Comment != "this is my first user" { + t.Errorf("Comment = %q, want %q", stmt.Comment, "this is my first user") + } +} + +func TestCreateUser_DomainHost(t *testing.T) { + // CREATE USER 'jack'@'example_domain' IDENTIFIED BY '12345'; + stmt := parseCreateUserStmt(t, `CREATE USER 'jack'@'example_domain' IDENTIFIED BY '12345'`) + if stmt.Name.Host != "example_domain" { + t.Errorf("Host = %q, want %q", stmt.Name.Host, "example_domain") + } +} + +// --------------------------------------------------------------------------- +// USER — ALTER USER — legacy corpus: account_alter_user.sql +// --------------------------------------------------------------------------- + +func TestAlterUser_IdentifiedBy(t *testing.T) { + // ALTER USER jack@'%' IDENTIFIED BY "12345"; + stmt := parseAlterUserStmt(t, `ALTER USER jack@'%' IDENTIFIED BY "12345"`) + if stmt.Name.Username != "jack" { + t.Errorf("Username = %q, want %q", stmt.Name.Username, "jack") + } + if stmt.Name.Host != "%" { + t.Errorf("Host = %q, want %%", stmt.Name.Host) + } + if stmt.Password != "12345" { + t.Errorf("Password = %q, want %q", stmt.Password, "12345") + } + if stmt.IfExists { + t.Error("IfExists should be false") + } +} + +func TestAlterUser_PasswordPolicy(t *testing.T) { + // ALTER USER jack@'%' FAILED_LOGIN_ATTEMPTS 3 PASSWORD_LOCK_TIME 1 DAY; + stmt := parseAlterUserStmt(t, `ALTER USER jack@'%' FAILED_LOGIN_ATTEMPTS 3 PASSWORD_LOCK_TIME 1 DAY`) + if stmt.FailedLoginAttempts != 3 { + t.Errorf("FailedLoginAttempts = %d, want 3", stmt.FailedLoginAttempts) + } + if stmt.PasswordLockTime != 1 { + t.Errorf("PasswordLockTime = %d, want 1", stmt.PasswordLockTime) + } +} + +func TestAlterUser_AccountUnlock(t *testing.T) { + // ALTER USER jack@'%' ACCOUNT_UNLOCK; + stmt := parseAlterUserStmt(t, `ALTER USER jack@'%' ACCOUNT_UNLOCK`) + if !stmt.AccountUnlock { + t.Error("AccountUnlock should be true") + } + if stmt.AccountLock { + t.Error("AccountLock should be false") + } +} + +func TestAlterUser_Comment(t *testing.T) { + // ALTER USER jack@'%' COMMENT "this is my first user" + stmt := parseAlterUserStmt(t, `ALTER USER jack@'%' COMMENT "this is my first user"`) + if stmt.Comment != "this is my first user" { + t.Errorf("Comment = %q, want %q", stmt.Comment, "this is my first user") + } +} + +// --------------------------------------------------------------------------- +// USER — DROP USER — legacy corpus: account_drop_user.sql +// --------------------------------------------------------------------------- + +func TestDropUser_Simple(t *testing.T) { + // DROP USER 'jack'@'192.%' + stmt := parseDropUserStmt(t, `DROP USER 'jack'@'192.%'`) + if stmt.Name.Username != "jack" { + t.Errorf("Username = %q, want %q", stmt.Name.Username, "jack") + } + if stmt.Name.Host != "192.%" { + t.Errorf("Host = %q, want %q", stmt.Name.Host, "192.%") + } + if stmt.IfExists { + t.Error("IfExists should be false") + } +} + +func TestDropUser_IfExists(t *testing.T) { + stmt := parseDropUserStmt(t, `DROP USER IF EXISTS 'jack'@'%'`) + if !stmt.IfExists { + t.Error("IfExists should be true") + } + if stmt.Name.Username != "jack" { + t.Errorf("Username = %q, want %q", stmt.Name.Username, "jack") + } +} + +// --------------------------------------------------------------------------- +// SET PASSWORD — legacy corpus: account_set_password.sql +// --------------------------------------------------------------------------- + +func TestSetPassword_HashNoFor(t *testing.T) { + // SET PASSWORD = '*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9'; + stmt := parseSetPasswordStmt(t, `SET PASSWORD = '*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9'`) + if stmt.For != nil { + t.Error("For should be nil") + } + if stmt.Password != "*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9" { + t.Errorf("Password = %q, unexpected", stmt.Password) + } + if !stmt.IsHash { + t.Error("IsHash should be true") + } +} + +func TestSetPassword_PasswordFuncNoFor(t *testing.T) { + // SET PASSWORD = PASSWORD('123456'); + stmt := parseSetPasswordStmt(t, `SET PASSWORD = PASSWORD('123456')`) + if stmt.For != nil { + t.Error("For should be nil") + } + if stmt.Password != "123456" { + t.Errorf("Password = %q, want %q", stmt.Password, "123456") + } + if stmt.IsHash { + t.Error("IsHash should be false for PASSWORD(...) form") + } +} + +func TestSetPassword_PasswordFuncWithFor(t *testing.T) { + // SET PASSWORD FOR 'jack'@'192.%' = PASSWORD('123456'); + stmt := parseSetPasswordStmt(t, `SET PASSWORD FOR 'jack'@'192.%' = PASSWORD('123456')`) + if stmt.For == nil { + t.Fatal("For should not be nil") + } + if stmt.For.Username != "jack" { + t.Errorf("For.Username = %q, want %q", stmt.For.Username, "jack") + } + if stmt.For.Host != "192.%" { + t.Errorf("For.Host = %q, want %q", stmt.For.Host, "192.%") + } + if stmt.Password != "123456" { + t.Errorf("Password = %q, want %q", stmt.Password, "123456") + } +} + +func TestSetPassword_HashWithFor(t *testing.T) { + // SET PASSWORD FOR 'jack'@'domain' = '*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9' + stmt := parseSetPasswordStmt(t, `SET PASSWORD FOR 'jack'@'domain' = '*6BB4837EB74329105EE4568DDA7DC67ED2CA2AD9'`) + if stmt.For == nil { + t.Fatal("For should not be nil") + } + if stmt.For.Username != "jack" { + t.Errorf("For.Username = %q, want %q", stmt.For.Username, "jack") + } + if stmt.For.Host != "domain" { + t.Errorf("For.Host = %q, want %q", stmt.For.Host, "domain") + } + if !stmt.IsHash { + t.Error("IsHash should be true") + } +} + +// --------------------------------------------------------------------------- +// ROW POLICY +// --------------------------------------------------------------------------- + +func TestCreateRowPolicy_Basic(t *testing.T) { + sql := `CREATE ROW POLICY test_policy ON test_table TO test_user USING (k1 = 1)` + n := parseOne(t, sql) + stmt, ok := n.(*ast.CreateRowPolicyStmt) + if !ok { + t.Fatalf("expected *ast.CreateRowPolicyStmt, got %T", n) + } + if stmt.Name != "test_policy" { + t.Errorf("Name = %q, want %q", stmt.Name, "test_policy") + } + if stmt.On.String() != "test_table" { + t.Errorf("On = %q, want %q", stmt.On.String(), "test_table") + } + if stmt.To != "test_user" { + t.Errorf("To = %q, want %q", stmt.To, "test_user") + } + if stmt.Type != "" { + t.Errorf("Type = %q, want empty", stmt.Type) + } +} + +func TestCreateRowPolicy_Restrictive(t *testing.T) { + sql := `CREATE ROW POLICY IF NOT EXISTS p1 ON db.tbl AS RESTRICTIVE TO role1 USING (col > 0)` + n := parseOne(t, sql) + stmt, ok := n.(*ast.CreateRowPolicyStmt) + if !ok { + t.Fatalf("expected *ast.CreateRowPolicyStmt, got %T", n) + } + if !stmt.IfNotExists { + t.Error("IfNotExists should be true") + } + if stmt.Type != "RESTRICTIVE" { + t.Errorf("Type = %q, want RESTRICTIVE", stmt.Type) + } + if stmt.On.String() != "db.tbl" { + t.Errorf("On = %q, want %q", stmt.On.String(), "db.tbl") + } +} + +func TestCreateRowPolicy_Permissive(t *testing.T) { + sql := `CREATE ROW POLICY p2 ON tbl AS PERMISSIVE TO user1 USING (a = 1)` + n := parseOne(t, sql) + stmt, ok := n.(*ast.CreateRowPolicyStmt) + if !ok { + t.Fatalf("expected *ast.CreateRowPolicyStmt, got %T", n) + } + if stmt.Type != "PERMISSIVE" { + t.Errorf("Type = %q, want PERMISSIVE", stmt.Type) + } +} + +func TestDropRowPolicy_Basic(t *testing.T) { + sql := `DROP ROW POLICY test_policy ON test_table` + n := parseOne(t, sql) + stmt, ok := n.(*ast.DropRowPolicyStmt) + if !ok { + t.Fatalf("expected *ast.DropRowPolicyStmt, got %T", n) + } + if stmt.Name != "test_policy" { + t.Errorf("Name = %q, want %q", stmt.Name, "test_policy") + } + if stmt.On.String() != "test_table" { + t.Errorf("On = %q, want %q", stmt.On.String(), "test_table") + } +} + +// --------------------------------------------------------------------------- +// DICTIONARY +// --------------------------------------------------------------------------- + +func TestCreateDictionary_Basic(t *testing.T) { + sql := `CREATE DICTIONARY IF NOT EXISTS my_dict USING my_table (id KEY, name VALUE) LAYOUT(HASH_MAP) PROPERTIES("read_timeout" = "3000")` + n := parseOne(t, sql) + stmt, ok := n.(*ast.CreateDictionaryStmt) + if !ok { + t.Fatalf("expected *ast.CreateDictionaryStmt, got %T", n) + } + if stmt.Name.String() != "my_dict" { + t.Errorf("Name = %q, want %q", stmt.Name.String(), "my_dict") + } + if !stmt.IfNotExists { + t.Error("IfNotExists should be true") + } + if stmt.UsingTable.String() != "my_table" { + t.Errorf("UsingTable = %q, want %q", stmt.UsingTable.String(), "my_table") + } + if len(stmt.Columns) != 2 { + t.Fatalf("Columns: got %d, want 2", len(stmt.Columns)) + } + if stmt.Columns[0].Name != "id" || stmt.Columns[0].Role != "KEY" { + t.Errorf("Columns[0] = {%q, %q}, want {id, KEY}", stmt.Columns[0].Name, stmt.Columns[0].Role) + } + if stmt.Columns[1].Name != "name" || stmt.Columns[1].Role != "VALUE" { + t.Errorf("Columns[1] = {%q, %q}, want {name, VALUE}", stmt.Columns[1].Name, stmt.Columns[1].Role) + } + if stmt.Layout != "hash_map" { + t.Errorf("Layout = %q, want %q", stmt.Layout, "hash_map") + } + if len(stmt.Properties) != 1 { + t.Fatalf("Properties: got %d, want 1", len(stmt.Properties)) + } + if stmt.Properties[0].Key != "read_timeout" { + t.Errorf("Properties[0].Key = %q, want %q", stmt.Properties[0].Key, "read_timeout") + } +} + +func TestAlterDictionary_Basic(t *testing.T) { + sql := `ALTER DICTIONARY my_dict PROPERTIES("write_timeout" = "5000")` + n := parseOne(t, sql) + stmt, ok := n.(*ast.AlterDictionaryStmt) + if !ok { + t.Fatalf("expected *ast.AlterDictionaryStmt, got %T", n) + } + if stmt.Name.String() != "my_dict" { + t.Errorf("Name = %q, want %q", stmt.Name.String(), "my_dict") + } + if len(stmt.Properties) != 1 { + t.Fatalf("Properties: got %d, want 1", len(stmt.Properties)) + } +} + +func TestDropDictionary_Basic(t *testing.T) { + sql := `DROP DICTIONARY my_dict` + n := parseOne(t, sql) + stmt, ok := n.(*ast.DropDictionaryStmt) + if !ok { + t.Fatalf("expected *ast.DropDictionaryStmt, got %T", n) + } + if stmt.Name.String() != "my_dict" { + t.Errorf("Name = %q, want %q", stmt.Name.String(), "my_dict") + } + if stmt.IfExists { + t.Error("IfExists should be false") + } +} + +func TestDropDictionary_IfExists(t *testing.T) { + sql := `DROP DICTIONARY IF EXISTS db.my_dict` + n := parseOne(t, sql) + stmt, ok := n.(*ast.DropDictionaryStmt) + if !ok { + t.Fatalf("expected *ast.DropDictionaryStmt, got %T", n) + } + if !stmt.IfExists { + t.Error("IfExists should be true") + } + if stmt.Name.String() != "db.my_dict" { + t.Errorf("Name = %q, want %q", stmt.Name.String(), "db.my_dict") + } +} + +func TestRefreshDictionary_Basic(t *testing.T) { + sql := `REFRESH DICTIONARY my_dict` + n := parseOne(t, sql) + stmt, ok := n.(*ast.RefreshDictionaryStmt) + if !ok { + t.Fatalf("expected *ast.RefreshDictionaryStmt, got %T", n) + } + if stmt.Name.String() != "my_dict" { + t.Errorf("Name = %q, want %q", stmt.Name.String(), "my_dict") + } +} + +// --------------------------------------------------------------------------- +// Node tags +// --------------------------------------------------------------------------- + +func TestSecurityNodeTags(t *testing.T) { + tests := []struct { + node ast.Node + want ast.NodeTag + }{ + {&ast.CreateRowPolicyStmt{}, ast.T_CreateRowPolicyStmt}, + {&ast.DropRowPolicyStmt{}, ast.T_DropRowPolicyStmt}, + {&ast.CreateEncryptKeyStmt{}, ast.T_CreateEncryptKeyStmt}, + {&ast.DropEncryptKeyStmt{}, ast.T_DropEncryptKeyStmt}, + {&ast.DictionaryColumn{}, ast.T_DictionaryColumn}, + {&ast.CreateDictionaryStmt{}, ast.T_CreateDictionaryStmt}, + {&ast.AlterDictionaryStmt{}, ast.T_AlterDictionaryStmt}, + {&ast.DropDictionaryStmt{}, ast.T_DropDictionaryStmt}, + {&ast.RefreshDictionaryStmt{}, ast.T_RefreshDictionaryStmt}, + {&ast.CreateRoleStmt{}, ast.T_CreateRoleStmt}, + {&ast.AlterRoleStmt{}, ast.T_AlterRoleStmt}, + {&ast.DropRoleStmt{}, ast.T_DropRoleStmt}, + {&ast.UserIdentity{}, ast.T_UserIdentity}, + {&ast.CreateUserStmt{}, ast.T_CreateUserStmt}, + {&ast.AlterUserStmt{}, ast.T_AlterUserStmt}, + {&ast.DropUserStmt{}, ast.T_DropUserStmt}, + {&ast.SetPasswordStmt{}, ast.T_SetPasswordStmt}, + } + for _, tt := range tests { + if tt.node.Tag() != tt.want { + t.Errorf("%T.Tag() = %v, want %v", tt.node, tt.node.Tag(), tt.want) + } + } +} + +func TestSecurityNodeTagStrings(t *testing.T) { + tests := []struct { + tag ast.NodeTag + want string + }{ + {ast.T_CreateRowPolicyStmt, "CreateRowPolicyStmt"}, + {ast.T_DropRowPolicyStmt, "DropRowPolicyStmt"}, + {ast.T_CreateEncryptKeyStmt, "CreateEncryptKeyStmt"}, + {ast.T_DropEncryptKeyStmt, "DropEncryptKeyStmt"}, + {ast.T_DictionaryColumn, "DictionaryColumn"}, + {ast.T_CreateDictionaryStmt, "CreateDictionaryStmt"}, + {ast.T_AlterDictionaryStmt, "AlterDictionaryStmt"}, + {ast.T_DropDictionaryStmt, "DropDictionaryStmt"}, + {ast.T_RefreshDictionaryStmt, "RefreshDictionaryStmt"}, + {ast.T_CreateRoleStmt, "CreateRoleStmt"}, + {ast.T_AlterRoleStmt, "AlterRoleStmt"}, + {ast.T_DropRoleStmt, "DropRoleStmt"}, + {ast.T_UserIdentity, "UserIdentity"}, + {ast.T_CreateUserStmt, "CreateUserStmt"}, + {ast.T_AlterUserStmt, "AlterUserStmt"}, + {ast.T_DropUserStmt, "DropUserStmt"}, + {ast.T_SetPasswordStmt, "SetPasswordStmt"}, + } + for _, tt := range tests { + if tt.tag.String() != tt.want { + t.Errorf("NodeTag(%d).String() = %q, want %q", tt.tag, tt.tag.String(), tt.want) + } + } +}