diff --git a/consilens-cli/src/main/java/com/consilens/cli/model/normalization/TypeNormalizationRule.java b/consilens-cli/src/main/java/com/consilens/cli/model/normalization/TypeNormalizationRule.java index 7af0517..042796c 100644 --- a/consilens-cli/src/main/java/com/consilens/cli/model/normalization/TypeNormalizationRule.java +++ b/consilens-cli/src/main/java/com/consilens/cli/model/normalization/TypeNormalizationRule.java @@ -29,6 +29,12 @@ public class TypeNormalizationRule { */ @JsonProperty("timezone") private String timezone; + + /** + * Temporal comparison mode for date/datetime/timestamp values. + */ + @JsonProperty("comparisonMode") + private String comparisonMode; /** * Encoding method (hex/base64). diff --git a/consilens-cli/src/main/java/com/consilens/cli/service/DiffService.java b/consilens-cli/src/main/java/com/consilens/cli/service/DiffService.java index 8605d24..0daa14b 100644 --- a/consilens-cli/src/main/java/com/consilens/cli/service/DiffService.java +++ b/consilens-cli/src/main/java/com/consilens/cli/service/DiffService.java @@ -53,23 +53,26 @@ public class DiffService { * @return diff result */ public CliDiffResult performDiff(CliConfiguration config) throws Exception { - log.info("Starting diff operation with strategy: {}, algorithm: {}", + log.info("Starting diff operation with strategy: {}, algorithm: {}", config.getStrategyMode(), config.getAlgorithm()); long startTime = System.currentTimeMillis(); DiffLifecycle lifecycle = buildLifecycle(config); DiffContext diffContext = buildDiffContext(config); + CliDiffResult result = null; + Exception failure = null; try { lifecycle.onDiffStart(diffContext); - CompareRuntime runtime = new DefaultCompareRuntime(); + CompareRuntime runtime = createCompareRuntime(); DiffResult coreResult = runtime.execute(toCompareRequest(config)); + publishDifferences(coreResult, lifecycle, diffContext); lifecycle.onDiffComplete(coreResult, diffContext); // Convert core result to CLI result - CliDiffResult result = convertToCLIResult(coreResult, config.getStrategyMode(), config); + result = convertToCLIResult(coreResult, config.getStrategyMode(), config); if (result.getInfoTree() != null) { log.info(formatInfoTree(result.getInfoTree())); } @@ -80,26 +83,37 @@ public CliDiffResult performDiff(CliConfiguration config) throws Exception { log.info("Diff operation completed in {} ms with {} differences", duration, result.getTotalDifferences()); - return result; - } catch (Exception e) { log.error("Diff operation failed", e); + failure = new Exception("Diff operation failed: " + e.getMessage(), e); try { lifecycle.onDiffError(diffContext, e); } catch (Exception lifecycleEx) { - log.warn("Lifecycle onDiffError failed", lifecycleEx); + failure.addSuppressed(lifecycleEx); } - throw new Exception("Diff operation failed: " + e.getMessage(), e); } finally { try { lifecycle.close(); - } catch (Exception e) { - log.warn("Lifecycle close failed", e); + } catch (Exception closeEx) { + if (failure != null) { + failure.addSuppressed(closeEx); + } else { + failure = new Exception("Lifecycle close failed: " + closeEx.getMessage(), closeEx); + } } } + + if (failure != null) { + throw failure; + } + return result; + } + + protected CompareRuntime createCompareRuntime() { + return new DefaultCompareRuntime(); } - private DiffLifecycle buildLifecycle(CliConfiguration config) { + protected DiffLifecycle buildLifecycle(CliConfiguration config) { ResultConfig resultConfig = config.getResult(); if (resultConfig == null || resultConfig.getSinks() == null || resultConfig.getSinks().isEmpty()) { return new NoopDiffLifecycle(); @@ -107,7 +121,7 @@ private DiffLifecycle buildLifecycle(CliConfiguration config) { return new DefaultDiffLifecycle(resultConfig); } - private DiffContext buildDiffContext(CliConfiguration config) { + protected DiffContext buildDiffContext(CliConfiguration config) { TablePath sourcePath = null; TablePath targetPath = null; List sourceColumnNames = new ArrayList<>(); @@ -152,6 +166,13 @@ private DiffContext buildDiffContext(CliConfiguration config) { .build(); } + private void publishDifferences(DiffResult coreResult, DiffLifecycle lifecycle, DiffContext diffContext) throws Exception { + if (coreResult == null || coreResult.getDifferences() == null || coreResult.getDifferences().isEmpty()) { + return; + } + lifecycle.onDifferencesFound(coreResult.getDifferences(), diffContext); + } + /** * Convert core DiffResult to CLI DiffResult. */ @@ -178,7 +199,9 @@ private CliDiffResult convertToCLIResult(DiffResult coreResult, String strategy, .targetRowCount((int) stats.getTargetRowCount()) .differences(convertDiffRows(coreResult.getDifferences())) .tableMetadata(tableMetadata) - .infoTree(coreResult.getInfoTree().isPresent() ? coreResult.getInfoTree().orElse(null) : null) + .infoTree(coreResult.getInfoTree() != null && coreResult.getInfoTree().isPresent() + ? coreResult.getInfoTree().orElse(null) + : null) .build() .withSortKeyColumns(sortKeyColumns); } @@ -433,7 +456,7 @@ private List toNormalizationRules(String type, TypeNormalizat result.add(normalizationRule(type, "format_number", params)); } - if (rule.getFormat() != null || rule.getTimezone() != null) { + if (rule.getFormat() != null || rule.getTimezone() != null || rule.getComparisonMode() != null) { Map params = new LinkedHashMap<>(); if (rule.getFormat() != null) { params.put("format", rule.getFormat()); @@ -441,6 +464,9 @@ private List toNormalizationRules(String type, TypeNormalizat if (rule.getTimezone() != null) { params.put("timezone", rule.getTimezone()); } + if (rule.getComparisonMode() != null) { + params.put("comparisonMode", rule.getComparisonMode()); + } result.add(normalizationRule(type, "format_datetime", params)); } @@ -564,7 +590,7 @@ private Integer integerValue(Object value) { * Perform a dry run to validate configuration without executing diff. */ public CliDiffResult performDryRun(CliConfiguration config) throws Exception { - log.info("Performing dry run for diff operation with strategy: {}, algorithm: {}", + log.info("Performing dry run for diff operation with strategy: {}, algorithm: {}", config.getStrategyMode(), config.getAlgorithm()); ConnectorProbeService probeService = new ConnectorProbeService(); ComparisonConfig comparison = config.getComparison(); diff --git a/consilens-cli/src/test/java/com/consilens/cli/service/DiffServiceTest.java b/consilens-cli/src/test/java/com/consilens/cli/service/DiffServiceTest.java new file mode 100644 index 0000000..57c891c --- /dev/null +++ b/consilens-cli/src/test/java/com/consilens/cli/service/DiffServiceTest.java @@ -0,0 +1,172 @@ +package com.consilens.cli.service; + +import com.consilens.cli.model.CliConfiguration; +import com.consilens.cli.model.ComparisonConfig; +import com.consilens.cli.model.ConnectionConfig; +import com.consilens.cli.model.ListPairConfig; +import com.consilens.cli.model.LocalCompareConfig; +import com.consilens.cli.model.StrategyConfig; +import com.consilens.cli.model.StringPairConfig; +import com.consilens.cli.model.normalization.TypeNormalizationRule; +import com.consilens.connector.api.normalization.NormalizationRule; +import com.consilens.core.compare.CompareRuntime; +import com.consilens.core.diff.DiffResult; +import com.consilens.core.diff.DiffRow; +import com.consilens.core.lifecycle.DiffContext; +import com.consilens.core.lifecycle.DiffLifecycle; +import com.consilens.core.lifecycle.SegmentResult; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class DiffServiceTest { + + @Test + void shouldPublishDifferencesBeforeCompletingLifecycle() throws Exception { + RecordingLifecycle lifecycle = new RecordingLifecycle(false); + DiffResult diffResult = DiffResult.of( + List.of(DiffRow.removed(List.of(1), List.of("Alice"), List.of("name"))), + com.consilens.connector.api.model.TablePath.of("source_table"), + com.consilens.connector.api.model.TablePath.of("target_table")); + + TestableDiffService service = new TestableDiffService(lifecycle, request -> diffResult); + + service.performDiff(createConfig()); + + assertEquals(List.of("start", "differences", "complete", "close"), lifecycle.events); + } + + @Test + void shouldFailWhenLifecycleCloseFailsAfterSuccessfulDiff() { + RecordingLifecycle lifecycle = new RecordingLifecycle(true); + DiffResult diffResult = DiffResult.of( + List.of(DiffRow.removed(List.of(1), List.of("Alice"), List.of("name"))), + com.consilens.connector.api.model.TablePath.of("source_table"), + com.consilens.connector.api.model.TablePath.of("target_table")); + + TestableDiffService service = new TestableDiffService(lifecycle, request -> diffResult); + + Exception exception = assertThrows(Exception.class, () -> service.performDiff(createConfig())); + assertTrue(exception.getMessage().contains("Lifecycle close failed")); + assertEquals(List.of("start", "differences", "complete", "close"), lifecycle.events); + } + + @Test + @SuppressWarnings("unchecked") + void shouldConvertTemporalComparisonModeToNormalizationRule() throws Exception { + DiffService service = new DiffService(); + TypeNormalizationRule rule = new TypeNormalizationRule(); + rule.setTimezone("UTC"); + rule.setComparisonMode("DATE_ONLY"); + + Method method = DiffService.class.getDeclaredMethod("toNormalizationRules", String.class, TypeNormalizationRule.class); + method.setAccessible(true); + + List rules = (List) method.invoke(service, "timestamp", rule); + + assertEquals(1, rules.size()); + assertEquals("format_datetime", rules.get(0).getOperation()); + assertEquals(Map.of("timezone", "UTC", "comparisonMode", "DATE_ONLY"), rules.get(0).getParams()); + } + + private CliConfiguration createConfig() { + return CliConfiguration.builder() + .source(ConnectionConfig.builder() + .type("mysql") + .url("jdbc:mysql://localhost:3306/source_db") + .username("user") + .password("pwd") + .resource(ConnectionConfig.ResourceConfig.builder().type("table").name("source_table").build()) + .build()) + .target(ConnectionConfig.builder() + .type("mysql") + .url("jdbc:mysql://localhost:3306/target_db") + .username("user") + .password("pwd") + .resource(ConnectionConfig.ResourceConfig.builder().type("table").name("target_table").build()) + .build()) + .comparison(ComparisonConfig.builder() + .tables(StringPairConfig.builder().source("source_table").target("target_table").build()) + .keys(ListPairConfig.builder().source(List.of("id")).target(List.of("id")).build()) + .fields(ListPairConfig.builder().source(List.of("name")).target(List.of("name")).build()) + .build()) + .strategy(StrategyConfig.builder() + .mode("checksum") + .algorithm("concat") + .bisectionFactor(4) + .bisectionThreshold(1000L) + .batchSize(100) + .enableProfiling(false) + .localCompare(LocalCompareConfig.builder().mode("full").build()) + .build()) + .build(); + } + + private static class TestableDiffService extends DiffService { + private final DiffLifecycle lifecycle; + private final CompareRuntime runtime; + + private TestableDiffService(DiffLifecycle lifecycle, CompareRuntime runtime) { + this.lifecycle = lifecycle; + this.runtime = runtime; + } + + @Override + protected CompareRuntime createCompareRuntime() { + return runtime; + } + + @Override + protected DiffLifecycle buildLifecycle(CliConfiguration config) { + return lifecycle; + } + } + + private static class RecordingLifecycle implements DiffLifecycle { + private final boolean failOnClose; + private final List events = new ArrayList<>(); + + private RecordingLifecycle(boolean failOnClose) { + this.failOnClose = failOnClose; + } + + @Override + public void onDiffStart(DiffContext context) { + events.add("start"); + } + + @Override + public void onSegmentComplete(SegmentResult result) { + } + + @Override + public void onDifferencesFound(List diffs, DiffContext context) { + events.add("differences"); + } + + @Override + public void onDiffComplete(DiffResult result, DiffContext context) { + events.add("complete"); + } + + @Override + public void onDiffError(DiffContext context, Throwable error) { + events.add("error"); + } + + @Override + public void close() throws Exception { + events.add("close"); + if (failOnClose) { + throw new Exception("close failed"); + } + } + } +} diff --git a/consilens-connector/consilens-connector-api/src/main/java/com/consilens/connector/api/normalization/DefaultNormalizationOperationRegistry.java b/consilens-connector/consilens-connector-api/src/main/java/com/consilens/connector/api/normalization/DefaultNormalizationOperationRegistry.java index 4f55bbf..2ff3c94 100644 --- a/consilens-connector/consilens-connector-api/src/main/java/com/consilens/connector/api/normalization/DefaultNormalizationOperationRegistry.java +++ b/consilens-connector/consilens-connector-api/src/main/java/com/consilens/connector/api/normalization/DefaultNormalizationOperationRegistry.java @@ -30,7 +30,7 @@ private Map createDefinitions() { setOf("precision", "rounding"))); result.put("format_datetime", definition("format_datetime", setOf("date", "time", "datetime", "timestamp"), - setOf("format", "timezone"))); + setOf("format", "timezone", "comparisonMode"))); result.put("encode", definition("encode", setOf("binary"), setOf("format", "encoding", "uppercase"))); diff --git a/consilens-connector/consilens-connector-api/src/main/java/com/consilens/connector/api/normalization/DefaultNormalizationSpecValidator.java b/consilens-connector/consilens-connector-api/src/main/java/com/consilens/connector/api/normalization/DefaultNormalizationSpecValidator.java index 40f8c2e..e0d5645 100644 --- a/consilens-connector/consilens-connector-api/src/main/java/com/consilens/connector/api/normalization/DefaultNormalizationSpecValidator.java +++ b/consilens-connector/consilens-connector-api/src/main/java/com/consilens/connector/api/normalization/DefaultNormalizationSpecValidator.java @@ -17,6 +17,8 @@ public class DefaultNormalizationSpecValidator implements NormalizationSpecValid "datetime", "timestamp", "boolean", "binary", "json")); private static final Set SUPPORTED_BINARY_FORMATS = new HashSet<>(Arrays.asList("hex", "base64")); + private static final Set SUPPORTED_TEMPORAL_COMPARISON_MODES = new HashSet<>( + Arrays.asList("EXACT", "DATE_ONLY", "TRUNCATE_TO_SECOND", "TRUNCATE_TO_DAY")); private final NormalizationOperationRegistry operationRegistry; @@ -102,6 +104,7 @@ private void validateParams(Map params, } else if ("format_datetime".equals(operation)) { validateString(params.get("format"), "format", scope, operation, false); validateTimezone(params.get("timezone"), scope, operation); + validateTemporalComparisonMode(params.get("comparisonMode"), scope, operation); } else if ("encode".equals(operation)) { Object encoding = params.get("encoding") != null ? params.get("encoding") : params.get("format"); validateBinaryFormat(encoding, scope, operation); @@ -161,6 +164,18 @@ private void validateTimezone(Object value, String scope, String operation) { } } + private void validateTemporalComparisonMode(Object value, String scope, String operation) { + if (value == null) { + return; + } + validateString(value, "comparisonMode", scope, operation, false); + String normalized = ((String) value).trim().toUpperCase(); + if (!SUPPORTED_TEMPORAL_COMPARISON_MODES.contains(normalized)) { + throw new ConnectorException("Parameter 'comparisonMode' for operation '" + operation + "' in " + scope + + " must be one of " + SUPPORTED_TEMPORAL_COMPARISON_MODES); + } + } + private void validateBinaryFormat(Object value, String scope, String operation) { if (value == null) { return; diff --git a/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/main/java/com/consilens/conncetor/base/BaseDataTypeHandler.java b/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/main/java/com/consilens/conncetor/base/BaseDataTypeHandler.java index d452bb8..aaf7b4b 100644 --- a/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/main/java/com/consilens/conncetor/base/BaseDataTypeHandler.java +++ b/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/main/java/com/consilens/conncetor/base/BaseDataTypeHandler.java @@ -6,11 +6,35 @@ import lombok.extern.slf4j.Slf4j; import java.lang.reflect.Method; +import java.util.Locale; import java.util.Map; +import java.util.Set; @Slf4j public class BaseDataTypeHandler implements DataTypeHandler { + private static final Set DATE_ONLY_COMPARISON_MODES = Set.of("DATE_ONLY", "TRUNCATE_TO_DAY"); + + private static final Map DATA_TYPE_ALIASES = Map.ofEntries( + Map.entry("BOOL", DataType.BOOLEAN), + Map.entry("BIT", DataType.BIT), + Map.entry("BPCHAR", DataType.CHAR), + Map.entry("INT", DataType.INTEGER), + Map.entry("INT2", DataType.SMALLINT), + Map.entry("INT4", DataType.INTEGER), + Map.entry("INT8", DataType.BIGINT), + Map.entry("DOUBLE_PRECISION", DataType.DOUBLE), + Map.entry("CHARACTER_VARYING", DataType.VARCHAR), + Map.entry("CHARACTER", DataType.CHAR), + Map.entry("ENUM", DataType.VARCHAR), + Map.entry("SET", DataType.VARCHAR), + Map.entry("TIMESTAMP_WITH_TIME_ZONE", DataType.TIMESTAMP_WITH_TIMEZONE), + Map.entry("TIMESTAMP_WITHOUT_TIME_ZONE", DataType.TIMESTAMP), + Map.entry("TIME_WITH_TIME_ZONE", DataType.TIME_WITH_TIME_ZONE), + Map.entry("TIME_WITHOUT_TIME_ZONE", DataType.TIME), + Map.entry("DATETIME2", DataType.DATETIME) + ); + private final CapabilityProvider capabilityProvider; protected final Map normalizationConfig; @@ -31,25 +55,8 @@ public BaseDataTypeHandler(CapabilityProvider capabilityProvider, Map * @return configured precision, or defaultPrecision if not set */ protected int getPrecision(String dataTypeName, int defaultPrecision) { - if (normalizationConfig == null) { - return defaultPrecision; - } - - try { - Object rule = normalizationConfig.get(dataTypeName); - if (rule != null) { - // Get precision field via reflection - java.lang.reflect.Method getPrecisionMethod = rule.getClass().getMethod("getPrecision"); - Integer precision = (Integer) getPrecisionMethod.invoke(rule); - if (precision != null) { - return precision; - } - } - } catch (Exception e) { - log.debug("Failed to get precision config for type '{}': {}", dataTypeName, e.getMessage()); - } - - return defaultPrecision; + Integer precision = getRuleValue(dataTypeName, "getPrecision", Integer.class); + return precision != null ? precision : defaultPrecision; } /** @@ -60,25 +67,50 @@ protected int getPrecision(String dataTypeName, int defaultPrecision) { * @return configured rounding flag, or defaultRounding if not set */ protected boolean getRounding(String dataTypeName, boolean defaultRounding) { + Boolean rounding = getRuleValue(dataTypeName, "getRounding", Boolean.class); + return rounding != null ? rounding : defaultRounding; + } + + protected String getFormat(String dataTypeName, String defaultFormat) { + String format = getRuleValue(dataTypeName, "getFormat", String.class); + return format != null ? format : defaultFormat; + } + + protected String getTimezone(String dataTypeName, String defaultTimezone) { + String timezone = getRuleValue(dataTypeName, "getTimezone", String.class); + return timezone != null ? timezone : defaultTimezone; + } + + protected String getComparisonMode(String dataTypeName, String defaultMode) { + String comparisonMode = getRuleValue(dataTypeName, "getComparisonMode", String.class); + return comparisonMode != null ? comparisonMode.trim().toUpperCase(Locale.ROOT) : defaultMode; + } + + protected boolean isDateOnlyComparison(String dataTypeName) { + return DATE_ONLY_COMPARISON_MODES.contains(getComparisonMode(dataTypeName, "EXACT")); + } + + protected String escapeSqlLiteral(String value) { + return value == null ? null : value.replace("'", "''"); + } + + private T getRuleValue(String dataTypeName, String getterName, Class valueType) { if (normalizationConfig == null) { - return defaultRounding; + return null; } - + try { Object rule = normalizationConfig.get(dataTypeName); - if (rule != null) { - // Get rounding field via reflection - Method getRoundingMethod = rule.getClass().getMethod("getRounding"); - Boolean rounding = (Boolean) getRoundingMethod.invoke(rule); - if (rounding != null) { - return rounding; - } + if (rule == null) { + return null; } + Method getter = rule.getClass().getMethod(getterName); + Object value = getter.invoke(rule); + return valueType.isInstance(value) ? valueType.cast(value) : null; } catch (Exception e) { - log.debug("Failed to get rounding config for type '{}': {}", dataTypeName, e.getMessage()); + log.debug("Failed to get normalization config '{}' for type '{}': {}", getterName, dataTypeName, e.getMessage()); + return null; } - - return defaultRounding; } @Override @@ -88,8 +120,8 @@ public String normalizeColumn(String columnName, DataType dataType) { // Log normalization at DEBUG level log.debug("BaseDataTypeHandler.normalizeColumn: column='{}', dataType={}", columnName, dataType); - if (dataType == null) { - log.debug(" -> Using normalizeDefault (dataType is null)"); + if (dataType == null || dataType == DataType.UNKNOWN) { + log.warn("Unsupported data type normalization for column '{}', falling back to default normalization", columnName); return normalizeDefault(quotedCol); } @@ -118,6 +150,7 @@ public String normalizeColumn(String columnName, DataType dataType) { return normalizeDecimal(quotedCol); case BOOLEAN: + case BIT: log.debug(" -> Using normalizeBoolean"); return normalizeBoolean(quotedCol); @@ -154,7 +187,8 @@ public String normalizeColumn(String columnName, DataType dataType) { return normalizeJson(quotedCol); default: - log.debug(" -> Using normalizeDefault (default case)"); + log.warn("Unsupported data type normalization for column '{}': {}, falling back to default normalization", + columnName, dataType); return normalizeDefault(quotedCol); } } @@ -289,16 +323,22 @@ protected String normalizeDefault(String quotedCol) { @Override public DataType convertToDataType(String sourceType) { - if (sourceType == null) { + if (sourceType == null || sourceType.isBlank()) { return DataType.UNKNOWN; } - String upperType = sourceType.toUpperCase(); + + String canonicalType = canonicalizeTypeName(sourceType); + DataType aliasType = DATA_TYPE_ALIASES.get(canonicalType); + if (aliasType != null) { + return aliasType; + } + try { - return DataType.valueOf(upperType); + return DataType.valueOf(canonicalType); } catch (IllegalArgumentException e) { - // Try to find by display name or partial match for (DataType dt : DataType.values()) { - if (dt.getDisplayName().equalsIgnoreCase(upperType)) { + if (dt.getDisplayName().equalsIgnoreCase(canonicalType) + || dt.getDisplayName().replace(' ', '_').equalsIgnoreCase(canonicalType)) { return dt; } } @@ -306,6 +346,13 @@ public DataType convertToDataType(String sourceType) { } } + private String canonicalizeTypeName(String sourceType) { + String upperType = sourceType.trim().toUpperCase(Locale.ROOT); + upperType = upperType.replaceAll("\\s*\\([^)]*\\)", ""); + upperType = upperType.replaceAll("\\s+", " ").trim(); + return upperType.replace(' ', '_'); + } + @Override public String formatDataType(DataType dataType, int length, int precision, int scale) { @@ -395,4 +442,4 @@ protected boolean isBooleanType(String dataType) { dataType.equals("TINYINT(1)"); } -} \ No newline at end of file +} diff --git a/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/main/java/com/consilens/conncetor/base/jdbc/JdbcDatasetHandle.java b/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/main/java/com/consilens/conncetor/base/jdbc/JdbcDatasetHandle.java index f43108e..f24388a 100644 --- a/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/main/java/com/consilens/conncetor/base/jdbc/JdbcDatasetHandle.java +++ b/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/main/java/com/consilens/conncetor/base/jdbc/JdbcDatasetHandle.java @@ -111,10 +111,16 @@ public DatasetMetadata getMetadata() { @Override public SchemaDescriptor getSchema() throws ConnectorException { - if (schema == null) { - schema = discoverSchema(); + SchemaDescriptor local = schema; + if (local != null) { + return local; + } + synchronized (this) { + if (schema == null) { + schema = discoverSchema(); + } + return schema; } - return schema; } @Override @@ -719,6 +725,7 @@ private JdbcTypeNormalizationRule mapNormalizationRule(NormalizationRule rule) { } else if ("format_datetime".equals(operation)) { mapped.setFormat(stringValue(params.get("format"))); mapped.setTimezone(stringValue(params.get("timezone"))); + mapped.setComparisonMode(stringValue(params.get("comparisonMode"))); } else if ("encode".equals(operation)) { mapped.setEncoding(firstString(params.get("encoding"), params.get("format"))); mapped.setUppercase(booleanValue(params.get("uppercase"))); @@ -858,6 +865,7 @@ public static class JdbcTypeNormalizationRule { private Boolean rounding; private String format; private String timezone; + private String comparisonMode; private String encoding; private Boolean uppercase; private String trueValue; @@ -896,6 +904,14 @@ public void setTimezone(String timezone) { this.timezone = timezone; } + public String getComparisonMode() { + return comparisonMode; + } + + public void setComparisonMode(String comparisonMode) { + this.comparisonMode = comparisonMode; + } + public String getEncoding() { return encoding; } diff --git a/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/test/java/com/consilens/conncetor/base/BaseDataTypeHandlerTest.java b/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/test/java/com/consilens/conncetor/base/BaseDataTypeHandlerTest.java new file mode 100644 index 0000000..ddb4ac4 --- /dev/null +++ b/consilens-connector/consilens-connector-plugins/consilens-connector-base/src/test/java/com/consilens/conncetor/base/BaseDataTypeHandlerTest.java @@ -0,0 +1,91 @@ +package com.consilens.conncetor.base; + +import com.consilens.connector.api.CapabilityProvider; +import com.consilens.connector.api.enums.DatabaseFeature; +import com.consilens.connector.api.model.DataType; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class BaseDataTypeHandlerTest { + + private final BaseDataTypeHandler handler = new BaseDataTypeHandler(new TestCapabilityProvider()); + + @Test + void shouldFallbackForUnknownTypeNormalization() { + assertEquals("COALESCE(TRIM(CAST(\"payload\" AS VARCHAR)), '')", + handler.normalizeColumn("payload", DataType.UNKNOWN)); + assertEquals("COALESCE(TRIM(CAST(\"payload\" AS VARCHAR)), '')", + handler.normalizeColumn("payload", null)); + } + + @Test + void shouldConvertParameterizedAndAliasedTypes() { + assertEquals(DataType.VARCHAR, handler.convertToDataType("varchar(255)")); + assertEquals(DataType.DECIMAL, handler.convertToDataType("decimal(10,2)")); + assertEquals(DataType.TIMESTAMP_WITH_TIMEZONE, handler.convertToDataType("timestamp with time zone")); + assertEquals(DataType.DOUBLE, handler.convertToDataType("double precision")); + assertEquals(DataType.BIGINT, handler.convertToDataType("int8")); + assertEquals(DataType.BOOLEAN, handler.convertToDataType("bool")); + assertEquals(DataType.CHAR, handler.convertToDataType("bpchar")); + assertEquals(DataType.VARCHAR, handler.convertToDataType("enum")); + assertEquals(DataType.VARCHAR, handler.convertToDataType("set")); + } + + @Test + void shouldNormalizeBitAsBoolean() { + assertEquals("CASE WHEN \"flag\" = TRUE THEN '1' ELSE '0' END", + handler.normalizeColumn("flag", DataType.BIT)); + } + + private static class TestCapabilityProvider implements CapabilityProvider { + @Override + public boolean supportsFeature(DatabaseFeature feature) { + return false; + } + + @Override + public Set getSupportedFeatures() { + return Collections.emptySet(); + } + + @Override + public String getDefaultSchema() { + return "public"; + } + + @Override + public String getCatalogSeparator() { + return "."; + } + + @Override + public String getPaginationHint(long offset, long limit) { + return ""; + } + + @Override + public char getWildcardEscapeChar() { + return '\\'; + } + + @Override + public String escapePattern(String pattern) { + return pattern; + } + + @Override + public String getOpenQuote() { + return "\""; + } + + @Override + public String getCloseQuote() { + return "\""; + } + } +} diff --git a/consilens-connector/consilens-connector-plugins/consilens-connector-mysql/src/main/java/com/consilens/connector/mysql/MySQLDataTypeHandler.java b/consilens-connector/consilens-connector-plugins/consilens-connector-mysql/src/main/java/com/consilens/connector/mysql/MySQLDataTypeHandler.java index 7ce8502..8056eae 100644 --- a/consilens-connector/consilens-connector-plugins/consilens-connector-mysql/src/main/java/com/consilens/connector/mysql/MySQLDataTypeHandler.java +++ b/consilens-connector/consilens-connector-plugins/consilens-connector-mysql/src/main/java/com/consilens/connector/mysql/MySQLDataTypeHandler.java @@ -117,7 +117,8 @@ protected String normalizeFloat(String quotedCol) { */ @Override protected String normalizeDate(String quotedCol) { - return "COALESCE(DATE_FORMAT(" + quotedCol + ", '%Y-%m-%d'), '')"; + return "COALESCE(DATE_FORMAT(" + quotedCol + ", '" + resolveMySqlTemporalFormat("date", + "%Y-%m-%d", "%Y-%m-%d") + "'), '')"; } /** @@ -125,7 +126,8 @@ protected String normalizeDate(String quotedCol) { */ @Override protected String normalizeTime(String quotedCol) { - return "COALESCE(TIME_FORMAT(" + quotedCol + ", '%H:%i:%s'), '')"; + return "COALESCE(TIME_FORMAT(" + quotedCol + ", '" + resolveMySqlTemporalFormat("time", + "%H:%i:%s", "%H:%i:%s") + "'), '')"; } /** @@ -141,9 +143,10 @@ protected String normalizeTime(String quotedCol) { */ @Override protected String normalizeDateTime(String quotedCol) { - // DATETIME: Convert to UTC assuming it's in session timezone - // This ensures consistency with PostgreSQL TIMESTAMPTZ - return "COALESCE(DATE_FORMAT(CONVERT_TZ(" + quotedCol + ", @@session.time_zone, '+00:00'), '%Y-%m-%d %H:%i:%s'), '')"; + String targetTimezone = resolveMySqlTimezone("datetime", "+00:00"); + String format = resolveMySqlTemporalFormat("datetime", "%Y-%m-%d %H:%i:%s", "%Y-%m-%d"); + return "COALESCE(DATE_FORMAT(CONVERT_TZ(" + quotedCol + ", @@session.time_zone, '" + + targetTimezone + "'), '" + format + "'), '')"; } /** @@ -156,8 +159,10 @@ protected String normalizeDateTime(String quotedCol) { */ @Override protected String normalizeTimestamp(String quotedCol) { - // Same as normalizeDateTime for MySQL - return "COALESCE(DATE_FORMAT(CONVERT_TZ(" + quotedCol + ", @@session.time_zone, '+00:00'), '%Y-%m-%d %H:%i:%s'), '')"; + String targetTimezone = resolveMySqlTimezone("timestamp", "+00:00"); + String format = resolveMySqlTemporalFormat("timestamp", "%Y-%m-%d %H:%i:%s", "%Y-%m-%d"); + return "COALESCE(DATE_FORMAT(CONVERT_TZ(" + quotedCol + ", @@session.time_zone, '" + + targetTimezone + "'), '" + format + "'), '')"; } /** @@ -170,8 +175,10 @@ protected String normalizeTimestamp(String quotedCol) { */ @Override protected String normalizeTimestampWithTimezone(String quotedCol) { - // MySQL TIMESTAMP is already timezone-aware, same as normalizeTimestamp - return "COALESCE(DATE_FORMAT(CONVERT_TZ(" + quotedCol + ", @@session.time_zone, '+00:00'), '%Y-%m-%d %H:%i:%s'), '')"; + String targetTimezone = resolveMySqlTimezone("timestamp", "+00:00"); + String format = resolveMySqlTemporalFormat("timestamp", "%Y-%m-%d %H:%i:%s", "%Y-%m-%d"); + return "COALESCE(DATE_FORMAT(CONVERT_TZ(" + quotedCol + ", @@session.time_zone, '" + + targetTimezone + "'), '" + format + "'), '')"; } /** @@ -210,6 +217,53 @@ protected String normalizeDefault(String quotedCol) { return "COALESCE(TRIM(CAST(" + quotedCol + " AS CHAR)), '0')"; } + private String resolveMySqlTimezone(String dataTypeName, String defaultTimezone) { + String timezone = getTimezone(dataTypeName, defaultTimezone); + if ("UTC".equalsIgnoreCase(timezone)) { + return "+00:00"; + } + return escapeSqlLiteral(timezone); + } + + private String resolveMySqlTemporalFormat(String dataTypeName, String defaultFormat, String dateOnlyDefaultFormat) { + String configuredFormat = getFormat(dataTypeName, null); + String effectiveDefault = isDateOnlyComparison(dataTypeName) ? dateOnlyDefaultFormat : defaultFormat; + if (configuredFormat == null || configuredFormat.isBlank()) { + return effectiveDefault; + } + return toMySqlDateFormat(configuredFormat, effectiveDefault); + } + + private String toMySqlDateFormat(String javaFormat, String fallbackFormat) { + if (!isSupportedJavaTemporalFormat(javaFormat)) { + log.warn("Unsupported MySQL temporal format '{}', falling back to '{}'", javaFormat, fallbackFormat); + return fallbackFormat; + } + return javaFormat + .replace("yyyy", "%Y") + .replace("MM", "%m") + .replace("dd", "%d") + .replace("HH", "%H") + .replace("mm", "%i") + .replace("ss", "%s"); + } + + private boolean isSupportedJavaTemporalFormat(String format) { + String residual = format + .replace("yyyy", "") + .replace("MM", "") + .replace("dd", "") + .replace("HH", "") + .replace("mm", "") + .replace("ss", "") + .replace("-", "") + .replace(":", "") + .replace(" ", "") + .replace("/", "") + .replace("T", ""); + return residual.isEmpty(); + } + @Override public DataType convertToDataType(String sourceType) { if (sourceType == null) { diff --git a/consilens-connector/consilens-connector-plugins/consilens-connector-mysql/src/test/java/com/consilens/connector/mysql/MySQLDataTypeHandlerTest.java b/consilens-connector/consilens-connector-plugins/consilens-connector-mysql/src/test/java/com/consilens/connector/mysql/MySQLDataTypeHandlerTest.java index de8a4f0..6de183e 100644 --- a/consilens-connector/consilens-connector-plugins/consilens-connector-mysql/src/test/java/com/consilens/connector/mysql/MySQLDataTypeHandlerTest.java +++ b/consilens-connector/consilens-connector-plugins/consilens-connector-mysql/src/test/java/com/consilens/connector/mysql/MySQLDataTypeHandlerTest.java @@ -1,9 +1,12 @@ package com.consilens.connector.mysql; import com.consilens.connector.api.model.DataType; +import com.consilens.conncetor.base.jdbc.JdbcDatasetHandle; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.util.Map; + import static org.junit.jupiter.api.Assertions.*; /** @@ -34,6 +37,19 @@ void testNormalizeColumn_Timestamp() { result); } + @Test + void testNormalizeColumn_TimestampDateOnlyMode() { + JdbcDatasetHandle.JdbcTypeNormalizationRule rule = new JdbcDatasetHandle.JdbcTypeNormalizationRule(); + rule.setComparisonMode("DATE_ONLY"); + handler = new MySQLDataTypeHandler(capabilityProvider, Map.of("timestamp", rule)); + + String result = handler.normalizeColumn("created_at", DataType.TIMESTAMP); + + assertEquals( + "COALESCE(DATE_FORMAT(CONVERT_TZ(`created_at`, @@session.time_zone, '+00:00'), '%Y-%m-%d'), '')", + result); + } + @Test void testNormalizeColumn_Boolean() { String result = handler.normalizeColumn("is_active", DataType.BOOLEAN); @@ -42,6 +58,14 @@ void testNormalizeColumn_Boolean() { assertTrue(result.contains("'0'")); } + @Test + void testNormalizeColumn_Bit() { + String result = handler.normalizeColumn("is_active", DataType.BIT); + assertTrue(result.contains("CASE")); + assertTrue(result.contains("'1'")); + assertTrue(result.contains("'0'")); + } + @Test void testGetDataTypeMappingVarchar() { assertEquals("VARCHAR(100)", handler.getDataTypeMapping("varchar", 100, 0, 0)); diff --git a/consilens-connector/consilens-connector-plugins/consilens-connector-postgresql/src/main/java/com/consilens/connector/postgresql/PostgreSQLDataTypeHandler.java b/consilens-connector/consilens-connector-plugins/consilens-connector-postgresql/src/main/java/com/consilens/connector/postgresql/PostgreSQLDataTypeHandler.java index 0b442f6..1cb553c 100644 --- a/consilens-connector/consilens-connector-plugins/consilens-connector-postgresql/src/main/java/com/consilens/connector/postgresql/PostgreSQLDataTypeHandler.java +++ b/consilens-connector/consilens-connector-plugins/consilens-connector-postgresql/src/main/java/com/consilens/connector/postgresql/PostgreSQLDataTypeHandler.java @@ -124,7 +124,8 @@ protected String normalizeFloat(String quotedCol) { */ @Override protected String normalizeDate(String quotedCol) { - return "COALESCE(TO_CHAR(" + quotedCol + ", 'YYYY-MM-DD'), '')"; + return "COALESCE(TO_CHAR(" + quotedCol + ", '" + resolvePostgreSqlTemporalFormat("date", + "YYYY-MM-DD", "YYYY-MM-DD") + "'), '')"; } /** @@ -140,7 +141,8 @@ protected String normalizeDate(String quotedCol) { */ @Override protected String normalizeTime(String quotedCol) { - return "COALESCE(TO_CHAR(" + quotedCol + ", 'HH24:MI:SS'), '')"; + return "COALESCE(TO_CHAR(" + quotedCol + ", '" + resolvePostgreSqlTemporalFormat("time", + "HH24:MI:SS", "HH24:MI:SS") + "'), '')"; } /** @@ -152,8 +154,14 @@ protected String normalizeTime(String quotedCol) { */ @Override protected String normalizeDateTime(String quotedCol) { - // DATETIME: No timezone conversion, format directly - return "COALESCE(TO_CHAR(" + quotedCol + ", 'YYYY-MM-DD HH24:MI:SS'), '')"; + String expression = quotedCol; + String targetTimezone = getTimezone("datetime", null); + if (targetTimezone != null && !targetTimezone.isBlank()) { + expression = quotedCol + " AT TIME ZONE current_setting('TIMEZONE') AT TIME ZONE '" + + escapeSqlLiteral(targetTimezone) + "'"; + } + return "COALESCE(TO_CHAR(" + expression + ", '" + resolvePostgreSqlTemporalFormat("datetime", + "YYYY-MM-DD HH24:MI:SS", "YYYY-MM-DD") + "'), '')"; } /** @@ -171,10 +179,10 @@ protected String normalizeDateTime(String quotedCol) { */ @Override protected String normalizeTimestamp(String quotedCol) { - // TIMESTAMP: Two-step timezone conversion using session timezone - // Step 1: AT TIME ZONE current_setting('TIMEZONE') - interpret in session timezone, returns TIMESTAMPTZ - // Step 2: AT TIME ZONE 'UTC' - convert to UTC, returns TIMESTAMP - return "COALESCE(TO_CHAR(" + quotedCol + " AT TIME ZONE current_setting('TIMEZONE') AT TIME ZONE 'UTC', 'YYYY-MM-DD HH24:MI:SS'), '')"; + String targetTimezone = escapeSqlLiteral(getTimezone("timestamp", "UTC")); + return "COALESCE(TO_CHAR(" + quotedCol + " AT TIME ZONE current_setting('TIMEZONE') AT TIME ZONE '" + + targetTimezone + "', '" + resolvePostgreSqlTemporalFormat("timestamp", + "YYYY-MM-DD HH24:MI:SS", "YYYY-MM-DD") + "'), '')"; } /** @@ -188,9 +196,9 @@ protected String normalizeTimestamp(String quotedCol) { */ @Override protected String normalizeTimestampWithTimezone(String quotedCol) { - // TIMESTAMPTZ: Direct conversion to UTC - // AT TIME ZONE 'UTC' converts TIMESTAMPTZ to TIMESTAMP in UTC - return "COALESCE(TO_CHAR(" + quotedCol + " AT TIME ZONE 'UTC', 'YYYY-MM-DD HH24:MI:SS'), '')"; + String targetTimezone = escapeSqlLiteral(getTimezone("timestamp", "UTC")); + return "COALESCE(TO_CHAR(" + quotedCol + " AT TIME ZONE '" + targetTimezone + "', '" + + resolvePostgreSqlTemporalFormat("timestamp", "YYYY-MM-DD HH24:MI:SS", "YYYY-MM-DD") + "'), '')"; } /** @@ -210,6 +218,45 @@ protected String normalizeBoolean(String quotedCol) { return "CASE WHEN " + quotedCol + " = TRUE THEN '1' ELSE '0' END"; } + private String resolvePostgreSqlTemporalFormat(String dataTypeName, String defaultFormat, String dateOnlyDefaultFormat) { + String configuredFormat = getFormat(dataTypeName, null); + String effectiveDefault = isDateOnlyComparison(dataTypeName) ? dateOnlyDefaultFormat : defaultFormat; + if (configuredFormat == null || configuredFormat.isBlank()) { + return effectiveDefault; + } + return toPostgreSqlDateFormat(configuredFormat, effectiveDefault); + } + + private String toPostgreSqlDateFormat(String javaFormat, String fallbackFormat) { + if (!isSupportedJavaTemporalFormat(javaFormat)) { + log.warn("Unsupported PostgreSQL temporal format '{}', falling back to '{}'", javaFormat, fallbackFormat); + return fallbackFormat; + } + return javaFormat + .replace("yyyy", "YYYY") + .replace("MM", "MM") + .replace("dd", "DD") + .replace("HH", "HH24") + .replace("mm", "MI") + .replace("ss", "SS"); + } + + private boolean isSupportedJavaTemporalFormat(String format) { + String residual = format + .replace("yyyy", "") + .replace("MM", "") + .replace("dd", "") + .replace("HH", "") + .replace("mm", "") + .replace("ss", "") + .replace("-", "") + .replace(":", "") + .replace(" ", "") + .replace("/", "") + .replace("T", ""); + return residual.isEmpty(); + } + @Override public DataType convertToDataType(String sourceType) { if (sourceType == null) { diff --git a/consilens-connector/consilens-connector-plugins/consilens-connector-postgresql/src/test/java/com/consilens/connector/postgresql/PostgreSQLDataTypeHandlerTest.java b/consilens-connector/consilens-connector-plugins/consilens-connector-postgresql/src/test/java/com/consilens/connector/postgresql/PostgreSQLDataTypeHandlerTest.java index 7736ec8..404b449 100644 --- a/consilens-connector/consilens-connector-plugins/consilens-connector-postgresql/src/test/java/com/consilens/connector/postgresql/PostgreSQLDataTypeHandlerTest.java +++ b/consilens-connector/consilens-connector-plugins/consilens-connector-postgresql/src/test/java/com/consilens/connector/postgresql/PostgreSQLDataTypeHandlerTest.java @@ -1,9 +1,12 @@ package com.consilens.connector.postgresql; import com.consilens.connector.api.model.DataType; +import com.consilens.conncetor.base.jdbc.JdbcDatasetHandle; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import java.util.Map; + import static org.junit.jupiter.api.Assertions.*; /** @@ -34,6 +37,20 @@ void testNormalizeColumn_Timestamp() { result); } + @Test + void testNormalizeColumn_TimestampDateOnlyModeWithTimezone() { + JdbcDatasetHandle.JdbcTypeNormalizationRule rule = new JdbcDatasetHandle.JdbcTypeNormalizationRule(); + rule.setComparisonMode("TRUNCATE_TO_DAY"); + rule.setTimezone("Asia/Shanghai"); + handler = new PostgreSQLDataTypeHandler(capabilityProvider, Map.of("timestamp", rule)); + + String result = handler.normalizeColumn("created_at", DataType.TIMESTAMP); + + assertEquals( + "COALESCE(TO_CHAR(\"created_at\" AT TIME ZONE current_setting('TIMEZONE') AT TIME ZONE 'Asia/Shanghai', 'YYYY-MM-DD'), '')", + result); + } + @Test void testNormalizeColumn_Boolean() { String result = handler.normalizeColumn("is_active", DataType.BOOLEAN); @@ -71,4 +88,9 @@ void testGetDataTypeMappingJSON() { String result = handler.getDataTypeMapping("jsonb", 0, 0, 0); assertEquals("JSONB", result); } + + @Test + void testConvertBpcharToChar() { + assertEquals(DataType.CHAR, handler.convertToDataType("bpchar")); + } } diff --git a/consilens-core/src/main/java/com/consilens/core/algorithm/ChecksumDiffer.java b/consilens-core/src/main/java/com/consilens/core/algorithm/ChecksumDiffer.java index cfa9c1f..19071dc 100644 --- a/consilens-core/src/main/java/com/consilens/core/algorithm/ChecksumDiffer.java +++ b/consilens-core/src/main/java/com/consilens/core/algorithm/ChecksumDiffer.java @@ -1,7 +1,7 @@ package com.consilens.core.algorithm; -import com.consilens.common.enums.ChecksumAlgorithm; import com.consilens.common.enums.LocalCompareMode; +import com.consilens.core.database.adpter.DatabaseAdapter; import com.consilens.core.segment.TableSegment; import com.consilens.core.diff.DiffResult; import com.consilens.core.segment.TableSegment.ChecksumResult; @@ -9,6 +9,7 @@ import com.consilens.core.diff.DiffRow; import com.consilens.core.diff.DiffOperation; import com.consilens.core.diff.InfoTreeRecorder; +import com.consilens.core.thread.ExecutorProvider; import com.consilens.connector.api.model.DataType; import com.consilens.connector.api.model.TableSchema; import com.github.benmanes.caffeine.cache.Cache; @@ -21,7 +22,10 @@ import java.time.Instant; import java.util.*; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicLong; /** @@ -37,6 +41,8 @@ public class ChecksumDiffer extends TableDiffer implements AutoCloseable { private final AtomicLong mismatchCount = new AtomicLong(0); private final ProgressReporter progressReporter; private final AtomicLong segmentSequence = new AtomicLong(0); + private final Map segmenterCache = new ConcurrentHashMap<>(); + private final Semaphore activeSegmentBudget; /** * Creates a new ChecksumDiffer with the given configuration. @@ -45,9 +51,18 @@ public class ChecksumDiffer extends TableDiffer implements AutoCloseable { * @throws IllegalArgumentException if bisectionFactor >= bisectionThreshold or bisectionFactor < 2 */ public ChecksumDiffer(DifferConfig config) { - super(config); + this(config, null); + } + + public ChecksumDiffer(DifferConfig config, ExecutorProvider executorProvider) { + super(config, + executorProvider != null ? executorProvider : new ExecutorProvider(config.getConcurrencyConfig()), + executorProvider == null); // Prevent degenerate recursion settings that would explode tasks or never split + if (config.getBisectionThreshold() <= 0) { + throw new IllegalArgumentException("Bisection threshold must be greater than 0"); + } if (config.getBisectionFactor() >= config.getBisectionThreshold()) { throw new IllegalArgumentException("Bisection factor must be lower than threshold"); } @@ -58,6 +73,7 @@ public ChecksumDiffer(DifferConfig config) { this.checksumCache = new ChecksumCache(); this.performanceMonitor = new PerformanceMonitor(); this.progressReporter = new ProgressReporter(); + this.activeSegmentBudget = new Semaphore(resolveSegmentBudget(config)); log.info("ChecksumDiffer initialized with checksumAlgorithm: {}, local comparison mode: auto", config.getChecksumAlgorithm()); @@ -165,7 +181,13 @@ private CompletableFuture diffSegmentsWithParent( segmentId, level, maxRows, config.getChecksumAlgorithm()); return diffSegmentsWithChecksum(table1, table2, infoTreeRecorder, maxRows, level, segmentId) - .whenComplete((result, error) -> infoTreeRecorder.endNode(segmentId)); + .handle((result, error) -> { + infoTreeRecorder.endNode(segmentId); + if (error != null) { + throw propagateAsyncFailure(error); + } + return result; + }); } /** @@ -232,8 +254,10 @@ private CompletableFuture determineSegmentAction( // Calculate segment characteristics long totalRows = result1.getCount() + result2.getCount(); - double sizeRatio = Math.max(result1.getCount(), result2.getCount()) - / (double) Math.min(result1.getCount(), result2.getCount()); + long minCount = Math.min(result1.getCount(), result2.getCount()); + double sizeRatio = minCount == 0 + ? (Math.max(result1.getCount(), result2.getCount()) == 0 ? 1.0 : Double.POSITIVE_INFINITY) + : Math.max(result1.getCount(), result2.getCount()) / (double) minCount; /* * Decision tree for segment processing: @@ -321,10 +345,7 @@ private CompletableFuture bisectAndDiffSegmentsEnhanced( largerTable.getTablePath(), smallerTable.getTablePath()); // Create a dedicated segmenter for the larger table to avoid database adapter confusion - TableSegmenter largerTableSegmenter = new TableSegmenter( - largerTable.getDatabase(), - TableSegmenter.SegmenterConfig.defaultConfig(), - executorProvider); + TableSegmenter largerTableSegmenter = getOrCreateSegmenter(largerTable); return largerTableSegmenter.createOptimalSegments(largerTable, bisectionFactor, config.getBisectionThreshold()) .thenCompose(largerSegments -> { @@ -336,9 +357,21 @@ private CompletableFuture bisectAndDiffSegmentsEnhanced( seg.getDatabase() != null ? seg.getDatabase().getName() : "null"); } + if (largerSegments.isEmpty()) { + return CompletableFuture.completedFuture(null); + } + + int requestedPermits = largerSegments.size(); + if (!activeSegmentBudget.tryAcquire(requestedPermits)) { + log.warn("Segment budget exhausted (requested={}, available={}), falling back to local comparison for {}", + requestedPermits, activeSegmentBudget.availablePermits(), parentSegmentId); + return performLocalComparison(table1, table2, infoTreeRecorder, parentSegmentId); + } + // Create corresponding segments for the smaller table and process all segments - return createCorrespondingSegmentsAndProcess(smallerTable, largerSegments, table1, table2, - largerTable, infoTreeRecorder, level, parentSegmentId); + return createCorrespondingSegmentsAndProcess(smallerTable, largerSegments, table1, table2, + largerTable, infoTreeRecorder, level, parentSegmentId) + .whenComplete((ignored, error) -> activeSegmentBudget.release(requestedPermits)); }); } @@ -423,6 +456,41 @@ private TableSegment createBoundedSegment(TableSegment original, ChecksumResult .build(); } + private RuntimeException propagateAsyncFailure(Throwable error) { + Throwable unwrapped = unwrapAsyncFailure(error); + if (unwrapped instanceof RuntimeException) { + return (RuntimeException) unwrapped; + } + return new RuntimeException(unwrapped); + } + + private Throwable unwrapAsyncFailure(Throwable error) { + Throwable current = error; + while ((current instanceof CompletionException || current instanceof ExecutionException) + && current.getCause() != null) { + current = current.getCause(); + } + return current; + } + + private TableSegmenter getOrCreateSegmenter(TableSegment table) { + DatabaseAdapter database = table.getDatabase(); + if (database == null) { + return new TableSegmenter(null, TableSegmenter.SegmenterConfig.defaultConfig(), executorProvider); + } + return segmenterCache.computeIfAbsent(database, + db -> new TableSegmenter(db, TableSegmenter.SegmenterConfig.defaultConfig(), executorProvider)); + } + + private int resolveSegmentBudget(DifferConfig config) { + if (config == null || config.getConcurrencyConfig() == null || config.getConcurrencyConfig().getIo() == null) { + return 64; + } + com.consilens.core.thread.ConcurrencyConfig.PoolConfig ioConfig = config.getConcurrencyConfig().getIo(); + int executorCapacity = Math.max(ioConfig.getCore(), ioConfig.getMax()); + return Math.max(64, executorCapacity * 4); + } + /** * Check if segment should be compared locally. */ diff --git a/consilens-core/src/main/java/com/consilens/core/algorithm/JoinDiffer.java b/consilens-core/src/main/java/com/consilens/core/algorithm/JoinDiffer.java index fefb8c5..8e05f35 100644 --- a/consilens-core/src/main/java/com/consilens/core/algorithm/JoinDiffer.java +++ b/consilens-core/src/main/java/com/consilens/core/algorithm/JoinDiffer.java @@ -6,6 +6,7 @@ import com.consilens.core.diff.DiffOperation; import com.consilens.core.diff.InfoTreeRecorder; import com.consilens.core.database.adpter.DatabaseAdapter; +import com.consilens.core.thread.ExecutorProvider; import lombok.Getter; import lombok.extern.slf4j.Slf4j; @@ -29,7 +30,13 @@ public class JoinDiffer extends TableDiffer implements AutoCloseable { private final AtomicLong segmentSequence = new AtomicLong(0); public JoinDiffer(DifferConfig config, JoinDifferOptions options) { - super(config); + this(config, options, null); + } + + public JoinDiffer(DifferConfig config, JoinDifferOptions options, ExecutorProvider executorProvider) { + super(config, + executorProvider != null ? executorProvider : new ExecutorProvider(config.getConcurrencyConfig()), + executorProvider == null); this.validateUniqueKeys = options.isValidateUniqueKeys(); this.performanceMonitor = new JoinPerformanceMonitor(); } diff --git a/consilens-core/src/main/java/com/consilens/core/algorithm/TableDiffer.java b/consilens-core/src/main/java/com/consilens/core/algorithm/TableDiffer.java index 0b363c2..e4f7b91 100644 --- a/consilens-core/src/main/java/com/consilens/core/algorithm/TableDiffer.java +++ b/consilens-core/src/main/java/com/consilens/core/algorithm/TableDiffer.java @@ -28,11 +28,21 @@ public abstract class TableDiffer implements DiffEmitter { protected final ExecutorProvider executorProvider; protected final DifferConfig config; + protected final boolean ownsExecutorProvider; private DiffSink diffSink; protected TableDiffer(DifferConfig config) { + this(config, new ExecutorProvider(config.getConcurrencyConfig()), true); + } + + protected TableDiffer(DifferConfig config, ExecutorProvider executorProvider) { + this(config, executorProvider, false); + } + + protected TableDiffer(DifferConfig config, ExecutorProvider executorProvider, boolean ownsExecutorProvider) { this.config = config; - this.executorProvider = new ExecutorProvider(config.getConcurrencyConfig()); + this.executorProvider = executorProvider; + this.ownsExecutorProvider = ownsExecutorProvider; this.diffSink = new InMemoryDiffSink(); } @@ -144,6 +154,10 @@ protected void validateAndAdjustColumns(TableSegment table1, TableSegment table2 public void shutdown() { try { log.info("Shutting down TableDiffer..."); + if (!ownsExecutorProvider) { + log.debug("Skipping ExecutorProvider shutdown because TableDiffer does not own it"); + return; + } if (executorProvider != null) { executorProvider.shutdown(); log.info("ExecutorProvider shutdown completed"); diff --git a/consilens-core/src/main/java/com/consilens/core/database/adpter/AbstractDatabaseAdapter.java b/consilens-core/src/main/java/com/consilens/core/database/adpter/AbstractDatabaseAdapter.java index 830146a..a534635 100644 --- a/consilens-core/src/main/java/com/consilens/core/database/adpter/AbstractDatabaseAdapter.java +++ b/consilens-core/src/main/java/com/consilens/core/database/adpter/AbstractDatabaseAdapter.java @@ -13,7 +13,6 @@ import java.sql.*; import java.util.*; -import java.util.concurrent.ConcurrentHashMap; /** * Abstract base class for database adapters providing common functionality. @@ -310,7 +309,6 @@ public void insertIntoTable(String tableName, String sql, int limit) { @Override public void close() { - clearStatementCache(); connectionPool.close(); log.info("Database adapter '{}' closed", name); } @@ -474,13 +472,14 @@ private String buildPostgreSQLChecksum(List columns) { if (columns.size() == 1) { expr.append("COALESCE(CAST(").append(columns.get(0)).append(" AS TEXT), '')"); } else { - expr.append("MD5(CONCAT("); + expr.append("MD5("); for (int i = 0; i < columns.size(); i++) { - if (i > 0) - expr.append(", '|', "); + if (i > 0) { + expr.append(" || '|' || "); + } expr.append("COALESCE(CAST(").append(columns.get(i)).append(" AS TEXT), '')"); } - expr.append("))"); + expr.append(")"); } return expr.toString(); } @@ -625,39 +624,29 @@ public Map, String> querySegmentRowHashes(TableSegment segment) { Map, String> rowHashes = new LinkedHashMap<>(); - Connection connection = null; - PreparedStatement statement = null; - ResultSet resultSet = null; - - try { - connection = getConnection(); - statement = connection.prepareStatement(rowHashQuery); + try (Connection connection = getConnection(); + PreparedStatement statement = connection.prepareStatement(rowHashQuery)) { statement.setFetchSize(1000); - resultSet = statement.executeQuery(); - - int keyColumnCount = segment.getKeyColumns().size(); - int rowCount = 0; + try (ResultSet resultSet = statement.executeQuery()) { + int keyColumnCount = segment.getKeyColumns().size(); + int rowCount = 0; + + while (resultSet.next()) { + // Extract primary key values + List primaryKey = new ArrayList<>(keyColumnCount); + for (int i = 1; i <= keyColumnCount; i++) { + primaryKey.add(resultSet.getObject(i)); + } - while (resultSet.next()) { - // Extract primary key values - List primaryKey = new ArrayList<>(keyColumnCount); - for (int i = 1; i <= keyColumnCount; i++) { - primaryKey.add(resultSet.getObject(i)); + // Extract row hash (last column) + String rowHash = resultSet.getString(resultSet.getMetaData().getColumnCount()); + rowHashes.put(primaryKey, rowHash); + rowCount++; } - // Extract row hash (last column) - String rowHash = resultSet.getString(resultSet.getMetaData().getColumnCount()); - rowHashes.put(primaryKey, rowHash); - rowCount++; + log.debug("Retrieved {} row hashes for segment: {}", rowHashes.size(), segment.getTablePath()); + return rowHashes; } - - log.debug("Retrieved {} row hashes for segment: {}", rowHashes.size(), segment.getTablePath()); - - return rowHashes; - - } finally { - ResourceManager.closeJdbcResources(statement, resultSet); - releaseQuietly(connection); } } catch (Exception e) { @@ -1038,41 +1027,9 @@ protected void releaseQuietly(Connection connection) { } } - /** - * Cache for compiled statements to improve performance. - */ - protected static final Map statementCache = new ConcurrentHashMap<>(); - // Note: DatabaseChecksumHelper will be implemented in the query package // For now, we'll use a simplified approach - /** - * Get cached prepared statement or create new one. - */ - protected PreparedStatement getCachedStatement(Connection connection, String sql) throws SQLException { - return statementCache.computeIfAbsent(sql, key -> { - try { - return connection.prepareStatement(key); - } catch (SQLException e) { - throw new RuntimeException("Error preparing statement: " + sql, e); - } - }); - } - - /** - * Clear the statement cache. - */ - protected void clearStatementCache() { - statementCache.values().forEach(stmt -> { - try { - stmt.close(); - } catch (Exception e) { - log.debug("Error closing statement", e); - } - }); - statementCache.clear(); - } - /** * Build the SQL query for row-hash based local comparison. diff --git a/consilens-core/src/main/java/com/consilens/core/segment/TableSegment.java b/consilens-core/src/main/java/com/consilens/core/segment/TableSegment.java index bee64c0..421ab8c 100644 --- a/consilens-core/src/main/java/com/consilens/core/segment/TableSegment.java +++ b/consilens-core/src/main/java/com/consilens/core/segment/TableSegment.java @@ -13,8 +13,11 @@ import java.util.ArrayList; import java.util.LinkedHashSet; import java.util.List; +import java.util.Locale; import java.util.Optional; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; /** * Represents a segment of table data for comparison. @@ -24,6 +27,17 @@ @Builder(toBuilder = true) public class TableSegment { + private static final Pattern CUSTOM_WHERE_TOKEN_PATTERN = Pattern.compile( + "\\s+|<=|>=|<>|!=|=|<|>|\\(|\\)|,|-?\\d+(?:\\.\\d+)?|'(?:''|[^'])*'|[A-Za-z_][A-Za-z0-9_]*"); + + private static final Set ALLOWED_WHERE_KEYWORDS = Set.of( + "AND", "OR", "NOT", "IN", "IS", "NULL", "LIKE", "BETWEEN", "TRUE", "FALSE"); + + private static final Set BLOCKED_WHERE_KEYWORDS = Set.of( + "SELECT", "UNION", "DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", + "TRUNCATE", "MERGE", "CALL", "EXEC", "EXECUTE", "FROM", "JOIN", "HAVING", + "ORDER", "GROUP", "LIMIT", "OFFSET", "WITH"); + // Database and table identification private DatabaseAdapter database; @@ -157,15 +171,16 @@ public long approximateSize() { long minVal = ((Number) minValue).longValue(); long maxVal = ((Number) maxValue).longValue(); - long range = Math.max(1, maxVal - minVal); - - // Check for overflow - if (estimatedRows > Long.MAX_VALUE / range) { + long diff = Math.subtractExact(maxVal, minVal); + long range = Math.max(1L, diff); + if (range <= 0) { return Long.MAX_VALUE; } - estimatedRows *= range; + estimatedRows = Math.multiplyExact(estimatedRows, range); } return estimatedRows; + } catch (ArithmeticException e) { + return Long.MAX_VALUE; } catch (Exception e) { log.warn("Failed to estimate segment size", e); return -1; // Unknown size @@ -409,14 +424,9 @@ public String buildWhereClause() { // Add custom where clause if (whereClause != null && whereClause.isPresent()) { - // Basic validation for custom where clause to prevent obvious injection - // In a real scenario, this should be more robust or use a parser String customWhere = whereClause.get(); - if (customWhere.toLowerCase().contains("drop table") || - customWhere.toLowerCase().contains("delete from") || - customWhere.toLowerCase().contains("update ") || - customWhere.toLowerCase().contains("insert into")) { - throw new IllegalArgumentException("Dangerous SQL detected in where clause"); + if (!customWhere.isBlank()) { + validateCustomWhereClause(customWhere); } if (whereBuilder.length() > 0) { @@ -445,6 +455,55 @@ private void validateColumnName(String column) { } } + private void validateCustomWhereClause(String customWhere) { + if (customWhere.contains(";") || customWhere.contains("--") + || customWhere.contains("/*") || customWhere.contains("*/")) { + throw new IllegalArgumentException("Invalid custom where clause"); + } + + Matcher matcher = CUSTOM_WHERE_TOKEN_PATTERN.matcher(customWhere); + int index = 0; + while (index < customWhere.length()) { + if (!matcher.find(index) || matcher.start() != index) { + throw new IllegalArgumentException("Unsupported token in custom where clause"); + } + + String token = matcher.group(); + if (!token.isBlank()) { + validateCustomWhereToken(customWhere, token, matcher.end()); + } + index = matcher.end(); + } + } + + private void validateCustomWhereToken(String customWhere, String token, int nextIndex) { + if (!Character.isLetter(token.charAt(0)) && token.charAt(0) != '_') { + return; + } + + String upperToken = token.toUpperCase(Locale.ROOT); + if (BLOCKED_WHERE_KEYWORDS.contains(upperToken)) { + throw new IllegalArgumentException("Invalid custom where clause"); + } + if (ALLOWED_WHERE_KEYWORDS.contains(upperToken)) { + return; + } + + validateColumnName(token); + int followingIndex = skipWhitespace(customWhere, nextIndex); + if (followingIndex < customWhere.length() && customWhere.charAt(followingIndex) == '(') { + throw new IllegalArgumentException("Function calls are not allowed in custom where clause"); + } + } + + private int skipWhitespace(String value, int startIndex) { + int index = startIndex; + while (index < value.length() && Character.isWhitespace(value.charAt(index))) { + index++; + } + return index; + } + /** * Get the where clause as Optional. */ @@ -526,13 +585,16 @@ private String formatValue(Object value) { return "NULL"; } else if (value instanceof String) { return "'" + escapeSQL((String) value) + "'"; - } else if (value instanceof java.time.Instant) { + } else if (value instanceof Boolean) { + return (Boolean) value ? "TRUE" : "FALSE"; + } else if (value instanceof java.time.temporal.TemporalAccessor) { return "'" + value.toString() + "'"; + } else if (value instanceof java.util.Date) { + return "'" + new java.sql.Timestamp(((java.util.Date) value).getTime()) + "'"; } else if (value instanceof Number) { return value.toString(); } else { - // Fallback for other types, but still escape if it's treated as string - return "'" + escapeSQL(value.toString()) + "'"; + throw new IllegalArgumentException("Unsupported key value type: " + value.getClass().getName()); } } diff --git a/consilens-core/src/main/java/com/consilens/core/segmentation/CheckpointSelector.java b/consilens-core/src/main/java/com/consilens/core/segmentation/CheckpointSelector.java index 14db5da..8ca1874 100644 --- a/consilens-core/src/main/java/com/consilens/core/segmentation/CheckpointSelector.java +++ b/consilens-core/src/main/java/com/consilens/core/segmentation/CheckpointSelector.java @@ -26,8 +26,11 @@ public CheckpointSelector(int bisectionFactor, int bisectionThreshold) { * Choose optimal checkpoints for the given key range. */ public List chooseCheckpoints(KeyVector minKey, KeyVector maxKey, int preferredCount) { - if (minKey.isGreaterThanOrEqual(maxKey)) { - throw new IllegalArgumentException("minKey must be less than maxKey"); + if (minKey.isGreaterThan(maxKey)) { + throw new IllegalArgumentException("minKey must not be greater than maxKey"); + } + if (minKey.equals(maxKey)) { + return List.of(minKey, maxKey); } // Calculate the total space size diff --git a/consilens-core/src/main/java/com/consilens/core/thread/ConcurrencyConfig.java b/consilens-core/src/main/java/com/consilens/core/thread/ConcurrencyConfig.java index 7f8b9f2..b2a973a 100644 --- a/consilens-core/src/main/java/com/consilens/core/thread/ConcurrencyConfig.java +++ b/consilens-core/src/main/java/com/consilens/core/thread/ConcurrencyConfig.java @@ -10,12 +10,17 @@ */ @Data @NoArgsConstructor -@AllArgsConstructor public class ConcurrencyConfig { private PoolConfig io; private PoolConfig cpu; + public ConcurrencyConfig(PoolConfig io, PoolConfig cpu) { + this.io = io; + this.cpu = cpu; + validate(); + } + public static ConcurrencyConfig defaultConfig() { int cores = Math.max(1, Runtime.getRuntime().availableProcessors()); return new ConcurrencyConfig( @@ -24,9 +29,17 @@ public static ConcurrencyConfig defaultConfig() { ); } + public void validate() { + if (io != null) { + io.validate("io"); + } + if (cpu != null) { + cpu.validate("cpu"); + } + } + @Data @NoArgsConstructor - @AllArgsConstructor public static class PoolConfig { private int core; private int max; @@ -35,6 +48,15 @@ public static class PoolConfig { @JsonProperty(access = JsonProperty.Access.WRITE_ONLY) private String threadNamePrefix; + public PoolConfig(int core, int max, int queueSize, long keepAliveSeconds, String threadNamePrefix) { + this.core = core; + this.max = max; + this.queueSize = queueSize; + this.keepAliveSeconds = keepAliveSeconds; + this.threadNamePrefix = threadNamePrefix; + validate("pool"); + } + public static PoolConfig defaultIo(int cores) { int core = Math.max(8, cores * 2); int max = Math.max(32, cores * 8); @@ -46,5 +68,23 @@ public static PoolConfig defaultCpu(int cores) { int max = Math.max(core, cores * 2); return new PoolConfig(core, max, 10000, 60L, "consilens-cpu-"); } + + public void validate(String poolName) { + if (core <= 0) { + throw new IllegalArgumentException(poolName + ".core 必须大于 0"); + } + if (max <= 0) { + throw new IllegalArgumentException(poolName + ".max 必须大于 0"); + } + if (core > max) { + throw new IllegalArgumentException(poolName + ".core 不能大于 max"); + } + if (queueSize < 0) { + throw new IllegalArgumentException(poolName + ".queueSize 不能小于 0"); + } + if (keepAliveSeconds < 0) { + throw new IllegalArgumentException(poolName + ".keepAliveSeconds 不能小于 0"); + } + } } } diff --git a/consilens-core/src/main/java/com/consilens/core/thread/ExecutorProvider.java b/consilens-core/src/main/java/com/consilens/core/thread/ExecutorProvider.java index 103a624..0bf52da 100644 --- a/consilens-core/src/main/java/com/consilens/core/thread/ExecutorProvider.java +++ b/consilens-core/src/main/java/com/consilens/core/thread/ExecutorProvider.java @@ -19,19 +19,39 @@ public class ExecutorProvider { private final ExecutorService ioExecutor; private final ExecutorService cpuExecutor; + private final boolean managesExecutors; public ExecutorProvider(ConcurrencyConfig config) { ConcurrencyConfig effective = config != null ? config : ConcurrencyConfig.defaultConfig(); this.ioExecutor = createExecutor(effective.getIo(), "io"); this.cpuExecutor = createExecutor(effective.getCpu(), "cpu"); + this.managesExecutors = true; + } + + public ExecutorProvider(ExecutorService ioExecutor, ExecutorService cpuExecutor) { + this(ioExecutor, cpuExecutor, false); + } + + public ExecutorProvider(ExecutorService ioExecutor, ExecutorService cpuExecutor, boolean managesExecutors) { + this.ioExecutor = ioExecutor; + this.cpuExecutor = cpuExecutor; + this.managesExecutors = managesExecutors; } public void shutdown() { + if (!managesExecutors) { + log.debug("Skipping executor shutdown because ExecutorProvider does not own the executors"); + return; + } shutdownExecutor("io", ioExecutor); shutdownExecutor("cpu", cpuExecutor); } public void shutdownNow() { + if (!managesExecutors) { + log.debug("Skipping immediate executor shutdown because ExecutorProvider does not own the executors"); + return; + } shutdownExecutorNow("io", ioExecutor); shutdownExecutorNow("cpu", cpuExecutor); } @@ -41,6 +61,7 @@ private ExecutorService createExecutor(ConcurrencyConfig.PoolConfig poolConfig, ? poolConfig : ("io".equals(name) ? ConcurrencyConfig.PoolConfig.defaultIo(1) : ConcurrencyConfig.PoolConfig.defaultCpu(1)); + effective.validate(name); ThreadFactory factory = new NamedThreadFactory( effective.getThreadNamePrefix() != null ? effective.getThreadNamePrefix() : "consilens-" + name + "-" ); diff --git a/consilens-core/src/test/java/com/consilens/core/algorithm/ChecksumDifferTest.java b/consilens-core/src/test/java/com/consilens/core/algorithm/ChecksumDifferTest.java index 7a04931..ba8944f 100644 --- a/consilens-core/src/test/java/com/consilens/core/algorithm/ChecksumDifferTest.java +++ b/consilens-core/src/test/java/com/consilens/core/algorithm/ChecksumDifferTest.java @@ -6,6 +6,7 @@ import com.consilens.core.database.connection.ConnectionPool; import com.consilens.core.diff.DiffResult; import com.consilens.core.diff.DiffResult.InfoTreeNode; +import com.consilens.core.thread.ExecutorProvider; import com.consilens.connector.api.model.TablePath; import com.consilens.connector.api.model.PoolConfiguration; import com.consilens.core.segment.TableSegment; @@ -36,6 +37,9 @@ class ChecksumDifferTest { @Mock private DatabaseAdapter mockAdapter2; + @Mock + private ExecutorProvider mockExecutorProvider; + private TableDiffer.DifferConfig config; private ChecksumDiffer differ; private TableSegment segment1; @@ -85,6 +89,16 @@ void testBisectionFactorMustBeAtLeast2() { assertTrue(exception.getMessage().contains("Bisection factor must be at least 2")); } + @Test + @DisplayName("测试二分阈值必须大于0") + void testBisectionThresholdMustBePositive() { + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> new ChecksumDiffer(new TableDiffer.DifferConfig(4, 0, false, + ChecksumAlgorithm.CONCAT))); + assertTrue(exception.getMessage().contains("Bisection threshold must be greater than 0")); + } + @Test @DisplayName("测试有效配置创建成功") void testValidConfiguration() { @@ -96,6 +110,22 @@ void testValidConfiguration() { } } + @Nested + @DisplayName("线程池所有权测试") + class ExecutorOwnershipTests { + + @Test + @DisplayName("外部注入的 ExecutorProvider 不应被自动关闭") + void shouldNotShutdownInjectedExecutorProvider() { + ChecksumDiffer externalDiffer = new ChecksumDiffer(config, mockExecutorProvider); + + externalDiffer.close(); + + verify(mockExecutorProvider, never()).shutdown(); + verify(mockExecutorProvider, never()).shutdownNow(); + } + } + @Nested @DisplayName("正常差异比较测试") class NormalDiffTests { diff --git a/consilens-core/src/test/java/com/consilens/core/segment/TableSegmentTest.java b/consilens-core/src/test/java/com/consilens/core/segment/TableSegmentTest.java index ac6f8ec..2315675 100644 --- a/consilens-core/src/test/java/com/consilens/core/segment/TableSegmentTest.java +++ b/consilens-core/src/test/java/com/consilens/core/segment/TableSegmentTest.java @@ -100,6 +100,18 @@ public void testApproximateSizeWithStringKey() { assertEquals(-1, estimatedSize); } + @Test + public void testApproximateSizeReturnsLongMaxValueOnOverflow() { + TableSegment segment = TableSegment.builder() + .tablePath(TablePath.of("test_table")) + .keyColumns(Arrays.asList("id_part1", "id_part2")) + .minKey(Optional.of(Arrays.asList(0L, 0L))) + .maxKey(Optional.of(Arrays.asList(Long.MAX_VALUE, 3L))) + .build(); + + assertEquals(Long.MAX_VALUE, segment.approximateSize()); + } + @Test public void testChooseCheckpoints() { TableSegment segment = TableSegment.builder() @@ -171,6 +183,36 @@ public void testBuildWhereClause() { assertTrue(whereClause.contains("status = 'active'")); } + @Test + public void testBuildWhereClauseRejectsUnsafeCustomClause() { + TableSegment segment = TableSegment.builder() + .tablePath(TablePath.of("test_table")) + .keyColumns(Arrays.asList("id")) + .minKey(Optional.of(Arrays.asList(1))) + .maxKey(Optional.of(Arrays.asList(10))) + .whereClause(Optional.of("status = 'active'; DROP TABLE user_data")) + .build(); + + assertThrows(IllegalArgumentException.class, segment::buildWhereClause); + } + + @Test + public void testBuildWhereClauseAllowsSafeCustomClause() { + TableSegment segment = TableSegment.builder() + .tablePath(TablePath.of("test_table")) + .keyColumns(Arrays.asList("id")) + .minKey(Optional.of(Arrays.asList(1))) + .maxKey(Optional.of(Arrays.asList(10))) + .whereClause(Optional.of("status = 'active' AND region IN ('cn', 'us') AND archived IS NULL")) + .build(); + + String whereClause = segment.buildWhereClause(); + + assertTrue(whereClause.contains("status = 'active'")); + assertTrue(whereClause.contains("region IN ('cn', 'us')")); + assertTrue(whereClause.contains("archived IS NULL")); + } + @Test public void testValidation() { // Valid segment should not throw diff --git a/consilens-core/src/test/java/com/consilens/core/segmentation/CheckpointSelectorTest.java b/consilens-core/src/test/java/com/consilens/core/segmentation/CheckpointSelectorTest.java new file mode 100644 index 0000000..17de2ed --- /dev/null +++ b/consilens-core/src/test/java/com/consilens/core/segmentation/CheckpointSelectorTest.java @@ -0,0 +1,29 @@ +package com.consilens.core.segmentation; + +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CheckpointSelectorTest { + + @Test + void shouldAllowSingleValueKeySpace() { + CheckpointSelector selector = new CheckpointSelector(4, 100); + KeyVector value = new KeyVector(10); + + List checkpoints = selector.chooseCheckpoints(value, value, 4); + + assertEquals(List.of(value, value), checkpoints); + } + + @Test + void shouldRejectInvertedBounds() { + CheckpointSelector selector = new CheckpointSelector(4, 100); + + assertThrows(IllegalArgumentException.class, + () -> selector.chooseCheckpoints(new KeyVector(10), new KeyVector(1), 4)); + } +} diff --git a/consilens-core/src/test/java/com/consilens/core/thread/ConcurrencyConfigTest.java b/consilens-core/src/test/java/com/consilens/core/thread/ConcurrencyConfigTest.java new file mode 100644 index 0000000..e832347 --- /dev/null +++ b/consilens-core/src/test/java/com/consilens/core/thread/ConcurrencyConfigTest.java @@ -0,0 +1,32 @@ +package com.consilens.core.thread; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class ConcurrencyConfigTest { + + @Test + void shouldRejectInvalidPoolSizes() { + assertThrows(IllegalArgumentException.class, + () -> new ConcurrencyConfig.PoolConfig(0, 4, 10, 60L, "bad-")); + assertThrows(IllegalArgumentException.class, + () -> new ConcurrencyConfig.PoolConfig(4, 2, 10, 60L, "bad-")); + assertThrows(IllegalArgumentException.class, + () -> new ConcurrencyConfig.PoolConfig(2, 4, -1, 60L, "bad-")); + assertThrows(IllegalArgumentException.class, + () -> new ConcurrencyConfig.PoolConfig(2, 4, 10, -1L, "bad-")); + } + + @Test + void shouldAllowValidExecutorConfiguration() { + assertDoesNotThrow(() -> { + ExecutorProvider provider = new ExecutorProvider(new ConcurrencyConfig( + new ConcurrencyConfig.PoolConfig(2, 4, 16, 30L, "test-io-"), + new ConcurrencyConfig.PoolConfig(1, 2, 16, 30L, "test-cpu-") + )); + provider.shutdownNow(); + }); + } +} diff --git a/consilens-sink/consilens-sink-plugins/consilens-sink-table/src/main/java/com/consilens/sink/table/TableDiffRecordSink.java b/consilens-sink/consilens-sink-plugins/consilens-sink-table/src/main/java/com/consilens/sink/table/TableDiffRecordSink.java index e066b9b..033a2a0 100644 --- a/consilens-sink/consilens-sink-plugins/consilens-sink-table/src/main/java/com/consilens/sink/table/TableDiffRecordSink.java +++ b/consilens-sink/consilens-sink-plugins/consilens-sink-table/src/main/java/com/consilens/sink/table/TableDiffRecordSink.java @@ -48,6 +48,9 @@ public void open(SinkConfig config, DiffContext context) throws Exception { dataSource = createDataSource(sinkConfig); tableName = sinkConfig.resolveTableName(); batchSize = sinkConfig.getBatchSize(); + if (batchSize <= 0) { + throw new IllegalArgumentException("sink.batchSize 必须大于 0"); + } sourceColumns = context.getSourceColumnNames() != null ? context.getSourceColumnNames() : new ArrayList<>(); @@ -70,18 +73,24 @@ public void onDiffRecords(List rows, DiffContext context) throws SQLExc try (Connection conn = dataSource.getConnection(); PreparedStatement ps = conn.prepareStatement(insertSql)) { conn.setAutoCommit(false); - int count = 0; - for (DiffRow row : rows) { - bindInsertParams(ps, row, context); - ps.addBatch(); - if (++count % batchSize == 0) { + try { + int count = 0; + for (DiffRow row : rows) { + bindInsertParams(ps, row, context); + ps.addBatch(); + if (++count % batchSize == 0) { + ps.executeBatch(); + } + } + if (count % batchSize != 0) { ps.executeBatch(); - conn.commit(); } + conn.commit(); + log.debug("TableDiffRecordSink inserted {} rows", rows.size()); + } catch (SQLException | RuntimeException e) { + rollbackQuietly(conn); + throw e; } - ps.executeBatch(); - conn.commit(); - log.debug("TableDiffRecordSink inserted {} rows", rows.size()); } } @@ -216,6 +225,17 @@ private Map buildOverrideMap() { return map; } + private void rollbackQuietly(Connection conn) { + if (conn == null) { + return; + } + try { + conn.rollback(); + } catch (SQLException rollbackError) { + log.warn("TableDiffRecordSink rollback failed", rollbackError); + } + } + /** * Returns the override-resolved value if an override exists for {@code colName}, * falling back to {@code defaultValue} when overrideMap is null or no match. diff --git a/consilens-sink/consilens-sink-plugins/consilens-sink-table/src/test/java/com/consilens/sink/table/TableDiffRecordSinkTest.java b/consilens-sink/consilens-sink-plugins/consilens-sink-table/src/test/java/com/consilens/sink/table/TableDiffRecordSinkTest.java new file mode 100644 index 0000000..aadceee --- /dev/null +++ b/consilens-sink/consilens-sink-plugins/consilens-sink-table/src/test/java/com/consilens/sink/table/TableDiffRecordSinkTest.java @@ -0,0 +1,87 @@ +package com.consilens.sink.table; + +import com.consilens.core.diff.DiffRow; +import com.consilens.core.lifecycle.DiffContext; +import com.consilens.sink.api.model.SinkConfig; +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class TableDiffRecordSinkTest { + + @Test + void shouldRejectNonPositiveBatchSize() { + SinkConfig sinkConfig = new SinkConfig(); + sinkConfig.setFormat("table"); + sinkConfig.setType("diff-record"); + sinkConfig.setProperties("{" + + "\"url\":\"jdbc:h2:mem:invalid_batch_size;MODE=MySQL;DB_CLOSE_DELAY=-1\"," + + "\"username\":\"sa\"," + + "\"password\":\"\"," + + "\"driver\":\"org.h2.Driver\"," + + "\"tableName\":\"diff_record_invalid_batch\"," + + "\"createTable\":true," + + "\"dropIfExists\":true," + + "\"batchSize\":0" + + "}"); + + TableDiffRecordSink sink = new TableDiffRecordSink(); + DiffContext context = DiffContext.builder().taskId("task-invalid-batch").build(); + + assertThrows(IllegalArgumentException.class, () -> sink.open(sinkConfig, context)); + } + + @Test + void shouldRollbackEntireTransactionWhenBatchInsertFails() throws Exception { + String url = "jdbc:h2:mem:diff_record_rollback;MODE=MySQL;DB_CLOSE_DELAY=-1"; + SinkConfig sinkConfig = new SinkConfig(); + sinkConfig.setFormat("table"); + sinkConfig.setType("diff-record"); + sinkConfig.setProperties("{" + + "\"url\":\"" + url + "\"," + + "\"username\":\"sa\"," + + "\"password\":\"\"," + + "\"driver\":\"org.h2.Driver\"," + + "\"tableName\":\"diff_record_rollback_test\"," + + "\"createTable\":true," + + "\"dropIfExists\":true," + + "\"batchSize\":1," + + "\"columns\":[" + + "{\"name\":\"task_id\",\"value\":\"${taskId}\",\"columnType\":\"VARCHAR(64) PRIMARY KEY\"}," + + "{\"name\":\"diff_type\",\"value\":\"${operation}\",\"columnType\":\"VARCHAR(20) NOT NULL\"}" + + "]" + + "}"); + + DiffContext context = DiffContext.builder() + .taskId("task-rollback") + .build(); + + TableDiffRecordSink sink = new TableDiffRecordSink(); + sink.open(sinkConfig, context); + try { + List rows = List.of( + DiffRow.removed(List.of(1), List.of("Alice"), List.of("name")), + DiffRow.modified(List.of(2), List.of("Bob"), List.of("Bobby"), List.of("name"), List.of("name")) + ); + + assertThrows(SQLException.class, () -> sink.onDiffRecords(rows, context)); + } finally { + sink.close(); + } + + try (Connection connection = DriverManager.getConnection(url, "sa", ""); + Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery("SELECT COUNT(*) FROM diff_record_rollback_test")) { + resultSet.next(); + assertEquals(0, resultSet.getInt(1)); + } + } +}