def _ddb_filter_contract_errors(
    self,
    working_directory: URI,
    entity: DuckDBPyRelation,
    entity_name: EntityName,
    **kwargs,  # pylint: disable=unused-argument
) -> DuckDBPyRelation:
    """Remove records rejected by the data contract from a DuckDB entity.

    The data contract feedback file for ``working_directory`` is read and
    reduced to the distinct ``RecordIndex`` values of rows where
    ``FailureType`` is ``record``, ``Status`` is not ``informational`` and
    ``Entity`` equals ``entity_name``. Those indexes are anti-joined against
    the entity's ``__record_index__`` column.

    Args:
        working_directory: Location passed to ``get_feedback_errors_uri`` to
            resolve the data contract feedback file.
        entity: The relation to filter.
        entity_name: Name used to select the feedback rows for this entity.
        **kwargs: Accepted for parity with the base class signature; unused.

    Returns:
        ``entity`` unchanged when the feedback file does not exist or holds
        no qualifying rejections, otherwise ``entity`` without the rejected
        records.
    """
    # Function-scope imports: only this helper needs them.
    from urllib.parse import urlparse
    from urllib.request import url2pathname

    contract_error_location = get_feedback_errors_uri(working_directory, "data_contract")

    # ``Path("file:///x").exists()`` is always False, which silently disabled
    # the filter for ``file://`` URIs. Translate those to a filesystem path;
    # anything else is handled exactly as before.
    parsed_location = urlparse(str(contract_error_location))
    if parsed_location.scheme == "file":
        local_location = url2pathname(parsed_location.path)
    else:
        local_location = contract_error_location

    if not Path(local_location).exists():
        return entity

    # The name is embedded in a SQL expression, so double any single quotes.
    escaped_entity_name = str(entity_name).replace("'", "''")

    # No ORDER BY: the anti-join only tests membership, so sorting the keys
    # would add work without changing the result.
    rejected_record_indexes = (
        self._connection.read_json(
            local_location,
            columns={
                "RecordIndex": "INTEGER",
                "FailureType": "STRING",
                "Status": "STRING",
                "Entity": "STRING",
            },
        )
        .filter(
            "FailureType = 'record' AND Status != 'informational' "
            f"AND Entity = '{escaped_entity_name}'"
        )
        .select("RecordIndex")
        .distinct()
    )

    if relation_is_empty(rejected_record_indexes):
        return entity

    # NOTE(review): the literal presumably equals RECORD_INDEX_COLUMN_NAME —
    # confirm, then switch to the constant.
    return entity.join(
        rejected_record_indexes,
        condition="__record_index__ == RecordIndex",
        how="anti",
    )
