-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathsqlParser_test.go
More file actions
1198 lines (966 loc) · 36.9 KB
/
sqlParser_test.go
File metadata and controls
1198 lines (966 loc) · 36.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package zorm
import (
"context"
"strings"
"testing"
)
// assertPart reports a test error unless part is a non-empty, in-bounds span of
// sql whose (whitespace-trimmed) text contains expectedKeyword, compared
// case-insensitively.
//
// Parameters:
//   - t: the running test; assertPart marks itself as a helper so failures are
//     attributed to the caller's line.
//   - sql: the statement that was handed to parseSQL.
//   - part: the span under test; Start/End are byte offsets into sql.
//   - expectedKeyword: text the span must contain, e.g. "SELECT" or "GROUP".
//
// Failures are reported with t.Errorf, so the calling test keeps running.
func assertPart(t *testing.T, sql string, part sqlSpan, expectedKeyword string) {
	t.Helper()
	if part.End <= part.Start {
		t.Errorf("SQL 片段未正确设置:%s, Start=%d, End=%d", expectedKeyword, part.Start, part.End)
		return
	}
	// An out-of-range span is a parser bug; report it as a failure instead of
	// letting the slice expression below panic and abort the whole test binary.
	if part.Start < 0 || part.End > len(sql) {
		t.Errorf("SQL 片段越界:%s, Start=%d, End=%d, len=%d", expectedKeyword, part.Start, part.End, len(sql))
		return
	}
	fragment := sql[part.Start:part.End]
	content := strings.ToUpper(strings.TrimSpace(fragment))
	if !strings.Contains(content, strings.ToUpper(expectedKeyword)) {
		t.Errorf("SQL 片段内容不正确:%s, 期望包含:%s, 实际:%s",
			expectedKeyword, expectedKeyword, fragment)
	}
}
// ---------------- TestParseSQL_Basic ----------------
// TestParseSQL_Basic parses a flat, single-level query and expects every
// top-level clause (SELECT, FROM, WHERE, GROUP BY, ORDER BY) to be located.
func TestParseSQL_Basic(t *testing.T) {
	const query = `SELECT name, count(*) FROM user WHERE age > 18 GROUP BY name ORDER BY count(*) DESC`
	got := parseSQL(query)
	checks := []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.GroupBy, "GROUP"},
		{got.OrderBy, "ORDER"},
	}
	for _, c := range checks {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_Subquery ----------------
// TestParseSQL_Subquery uses a derived table in FROM whose inner query has its
// own FROM and GROUP BY. The outer clauses must still be found, and the outer
// FROM span must contain the whole inner query.
func TestParseSQL_Subquery(t *testing.T) {
	const query = `SELECT name, count(*)
FROM (
SELECT name FROM user GROUP BY name
) t
WHERE age > 18
GROUP BY name
ORDER BY count(*) DESC`
	got := parseSQL(query)
	checks := []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.GroupBy, "GROUP"},
		{got.OrderBy, "ORDER"},
	}
	for _, c := range checks {
		assertPart(t, query, c.span, c.keyword)
	}
	// The derived table must sit inside the outer FROM span.
	if from := query[got.From.Start:got.From.End]; !strings.Contains(from, "SELECT name FROM user") {
		t.Errorf("FROM 子句应包含子查询, 实际:%s", from)
	}
}
// ---------------- TestParseSQL_NoWhere ----------------
// TestParseSQL_NoWhere checks that a query without a WHERE clause leaves the
// Where span empty (Start == End) while the remaining clauses are still found.
func TestParseSQL_NoWhere(t *testing.T) {
	const query = `SELECT * FROM users ORDER BY id`
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	assertPart(t, query, got.From, "FROM")
	assertPart(t, query, got.OrderBy, "ORDER")
	// An absent clause is represented by an empty span.
	if w := got.Where; w.Start != w.End {
		t.Errorf("WHERE 子句不应被设置, 实际 Start=%d, End=%d", w.Start, w.End)
	}
}
// ---------------- TestParseSQL_NoOrderBy ----------------
// TestParseSQL_NoOrderBy checks that a query without ORDER BY leaves the
// OrderBy span empty (Start == End).
func TestParseSQL_NoOrderBy(t *testing.T) {
	const query = `SELECT id, name FROM users WHERE status = 1`
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	assertPart(t, query, got.From, "FROM")
	assertPart(t, query, got.Where, "WHERE")
	// An absent clause is represented by an empty span.
	if ob := got.OrderBy; ob.Start != ob.End {
		t.Errorf("ORDER BY 子句不应被设置, 实际 Start=%d, End=%d", ob.Start, ob.End)
	}
}
// ---------------- TestParseSQL_NoGroupBy ----------------
// TestParseSQL_NoGroupBy checks that a query without GROUP BY leaves the
// GroupBy span empty (Start == End).
func TestParseSQL_NoGroupBy(t *testing.T) {
	const query = `SELECT * FROM users WHERE id = 1 ORDER BY created_at`
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	assertPart(t, query, got.From, "FROM")
	assertPart(t, query, got.Where, "WHERE")
	assertPart(t, query, got.OrderBy, "ORDER")
	// An absent clause is represented by an empty span.
	if gb := got.GroupBy; gb.Start != gb.End {
		t.Errorf("GROUP BY 子句不应被设置, 实际 Start=%d, End=%d", gb.Start, gb.End)
	}
}
// ---------------- TestParseSQL_StringWithQuote ----------------
// TestParseSQL_StringWithQuote uses doubled single quotes ('') inside string
// literals; clause detection must still succeed around them.
func TestParseSQL_StringWithQuote(t *testing.T) {
	const query = `SELECT * FROM users WHERE name = 'O''Brien' AND desc = 'It''s fine'`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_StringWithBackslash ----------------
// TestParseSQL_StringWithBackslash puts backslashes inside string literals
// (the Go raw string keeps them verbatim); clause detection must still succeed.
func TestParseSQL_StringWithBackslash(t *testing.T) {
	const query = `SELECT * FROM users WHERE path = 'C:\\Users\\test' AND regex = 'a\b'`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_DoubleQuoteString ----------------
// TestParseSQL_DoubleQuoteString uses double-quoted identifiers throughout and
// expects SELECT, FROM and WHERE to be found.
func TestParseSQL_DoubleQuoteString(t *testing.T) {
	const query = `SELECT "name", "age" FROM "users" WHERE "id" = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_LineComment ----------------
// TestParseSQL_LineComment places a "--" line comment between FROM and the
// WHERE clause on the next line; WHERE must still be found.
func TestParseSQL_LineComment(t *testing.T) {
	const query = `SELECT * FROM users -- 这是注释
WHERE id = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_LineCommentEOF ----------------
// TestParseSQL_LineCommentEOF covers the edge case of a "--" comment that runs to
// the end of input with no trailing newline, e.g. `SELECT * FROM user -- comment`.
// Nothing follows the comment, so no WHERE clause may be reported.
func TestParseSQL_LineCommentEOF(t *testing.T) {
	const query = "SELECT * FROM users -- comment at end without newline"
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	assertPart(t, query, got.From, "FROM")
	if w := got.Where; w.Start != w.End {
		t.Errorf("WHERE 子句不应被设置, 实际 Start=%d, End=%d", w.Start, w.End)
	}
}
// ---------------- TestParseSQL_LineCommentNoSpace ----------------
// TestParseSQL_LineCommentNoSpace glues a "--" comment directly onto the table
// name with no separating space; WHERE on the next line must still be found.
func TestParseSQL_LineCommentNoSpace(t *testing.T) {
	const query = "SELECT * FROM users--comment without space\nWHERE id = 1"
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_MultiLineComment ----------------
// TestParseSQL_MultiLineComment places a /* ... */ block comment inside the
// SELECT list; the clauses around it must still be found.
func TestParseSQL_MultiLineComment(t *testing.T) {
	const query = `SELECT /* 这是注释 */ * FROM users WHERE id = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_NestedParentheses ----------------
// TestParseSQL_NestedParentheses puts a scalar subquery (with its own FROM and
// WHERE) in the SELECT list and an IN (...) list in WHERE; the outer clauses
// must still be found.
func TestParseSQL_NestedParentheses(t *testing.T) {
	const query = `SELECT (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) AS order_count FROM users WHERE status IN (1, 2, 3)`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_CaseInsensitive ----------------
// TestParseSQL_CaseInsensitive writes every keyword in lower case and expects
// all clauses to be found regardless.
func TestParseSQL_CaseInsensitive(t *testing.T) {
	const query = `select * from USERS where ID = 1 group by NAME order by TIME limit 10`
	got := parseSQL(query)
	checks := []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.GroupBy, "GROUP"},
		{got.OrderBy, "ORDER"},
	}
	for _, c := range checks {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_MixedCase ----------------
// TestParseSQL_MixedCase alternates upper and lower case inside each keyword
// and expects all clauses to be found.
func TestParseSQL_MixedCase(t *testing.T) {
	const query = `SeLeCt * FrOm users WhErE id = 1 GrOuP By name OrDeR By id`
	got := parseSQL(query)
	checks := []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.GroupBy, "GROUP"},
		{got.OrderBy, "ORDER"},
	}
	for _, c := range checks {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_ExtraSpaces ----------------
// TestParseSQL_ExtraSpaces separates the clauses with runs of spaces and tabs
// and expects every clause to be found.
//
// The query is an interpreted string literal with explicit \t escapes, so the
// irregular whitespace is visible in review and cannot be silently collapsed by
// an editor. The two-word keywords "GROUP BY" / "ORDER BY" keep a single
// internal space; only the gaps between clauses and operands are irregular.
func TestParseSQL_ExtraSpaces(t *testing.T) {
	sql := "SELECT   *   FROM\tusers \t WHERE   id = 1    GROUP BY   name\t\tORDER BY   id"
	parts := parseSQL(sql)
	assertPart(t, sql, parts.Select, "SELECT")
	assertPart(t, sql, parts.From, "FROM")
	assertPart(t, sql, parts.Where, "WHERE")
	assertPart(t, sql, parts.GroupBy, "GROUP")
	assertPart(t, sql, parts.OrderBy, "ORDER")
}
// ---------------- TestParseSQL_NewLines ----------------
// TestParseSQL_NewLines puts each clause on its own line; newlines must act as
// clause separators just like spaces.
func TestParseSQL_NewLines(t *testing.T) {
	const query = `SELECT *
FROM users
WHERE id = 1
GROUP BY name
ORDER BY id`
	got := parseSQL(query)
	checks := []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.GroupBy, "GROUP"},
		{got.OrderBy, "ORDER"},
	}
	for _, c := range checks {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_KeywordInString ----------------
// TestParseSQL_KeywordInString hides keywords inside string literals. Only the
// real clauses may be detected; in particular, the 'ORDER BY' text inside a
// literal must not yield an OrderBy span.
func TestParseSQL_KeywordInString(t *testing.T) {
	const query = `SELECT * FROM users WHERE name = 'SELECT FROM WHERE' AND desc = 'ORDER BY LIMIT'`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	if ob := got.OrderBy; ob.Start != ob.End {
		t.Errorf("ORDER BY 不应被设置 (字符串中的伪关键字) ")
	}
}
// ---------------- TestParseSQL_KeywordInIdentifier ----------------
// TestParseSQL_KeywordInIdentifier uses identifiers that merely contain keywords
// (select_time, from_addr, orderby, grouping). Those must not be mistaken for
// clause keywords, so OrderBy and GroupBy must stay empty.
func TestParseSQL_KeywordInIdentifier(t *testing.T) {
	const query = `SELECT select_time, from_addr, where_field FROM table1 WHERE orderby = 1 AND grouping = 2`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	for _, absent := range []struct {
		span sqlSpan
		name string
	}{
		{got.OrderBy, "ORDER BY"},
		{got.GroupBy, "GROUP BY"},
	} {
		if absent.span.Start != absent.span.End {
			t.Errorf("%s 不应被设置 (标识符中的伪关键字) ", absent.name)
		}
	}
}
// ---------------- TestParseSQL_Union ----------------
// TestParseSQL_Union parses two SELECTs joined by UNION with a trailing ORDER BY.
// It expects the clauses of the first SELECT to be found, and a Union span whose
// trimmed text is exactly the keyword UNION.
func TestParseSQL_Union(t *testing.T) {
	const query = `SELECT id, name FROM users WHERE status = 1 UNION SELECT id, name FROM admins WHERE status = 1 ORDER BY id`
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	// FROM is expected to point at the first (left-hand) FROM.
	assertPart(t, query, got.From, "FROM")
	assertPart(t, query, got.Where, "WHERE")
	assertPart(t, query, got.OrderBy, "ORDER")

	u := got.Union
	if u.Start == u.End {
		t.Errorf("UNION 应被解析")
	}
	keyword := query[u.Start:u.End]
	if strings.ToUpper(strings.TrimSpace(keyword)) != "UNION" {
		t.Errorf("UNION 内容不正确, 实际:%s", keyword)
	}
}
// ---------------- TestParseSQL_UnionAll ----------------
// TestParseSQL_UnionAll chains three SELECTs with UNION ALL and expects at least
// the first set operator to produce a non-empty Union span.
func TestParseSQL_UnionAll(t *testing.T) {
	const query = `SELECT id FROM users UNION ALL SELECT id FROM orders UNION ALL SELECT id FROM products`
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	assertPart(t, query, got.From, "FROM")
	if u := got.Union; u.Start == u.End {
		t.Errorf("UNION 应被解析")
	}
}
// ---------------- TestParseSQL_Distinct ----------------
// TestParseSQL_Distinct expects SELECT DISTINCT to produce a Distinct span whose
// trimmed text is exactly DISTINCT, alongside the usual clauses.
func TestParseSQL_Distinct(t *testing.T) {
	const query = `SELECT DISTINCT name FROM users WHERE status = 1 GROUP BY name`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.GroupBy, "GROUP"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}

	d := got.Distinct
	if d.Start == d.End {
		t.Errorf("DISTINCT 应被解析")
	}
	keyword := query[d.Start:d.End]
	if strings.ToUpper(strings.TrimSpace(keyword)) != "DISTINCT" {
		t.Errorf("DISTINCT 内容不正确, 实际:%s", keyword)
	}
}
// ---------------- TestParseSQL_DistinctInString ----------------
// TestParseSQL_DistinctInString hides "SELECT DISTINCT" inside a string literal;
// the Distinct span must stay empty.
func TestParseSQL_DistinctInString(t *testing.T) {
	const query = `SELECT * FROM users WHERE note = 'SELECT DISTINCT is a keyword'`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	if d := got.Distinct; d.Start != d.End {
		t.Errorf("DISTINCT 不应被解析 (字符串中的伪关键字)")
	}
}
// ---------------- TestParseSQL_DistinctAsIdentifier ----------------
// TestParseSQL_DistinctAsIdentifier uses column names that contain "distinct"
// (distinct_count, is_distinct); the Distinct span must stay empty.
func TestParseSQL_DistinctAsIdentifier(t *testing.T) {
	const query = `SELECT distinct_count, is_distinct FROM metrics WHERE id = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	if d := got.Distinct; d.Start != d.End {
		t.Errorf("DISTINCT 不应被解析 (标识符中的伪关键字)")
	}
}
// ---------------- TestParseSQL_DistinctOrderBy ----------------
// TestParseSQL_DistinctOrderBy combines SELECT DISTINCT with ORDER BY and
// expects both the clauses and the Distinct span to be found.
func TestParseSQL_DistinctOrderBy(t *testing.T) {
	const query = `SELECT DISTINCT name FROM users ORDER BY name`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.OrderBy, "ORDER"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	if d := got.Distinct; d.Start == d.End {
		t.Errorf("DISTINCT 应被解析")
	}
}
// ---------------- TestParseSQL_UnionInSubquery ----------------
// TestParseSQL_UnionInSubquery nests a UNION inside a derived table. The outer
// clauses must still be found. The original author notes that the current
// implementation also reports the inner UNION. Whether the FROM span reaches the
// outer alias "temp" is only logged, not enforced.
func TestParseSQL_UnionInSubquery(t *testing.T) {
	const query = `SELECT * FROM (SELECT id FROM users UNION SELECT id FROM orders) AS temp WHERE id > 0`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	from := query[got.From.Start:got.From.End]
	if !strings.Contains(from, "temp") {
		t.Logf("FROM 应包含外层查询范围:%s", from)
	}
}
// ---------------- TestParseSQL_Intersect ----------------
// TestParseSQL_Intersect expects an INTERSECT between two SELECTs to yield an
// Intersect span whose trimmed text is exactly INTERSECT.
func TestParseSQL_Intersect(t *testing.T) {
	const query = `SELECT id FROM users INTERSECT SELECT id FROM admins`
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	assertPart(t, query, got.From, "FROM")

	span := got.Intersect
	if span.Start == span.End {
		t.Errorf("INTERSECT 应被解析")
	}
	keyword := query[span.Start:span.End]
	if strings.ToUpper(strings.TrimSpace(keyword)) != "INTERSECT" {
		t.Errorf("INTERSECT 内容不正确, 实际:%s", keyword)
	}
}
// ---------------- TestParseSQL_Except ----------------
// TestParseSQL_Except expects an EXCEPT between two SELECTs to yield an Except
// span whose trimmed text is exactly EXCEPT.
func TestParseSQL_Except(t *testing.T) {
	const query = `SELECT id FROM users EXCEPT SELECT id FROM admins`
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	assertPart(t, query, got.From, "FROM")

	span := got.Except
	if span.Start == span.End {
		t.Errorf("EXCEPT 应被解析")
	}
	keyword := query[span.Start:span.End]
	if strings.ToUpper(strings.TrimSpace(keyword)) != "EXCEPT" {
		t.Errorf("EXCEPT 内容不正确, 实际:%s", keyword)
	}
}
// ---------------- TestParseSQL_Join ----------------
// TestParseSQL_Join uses a LEFT JOIN ... ON between FROM and WHERE and expects
// all four clauses to be found.
func TestParseSQL_Join(t *testing.T) {
	const query = `SELECT u.id, u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1 ORDER BY o.created_at`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.OrderBy, "ORDER"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_MultipleJoins ----------------
// TestParseSQL_MultipleJoins chains INNER and LEFT JOINs over several lines and
// expects SELECT, FROM and WHERE to be found.
func TestParseSQL_MultipleJoins(t *testing.T) {
	const query = `SELECT * FROM users u
INNER JOIN orders o ON u.id = o.user_id
LEFT JOIN products p ON o.product_id = p.id
WHERE u.status = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_WithHint ----------------
// TestParseSQL_WithHint places an optimizer hint comment (/*+ ... */) right
// after SELECT and expects the clauses to be found.
func TestParseSQL_WithHint(t *testing.T) {
	const query = `SELECT /*+ INDEX(users idx_status) */ * FROM users WHERE status = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_Count ----------------
// TestParseSQL_Count uses a COUNT(*) aggregate with an alias and expects the
// clauses to be found.
func TestParseSQL_Count(t *testing.T) {
	const query = `SELECT COUNT(*) AS total FROM users WHERE status = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_CountSubquery ----------------
// TestParseSQL_CountSubquery parses the shape of a paging COUNT wrapper query
// (the scenario that replaces the old regex approach). The outer FROM span must
// contain the whole inner query, including its DISTINCT and GROUP BY.
func TestParseSQL_CountSubquery(t *testing.T) {
	const query = `SELECT COUNT(*) temp_zorm_row_count FROM (SELECT DISTINCT name FROM users WHERE age > 18 GROUP BY name ORDER BY name) temp_zorm_noob_table_name WHERE 1=1`
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	assertPart(t, query, got.From, "FROM")

	from := query[got.From.Start:got.From.End]
	for _, kw := range []string{"DISTINCT", "GROUP BY"} {
		if !strings.Contains(from, kw) {
			t.Errorf("FROM 子句应包含 %s", kw)
		}
	}
}
// ---------------- TestParseSQL_Insert ----------------
// TestParseSQL_Insert feeds an INSERT statement, which has no SELECT keyword.
// The test expects the Select span to run to the end of the statement.
func TestParseSQL_Insert(t *testing.T) {
	const query = `INSERT INTO users (name, age) VALUES ('test', 18)`
	got := parseSQL(query)
	if end := got.Select.End; end != len(query) {
		t.Errorf("INSERT 语句的 Select 应包含整个语句")
	}
}
// ---------------- TestParseSQL_Update ----------------
// TestParseSQL_Update feeds an UPDATE statement (no SELECT/FROM keywords) and
// expects a WHERE span with non-zero Start and End.
func TestParseSQL_Update(t *testing.T) {
	const query = `UPDATE users SET name = 'test', age = 18 WHERE id = 1`
	got := parseSQL(query)
	if w := got.Where; w.Start == 0 || w.End == 0 {
		t.Errorf("UPDATE 语句应解析 WHERE 子句")
	}
}
// ---------------- TestParseSQL_Delete ----------------
// TestParseSQL_Delete feeds a DELETE FROM statement and expects the FROM and
// WHERE spans to be found.
func TestParseSQL_Delete(t *testing.T) {
	const query = `DELETE FROM users WHERE id = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_EmptyString ----------------
// TestParseSQL_EmptyString expects an empty input to produce a zero Select span
// (Start == 0 and End == 0).
func TestParseSQL_EmptyString(t *testing.T) {
	const query = ``
	got := parseSQL(query)
	if sel := got.Select; !(sel.Start == 0 && sel.End == 0) {
		t.Errorf("空 SQL 的 Select 应为 Start=0, End=0")
	}
}
// ---------------- TestParseSQL_OnlySelect ----------------
// TestParseSQL_OnlySelect parses a SELECT with no other clauses. Only the Select
// span may be set; FROM, WHERE, GROUP BY and ORDER BY must all stay empty.
func TestParseSQL_OnlySelect(t *testing.T) {
	const query = `SELECT 1 + 1 AS result`
	got := parseSQL(query)
	assertPart(t, query, got.Select, "SELECT")
	for _, absent := range []struct {
		span sqlSpan
		msg  string
	}{
		{got.From, "FROM 不应被设置"},
		{got.Where, "WHERE 不应被设置"},
		{got.GroupBy, "GROUP BY 不应被设置"},
		{got.OrderBy, "ORDER BY 不应被设置"},
	} {
		if absent.span.Start != absent.span.End {
			t.Error(absent.msg)
		}
	}
}
// ---------------- TestParseSQL_MultipleOrderBy ----------------
// TestParseSQL_MultipleOrderBy sorts on several columns and expects the ORDER BY
// span to include every one of them.
func TestParseSQL_MultipleOrderBy(t *testing.T) {
	const query = `SELECT * FROM users ORDER BY status DESC, created_at ASC, id DESC`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.OrderBy, "ORDER"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	clause := query[got.OrderBy.Start:got.OrderBy.End]
	for _, field := range []string{"status", "created_at", "id"} {
		if !strings.Contains(clause, field) {
			t.Errorf("ORDER BY 应包含 %s 字段", field)
		}
	}
}
// ---------------- TestParseSQL_SelectCountSubquery ----------------
// TestParseSQL_SelectCountSubquery runs selectCount on a query whose SELECT list
// contains a correlated COUNT subquery. The generated COUNT statement must keep
// only the outer FROM clause; the subquery's own FROM must not be taken for it.
//
// queryRow is a package-level hook. It is swapped for a stub that records the
// SQL it receives and is restored via defer, so this test must not run in
// parallel with other users of queryRow.
func TestParseSQL_SelectCountSubquery(t *testing.T) {
	// Typical case: list rows and count related rows per row with a subquery.
	originalSQL := `select ut.*,(SELECT COUNT (*) FROM users us WHERE us.unit_id = ut.unit_id) num from units ut `
	finder := NewFinder()
	finder.Append(originalSQL)
	finder.InjectionCheck = false

	// Stub queryRow so the SQL built by selectCount is captured instead of executed.
	var capturedSQL string
	originalQueryRow := queryRow
	queryRow = func(ctx context.Context, f *Finder, entity interface{}) (bool, error) {
		sqlstr, _ := f.GetSQL()
		capturedSQL = sqlstr
		return false, nil
	}
	defer func() {
		queryRow = originalQueryRow
	}()

	// selectCount builds the COUNT SQL and calls QueryRow -> queryRow. Its results
	// are deliberately ignored: the stub returns no row, and only the generated SQL
	// is under test. No Page is involved, because selectCount only receives the finder.
	_, _ = selectCount(context.Background(), finder)
	t.Logf("原始 SQL: %s", originalSQL)
	t.Logf("生成的 COUNT SQL: %s", capturedSQL)

	expectedCountSQL := "SELECT COUNT(*) from units ut "
	if !strings.EqualFold(strings.TrimSpace(capturedSQL), strings.TrimSpace(expectedCountSQL)) {
		t.Errorf("生成的 COUNT SQL 不正确.\n期望: %s\n实际: %s", expectedCountSQL, capturedSQL)
	}
}
// ---------------- TestParseSQL_RealWorldExamples ----------------
// Real-world style queries.
// Example 1: TestParseSQL_ComplexPaging mixes a scalar subquery in the SELECT
// list, a derived table with its own ORDER BY in FROM, and outer WHERE / ORDER BY.
func TestParseSQL_ComplexPaging(t *testing.T) {
	const query = `SELECT u.*, (SELECT COUNT(*) FROM orders WHERE user_id = u.id) AS order_count
FROM (SELECT * FROM users WHERE status IN (1,2,3) ORDER BY created_at DESC) u
WHERE u.age > 18
ORDER BY u.created_at DESC`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.OrderBy, "ORDER"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// Example 2: TestParseSQL_MultiTableJoin joins several tables and also has
// WHERE, GROUP BY, HAVING and ORDER BY clauses.
func TestParseSQL_MultiTableJoin(t *testing.T) {
	const query = `SELECT
u.id, u.name,
o.id AS order_id, o.amount,
p.name AS product_name
FROM users u
INNER JOIN orders o ON u.id = o.user_id AND o.status = 1
LEFT JOIN order_items oi ON o.id = oi.order_id
INNER JOIN products p ON oi.product_id = p.id
WHERE u.status = 1 AND u.created_at > '2024-01-01'
GROUP BY u.id, o.id
HAVING SUM(oi.quantity) > 0
ORDER BY o.created_at DESC`
	got := parseSQL(query)
	checks := []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.GroupBy, "GROUP"},
		{got.OrderBy, "ORDER"},
	}
	for _, c := range checks {
		assertPart(t, query, c.span, c.keyword)
	}
}
// Example 3: TestParseSQL_CountForPaging checks what a paging COUNT query
// (the regex-free replacement) needs from parseSQL. The statement is a typical
// input for selectCount in dialect.go and contains DISTINCT, GROUP BY and
// ORDER BY. The ORDER BY span must be found so it can be cut off, and both
// GROUP BY and DISTINCT must be detected.
func TestParseSQL_CountForPaging(t *testing.T) {
	originalSQL := `SELECT DISTINCT u.name, u.age FROM users u WHERE u.age > 18 GROUP BY u.name ORDER BY u.name`
	parts := parseSQL(originalSQL)

	// Require the ORDER BY span explicitly; otherwise a missed ORDER BY would
	// silently skip the stripping check below and the test would still pass.
	assertPart(t, originalSQL, parts.OrderBy, "ORDER")
	if parts.OrderBy.Start != parts.OrderBy.End {
		// Cutting at OrderBy.Start must remove the ORDER BY clause entirely.
		countSQL := originalSQL[:parts.OrderBy.Start]
		if strings.Contains(strings.ToUpper(countSQL), "ORDER") {
			t.Errorf("去掉 ORDER BY 后不应包含 ORDER 关键字:%s", countSQL)
		}
	}

	// The statement contains both GROUP BY and DISTINCT, so each must be reported
	// on its own. An "either one" check would hide a parser that misses one of them.
	if parts.GroupBy.Start == parts.GroupBy.End {
		t.Errorf("应检测到 GROUP BY")
	}
	if parts.Distinct.Start == parts.Distinct.End {
		t.Errorf("应检测到 DISTINCT")
	}
}
// ---------------- TestParseSQL_SpecialCases ----------------
// Edge cases.
// TestParseSQL_UnclosedString feeds a string literal that never closes. The
// parser must not panic, and the Select span must still start at offset 0.
func TestParseSQL_UnclosedString(t *testing.T) {
	const query = `SELECT * FROM users WHERE name = 'unclosed`
	got := parseSQL(query)
	if start := got.Select.Start; start != 0 {
		t.Errorf("Select Start 应为 0")
	}
}
// TestParseSQL_UnclosedComment feeds a /* comment that never closes. The parser
// must not panic, and the Select span must still start at offset 0.
func TestParseSQL_UnclosedComment(t *testing.T) {
	const query = `SELECT /* unclosed comment * FROM users`
	got := parseSQL(query)
	if start := got.Select.Start; start != 0 {
		t.Errorf("Select Start 应为 0")
	}
}
// TestParseSQL_NestedBrackets uses deeply nested parentheses in both the SELECT
// list and the FROM target, and expects SELECT and FROM to be found.
func TestParseSQL_NestedBrackets(t *testing.T) {
	const query = `SELECT (((1 + 2) * 3) - 4) AS result FROM (((users)))`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// TestParseSQL_KeywordInSubquery puts a FROM inside a parenthesised scalar
// subquery. That inner FROM must not affect the outer clause, so the FROM span
// must cover outer_table, not inner_table.
func TestParseSQL_KeywordInSubquery(t *testing.T) {
	const query = `SELECT (SELECT name FROM inner_table WHERE id = 1) AS name FROM outer_table WHERE status = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	if from := query[got.From.Start:got.From.End]; !strings.Contains(from, "outer_table") {
		t.Errorf("FROM 应指向 outer_table, 实际:%s", from)
	}
}
// ---------------- Advanced Scenarios ----------------
// ---------------- TestParseSQL_CaseWhen ----------------
// TestParseSQL_CaseWhen puts a multi-line CASE WHEN ... END expression in the
// SELECT list. The text from the start of SELECT up to the FROM span must not
// contain FROM, i.e. the FROM span must start at the real FROM keyword.
func TestParseSQL_CaseWhen(t *testing.T) {
	const query = `SELECT id,
CASE
WHEN status = 1 THEN 'active'
WHEN status = 0 THEN 'inactive'
ELSE 'unknown'
END AS status_name
FROM users
WHERE created_at > '2024-01-01'`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	head := strings.ToUpper(query[got.Select.Start:got.From.Start])
	if strings.Contains(head, "FROM") {
		t.Errorf("SELECT 子句中不应包含 FROM 关键字")
	}
}
// ---------------- TestParseSQL_CaseWhenInWhere ----------------
// TestParseSQL_CaseWhenInWhere uses a CASE WHEN expression as the WHERE
// predicate and expects SELECT, FROM and WHERE to be found.
func TestParseSQL_CaseWhenInWhere(t *testing.T) {
	const query = `SELECT * FROM users
WHERE CASE
WHEN age > 18 THEN status
ELSE 'pending'
END = 'active'`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_WindowFunction ----------------
// TestParseSQL_WindowFunction uses window functions whose OVER (...) contains
// PARTITION BY and ORDER BY. The query has no outer ORDER BY, so the ORDER BY
// inside OVER must not produce an OrderBy span.
func TestParseSQL_WindowFunction(t *testing.T) {
	const query = `SELECT id, name, salary,
ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) AS rn,
AVG(salary) OVER (PARTITION BY dept) AS avg_salary
FROM employees
WHERE status = 1`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	if ob := got.OrderBy; ob.Start != ob.End {
		t.Errorf("此 SQL 不应解析出 ORDER BY 子句 (OVER 中的 ORDER BY 是窗口函数的一部分) ")
	}
}
// ---------------- TestParseSQL_WindowFunctionWithOrderBy ----------------
// TestParseSQL_WindowFunctionWithOrderBy combines a window function (with its
// own ORDER BY inside OVER) and an outer ORDER BY. The OrderBy span must be the
// outer one, "ORDER BY id DESC".
func TestParseSQL_WindowFunctionWithOrderBy(t *testing.T) {
	const query = `SELECT id, name,
ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) AS rn
FROM employees
WHERE status = 1
ORDER BY id DESC`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.OrderBy, "ORDER"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	if clause := query[got.OrderBy.Start:got.OrderBy.End]; !strings.Contains(clause, "id DESC") {
		t.Errorf("ORDER BY 应包含外层的 id DESC, 实际:%s", clause)
	}
}
// ---------------- TestParseSQL_CTE ----------------
// TestParseSQL_CTE prefixes the main query with two WITH (CTE) definitions. The
// original author notes that the WITH part ends up inside the Select span. The
// FROM span must reference both CTE names used by the main query.
func TestParseSQL_CTE(t *testing.T) {
	const query = `WITH active_users AS (
SELECT id, name FROM users WHERE status = 1
),
order_counts AS (
SELECT user_id, COUNT(*) AS cnt FROM orders GROUP BY user_id
)
SELECT u.id, u.name, oc.cnt
FROM active_users u
LEFT JOIN order_counts oc ON u.id = oc.user_id
WHERE oc.cnt > 5`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
	from := query[got.From.Start:got.From.End]
	for _, table := range []string{"active_users", "order_counts"} {
		if !strings.Contains(from, table) {
			t.Errorf("FROM 应包含 CTE 表 %s", table)
		}
	}
}
// ---------------- TestParseSQL_CTEWithOrderBy ----------------
// TestParseSQL_CTEWithOrderBy has a CTE containing a window-function ORDER BY,
// followed by a main query with WHERE and its own ORDER BY. It expects all of
// the main query's clauses to be found.
func TestParseSQL_CTEWithOrderBy(t *testing.T) {
	const query = `WITH ranked_users AS (
SELECT id, name, ROW_NUMBER() OVER (ORDER BY created_at DESC) AS rn
FROM users
)
SELECT * FROM ranked_users WHERE rn <= 10 ORDER BY rn`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.Where, "WHERE"},
		{got.OrderBy, "ORDER"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_RecursiveCTE ----------------
// TestParseSQL_RecursiveCTE uses WITH RECURSIVE with a UNION ALL inside the CTE
// body, followed by a main query with ORDER BY. It expects SELECT, FROM and
// ORDER BY to be found.
func TestParseSQL_RecursiveCTE(t *testing.T) {
	const query = `WITH RECURSIVE hierarchy AS (
SELECT id, parent_id, name, 0 AS level
FROM categories WHERE parent_id IS NULL
UNION ALL
SELECT c.id, c.parent_id, c.name, h.level + 1
FROM categories c
INNER JOIN hierarchy h ON c.parent_id = h.id
)
SELECT * FROM hierarchy ORDER BY level, name`
	got := parseSQL(query)
	for _, c := range []struct {
		span    sqlSpan
		keyword string
	}{
		{got.Select, "SELECT"},
		{got.From, "FROM"},
		{got.OrderBy, "ORDER"},
	} {
		assertPart(t, query, c.span, c.keyword)
	}
}
// ---------------- TestParseSQL_SelectInto ----------------
// 测试 SELECT INTO 语句 (MySQL)
func TestParseSQL_SelectInto(t *testing.T) {
// MySQL 的 SELECT INTO 语法
sql := `SELECT id, name INTO @var_id, @var_name FROM users WHERE id = 1`
parts := parseSQL(sql)