-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsql.go
More file actions
449 lines (388 loc) · 12.2 KB
/
sql.go
File metadata and controls
449 lines (388 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
package QueryHelper
import (
"context"
"database/sql"
"fmt"
"github.com/Seann-Moser/go-serve/pkg/ctxLogger"
"github.com/jmoiron/sqlx"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"go.uber.org/zap"
"os"
"sort"
"strconv"
"strings"
"time"
)
// Compile-time assertion that *SqlDB satisfies the DB interface.
var _ DB = &SqlDB{}
// SqlDB is a DB implementation backed by a sqlx database handle.
type SqlDB struct {
// sql is the underlying sqlx handle every query in this type goes through.
sql *sqlx.DB
// updateColumns makes CreateTable call ColumnUpdater after creating the table.
updateColumns bool
// tablePrefix is prepended to dataset names by GetDataset.
tablePrefix string
}
// Flags returns the command-line flags that configure SqlDB.
//
// NewSql reads the same keys through viper, so the names here must stay in
// sync with the keys used there.
func Flags() *pflag.FlagSet {
	fs := pflag.NewFlagSet("sql-db", pflag.ExitOnError)
	fs.Bool("sql-db-update-columns", false,
		"after creating a table, reconcile its existing columns with the declared columns")
	fs.String("sql-db-prefix", "",
		"prefix prepended to every dataset (schema) name")
	return fs
}
// NewSql wraps db in a SqlDB, taking the column-update switch and the table
// prefix from the viper keys "sql-db-update-columns" and "sql-db-prefix".
func NewSql(db *sqlx.DB) *SqlDB {
	store := new(SqlDB)
	store.sql = db
	store.updateColumns = viper.GetBool("sql-db-update-columns")
	store.tablePrefix = viper.GetString("sql-db-prefix")
	return store
}
// Ping checks that the database is reachable, giving up after two seconds
// or when ctx is done, whichever comes first.
func (s *SqlDB) Ping(ctx context.Context) error {
	const pingTimeout = 2 * time.Second

	pingCtx, stop := context.WithTimeout(ctx, pingTimeout)
	defer stop()

	return s.sql.PingContext(pingCtx)
}
// Close closes the underlying database handle. Close has no return value,
// so an error from the handle is deliberately discarded.
func (s *SqlDB) Close() {
	if closeErr := s.sql.Close(); closeErr != nil {
		return
	}
}
// GetDataset returns ds with the configured table prefix prepended.
func (s *SqlDB) GetDataset(ds string) string {
	return s.tablePrefix + ds
}
// BuildCreateTableQueries renders the DDL needed to create dataset.table.
//
// It returns, in order, a CREATE SCHEMA IF NOT EXISTS statement and a
// CREATE TABLE IF NOT EXISTS statement. Columns are emitted in ascending
// ColumnOrder; columns sharing the same ColumnOrder are ordered by Name so
// the generated statement is deterministic even though columns is a map.
//
// It returns MissingPrimaryKeyErr when no column is marked Primary, and
// forwards any error from a column's GetFK.
func (s *SqlDB) BuildCreateTableQueries(dataset, table string, columns map[string]Column) (string, string, error) {
	createSchemaStatement := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS `%s`", dataset)

	// Flatten the map and sort it; map iteration order is random, so the
	// Name tie-break is what makes the output reproducible.
	cols := make([]Column, 0, len(columns))
	for _, column := range columns {
		cols = append(cols, column)
	}
	sort.Slice(cols, func(i, j int) bool {
		if cols[i].ColumnOrder != cols[j].ColumnOrder {
			return cols[i].ColumnOrder < cols[j].ColumnOrder
		}
		return cols[i].Name < cols[j].Name
	})

	var (
		primaryKeys []string
		foreignKeys []string
		stmt        strings.Builder
	)
	fmt.Fprintf(&stmt, "CREATE TABLE IF NOT EXISTS `%s`.`%s` (", dataset, table)

	// Every column definition is followed by a comma; the primary key
	// clause below always comes after it, so no trailing comma is left.
	for _, column := range cols {
		stmt.WriteString(column.GetDefinition())
		stmt.WriteString(",")

		if column.HasFK() {
			fk, err := column.GetFK()
			if err != nil {
				return "", "", err
			}
			foreignKeys = append(foreignKeys, fk)
		}
		if column.Primary {
			primaryKeys = append(primaryKeys, column.Name)
		}
	}

	// The combined primary key length check (CheckPrimaryKeyLength) is
	// intentionally not applied here.
	switch len(primaryKeys) {
	case 0:
		return "", "", MissingPrimaryKeyErr
	case 1:
		fmt.Fprintf(&stmt, "\n\tPRIMARY KEY(`%s`)", primaryKeys[0])
	default:
		fmt.Fprintf(&stmt, "\n\tCONSTRAINT `PK_%s_%s` PRIMARY KEY (%s)", dataset, table, joinQuoted(primaryKeys, ","))
	}

	if len(foreignKeys) > 0 {
		stmt.WriteString(",")
		stmt.WriteString(strings.Join(foreignKeys, ","))
	}
	stmt.WriteString("\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4")

	return createSchemaStatement, stmt.String(), nil
}
// CreateTable creates the schema and table for dataset.table, running the
// two statements produced by BuildCreateTableQueries in order. A failing
// statement is logged together with its SQL and its error is returned
// unchanged. When updateColumns is set, ColumnUpdater runs afterwards and
// its result is returned.
func (s *SqlDB) CreateTable(ctx context.Context, dataset, table string, columns map[string]Column) error {
	schemaSQL, tableSQL, buildErr := s.BuildCreateTableQueries(dataset, table, columns)
	if buildErr != nil {
		return fmt.Errorf("failed BuildCreateTableQueries: %w", buildErr)
	}

	statements := [...]string{schemaSQL, tableSQL}
	for i := range statements {
		if _, execErr := s.sql.ExecContext(ctx, statements[i]); execErr != nil {
			ctxLogger.Error(ctx, "failed creating tables", zap.Error(execErr), zap.String("statement", statements[i]))
			return execErr
		}
	}

	if !s.updateColumns {
		return nil
	}
	return s.ColumnUpdater(ctx, dataset, table, columns)
}
// joinQuoted wraps each item in backticks and joins the results with sep.
// An empty or nil slice yields the empty string.
func joinQuoted(items []string, sep string) string {
	var b strings.Builder
	for i, name := range items {
		if i > 0 {
			b.WriteString(sep)
		}
		fmt.Fprintf(&b, "`%s`", name)
	}
	return b.String()
}
// QueryContext runs a named query and returns its rows.
//
// Without options, or when neither NoLock nor ReadPast is set, the query
// runs directly on the pool. Otherwise it runs inside a transaction opened
// at READ UNCOMMITTED isolation.
//
// The isolation level is requested through sql.TxOptions rather than a
// "SET SESSION TRANSACTION ISOLATION LEVEL" statement inside the
// transaction: in MySQL the SESSION form does not affect the transaction
// that is already open, and it leaves the pooled connection at
// READ UNCOMMITTED for whoever uses it next.
//
// NOTE(review): database/sql closes rows that belong to a transaction when
// that transaction is committed, so rows returned from the transactional
// path may already be closed when the caller iterates them — confirm, and
// if so defer the commit until the caller has finished reading.
func (s *SqlDB) QueryContext(ctx context.Context, query string, options *DBOptions, args interface{}) (DBRow, error) {
	if options == nil || !(options.NoLock || options.ReadPast) {
		return s.sql.NamedQueryContext(ctx, query, args)
	}

	tx, err := s.sql.BeginTxx(ctx, &sql.TxOptions{Isolation: sql.LevelReadUncommitted})
	if err != nil {
		return nil, fmt.Errorf("error starting transaction: %w", err)
	}

	rows, err := tx.NamedQuery(query, args)
	if err != nil {
		_ = tx.Rollback() // best effort; the query error is the one worth reporting
		return nil, fmt.Errorf("error executing query: %w", err)
	}
	if err = tx.Commit(); err != nil {
		return nil, fmt.Errorf("error committing query: %w", err)
	}
	return rows, nil
}
// RawQueryContext runs a positional-argument query and returns its rows.
//
// Without options, or when neither NoLock nor ReadPast is set, the query
// runs directly on the pool. Otherwise it runs inside a transaction opened
// at READ UNCOMMITTED isolation, requested through sql.TxOptions so the
// level applies to this transaction only and does not stick to the pooled
// connection.
//
// A panic raised while running the query is reported on stderr, any open
// transaction is rolled back, and the panic is returned as an error
// instead of terminating the process.
//
// NOTE(review): database/sql closes rows that belong to a transaction when
// that transaction is committed, so rows returned from the transactional
// path may already be closed when the caller iterates them — confirm, and
// if so defer the commit until the caller has finished reading.
func (s *SqlDB) RawQueryContext(ctx context.Context, query string, options *DBOptions, args ...interface{}) (rows DBRow, err error) {
	var tx *sqlx.Tx
	defer func() {
		if r := recover(); r != nil {
			fmt.Fprintf(os.Stderr, "Exception: %v\n", r)
			if tx != nil {
				_ = tx.Rollback() // best effort; the panic is what gets reported
			}
			rows, err = nil, fmt.Errorf("panic executing query: %v", r)
		}
	}()

	if options == nil || !(options.NoLock || options.ReadPast) {
		return s.sql.QueryxContext(ctx, query, args...)
	}

	tx, err = s.sql.BeginTxx(ctx, &sql.TxOptions{Isolation: sql.LevelReadUncommitted})
	if err != nil {
		return nil, fmt.Errorf("error starting transaction: %w", err)
	}

	result, err := tx.QueryxContext(ctx, query, args...)
	if err != nil {
		_ = tx.Rollback() // best effort; the query error is the one worth reporting
		return nil, fmt.Errorf("error executing query: %w", err)
	}
	if err = tx.Commit(); err != nil {
		return nil, fmt.Errorf("error committing query: %w", err)
	}
	return result, nil
}
// ExecContext runs a named statement inside its own transaction and
// commits it.
//
// If the statement fails, the transaction is rolled back, a warning with
// the query and arguments is logged, and the statement's error is
// returned. A panic raised along the way is reported on stderr, the
// transaction (if one was opened) is rolled back, and the panic is
// returned as an error, so a recovered panic is never reported as success.
func (s *SqlDB) ExecContext(ctx context.Context, query string, args interface{}) (err error) {
	var tx *sqlx.Tx
	defer func() {
		if r := recover(); r != nil {
			fmt.Fprintf(os.Stderr, "Exception: %v\n", r)
			if tx != nil {
				_ = tx.Rollback() // best effort; the panic is what gets reported
			}
			err = fmt.Errorf("panic executing query: %v", r)
		}
	}()

	tx, err = s.sql.BeginTxx(ctx, nil)
	if err != nil {
		return fmt.Errorf("error starting transaction: %w", err)
	}

	if _, err = tx.NamedExecContext(ctx, query, args); err != nil {
		ctxLogger.Warn(ctx, "rolled back transaction", zap.String("query", query), zap.Any("args", args), zap.Error(err))
		_ = tx.Rollback() // best effort; the statement error is the one worth reporting
		return err
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("error committing query: %w", err)
	}
	return nil
}
// ColumnUpdater reconciles the columns of dataset.table with columns.
//
// Declared columns whose Name is not present in the table are added with a
// single ALTER TABLE ... ADD statement. Table columns that match neither a
// key of columns nor the Name of any declared column are dropped with a
// single ALTER TABLE ... DROP COLUMN statement. Dropping is destructive,
// which is why a table column is kept if it matches either spelling.
//
// The first failing statement's error is returned.
func (s *SqlDB) ColumnUpdater(ctx context.Context, dataset, table string, columns map[string]Column) error {
	cols, err := getColumns(ctx, s.sql, dataset, table)
	if err != nil {
		return err
	}

	existing := make(map[string]struct{}, len(cols))
	for _, c := range cols {
		existing[c.Name()] = struct{}{}
	}

	declared := make(map[string]struct{}, 2*len(columns))
	var addColumns []*Column
	for key, e := range columns {
		declared[key] = struct{}{}
		declared[e.Name] = struct{}{}
		if _, found := existing[e.Name]; !found {
			// Copy before taking the address: before Go 1.22 the range
			// variable is shared by every iteration.
			column := e
			addColumns = append(addColumns, &column)
		}
	}

	var removeColumns []*sql.ColumnType
	for _, c := range cols {
		if _, found := declared[c.Name()]; !found {
			removeColumns = append(removeColumns, c)
		}
	}

	alterTable := fmt.Sprintf("ALTER TABLE %s.%s ", dataset, table)

	if len(addColumns) > 0 {
		addStmt := generateColumnStatements(alterTable, "add", addColumns)
		ctxLogger.Debug(ctx, "adding columns to table", zap.String("query", addStmt))
		if _, err := s.sql.ExecContext(ctx, addStmt); err != nil {
			return err
		}
	}

	if len(removeColumns) > 0 {
		// "drop" is the verb generateColumnTypeStmt understands; any other
		// verb renders empty clauses and an invalid statement.
		removeStmt := generateColumnTypeStatements(alterTable, "drop", removeColumns)
		ctxLogger.Debug(ctx, "removing columns from table", zap.String("table", table), zap.String("query", removeStmt))
		if _, err := s.sql.ExecContext(ctx, removeStmt); err != nil {
			return err
		}
	}
	return nil
}
// getColumns returns the column types of dataset.table by selecting at most
// one row and inspecting the result set's metadata. A nil db yields
// (nil, nil).
func getColumns(ctx context.Context, db *sqlx.DB, dataset, table string) ([]*sql.ColumnType, error) {
	if db == nil {
		return nil, nil
	}

	rows, err := db.QueryxContext(ctx, fmt.Sprintf("SELECT * FROM %s.%s limit 1;", dataset, table))
	if err != nil {
		return nil, err
	}
	// Only the metadata is needed; close the result set so its connection
	// goes back to the pool instead of leaking.
	defer func() { _ = rows.Close() }()

	cols, err := rows.ColumnTypes()
	if err != nil {
		return nil, err
	}
	return cols, nil
}
// generateColumnTypeStatements renders one clause per column type with
// generateColumnTypeStmt, joins the clauses with commas, and appends them
// to alterTable after a single space, ending the statement with ";".
func generateColumnTypeStatements(alterTable, columnType string, e []*sql.ColumnType) string {
	clauses := make([]string, len(e))
	for i := range e {
		clauses[i] = generateColumnTypeStmt(columnType, e[i])
	}
	return alterTable + " " + strings.Join(clauses, ",") + ";"
}
// generateColumnStatements renders one clause per column with
// generateColumnStmt, joins the clauses with commas, and appends them to
// alterTable after a single space, ending the statement with ";".
func generateColumnStatements(alterTable, columnType string, e []*Column) string {
	clauses := make([]string, len(e))
	for i := range e {
		clauses[i] = generateColumnStmt(columnType, e[i])
	}
	return alterTable + " " + strings.Join(clauses, ",") + ";"
}
// generateColumnStmt renders a single ALTER TABLE clause for e.
//
// columnType is matched case-insensitively: "drop" yields
// "DROP COLUMN <name>", "add" yields "ADD <definition>", and anything else
// yields the empty string. Clauses carry no terminator; the caller joins
// them with commas and terminates the whole statement.
func generateColumnStmt(columnType string, e *Column) string {
	switch strings.ToLower(columnType) {
	case "drop":
		return fmt.Sprintf("DROP COLUMN %s", e.Name)
	case "add":
		return fmt.Sprintf("ADD %s", e.GetDefinition())
	}
	return ""
}
// generateColumnTypeStmt renders a single ALTER TABLE clause for the column
// described by e. columnType is matched case-insensitively: "drop" yields
// "DROP COLUMN <name>", "add" yields "ADD <name>", and any other value
// yields the empty string.
func generateColumnTypeStmt(columnType string, e *sql.ColumnType) string {
	action := strings.ToLower(columnType)
	if action == "drop" {
		return fmt.Sprintf("DROP COLUMN %s", e.Name())
	}
	if action == "add" {
		return fmt.Sprintf("ADD %s", e.Name())
	}
	return ""
}
// IndexInfo is one row of index metadata as selected by GetTableIndexes
// from information_schema.statistics. Each field maps to the result column
// named in its db tag.
type IndexInfo struct {
// IndexName is scanned from INDEX_NAME.
IndexName string `db:"INDEX_NAME" json:"index_name"`
// ColumnName is scanned from COLUMN_NAME.
ColumnName string `db:"COLUMN_NAME" json:"column_name"`
// NonUnique is scanned from NON_UNIQUE.
NonUnique int `db:"NON_UNIQUE" json:"non_unique"`
// SeqInIndex is scanned from SEQ_IN_INDEX.
SeqInIndex int `db:"SEQ_IN_INDEX" json:"seq_in_index"`
}
// ColumnInfo is one row of column metadata as selected by
// GetTableDefinition from information_schema.columns. Each field maps to
// the result column named in its db tag.
type ColumnInfo struct {
// ColumnName is scanned from COLUMN_NAME.
ColumnName string `db:"COLUMN_NAME" json:"column_name"`
// ColumnType is scanned from COLUMN_TYPE.
ColumnType string `db:"COLUMN_TYPE" json:"column_type"`
// IsNullable is scanned from IS_NULLABLE and kept as the raw string.
IsNullable string `db:"IS_NULLABLE" json:"is_nullable"`
// ColumnKey is scanned from COLUMN_KEY.
ColumnKey string `db:"COLUMN_KEY" json:"column_key"`
// ColumnDefault is scanned from COLUMN_DEFAULT.
ColumnDefault string `db:"COLUMN_DEFAULT" json:"column_default"`
// Extra is scanned from EXTRA.
Extra string `db:"EXTRA" json:"extra"`
}
// GetTableDefinition returns the column metadata of database.tableName as
// recorded in information_schema.columns.
//
// COLUMN_DEFAULT is NULL for columns without a default value, and NULL
// cannot be scanned into ColumnInfo.ColumnDefault (a plain string), so the
// query maps NULL to the empty string and aliases the result back to
// COLUMN_DEFAULT to keep the db tag mapping intact.
func (s *SqlDB) GetTableDefinition(database string, tableName string) ([]ColumnInfo, error) {
	query := `SELECT COLUMN_NAME, COLUMN_TYPE, IS_NULLABLE, COLUMN_KEY,
       COALESCE(COLUMN_DEFAULT, '') AS COLUMN_DEFAULT, EXTRA
FROM information_schema.columns
WHERE table_schema = ? AND table_name = ?`

	var columns []ColumnInfo
	if err := s.sql.Select(&columns, query, database, tableName); err != nil {
		return nil, err
	}
	return columns, nil
}
// GetTableIndexes returns the index metadata of database.tableName as
// recorded in information_schema.statistics. Any error from the select is
// returned unchanged together with a nil slice.
func (s *SqlDB) GetTableIndexes(database, tableName string) ([]IndexInfo, error) {
	const indexQuery = `SELECT INDEX_NAME, COLUMN_NAME, NON_UNIQUE, SEQ_IN_INDEX
FROM information_schema.statistics
WHERE table_schema = ? AND table_name = ?`

	var found []IndexInfo
	if selectErr := s.sql.Select(&found, indexQuery, database, tableName); selectErr != nil {
		return nil, selectErr
	}
	return found, nil
}
// Version reports the server version string obtained through
// GetMySQLVersion. It returns the fixed string "8.0.40" when no database
// handle is set, and "unknown" when the version query fails.
func (s *SqlDB) Version() string {
	if s.sql != nil {
		if reported, err := GetMySQLVersion(s.sql); err == nil {
			return reported
		}
		return "unknown"
	}
	return "8.0.40"
}
// defaultMaxPrimaryKeyLength is the primary key byte limit applied unless
// the server version selects the larger one in CheckPrimaryKeyLength.
const defaultMaxPrimaryKeyLength = 767

// CheckPrimaryKeyLength sums GetByteLength over the columns marked Primary
// and compares the total with the limit for the server version: 3072 bytes
// when Version() compares strictly greater than "8.0.17", otherwise
// defaultMaxPrimaryKeyLength. The comparison is strict, so exactly 8.0.17
// keeps the default limit.
//
// It returns (true, error) when the total exceeds the limit and
// (false, nil) otherwise.
func (s *SqlDB) CheckPrimaryKeyLength(columns []Column) (bool, error) {
	limit := defaultMaxPrimaryKeyLength
	if CompareVersions(s.Version(), "8.0.17") > 0 {
		limit = 3072
	}

	var keyBytes int
	for _, column := range columns {
		if !column.Primary {
			continue
		}
		keyBytes += column.GetByteLength()
	}

	if keyBytes <= limit {
		return false, nil
	}
	return true, fmt.Errorf("primary key length exceeds the maximum allowed length of %d bytes", limit)
}
// GetMySQLVersion asks the server for its version with SELECT VERSION() and
// returns the resulting string. A query failure is wrapped and returned
// with an empty version.
func GetMySQLVersion(db *sqlx.DB) (string, error) {
	var reported string
	if err := db.Get(&reported, "SELECT VERSION()"); err != nil {
		return "", fmt.Errorf("failed to get MySQL version: %w", err)
	}
	return reported, nil
}
// CompareVersions compares two dot-separated version strings component by
// component and returns 1 if version1 is newer, -1 if it is older, and 0
// if they are equal.
//
// Each component contributes only its leading run of decimal digits, so
// suffixed server versions such as "8.0.40-log" or "10.5.8-MariaDB"
// compare by their numeric part. A missing component, or one that does
// not start with a digit, counts as 0, so "8.0" equals "8.0.0".
func CompareVersions(version1, version2 string) int {
	v1Parts := strings.Split(version1, ".")
	v2Parts := strings.Split(version2, ".")

	maxParts := len(v1Parts)
	if len(v2Parts) > maxParts {
		maxParts = len(v2Parts)
	}

	for i := 0; i < maxParts; i++ {
		var v1, v2 int
		if i < len(v1Parts) {
			v1 = leadingVersionNumber(v1Parts[i])
		}
		if i < len(v2Parts) {
			v2 = leadingVersionNumber(v2Parts[i])
		}

		switch {
		case v1 > v2:
			return 1
		case v1 < v2:
			return -1
		}
	}
	return 0
}

// leadingVersionNumber parses the leading run of decimal digits in part and
// returns 0 when part does not start with a digit.
func leadingVersionNumber(part string) int {
	end := 0
	for end < len(part) && part[end] >= '0' && part[end] <= '9' {
		end++
	}
	// Atoi yields 0 for an empty prefix; its error is ignored on purpose so
	// unparsable components simply count as 0.
	n, _ := strconv.Atoi(part[:end])
	return n
}