DuckDBStepImplementations(BaseStepImplementations[DuckDBPyRelation]): """An implementation of transformation steps in duckdb.""" diff --git a/src/dve/core_engine/backends/implementations/spark/rules.py b/src/dve/core_engine/backends/implementations/spark/rules.py index 307e71a..66564ee 100644 --- a/src/dve/core_engine/backends/implementations/spark/rules.py +++ b/src/dve/core_engine/backends/implementations/spark/rules.py @@ -17,6 +17,7 @@ spark_read_parquet, spark_record_index, spark_write_parquet, + spark_filter_contract_errors, ) from dve.core_engine.backends.implementations.spark.types import ( Joined, @@ -53,6 +54,7 @@ @spark_record_index @spark_write_parquet @spark_read_parquet +@spark_filter_contract_errors class SparkStepImplementations(BaseStepImplementations[DataFrame]): """An implementation of transformation steps in Apache Spark.""" diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py index ced985a..2c2fde4 100644 --- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py +++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py @@ -12,6 +12,7 @@ from dataclasses import dataclass, is_dataclass from decimal import Decimal from functools import wraps +from pathlib import Path from typing import Any, ClassVar, Optional, TypeVar, Union, overload from delta.exceptions import ConcurrentAppendException, DeltaConcurrentModificationException @@ -26,8 +27,9 @@ from typing_extensions import Annotated, Protocol, TypedDict, get_args, get_origin, get_type_hints from dve.core_engine.backends.base.utilities import _get_non_heterogenous_type +from dve.common.error_utils import get_feedback_errors_uri from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME -from dve.core_engine.type_hints import URI +from dve.core_engine.type_hints import URI, EntityName # It would be really nice if there was a more parameterisable # way of doing this. 
def _spark_filter_contract_errors(
    self,
    working_directory: URI,
    entity: DataFrame,
    entity_name: EntityName,
    **kwargs,  # pylint: disable=unused-argument
) -> DataFrame:
    """Remove records rejected by the data contract from a Spark entity.

    The data contract feedback file for ``working_directory`` is read and
    reduced to the distinct ``RecordIndex`` values of rows where
    ``FailureType`` is ``record``, ``Status`` is not ``informational`` and
    ``Entity`` equals ``entity_name``. Those indexes are anti-joined against
    the entity's ``__record_index__`` column.

    Args:
        working_directory: Location passed to ``get_feedback_errors_uri`` to
            resolve the data contract feedback file.
        entity: The dataframe to filter.
        entity_name: Name used to select the feedback rows for this entity.
        **kwargs: Accepted for parity with the base class signature; unused.

    Returns:
        ``entity`` unchanged when the feedback file does not exist or holds
        no qualifying rejections, otherwise ``entity`` without the rejected
        records.
    """
    # Function-scope imports: only this helper needs them.
    from urllib.parse import urlparse
    from urllib.request import url2pathname

    contract_error_location = get_feedback_errors_uri(working_directory, "data_contract")

    # ``Path("file:///x").exists()`` is always False, which silently disabled
    # the filter for ``file://`` URIs. Only the existence check uses the
    # translated path; Spark is still given the original location.
    parsed_location = urlparse(str(contract_error_location))
    if parsed_location.scheme == "file":
        local_location = url2pathname(parsed_location.path)
    else:
        local_location = contract_error_location

    if not Path(local_location).exists():
        return entity

    feedback_schema = st.StructType(
        [
            st.StructField("RecordIndex", st.IntegerType()),
            st.StructField("FailureType", st.StringType()),
            st.StructField("Status", st.StringType()),
            st.StructField("Entity", st.StringType()),
        ]
    )

    # Project to the key before de-duplicating (as the DuckDB helper does) so
    # ``distinct`` works on one column. No ``orderBy``: the anti-join only
    # tests membership, so sorting cannot change the result.
    rejected_record_indexes = (
        self.spark_session.read.json(path=contract_error_location, schema=feedback_schema)
        .filter(
            (sf.col("FailureType") == sf.lit("record"))
            & (sf.col("Status") != sf.lit("informational"))
            & (sf.col("Entity") == sf.lit(entity_name))
        )
        .select("RecordIndex")
        .distinct()
    )

    if df_is_empty(rejected_record_indexes):
        return entity

    # Item access instead of attribute access on a dunder-style column name.
    # NOTE(review): the literal presumably equals RECORD_INDEX_COLUMN_NAME —
    # confirm, then switch to the constant.
    return entity.join(
        rejected_record_indexes,
        on=entity["__record_index__"] == rejected_record_indexes["RecordIndex"],
        how="anti",
    )
+ """Number of submission rejections raised following validation""" number_record_rejections: Optional[int] """Number of record rejections raised following validation""" number_warnings: Optional[int] diff --git a/src/dve/pipeline/pipeline.py b/src/dve/pipeline/pipeline.py index 00a0c51..1c32e87 100644 --- a/src/dve/pipeline/pipeline.py +++ b/src/dve/pipeline/pipeline.py @@ -379,9 +379,6 @@ def file_transformation_step( failed.append((submission_info, submission_status)) else: success.append((submission_info, submission_status)) - except AttributeError as exc: - self._logger.error(f"File transformation raised exception: {exc}") - raise exc except PERMISSIBLE_EXCEPTIONS as exc: self._logger.warning( f"File transformation raised exception: {exc}. Will be retried later." @@ -509,9 +506,6 @@ def data_contract_step( submission_info: SubmissionInfo submission_status: SubmissionStatus submission_info, submission_status = future.result() - except AttributeError as exc: - self._logger.error(f"Data Contract raised exception: {exc}") - raise exc except PERMISSIBLE_EXCEPTIONS as exc: self._logger.warning( f"Data Contract raised exception: {exc}. Will be retried later." @@ -616,8 +610,19 @@ def apply_business_rules( # pylint: disable=R0914 submission_status.processing_failed = True for entity_name, entity in entity_manager.entities.items(): + # Note BI filtering done within the apply_rules + self._logger.info(f"applying data contract filter to {entity_name}.") + if not entity_name.startswith("Original"): + filtered_entity = self._step_implementations.filter_data_contract_record_rejections( + working_directory, + entity, + entity_name, + ) + else: + self._logger.info(f"Skipping {entity_name}. 
Marked original.") + filtered_entity = entity projected = self._step_implementations.write_parquet( # type: ignore - entity, + filtered_entity, fh.joinuri( self.processed_files_path, submission_info.submission_id, @@ -629,6 +634,7 @@ def apply_business_rules( # pylint: disable=R0914 projected ) + # todo - add to submission_status around records that have passed record validations/rejected submission_status.number_of_records = self.get_entity_count( entity=entity_manager.entities[ f"""Original{rules.global_variables.get( @@ -682,9 +688,6 @@ def business_rule_step( unsucessful_files.append((submission_info, submission_status)) # type: ignore else: successful_files.append((submission_info, submission_status)) # type: ignore - except AttributeError as exc: - self._logger.error(f"Business Rules raised exception: {exc}") - raise exc except PERMISSIBLE_EXCEPTIONS as exc: self._logger.warning( f"Business Rules raised exception: {exc}. Will be retried later." @@ -758,10 +761,12 @@ def _get_error_dataframes(self, submission_id: str): df = pl.DataFrame(errors, schema={key: pl.Utf8() for key in errors[0]}) # type: ignore df = df.with_columns( - pl.when(pl.col("Status") == pl.lit("error")) # type: ignore + pl.when(pl.col("Status") == pl.lit("informational")) + .then(pl.lit("Warning")) + .when(pl.col("FailureType") == pl.lit("submission")) # type: ignore .then(pl.lit("Submission Failure")) # type: ignore - .otherwise(pl.lit("Warning")) # type: ignore - .alias("error_type") + .otherwise(pl.lit("Record Rejection")) # type: ignore + .alias("error_type") # type: ignore ) df = df.select( pl.col("Entity").alias("Table"), # type: ignore @@ -823,7 +828,8 @@ def error_report( sub_stats = SubmissionStatisticsRecord( submission_id=submission_info.submission_id, record_count=submission_status.number_of_records, - number_record_rejections=err_types.get("Submission Failure", 0), + number_submission_rejections=err_types.get("Submission Failure", 0), + 
number_record_rejections=err_types.get("Record Rejection", 0), number_warnings=err_types.get("Warning", 0), ) @@ -835,7 +841,7 @@ def error_report( summary_items = er.SummaryItems( submission_status=submission_status, summary_dict=summary_dict, - row_headings=["Submission Failure", "Warning"], + row_headings=["Submission Failure", "Record Rejection", "Warning"], ) workbook = er.ExcelFormat( @@ -894,9 +900,6 @@ def error_report_step( try: submission_info, submission_status, submission_stats, feedback_uri = future.result() reports.append((submission_info, submission_status, submission_stats, feedback_uri)) - except AttributeError as exc: - self._logger.error(f"Error reports raised exception: {exc}") - raise exc except PERMISSIBLE_EXCEPTIONS as exc: self._logger.warning( f"Error reports raised exception: {exc}. Will be retried later." diff --git a/src/dve/reporting/excel_report.py b/src/dve/reporting/excel_report.py index 82aa510..9471c83 100644 --- a/src/dve/reporting/excel_report.py +++ b/src/dve/reporting/excel_report.py @@ -141,6 +141,11 @@ def _add_submission_info(self, status: str, summary: Worksheet): for key, value in self.summary_dict.items(): summary.append(["", _key_renames.get(key, key), str(value)]) + summary.append([ + "", + "Total Number of Records Processed", + self.submission_status.number_of_records if self.submission_status.number_of_records else 0 # pylint: disable=C0301 + ]) summary.append(["", ""]) diff --git a/tests/features/animals.feature b/tests/features/animals.feature new file mode 100644 index 0000000..d68ddbf --- /dev/null +++ b/tests/features/animals.feature @@ -0,0 +1,59 @@ +Feature: Pipeline tests using the animal dataset + Test record rejection and ensuring that records are correctly removed from the entity and that + the correct validation feedback is raised in the error report. 
+ + Scenario: Validate XML data with just record level rejections (duckdb) + Given I submit the animals file animals.xml for processing + And A duckdb pipeline is configured with schema file 'animals.dischema.json' + And I add initial audit entries for the submission + Then the latest audit record for the submission is marked with processing status file_transformation + When I run the file transformation phase + Then the animals entity is stored as a parquet after the file_transformation phase + And the latest audit record for the submission is marked with processing status data_contract + When I run the data contract phase + Then there are no record rejections from the data_contract phase + And the animals entity is stored as a parquet after the data_contract phase + And the latest audit record for the submission is marked with processing status business_rules + When I run the business rules phase + Then there are errors with the following details and associated error_count from the business_rules phase + | ErrorType | ErrorCode | error_count | + | record | ANE01 | 2 | + And The rules restrict "animals" to 3 qualifying records + When I run the error report phase + Then An error report is produced + And The statistics entry for the submission shows the following information + | parameter | value | + | record_count | 5 | + | number_record_rejections | 2 | + | number_warnings | 0 | + + Scenario: Validate XML data with a mixture of error types in (duckdb) + Given I submit the animals file animals_mixture.xml for processing + And A duckdb pipeline is configured with schema file 'animals.dischema.json' + And I add initial audit entries for the submission + Then the latest audit record for the submission is marked with processing status file_transformation + When I run the file transformation phase + Then the animals entity is stored as a parquet after the file_transformation phase + And the latest audit record for the submission is marked with processing status 
data_contract + When I run the data contract phase + Then there are no record rejections from the data_contract phase + # Then there are errors with the following details and associated error_count from the data_contract phase + # | FailureType | Status | ErrorCode | error_count | + # | record | error | FieldBlank | 1 | + And the animals entity is stored as a parquet after the data_contract phase + And the latest audit record for the submission is marked with processing status business_rules + When I run the business rules phase + Then there are errors with the following details and associated error_count from the business_rules phase + | FailureType | Status | ErrorCode | error_count | + | record | error | ANE01 | 2 | + | submission | error | ANE02 | 1 | + | record | informational | ANE03 | 1 | + And The rules restrict "animals" to 5 qualifying records + When I run the error report phase + Then An error report is produced + And The statistics entry for the submission shows the following information + | parameter | value | + | record_count | 7 | + | number_submission_rejections | 1 | + | number_record_rejections | 2 | + | number_warnings | 1 | diff --git a/tests/features/demographics.feature b/tests/features/demographics.feature index aa59bfc..af4b62a 100644 --- a/tests/features/demographics.feature +++ b/tests/features/demographics.feature @@ -17,7 +17,7 @@ Feature: Pipeline tests using the ambsys dataset And the demographics entity is stored as a parquet after the data_contract phase And the latest audit record for the submission is marked with processing status business_rules When I run the business rules phase - Then The rules restrict "demographics" to 6 qualifying records + Then The rules restrict "demographics" to 2 qualifying records And At least one row from "demographics" has generated error code "BAD_NHS" And the demographics entity is stored as a parquet after the business_rules phase And The entity "demographics" does not contain an entry for "FALSE" 
in column "NHS_Number_Valid" @@ -43,7 +43,7 @@ Feature: Pipeline tests using the ambsys dataset And the demographics entity is stored as a parquet after the data_contract phase And the latest audit record for the submission is marked with processing status business_rules When I run the business rules phase - Then The rules restrict "demographics" to 6 qualifying records + Then The rules restrict "demographics" to 2 qualifying records And At least one row from "demographics" has generated error code "BAD_NHS" And the demographics entity is stored as a parquet after the business_rules phase And The entity "demographics" does not contain an entry for "FALSE" in column "NHS_Number_Valid" diff --git a/tests/features/movies.feature b/tests/features/movies.feature index fa041ea..6916a4e 100644 --- a/tests/features/movies.feature +++ b/tests/features/movies.feature @@ -28,7 +28,7 @@ Feature: Pipeline tests using the movies dataset And the movies entity is stored as a parquet after the data_contract phase And the latest audit record for the submission is marked with processing status business_rules When I run the business rules phase - Then The rules restrict "movies" to 4 qualifying records + Then The rules restrict "movies" to 2 qualifying records And there are errors with the following details and associated error_count from the business_rules phase | ErrorCode | ErrorMessage | RecordIndex | error_count | | LIMITED_RATINGS | Movie has too few ratings ([6.5]) | 4 | 1 | @@ -64,7 +64,7 @@ Feature: Pipeline tests using the movies dataset And the movies entity is stored as a parquet after the data_contract phase And the latest audit record for the submission is marked with processing status business_rules When I run the business rules phase - Then The rules restrict "movies" to 4 qualifying records + Then The rules restrict "movies" to 2 qualifying records And there are errors with the following details and associated error_count from the business_rules phase | ErrorCode | 
ErrorMessage | RecordIndex | error_count | | LIMITED_RATINGS | Movie has too few ratings ([6.5]) | 4 | 1 | diff --git a/tests/features/steps/steps_pipeline.py b/tests/features/steps/steps_pipeline.py index 55acadd..061bfe1 100644 --- a/tests/features/steps/steps_pipeline.py +++ b/tests/features/steps/steps_pipeline.py @@ -50,14 +50,14 @@ def setup_spark_pipeline( rules_path = get_test_file_path(f"{dataset_id}/{schema_file_name}").resolve().as_uri() return SparkDVEPipeline( - processed_files_path=processing_path.as_uri(), + processed_files_path=processing_path.as_posix(), audit_tables=SparkAuditingManager( database="dve", spark=spark, ), job_run_id=12345, rules_path=rules_path, - submitted_files_path=processing_path.as_uri(), + submitted_files_path=processing_path.as_posix(), spark=spark, ) diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py index 19e96e2..4a24960 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py @@ -1,10 +1,15 @@ """Test Duck DB helpers""" +# pylint: disable=C0301,C0116 + import datetime +import json +import os import tempfile from pathlib import Path from typing import Any, List +import polars as pl import pytest import pyspark.sql.types as pst from duckdb import DuckDBPyRelation, DuckDBPyConnection @@ -12,10 +17,13 @@ from pyspark.sql import Row, SparkSession from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( + _ddb_filter_contract_errors, _ddb_read_parquet, duckdb_rel_to_dictionaries, get_duckdb_cast_statement_from_annotation, - get_duckdb_type_from_annotation) + get_duckdb_type_from_annotation, + relation_is_empty, +) @pytest.fixture def casting_test_table(temp_ddb_conn): @@ -51,8 +59,60 @@ def 
casting_test_table(temp_ddb_conn): yield temp_ddb_conn conn.sql("DROP TABLE IF EXISTS test_casting") - - + + +@pytest.fixture +def example_data_contract_error_codes(temp_ddb_conn): + _, con = temp_ddb_conn + + test_df = pl.DataFrame([ # pylint: disable=W0612 + {"id": "field1", "attr": 1, "__record_index__": 1,}, + {"id": "field2", "attr": None, "__record_index__": 2,}, + {"id": "field3", "attr": 2, "__record_index__": 3,}, + {"id": "field4", "attr": None, "__record_index__": 4,}, + ]) + test_entity = con.sql("SELECT * FROM test_df") + error_contract_messages = [ + { + "Entity": "test_entity", + "Key": "", + "FailureType": "record", + "Status": "error", + "ErrorType": "", + "ErrorLocation": "attr", + "ErrorMessage": "", + "ErrorCode": "", + "ReportingField": "attr", + "RecordIndex": 2, + "Value": "hello", + "Category": "Bad value" + }, + { + "Entity": "test_entity", + "Key": "", + "FailureType": "record", + "Status": "error", + "ErrorType": "", + "ErrorLocation": "attr", + "ErrorMessage": "", + "ErrorCode": "", + "ReportingField": "attr", + "RecordIndex": 4, + "Value": "world", + "Category": "Bad value" + } + ] + with tempfile.TemporaryDirectory() as temp_dir_path: + os.mkdir(Path(temp_dir_path, "errors")) + temp_error_file = Path(temp_dir_path, "errors", "data_contract_errors.jsonl") + with open(temp_error_file, encoding="utf-8", mode="w") as tpf: + for error in error_contract_messages: + json.dump(error, tpf) + tpf.write("\n") + + yield con, test_entity, temp_dir_path + + class BasicModel(BaseModel): str_field: str @@ -176,4 +236,23 @@ def test_use_cast_statements(casting_test_table): not dodgy_date_rec.get("basic_model",{}).get("date_field") and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) ) - + + +def test_ddb_filter_contract_errors(example_data_contract_error_codes): # pylint: disable=W0621 + ddb_cnn, entity_rel, temp_dir = example_data_contract_error_codes + expected_df = pl.DataFrame([ # pylint: 
disable=W0612 + {"id": "field1", "attr": 1, "__record_index__": 1,}, + {"id": "field3", "attr": 2, "__record_index__": 3,}, + ]) + expected_rel = ddb_cnn.sql("SELECT * FROM expected_df") + result_rel = _ddb_filter_contract_errors( + TempConnection(ddb_cnn), temp_dir, entity_rel, "test_entity" + ) + assert result_rel.pl().shape[0] == 2 + assert expected_rel.join(result_rel, "__record_index__", "anti").pl().shape[0] == 0 + + +def test_relation_is_empty(temp_ddb_conn: DuckDBPyConnection): + _, con = temp_ddb_conn + rel = con.sql("SELECT 'abc' AS test").filter("test IS NULL") + assert relation_is_empty(rel) diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py index 7502673..8a0e45e 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py @@ -1,9 +1,15 @@ """Tests for UDF helpers.""" # pylint: disable=redefined-outer-name +# pylint: disable=C0301,C0115,C0116 + import datetime as dt +import json +import os +import tempfile from dataclasses import dataclass from decimal import Decimal +from pathlib import Path from typing import Any, List, Optional, Union from uuid import UUID @@ -19,6 +25,7 @@ from dve.core_engine.backends.implementations.spark.spark_helpers import ( DecimalConfig, create_udf, + _spark_filter_contract_errors, get_spark_cast_statement_from_annotation, get_type_from_annotation, object_to_spark_literal, @@ -42,9 +49,56 @@ def casting_dataframe(spark): StructField("basic_model", bm_schema), StructField("another_model", StructType([StructField("unique_id", StringType()), StructField("basic_models", ArrayType(bm_schema))]))]) yield spark.createDataFrame(data, schema=schema) - - - + + +@pytest.fixture +def example_data_contract_error_codes(spark: SparkSession): + test_df = 
spark.createDataFrame([ # pylint: disable=W0612 + {"id": "field1", "attr": 1, "__record_index__": 1,}, + {"id": "field2", "attr": None, "__record_index__": 2,}, + {"id": "field3", "attr": 2, "__record_index__": 3,}, + {"id": "field4", "attr": None, "__record_index__": 4,}, + ]) + error_contract_messages = [ + { + "Entity": "test_entity", + "Key": "", + "FailureType": "record", + "Status": "error", + "ErrorType": "", + "ErrorLocation": "attr", + "ErrorMessage": "", + "ErrorCode": "", + "ReportingField": "attr", + "RecordIndex": 2, + "Value": "hello", + "Category": "Bad value" + }, + { + "Entity": "test_entity", + "Key": "", + "FailureType": "record", + "Status": "error", + "ErrorType": "", + "ErrorLocation": "attr", + "ErrorMessage": "", + "ErrorCode": "", + "ReportingField": "attr", + "RecordIndex": 4, + "Value": "world", + "Category": "Bad value" + } + ] + with tempfile.TemporaryDirectory() as temp_dir_path: + os.mkdir(Path(temp_dir_path, "errors")) + temp_error_file = Path(temp_dir_path, "errors", "data_contract_errors.jsonl") + with open(temp_error_file, encoding="utf-8", mode="w") as tpf: + for error in error_contract_messages: + json.dump(error, tpf) + tpf.write("\n") + + yield test_df, temp_dir_path + class BasicModel(BaseModel): str_field: str @@ -264,4 +318,25 @@ def test_use_cast_statements(spark, casting_dataframe): not dodgy_date_rec.get("basic_model",{}).get("date_field") and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) ) - assert cast_df \ No newline at end of file + assert cast_df + + +class TempSparkSession: + def __init__(self, spark: SparkSession): + self.spark_session = spark + + +def test_spark_filter_contract_errors(spark: SparkSession, example_data_contract_error_codes): # pylint: disable=W0621 + entity_df, temp_dir = example_data_contract_error_codes + expected_df = spark.createDataFrame([ # pylint: disable=W0612 + {"id": "field1", "attr": 1, "__record_index__": 1,}, + {"id": "field3", 
"attr": 2, "__record_index__": 3,}, + ]) + result_df = _spark_filter_contract_errors( + TempSparkSession(spark), + temp_dir, + entity_df, + "test_entity" + ) + assert result_df.count() == 2 + assert expected_df.join(result_df, "__record_index__", "anti").count() == 0 diff --git a/tests/test_pipeline/test_spark_pipeline.py b/tests/test_pipeline/test_spark_pipeline.py index b3048a1..063ced7 100644 --- a/tests/test_pipeline/test_spark_pipeline.py +++ b/tests/test_pipeline/test_spark_pipeline.py @@ -439,7 +439,9 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out ("Dataset Id", "planets"), ("File Name", "doesnotmatter"), ("File Extension", "json"), - ("Submission Failure", "2"), + ("Total Number of Records Processed", "9"), + ("Submission Failure", "0"), + ("Record Rejection", "2"), ("Warning", "0"), ] @@ -455,7 +457,7 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out [ OrderedDict( **{ - "Type": "Submission Failure", + "Type": "Record Rejection", "Group": "planets", "Data Item Submission Name": "orbitalPeriod", "Category": "Bad value", @@ -465,7 +467,7 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out ), OrderedDict( **{ - "Type": "Submission Failure", + "Type": "Record Rejection", "Group": "planets", "Data Item Submission Name": "gravity", "Category": "Bad value", @@ -485,7 +487,7 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out OrderedDict( **{ "Group": "planets", - "Type": "Submission Failure", + "Type": "Record Rejection", "Error Code": "LONG_ORBIT", "Data Item Submission Name": "orbitalPeriod", "Errors and Warnings": "Planet has long orbital period", @@ -498,7 +500,7 @@ def test_error_report_where_report_is_expected( # pylint: disable=redefined-out OrderedDict( **{ "Group": "planets", - "Type": "Submission Failure", + "Type": "Record Rejection", "Error Code": "STRONG_GRAVITY", "Data Item Submission Name": "gravity", "Errors and 
Warnings": "Planet has too strong gravity", diff --git a/tests/testdata/animals/animals.dischema.json b/tests/testdata/animals/animals.dischema.json new file mode 100644 index 0000000..0e1eda1 --- /dev/null +++ b/tests/testdata/animals/animals.dischema.json @@ -0,0 +1,54 @@ +{ + "contract": { + "schemas": {}, + "datasets": { + "animals": { + "fields": { + "name": "str", + "height": "float", + "weight": "float", + "region": "str" + }, + "reader_config": { + ".xml": { + "reader": "DuckDBXMLStreamReader", + "kwargs": { + "record_tag": "animal", + "root_tag": "animals" + } + } + }, + "mandatory_fields": [ + "name" + ] + } + } + }, + "transformations": { + "filters": [ + { + "entity": "animals", + "name": "check_valid_region", + "expression": "lower(region) in ('africa', 'asia')", + "error_code": "ANE01", + "failure_message": "Record rejected - `{{ region }}` is not in a valid region." + }, + { + "entity": "animals", + "name": "check_for_pets", + "expression": "lower(name) != 'human'", + "error_code": "ANE02", + "failure_message": "Submission Rejected - 'Human' is not a valid animal.", + "failure_type": "submission" + }, + { + "entity": "animals", + "name": "check_valid_weight", + "expression": "weight > 0", + "error_code": "ANE03", + "failure_message": "Warning - `{{ weight }}` is below zero.", + "is_informational": true + } + ] + } +} \ No newline at end of file diff --git a/tests/testdata/animals/animals.xml b/tests/testdata/animals/animals.xml new file mode 100644 index 0000000..60bdcef --- /dev/null +++ b/tests/testdata/animals/animals.xml @@ -0,0 +1,33 @@ + + + + African Elephant + 3.5 + 6000.0 + Africa + + + Bengal Tiger + 1.1 + 260.0 + Asia + + + Giraffe + 5.5 + 1200.0 + Africa + + + Polar Bear + 2.6 + 900.0 + Arctic + + + Blue Whale + 24.0 + 180000.0 + Oceans + + diff --git a/tests/testdata/animals/animals_mixture.xml b/tests/testdata/animals/animals_mixture.xml new file mode 100644 index 0000000..230f790 --- /dev/null +++ 
b/tests/testdata/animals/animals_mixture.xml @@ -0,0 +1,45 @@ + + + + African Elephant + 3.5 + 6000.0 + Africa + + + Bengal Tiger + 1.1 + 260.0 + Asia + + + Giraffe + 5.5 + 1200.0 + Africa + + + Polar Bear + 2.6 + 900.0 + Arctic + + + Blue Whale + 24.0 + 180000.0 + Oceans + + + Human + 1.7 + 70.0 + Africa + + + African Elephant + 3.5 + -6000.0 + Africa + +