From e70f798e580f318a0c8c3d891fa2d92f8b565b2e Mon Sep 17 00:00:00 2001 From: Jafeer Ali Date: Mon, 21 Jul 2025 13:41:31 +0530 Subject: [PATCH 1/5] VS code gitignore added. --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 38b20b5..f5c35df 100644 --- a/.gitignore +++ b/.gitignore @@ -171,4 +171,5 @@ cython_debug/ .pypirc #vscode -.DS_Store \ No newline at end of file +.DS_Store +.vscode/ \ No newline at end of file From 726e15cd9fac9f6a9b747f5ce9f9feb64a036ce6 Mon Sep 17 00:00:00 2001 From: Jafeer Ali Date: Mon, 21 Jul 2025 13:42:00 +0530 Subject: [PATCH 2/5] numeric field - difference tolerance logic added. --- README.md | 102 ++++++-- spark_data_test/constants/common_constants.py | 2 + spark_data_test/entities/config.py | 9 +- spark_data_test/jobs/comparison_job.py | 238 +++++++++++------- spark_data_test/utils/spark_utils.py | 12 - tests/test_comparison_job.py | 84 +++++-- 6 files changed, 296 insertions(+), 151 deletions(-) diff --git a/README.md b/README.md index 099f155..30a5d0a 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ run_comparison_job_from_dfs( - `source_df`: Source DataFrame. - `target_df`: Target DataFrame. - `params`: An instance of `DatasetParams` specifying dataset name, primary keys, columns to select/drop, etc. -- `output_config`: An instance of `OutputConfig` specifying output directory, file format, Spark write options, etc. +- `output_config`: An instance of [`OutputConfig`](#outputconfig) specifying output directory, file format, Spark write options, etc. #### Example @@ -81,7 +81,7 @@ run_comparison_job( #### Parameters - `spark`: The active `SparkSession`. -- `config`: A dictionary or `ComparisonJobConfig` instance describing one or more datasets to compare, their source/target configs, and output config. +- `config`: A dictionary or [`ComparisonJobConfig`](#comparisonjobconfig) instance describing one or more datasets to compare, their source/target configs, and output config. #### Example @@ -116,23 +116,87 @@ run_comparison_job(spark, config) --- +## Example Configuration (Python dict) + +Below is an example of how to create a configuration dictionary for `run_comparison_job` using the dataclass structure: + +```python +config = { + "job_name": "sample_comparison_job", + "dataset_configs": [ + { + "params": { + "dataset_name": "table1", + "primary_keys": ["id"], + "test_params": {"difference_tolerance": 0.1}, + "select_cols": ["id", "name", "value"], + "drop_cols": [] + }, + "source_config": { + "path": "/data/source/table1", + "file_format": "parquet", + "spark_options": {} + }, + "target_config": { + "path": "/data/target/table1", + "file_format": "parquet", + "spark_options": {} + } + }, + { + "params": { + "dataset_name": "table2", + "primary_keys": ["key"], + "test_params": {"difference_tolerance": 0.0}, + "select_cols": ["key", "amount"], + "drop_cols": ["extra_col"] + }, + "source_config": { + "path": "/data/source/table2", + "file_format": "csv", + "spark_options": {"header": "true"} + }, + "target_config": { + "path": "/data/target/table2", + "file_format": "csv", + "spark_options": {"header": "true"} + } + } + ], + "output_config": { + "output_dir": "/tmp/comparison_results", + "output_file_format": "parquet", + "spark_options": {}, + "no_of_partitions": -1 + } +} +``` + +You can pass this config directly to `run_comparison_job(spark, config)`. + +--- + ## Configuration Dataclasses Below are the main dataclasses used for configuration in `spark-data-test`. You can use these directly in Python or as a reference for your JSON configs. ### DatasetParams - Defines parameters for a single dataset comparison. - +```python +@dataclass +class TestParams: + difference_margin: float = 0.0 # Allowed numeric difference for matching numeric columns. +``` ```python from dataclasses import dataclass, field @dataclass class DatasetParams: - dataset_name: str # Name of the dataset/table - primary_keys: list # List of primary key column names - select_cols: list = field(default_factory=lambda: ["*"]) # Columns to select (default: all) - drop_cols: list = field(default_factory=list) # Columns to drop (default: none) + dataset_name: str # Name of the dataset/table + primary_keys: list # List of primary key column names + test_params: TestParams # Testing parameters for dataset (Optional) + select_cols: list # Columns to select (default: all) (Optional) + drop_cols: list # Columns to drop (default: none) (Optional) ``` ### DataframeConfig @@ -144,9 +208,9 @@ from dataclasses import dataclass, field @dataclass class DataframeConfig: - path: str # Path to the data (e.g., file or table) - file_format: str = "parquet" # File format (parquet, csv, etc.) - spark_options: dict = field(default_factory=dict) # Spark read options (e.g., {"header": "true"}) + path: str # Path to the data (e.g., file or table) + file_format: str # File format (parquet, csv, etc.) (default:parquet) (Optional) + spark_options: dict # Spark read options (e.g., {"header": "true"}) (Optional) ``` ### OutputConfig @@ -158,10 +222,10 @@ from dataclasses import dataclass, field @dataclass class OutputConfig: - output_dir: str # Directory to write output files - output_file_format: str = "parquet" # Output file format - spark_options: dict = field(default_factory=dict) # Spark write options - no_of_partitions: int = -1 # Number of partitions for output (-1 for default) + output_dir: str # Directory to write output files + output_file_format: str # Output file format (default:parquet) (Optional) + spark_options: dict # Spark write options (Optional) + no_of_partitions: int = -1 # Number of partitions for output (-1 for default partitions) (Optional) ``` ### DatasetConfig @@ -200,7 +264,7 @@ After running a comparison job, the following files/directories are generated un ### **overall_test_report** -Summary DataFrame with row counts, matched counts, duplicate counts, missing rows, and test status for each dataset. +Summary DataFrame with row counts, matched counts, duplicate counts, missing rows, and test status for each dataset. Output will generate under `//overall_test_report` | dataset_name | count | matched_count | duplicate_count | missing_rows | test_status | |--------------|----------------------|---------------|------------------------|------------------------|-------------| @@ -210,7 +274,7 @@ Summary DataFrame with row counts, matched counts, duplicate counts, missing row ### **col_lvl_test_report** -Column-level report showing the count of unmatched values for each non-key column. +Column-level report showing the count of unmatched values for each non-key column. Output will generate under `//col_lvl_test_report` | dataset_name | column_name | unmatched_rows_count | |--------------|-------------|---------------------| @@ -221,7 +285,7 @@ Column-level report showing the count of unmatched values for each non-key colum ### **row_lvl_test_report** -Row-level report with primary keys, duplicate count, missing row status, and match status for each row. +Row-level report with primary keys, duplicate count, missing row status, and match status for each row. Output will generate under `//row_lvl_test_report` | dataset_name | id | duplicate_count | missing_row_status | all_rows_matched | |--------------|----|----------------|----------------------|------------------| @@ -232,7 +296,7 @@ Row-level report with primary keys, duplicate count, missing row status, and mat ### **unmatched_rows/** -Directory containing one file per column with all rows where that column did not match between source and target. +Directory containing one file per column with all rows where that column did not match between source and target. Output will generate under `//unmatched_rows//` Example for `unmatched_rows/colA`: diff --git a/spark_data_test/constants/common_constants.py b/spark_data_test/constants/common_constants.py index 6b47589..c590551 100644 --- a/spark_data_test/constants/common_constants.py +++ b/spark_data_test/constants/common_constants.py @@ -13,6 +13,7 @@ SRC_COL_SUFFIX = "{0}_src" TGT_COL_SUFFIX = "{0}_target" MATCHED_COL_SUFFIX = "{0}_matched" +MATCHED_SUFFIX = "_matched" #common col names CHK_SUM_COL = "_chk_sum" @@ -38,6 +39,7 @@ #row level report ALL_ROWS_MATCHED_COL = "all_rows_matched" +ALL_ROWS_MATCHED_AFTR_TOL_COL = "all_rows_matched_after_tolerance" DUPLICATE_COUNT_COL = "duplicate_count" MISSING_ROW_STATUS_COL = "missing_row_status" MISSING_AT_SOURCE_STATUS = "MISSING_AT_SOURCE" diff --git a/spark_data_test/entities/config.py b/spark_data_test/entities/config.py index 3ffcb74..74e7765 100644 --- a/spark_data_test/entities/config.py +++ b/spark_data_test/entities/config.py @@ -1,12 +1,18 @@ from dataclasses import dataclass, field from spark_data_test.constants.common_constants import PARQUET_FMT +@dataclass +class TestParams: + difference_tolerance: float = 0.0 + @dataclass class DatasetParams: dataset_name: str primary_keys: list + test_params: TestParams = field(default_factory=TestParams) select_cols: list = field(default_factory=lambda: ["*"]) drop_cols: list = field(default_factory=list) + @dataclass class OutputConfig: @@ -21,14 +27,13 @@ class DataframeConfig: file_format: str = PARQUET_FMT spark_options: dict = field(default_factory=dict) + @dataclass class DatasetConfig: params: DatasetParams source_config: DataframeConfig target_config: DataframeConfig - - @dataclass class ComparisonJobConfig: job_name: str diff --git a/spark_data_test/jobs/comparison_job.py b/spark_data_test/jobs/comparison_job.py index 386b218..d4dd6bf 100644 --- a/spark_data_test/jobs/comparison_job.py +++ b/spark_data_test/jobs/comparison_job.py @@ -1,121 +1,163 @@ import pyspark.sql.functions as f -from pyspark.sql.types import StringType, IntegerType +from pyspark.sql.types import ( + ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DecimalType, StructField, StringType, BooleanType +) from pyspark.sql import Window -from spark_data_test.utils.spark_utils import apply_spark_transformations, empty_str_to_null, set_value_ifnull +from spark_data_test.utils.spark_utils import set_value_ifnull from spark_data_test.utils.io_utils import read_dataframe, write_result from spark_data_test.utils.config_reader import parse_comparison_job_config, dict_to_dataclass from spark_data_test.entities.config import DatasetParams, OutputConfig, ComparisonJobConfig from spark_data_test.constants.common_constants import * +from functools import reduce +numeric_types = [ + ByteType(), ShortType(), IntegerType(), LongType(), + FloatType(), DoubleType(), DecimalType() +] -def __add_matched_column(df, column_name): - return df.withColumn( - MATCHED_COL_SUFFIX.format(column_name), - f.col(SRC_COL_SUFFIX.format(column_name)).eqNullSafe( - f.col(TGT_COL_SUFFIX.format(column_name)) - ), - ) - +def __add_matched_column(df, struct_field, test_params): + if struct_field.dataType in numeric_types: + return df.withColumn( + MATCHED_COL_SUFFIX.format(struct_field.name), + f.abs( + f.col(SRC_COL_SUFFIX.format(struct_field.name)) - + f.col(TGT_COL_SUFFIX.format(struct_field.name)) + ) <= test_params.difference_tolerance + ) + else: + return df.withColumn( + MATCHED_COL_SUFFIX.format(struct_field.name), + f.col(SRC_COL_SUFFIX.format(struct_field.name)).eqNullSafe( + f.col(TGT_COL_SUFFIX.format(struct_field.name)) + ), + ) def __apply_source_target_transformations(df, suffix, non_key_cols, params): - return apply_spark_transformations( - df.select(*params.select_cols).drop(*params.drop_cols), - lambda df, column_name: df.withColumn( - column_name, - empty_str_to_null(column_name), - ), - non_key_cols, - ).withColumn( - suffix.format(CHK_SUM_COL), - f.sha2(f.concat_ws(EMPTY_STR, *non_key_cols), 256), - ).withColumnsRenamed( - dict(map(lambda column_name: (column_name, suffix.format(column_name)), non_key_cols)) + return ( + df.select(*params.select_cols) + .drop(*params.drop_cols) + .withColumn( + suffix.format(CHK_SUM_COL), + f.sha2(f.concat_ws(EMPTY_STR, *non_key_cols), 256), + ) + .withColumnsRenamed( + dict(map(lambda column_name: (column_name, suffix.format(column_name)), non_key_cols)) + ) ) - -def __process_unmatched_records(unmatched_records, non_key_cols): - return apply_spark_transformations( - unmatched_records, __add_matched_column, non_key_cols +def __process_unmatched_records(unmatched_records, target_df_schema, key_cols, test_params): + non_key_struct_fields = [field for field in target_df_schema.fields if field.name not in key_cols] + unmatched_records = reduce( + lambda df, struct_field: __add_matched_column(df, struct_field, test_params), + non_key_struct_fields, + unmatched_records + ).drop(ALL_ROWS_MATCHED_COL) + matched_columns = [ + f.col(column_name) + for column_name in unmatched_records.columns + if column_name.endswith(MATCHED_SUFFIX) + ] + return unmatched_records.withColumn( + ALL_ROWS_MATCHED_AFTR_TOL_COL, + reduce(lambda col1, col2: col1 & col2, matched_columns) ) - def __get_overall_test_report(spark, dataset_name, key_cols, source_df, target_df, row_lvl_report, matched_records): matched_count = matched_records.count() - rows = [dataset_name, {SOURCE_COL:source_df.count(), TARGET_COL:target_df.count()}, matched_count] - + rows = [ + dataset_name, + {SOURCE_COL: source_df.count(), TARGET_COL: target_df.count()}, + matched_count + ] missing_rows_count = row_lvl_report.agg( - f.create_map( - f.lit(SOURCE_COL), f.sum(f.when(f.col(MISSING_ROW_STATUS_COL) == MISSING_AT_SOURCE_STATUS, 1).otherwise(0)), - f.lit(TARGET_COL), f.sum(f.when(f.col(MISSING_ROW_STATUS_COL) == MISSING_AT_TARGET_STATUS, 1).otherwise(0)) - ).alias(MISSING_ROWS_COL) + f.create_map( + f.lit(SOURCE_COL), + f.sum(f.when(f.col(MISSING_ROW_STATUS_COL) == MISSING_AT_SOURCE_STATUS, 1).otherwise(0)), + f.lit(TARGET_COL), + f.sum(f.when(f.col(MISSING_ROW_STATUS_COL) == MISSING_AT_TARGET_STATUS, 1).otherwise(0)) + ).alias(MISSING_ROWS_COL) ).withColumn(DATASET_NAME_COL, f.lit(dataset_name)) - - rows.extend( - [ - {SOURCE_COL:source_df.groupBy(*key_cols).count().where(f.col(COUNT_COL) > 1).count(), - TARGET_COL:target_df.groupBy(*key_cols).count().where(f.col(COUNT_COL) > 1).count()} - ] + rows.extend([ + { + SOURCE_COL: source_df.groupBy(*key_cols).count().where(f.col(COUNT_COL) > 1).count(), + TARGET_COL: target_df.groupBy(*key_cols).count().where(f.col(COUNT_COL) > 1).count() + } + ]) + overall_report = ( + spark.createDataFrame([tuple(rows)], [DATASET_NAME_COL, COUNT_COL, MATCHED_COUNT_COL, DUPLICATE_COUNT_COL]) + .join(missing_rows_count, DATASET_NAME_COL, LEFT_JOIN) + .withColumn( + TEST_STATUS_COL, + f.when( + (f.col(COUNT_COL).getItem(SOURCE_COL) == f.col(MATCHED_COUNT_COL)) & + (f.col(COUNT_COL).getItem(TARGET_COL) == f.col(MATCHED_COUNT_COL)), + f.lit(PASSED_STATUS) + ).otherwise(f.lit(FAILED_STATUS)) + ) ) - overall_report = spark.createDataFrame([tuple(rows)], [DATASET_NAME_COL, COUNT_COL, MATCHED_COUNT_COL, DUPLICATE_COUNT_COL]).join( - missing_rows_count, DATASET_NAME_COL, LEFT_JOIN - ).withColumn(TEST_STATUS_COL, f.when( - (f.col(COUNT_COL).getItem(SOURCE_COL) == f.col(MATCHED_COUNT_COL)) & (f.col(COUNT_COL).getItem(TARGET_COL) == f.col(MATCHED_COUNT_COL)), - f.lit(PASSED_STATUS) - ).otherwise(f.lit(FAILED_STATUS))) return overall_report def __get_column_level_test_report(dataset_name, unmatched_records, non_key_cols): - # Aggregates unmatched counts for each non-key column and unpivots to a long format. unmatch_count_conds = map( lambda column_name: f.sum( f.when(~f.col(MATCHED_COL_SUFFIX.format(column_name)), 1).otherwise(0) ).alias(column_name), non_key_cols ) - col_lvl_report = unmatched_records.agg(*unmatch_count_conds).unpivot( - [], non_key_cols, variableColumnName=COL_NAME, valueColumnName=UNMATCHED_ROWS_COUNT_COL - ).withColumn(UNMATCHED_ROWS_COUNT_COL, set_value_ifnull(UNMATCHED_ROWS_COUNT_COL, f.lit(0).cast(IntegerType()))).withColumn(DATASET_NAME_COL, f.lit(dataset_name)).cache() - return col_lvl_report.select( - DATASET_NAME_COL, COL_NAME, UNMATCHED_ROWS_COUNT_COL) + col_lvl_report = ( + unmatched_records.agg(*unmatch_count_conds) + .unpivot([], non_key_cols, variableColumnName=COL_NAME, valueColumnName=UNMATCHED_ROWS_COUNT_COL) + .withColumn( + UNMATCHED_ROWS_COUNT_COL, + set_value_ifnull(UNMATCHED_ROWS_COUNT_COL, f.lit(0).cast(IntegerType())) + ) + .withColumn(DATASET_NAME_COL, f.lit(dataset_name)) + .cache() + ) + return col_lvl_report.select(DATASET_NAME_COL, COL_NAME, UNMATCHED_ROWS_COUNT_COL) def __get_unmatched_records(dataset_name, unmatched_records, col_lvl_report, key_cols): - # Builds a dictionary of DataFrames for each column with unmatched records. unmatched_records_map = {} - unmatched_rows_cols = list( - map( - lambda x: x.column_name, - col_lvl_report.where(f.col(UNMATCHED_ROWS_COUNT_COL) > 0).select(COL_NAME).collect() - ) - ) + unmatched_rows_cols = [ + x.column_name + for x in col_lvl_report.where(f.col(UNMATCHED_ROWS_COUNT_COL) > 0).select(COL_NAME).collect() + ] for column_name in unmatched_rows_cols: - unmatched_records_map[f'{dataset_name}/{column_name}'] = unmatched_records.where( - ~f.col(MATCHED_COL_SUFFIX.format(column_name)) - ).select( - (key_cols + [SRC_COL_SUFFIX.format(column_name), TGT_COL_SUFFIX.format(column_name)]) - ).distinct() + unmatched_records_map[f'{dataset_name}/{column_name}'] = ( + unmatched_records.where(~f.col(MATCHED_COL_SUFFIX.format(column_name))) + .select(key_cols + [SRC_COL_SUFFIX.format(column_name), TGT_COL_SUFFIX.format(column_name)]) + .distinct() + ) return unmatched_records_map def __get_row_level_test_report(dataset_name, joined_df, key_cols): - # Returns a DataFrame with row-level comparison, duplicate count, and missing row status. - result = joined_df.select( - *key_cols, - SRC_COL_SUFFIX.format(CHK_SUM_COL), - TGT_COL_SUFFIX.format(CHK_SUM_COL), - ALL_ROWS_MATCHED_COL - ).withColumn( - DUPLICATE_COUNT_COL, - f.count("*").over(Window.partitionBy(*key_cols).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) - 1 - ).withColumn( - MISSING_ROW_STATUS_COL, - f.when( - f.col(SRC_COL_SUFFIX.format(CHK_SUM_COL)).isNull(), MISSING_AT_SOURCE_STATUS - ).when( - f.col(TGT_COL_SUFFIX.format(CHK_SUM_COL)).isNull(), MISSING_AT_TARGET_STATUS - ).otherwise(f.lit(PRESENT_IN_BOTH_STATUS).cast(StringType())) - ).drop( - SRC_COL_SUFFIX.format(CHK_SUM_COL), - TGT_COL_SUFFIX.format(CHK_SUM_COL) - ).distinct().withColumn(DATASET_NAME_COL, f.lit(dataset_name)) + result = ( + joined_df.select( + *key_cols, + SRC_COL_SUFFIX.format(CHK_SUM_COL), + TGT_COL_SUFFIX.format(CHK_SUM_COL), + ALL_ROWS_MATCHED_COL + ) + .withColumn( + DUPLICATE_COUNT_COL, + f.count("*").over(Window.partitionBy(*key_cols).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)) - 1 + ) + .withColumn( + MISSING_ROW_STATUS_COL, + f.when( + f.col(SRC_COL_SUFFIX.format(CHK_SUM_COL)).isNull(), MISSING_AT_SOURCE_STATUS + ).when( + f.col(TGT_COL_SUFFIX.format(CHK_SUM_COL)).isNull(), MISSING_AT_TARGET_STATUS + ).otherwise(f.lit(PRESENT_IN_BOTH_STATUS).cast(StringType())) + ) + .drop( + SRC_COL_SUFFIX.format(CHK_SUM_COL), + TGT_COL_SUFFIX.format(CHK_SUM_COL) + ) + .distinct() + .withColumn(DATASET_NAME_COL, f.lit(dataset_name)) + ) return result.select(DATASET_NAME_COL, *key_cols, DUPLICATE_COUNT_COL, MISSING_ROW_STATUS_COL, ALL_ROWS_MATCHED_COL) def __write_results(comparison_result, job_name, output_config): @@ -143,6 +185,7 @@ def __write_results(comparison_result, job_name, output_config): def compare_dataframes(spark, source_df, target_df, params): non_key_cols = list(set(source_df.columns) - set(params.primary_keys)) + target_df_schema = target_df.select(*params.select_cols).drop(*params.drop_cols).schema source_df = __apply_source_target_transformations( source_df, SRC_COL_SUFFIX, @@ -159,22 +202,33 @@ def compare_dataframes(spark, source_df, target_df, params): joined_df = source_df.join( target_df, params.primary_keys, FULL_OUTER_JOIN ).cache() - joined_df = __add_matched_column(joined_df, CHK_SUM_COL).withColumnRenamed( + joined_df = __add_matched_column( + joined_df, + StructField(CHK_SUM_COL, StringType(), True), + params.test_params + ).withColumnRenamed( MATCHED_COL_SUFFIX.format(CHK_SUM_COL), ALL_ROWS_MATCHED_COL ).cache() - unmatched_records = __process_unmatched_records( + unmatched_rows_aftr_tolerance = __process_unmatched_records( joined_df.filter( (f.col(ALL_ROWS_MATCHED_COL) == False) & (f.col(SRC_COL_SUFFIX.format(CHK_SUM_COL)).isNotNull()) & (f.col(TGT_COL_SUFFIX.format(CHK_SUM_COL)).isNotNull()) ), - non_key_cols + target_df_schema, + params.primary_keys, + params.test_params ) + unmatched_records = unmatched_rows_aftr_tolerance.filter(f.col(ALL_ROWS_MATCHED_AFTR_TOL_COL) == False) + + joined_df = joined_df.join( + unmatched_rows_aftr_tolerance.filter(f.col(ALL_ROWS_MATCHED_AFTR_TOL_COL)).select(*(params.primary_keys+[ALL_ROWS_MATCHED_AFTR_TOL_COL])), + params.primary_keys, + LEFT_JOIN + ).withColumn(ALL_ROWS_MATCHED_COL, f.col(ALL_ROWS_MATCHED_COL) | f.coalesce(f.col(ALL_ROWS_MATCHED_AFTR_TOL_COL), f.lit(False).cast(BooleanType()))).cache() - matched_records = joined_df.filter(f.col(ALL_ROWS_MATCHED_COL) == True).dropDuplicates( - params.primary_keys - ) + matched_records = joined_df.filter(f.col(ALL_ROWS_MATCHED_COL) == True).dropDuplicates(params.primary_keys) col_lvl_report = __get_column_level_test_report( params.dataset_name, unmatched_records, non_key_cols @@ -200,13 +254,12 @@ def run_comparison_job_from_dfs(spark, job_name, source_df, target_df, params, o params = dict_to_dataclass(DatasetParams, params) if isinstance(output_config, dict): output_config = dict_to_dataclass(OutputConfig, output_config) - comparison_result = compare_dataframes(spark, source_df, target_df, params) - # Save the results to the specified output directory __write_results(comparison_result, job_name, output_config) def run_comparison_job(spark, config): - config: ComparisonJobConfig = parse_comparison_job_config(config) + if isinstance(config, dict): + config: ComparisonJobConfig = parse_comparison_job_config(config) consolidated_result = {} for dataset_config in config.dataset_configs: comparison_result = compare_dataframes( @@ -230,5 +283,4 @@ def run_comparison_job(spark, config): ) else: consolidated_result = comparison_result - # Save the results to the specified output directory __write_results(consolidated_result, config.job_name, config.output_config) \ No newline at end of file diff --git a/spark_data_test/utils/spark_utils.py b/spark_data_test/utils/spark_utils.py index 411a61f..d0c47f7 100644 --- a/spark_data_test/utils/spark_utils.py +++ b/spark_data_test/utils/spark_utils.py @@ -1,17 +1,5 @@ from functools import reduce import pyspark.sql.functions as f -from spark_data_test.constants.common_constants import ( - EMPTY_STR, -) -from pyspark.sql.types import StringType -def apply_spark_transformations(df, trans_function, columns:list=None): - if columns is None: - columns = df.columns - return reduce(lambda df, column_name: trans_function(df, column_name), columns, df) - - -def empty_str_to_null(column_name): - return f.when(f.col(column_name) == EMPTY_STR, f.lit(None).cast(StringType())).otherwise(f.col(column_name).cast(StringType())) def set_value_ifnull(column_name, value): return f.when(f.col(column_name).isNull(), value).otherwise(f.col(column_name)) diff --git a/tests/test_comparison_job.py b/tests/test_comparison_job.py index 213c8a7..bfed790 100644 --- a/tests/test_comparison_job.py +++ b/tests/test_comparison_job.py @@ -1,9 +1,13 @@ import pytest from pyspark.sql import SparkSession, Row from pyspark.sql.functions import col -from spark_data_test.jobs.comparison_job import run_comparison_job_from_dfs -from spark_data_test.entities.config import DatasetParams, OutputConfig +from spark_data_test.jobs.comparison_job import run_comparison_job_from_dfs, run_comparison_job +from spark_data_test.entities.config import DatasetParams, OutputConfig, TestParams, ComparisonJobConfig, DataframeConfig, DatasetConfig from spark_data_test.constants.common_constants import * +from pyspark.testing import assertDataFrameEqual + +from dataclasses import replace + job_name = "unit-testing" @pytest.fixture(scope="session") @@ -13,16 +17,16 @@ def spark(): @pytest.fixture def sample_data(spark): df1 = spark.createDataFrame([ - Row(id=1, value='foo', value2='foo', value3= True), - Row(id=3, value=None, value2=None, value3 = False), - Row(id=3, value=None, value2=None, value3 = False), - Row(id=4, value=None, value2=None, value3 = False), - Row(id=7, value='foo7', value2='foo2', value3 = False), + Row(id=1, value='foo', value2='foo', value3= True, value4 = 1.0), + Row(id=3, value=None, value2=None, value3 = False, value4 = 3.0), + Row(id=3, value=None, value2=None, value3 = False, value4 = 3.0), + Row(id=4, value=None, value2=None, value3 = False, value4 = 4.0), + Row(id=7, value='foo7', value2='foo2', value3 = False, value4 = 7.1), ]) df2 = spark.createDataFrame([ - Row(id=3, value=None, value2=None, value3 = False), - Row(id=3, value=None, value2=None, value3 = False), - Row(id=7, value='foo', value2='foo3', value3 = False), + Row(id=3, value=None, value2=None, value3 = False, value4 = 3.0), + Row(id=3, value=None, value2=None, value3 = False, value4 = 3.0), + Row(id=7, value='foo', value2='foo3', value3 = False, value4 = 7.0), ]) return df1, df2 @@ -30,7 +34,8 @@ def sample_data(spark): def configs(tmp_path): job_params = DatasetParams( dataset_name="pytest_job", - primary_keys=["id"] + primary_keys=["id"], + test_params= TestParams(difference_tolerance=0.1) ) output_config = OutputConfig( output_dir=str(tmp_path), @@ -40,7 +45,7 @@ def configs(tmp_path): return job_params, output_config @pytest.fixture -def get_result(spark, sample_data, configs): +def get_comparison_job_from_dfs_results(spark, sample_data, configs): df1, df2 = sample_data job_params, output_config = configs @@ -51,9 +56,35 @@ def get_result(spark, sample_data, configs): result[COL_LVL_TEST_REPORT_KEY] = spark.read.parquet(f"{output_config.output_dir}/{job_name}/{COL_LVL_TEST_REPORT_KEY}") return result +@pytest.fixture +def get_comparison_job_results(spark,tmp_path, sample_data, configs): + df1, df2 = sample_data + job_params, output_config = configs + output_config = replace(output_config,output_dir= f"{str(tmp_path)}/comparison_job") + + df1.write.mode("overwrite").parquet(f"{str(tmp_path)}/{job_name}/inputs/df1") + df2.write.mode("overwrite").parquet(f"{str(tmp_path)}/{job_name}/inputs/df2") + + source_config = DataframeConfig(path = f"{str(tmp_path)}/{job_name}/inputs/df1") + target_config = DataframeConfig(path = f"{str(tmp_path)}/{job_name}/inputs/df2") + ds_config = DatasetConfig(params=job_params, source_config=source_config, target_config=target_config) + comparison_job_config = ComparisonJobConfig(job_name = job_name, dataset_configs = [ds_config], output_config = output_config) + + run_comparison_job(spark, comparison_job_config) + result = {} + result[OVERALL_TEST_REPORT_KEY] = spark.read.parquet(f"{output_config.output_dir}/{job_name}/{OVERALL_TEST_REPORT_KEY}") + result[ROW_LVL_TEST_REPORT_KEY] = spark.read.parquet(f"{output_config.output_dir}/{job_name}/{ROW_LVL_TEST_REPORT_KEY}") + result[COL_LVL_TEST_REPORT_KEY] = spark.read.parquet(f"{output_config.output_dir}/{job_name}/{COL_LVL_TEST_REPORT_KEY}") + return result + +def test_run_comparison_job_results(spark, get_comparison_job_results, get_comparison_job_from_dfs_results): + assertDataFrameEqual(get_comparison_job_results[OVERALL_TEST_REPORT_KEY], get_comparison_job_from_dfs_results[OVERALL_TEST_REPORT_KEY]) + assertDataFrameEqual(get_comparison_job_results[ROW_LVL_TEST_REPORT_KEY], get_comparison_job_from_dfs_results[ROW_LVL_TEST_REPORT_KEY]) + assertDataFrameEqual(get_comparison_job_results[COL_LVL_TEST_REPORT_KEY], get_comparison_job_from_dfs_results[COL_LVL_TEST_REPORT_KEY]) -def test_overall_test_report(get_result): - row = get_result.get(OVERALL_TEST_REPORT_KEY).first() + +def test_overall_test_report(get_comparison_job_from_dfs_results): + row = get_comparison_job_from_dfs_results.get(OVERALL_TEST_REPORT_KEY).first() assert row[COUNT_COL][SOURCE_COL] == 5 and row[COUNT_COL][TARGET_COL]== 3 assert row[MATCHED_COUNT_COL] == 1 @@ -61,8 +92,8 @@ def test_overall_test_report(get_result): assert row[MISSING_ROWS_COL][SOURCE_COL] == 0 and row[MISSING_ROWS_COL][TARGET_COL] == 2 assert row[TEST_STATUS_COL] == FAILED_STATUS -def test_row_level_test_report(get_result): - row_level_report = get_result.get(ROW_LVL_TEST_REPORT_KEY).cache() +def test_row_level_test_report(get_comparison_job_from_dfs_results): + row_level_report = get_comparison_job_from_dfs_results.get(ROW_LVL_TEST_REPORT_KEY).cache() assert row_level_report.count() == 4 # Check for specific rows @@ -76,16 +107,19 @@ def test_row_level_test_report(get_result): assert not row4[ALL_ROWS_MATCHED_COL] -def test_column_level_test_report(get_result): - col_level_report = get_result.get(COL_LVL_TEST_REPORT_KEY).cache() - assert col_level_report.count() == 3 +def test_column_level_test_report(get_comparison_job_from_dfs_results): + col_level_report = get_comparison_job_from_dfs_results.get(COL_LVL_TEST_REPORT_KEY).cache() + assert col_level_report.count() == 4 # Check for specific rows - row1 = col_level_report.filter(col(COL_NAME) == "value").first() - assert row1[UNMATCHED_ROWS_COUNT_COL] == 1 + value_col = col_level_report.filter(col(COL_NAME) == "value").first() + assert value_col[UNMATCHED_ROWS_COUNT_COL] == 1 + + value2_col = col_level_report.filter(col(COL_NAME) == "value2").first() + assert value2_col[UNMATCHED_ROWS_COUNT_COL] == 1 - row2 = col_level_report.filter(col(COL_NAME) == "value2").first() - assert row2[UNMATCHED_ROWS_COUNT_COL] == 1 + value3_col = col_level_report.filter(col(COL_NAME) == "value3").first() + assert value3_col[UNMATCHED_ROWS_COUNT_COL] == 0 - row4 = col_level_report.filter(col(COL_NAME) == "value3").first() - assert row4[UNMATCHED_ROWS_COUNT_COL] == 0 \ No newline at end of file + value4_col = col_level_report.filter(col(COL_NAME) == "value4").first() + assert value4_col[UNMATCHED_ROWS_COUNT_COL] == 0 \ No newline at end of file From b3144a6761316b836013d395781270646e0b0c04 Mon Sep 17 00:00:00 2001 From: Jafeer Ali Date: Mon, 21 Jul 2025 13:56:21 +0530 Subject: [PATCH 3/5] pandas dependency added for pytest coverage --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 17ec335..779ce93 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ 'pyspark==3.5.6' ], extras_require={ - "dev": ["pytest>=5", "pytest-cov"] + "dev": ["pytest>=5", "pytest-cov", "pandas"] }, python_requires='>=3.7', ) \ No newline at end of file From 20982144c35b54acff5060fa07b1b8a8fd85cd91 Mon Sep 17 00:00:00 2001 From: Jafeer Ali Date: Mon, 21 Jul 2025 13:58:50 +0530 Subject: [PATCH 4/5] pyarrow dependency added for pytest coverage --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 779ce93..7cbdd61 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ 'pyspark==3.5.6' ], extras_require={ - "dev": ["pytest>=5", "pytest-cov", "pandas"] + "dev": ["pytest>=5", "pytest-cov", "pandas", "pyarrow"] }, python_requires='>=3.7', ) \ No newline at end of file From c8f3586e8a637a1f0601225c3780fded8cd6d376 Mon Sep 17 00:00:00 2001 From: Jafeer Ali Date: Mon, 21 Jul 2025 14:54:01 +0530 Subject: [PATCH 5/5] codecov depedency version update --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7cbdd61..7a66ca5 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ 'pyspark==3.5.6' ], extras_require={ - "dev": ["pytest>=5", "pytest-cov", "pandas", "pyarrow"] + "dev": ["numpy==1.26.4","pytest>=5", "pytest-cov", "pandas==2.2.2", "pyarrow==14.0.2"] }, python_requires='>=3.7', ) \ No newline at end of file