diff --git a/internal/mysqldump/escape.go b/internal/mysqldump/escape.go index 55db0a81..d933cbaf 100644 --- a/internal/mysqldump/escape.go +++ b/internal/mysqldump/escape.go @@ -4,6 +4,16 @@ import ( "bytes" ) +var unescapeMap = map[byte]byte{ + '0': 0, + 'n': '\n', + 'r': '\r', + '\\': '\\', + '\'': '\'', + '"': '"', + 'Z': '\032', +} + func escape(str string) string { var esc string var buf bytes.Buffer @@ -34,3 +44,18 @@ func escape(str string) string { _, _ = buf.WriteString(str[last:]) return buf.String() } + +func unescape(str string) string { + var buf bytes.Buffer + for i := 0; i < len(str); i++ { + if str[i] == '\\' && i+1 < len(str) { + if unescaped, ok := unescapeMap[str[i+1]]; ok { + buf.WriteByte(unescaped) + i++ + continue + } + } + buf.WriteByte(str[i]) + } + return buf.String() +} diff --git a/internal/mysqldump/escape_test.go b/internal/mysqldump/escape_test.go index 6d73b3b5..9794398e 100644 --- a/internal/mysqldump/escape_test.go +++ b/internal/mysqldump/escape_test.go @@ -12,3 +12,39 @@ func TestEscape(t *testing.T) { result := escape(input) assert.Equal(t, expected, result) } + +func TestUnescape(t *testing.T) { + input := `\0\n\r\\\'\"\Za` + expected := string([]byte{0, '\n', '\r', '\\', '\'', '"', '\032', 'a'}) + result := unescape(input) + assert.Equal(t, expected, result) +} + +func TestEscapeUnescape_RoundTrip(t *testing.T) { + testCases := []string{ + "simple text", + "text with\nnewline", + "text with\rcarriage return", + `text with "double quotes"`, + "text with 'single quotes'", + `text with \backslash`, + "null\x00byte", + `json_extract(\'$.taxStatus\')`, + } + + for _, original := range testCases { + escaped := escape(original) + unescaped := unescape(escaped) + assert.Equal(t, original, unescaped, "round-trip failed for: %q", original) + } +} + +func TestUnescape_UnknownSequencePassthrough(t *testing.T) { + input := `abc\xyz` + assert.Equal(t, input, unescape(input)) +} + +func TestUnescape_TrailingBackslashPassthrough(t *testing.T) { + input := "abc\\" + assert.Equal(t, input, unescape(input)) +} diff --git a/internal/mysqldump/schema.go b/internal/mysqldump/schema.go index 409061bd..5d98e2d4 100644 --- a/internal/mysqldump/schema.go +++ b/internal/mysqldump/schema.go @@ -212,7 +212,8 @@ func (col *ColumnSchema) writeCharsetAndCollation(b *strings.Builder, tableColla func (col *ColumnSchema) writeGeneratedOrDefault(b *strings.Builder) { if col.GenerationExpr.Valid && col.GenerationExpr.String != "" { b.WriteString(" GENERATED ALWAYS AS (") - b.WriteString(col.GenerationExpr.String) + // INFORMATION_SCHEMA can return escaped quotes in generation expressions. + b.WriteString(unescape(col.GenerationExpr.String)) b.WriteString(")") if col.IsVirtual { b.WriteString(" VIRTUAL") diff --git a/internal/mysqldump/schema_test.go b/internal/mysqldump/schema_test.go index 3555d916..0a2b778c 100644 --- a/internal/mysqldump/schema_test.go +++ b/internal/mysqldump/schema_test.go @@ -211,6 +211,37 @@ func TestTableSchema_BuildCreateTableSQL_GeneratedColumn(t *testing.T) { assert.Contains(t, sql, "GENERATED ALWAYS AS (CONCAT(first_name, ' ', last_name)) VIRTUAL") } +func TestTableSchema_BuildCreateTableSQL_GeneratedColumn_WithEscapedQuotes(t *testing.T) { + // This test covers issue #846: generation expressions from INFORMATION_SCHEMA.COLUMNS + // contain escaped quotes that need to be unescaped + schema := &TableSchema{ + Name: "b2b_components_pending_order", + Engine: "InnoDB", + Charset: "utf8mb4", + Collation: "utf8mb4_unicode_ci", + Columns: []ColumnSchema{ + {Name: "id", Type: "int", Nullable: false, Extra: "AUTO_INCREMENT"}, + {Name: "price", Type: "json", Nullable: false}, + { + Name: "tax_status", + Type: "varchar(255)", + Nullable: true, + // This simulates what MySQL's INFORMATION_SCHEMA.COLUMNS returns with escaped quotes + GenerationExpr: sql.NullString{String: "json_unquote(json_extract(`price`,_utf8mb4\\'$.taxStatus\\'))", Valid: true}, + IsVirtual: true, + }, + }, + PrimaryKey: []string{"id"}, + } + + sql := schema.BuildCreateTableSQL() + + // The output should have unescaped quotes + assert.Contains(t, sql, "GENERATED ALWAYS AS (json_unquote(json_extract(`price`,_utf8mb4'$.taxStatus'))) VIRTUAL") + // Ensure the broken version with backslash-escaped quotes is NOT in the output + assert.NotContains(t, sql, "\\'") +} + func TestTableSchema_BuildCreateTableSQL_ColumnCollation(t *testing.T) { schema := &TableSchema{ Name: "test",