Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .tool-versions
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
python 3.11.14
poetry 2.3.3
python 3.12.12
poetry 2.4.1
java liberica-1.8.0
112 changes: 56 additions & 56 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ optional = true
behave = "1.3.3"
coverage = "7.11.0"
moto = {extras = ["s3"], version = "4.2.14"}
cryptography = "48.0.1" # dependency of `moto`
requests = "2.33.0" # dependency of `moto`
Werkzeug = "3.1.6"
pytest = "9.0.3"
Expand Down
11 changes: 2 additions & 9 deletions src/dve/core_engine/backends/base/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,7 @@ def read_to_entity_type(
"""
if entity_name == Iterator[dict[str, Any]]:
return self.read_to_py_iterator(
resource,
entity_name,
schema, # type: ignore
all_model_fields
resource, entity_name, schema, all_model_fields # type: ignore
)

self.raise_if_not_sensible_file(resource, entity_name)
Expand All @@ -133,11 +130,7 @@ def read_to_entity_type(
raise ReaderLacksEntityTypeSupport(entity_type=entity_type) from err

return reader_func(
self,
resource,
entity_name,
schema,
all_model_fields=all_model_fields # type: ignore
self, resource, entity_name, schema, all_model_fields=all_model_fields # type: ignore
)

def add_record_index(self, entity: EntityType, **kwargs) -> EntityType:
Expand Down
10 changes: 10 additions & 0 deletions src/dve/core_engine/backends/base/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,3 +681,13 @@ def read_parquet(self, path: URI, **kwargs) -> EntityType:
def write_parquet(self, entity: EntityType, target_location: URI, **kwargs) -> URI:
    """Write `entity` to `target_location` as parquet.

    This base version is a placeholder: it does no work and always raises
    `NotImplementedError`.
    """
    raise NotImplementedError()

def filter_data_contract_record_rejections(
    self, working_directory: URI, entity: EntityType, entity_name: EntityName, **kwargs
):
    """Filter the data contract's record rejection errors out of `entity`.

    This base version is a placeholder: it does no work and always raises
    `NotImplementedError`.

    Args:
        working_directory: Location used to find the data contract feedback.
        entity: The entity to filter.
        entity_name: Name of the entity being filtered.
        **kwargs: Extra, implementation-specific options.
    """
    raise NotImplementedError()
5 changes: 1 addition & 4 deletions src/dve/core_engine/backends/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ class UnableToParseCSVError(MessageBearingError):
"""An error raised when unable to parse a CSV file"""

def __init__(
self,
entity_name: str,
field_check_error_message: str,
field_check_error_code: str
self, entity_name: str, field_check_error_message: str, field_check_error_code: str
):
super().__init__(
messages=[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from pydantic import BaseModel
from typing_extensions import Annotated, get_args, get_origin, get_type_hints

from dve.common.error_utils import get_feedback_errors_uri
from dve.core_engine.backends.base.utilities import _get_non_heterogenous_type
from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
from dve.core_engine.type_hints import URI
from dve.core_engine.type_hints import URI, EntityName
from dve.parser.file_handling.service import LocalFilesystemImplementation, _get_implementation


Expand Down Expand Up @@ -100,7 +101,7 @@ def table_exists(connection: DuckDBPyConnection, table_name: str) -> bool:

def relation_is_empty(relation: DuckDBPyRelation) -> bool:
"""Check if a duckdb relation is empty"""
if relation.limit(1).count("*"):
if relation.limit(1).shape[0] > 0:
return False
return True

Expand Down Expand Up @@ -256,6 +257,48 @@ def duckdb_write_parquet(cls):
return cls


def _ddb_filter_contract_errors(
    self,
    working_directory: URI,
    entity: DuckDBPyRelation,
    entity_name: EntityName,
) -> DuckDBPyRelation:
    """Remove rows of `entity` that have a record-level data contract rejection.

    Reads the data contract feedback errors for `working_directory`, keeps the
    entries for `entity_name` whose `FailureType` is `record` and whose `Status`
    is not `informational`, and anti-joins `entity` against their `RecordIndex`
    values on the `__record_index__` column.

    Args:
        self: Object exposing a DuckDB connection as `_connection`.
        working_directory: Location used to resolve the feedback errors file.
        entity: Relation to filter; must carry a `__record_index__` column.
        entity_name: Name matched against the `Entity` field of each error.

    Returns:
        `entity` unchanged when the errors file does not exist or holds no
        matching rejection, otherwise `entity` without the rejected records.
    """
    contract_error_location = get_feedback_errors_uri(working_directory, "data_contract")
    # NOTE(review): the location is checked as a local filesystem path - confirm
    # `get_feedback_errors_uri` never returns a scheme-prefixed or remote URI here.
    if not Path(contract_error_location).exists():
        return entity

    # Double embedded single quotes so the name stays a valid SQL string literal
    # (an entity name containing "'" would otherwise break the filter expression).
    entity_literal = str(entity_name).replace("'", "''")

    # No ordering is applied: the anti-join below only tests key membership.
    rejected_record_indexes = (
        self._connection.read_json(
            contract_error_location,
            columns={
                "RecordIndex": "INTEGER",
                "FailureType": "STRING",
                "Status": "STRING",
                "Entity": "STRING",
            },
        )
        .filter(
            "FailureType == 'record' AND Status != 'informational' "
            f"AND Entity = '{entity_literal}'"
        )
        .select("RecordIndex")
        .distinct()
    )

    if relation_is_empty(rejected_record_indexes):
        return entity

    return entity.join(
        rejected_record_indexes, condition="__record_index__ == RecordIndex", how="anti"
    )


def ddb_filter_contract_errors(cls):
    """Class decorator to filter out records that failed casting and have record rejection scope.

    Attaches `_ddb_filter_contract_errors` to `cls` under the name
    `filter_data_contract_record_rejections` and returns the same class.
    """
    setattr(cls, "filter_data_contract_record_rejections", _ddb_filter_contract_errors)
    return cls


@staticmethod # type: ignore
def _duckdb_get_entity_count(entity: DuckDBPyRelation) -> int:
"""Method to obtain entity count from a persisted parquet entity"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
quote_char=quotechar,
field_check=field_check,
field_check_error_code=field_check_error_code,
field_check_error_message=field_check_error_message
field_check_error_message=field_check_error_message,
)

def read_to_py_iterator(
Expand Down Expand Up @@ -254,7 +254,7 @@ def read_to_relation( # pylint: disable=unused-argument
resource=resource,
entity_name=entity_name,
schema=schema,
all_model_fields=all_model_fields
all_model_fields=all_model_fields,
)
entity = entity.select(StarExpression(exclude=[RECORD_INDEX_COLUMN_NAME])).distinct()
no_records = entity.shape[0]
Expand Down
2 changes: 2 additions & 0 deletions src/dve/core_engine/backends/implementations/duckdb/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dve.core_engine.backends.exceptions import ConstraintError
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
DDBStruct,
ddb_filter_contract_errors,
duckdb_read_parquet,
duckdb_record_index,
duckdb_rel_to_dictionaries,
Expand Down Expand Up @@ -61,6 +62,7 @@
@duckdb_record_index
@duckdb_write_parquet
@duckdb_read_parquet
@ddb_filter_contract_errors
class DuckDBStepImplementations(BaseStepImplementations[DuckDBPyRelation]):
"""An implementation of transformation steps in duckdb."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from pyspark.sql.types import StructType

from dve.core_engine.backends.base.reader import read_function
from dve.core_engine.backends.readers.csv import CSVFileReader
from dve.core_engine.backends.exceptions import EmptyFileError
from dve.core_engine.backends.implementations.spark.spark_helpers import (
get_type_from_annotation,
spark_record_index,
spark_write_parquet,
)
from dve.core_engine.backends.readers.csv import CSVFileReader
from dve.core_engine.type_hints import URI, EntityName
from dve.parser.file_handling import get_content_length

Expand Down
2 changes: 2 additions & 0 deletions src/dve/core_engine/backends/implementations/spark/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
create_udf,
get_all_registered_udfs,
object_to_spark_literal,
spark_filter_contract_errors,
spark_read_parquet,
spark_record_index,
spark_write_parquet,
Expand Down Expand Up @@ -53,6 +54,7 @@
@spark_record_index
@spark_write_parquet
@spark_read_parquet
@spark_filter_contract_errors
class SparkStepImplementations(BaseStepImplementations[DataFrame]):
"""An implementation of transformation steps in Apache Spark."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dataclasses import dataclass, is_dataclass
from decimal import Decimal
from functools import wraps
from pathlib import Path
from typing import Any, ClassVar, Optional, TypeVar, Union, overload

from delta.exceptions import ConcurrentAppendException, DeltaConcurrentModificationException
Expand All @@ -25,9 +26,10 @@
from pyspark.sql.types import LongType, StructField, StructType
from typing_extensions import Annotated, Protocol, TypedDict, get_args, get_origin, get_type_hints

from dve.common.error_utils import get_feedback_errors_uri
from dve.core_engine.backends.base.utilities import _get_non_heterogenous_type
from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
from dve.core_engine.type_hints import URI
from dve.core_engine.type_hints import URI, EntityName

# It would be really nice if there was a more parameterisable
# way of doing this.
Expand Down Expand Up @@ -365,6 +367,53 @@
return cls


def _spark_filter_contract_errors(
    self,
    working_directory: URI,
    entity: DataFrame,
    entity_name: EntityName,
) -> DataFrame:
    """Remove rows of `entity` that have a record-level data contract rejection.

    Reads the data contract feedback errors for `working_directory`, keeps the
    entries for `entity_name` whose `FailureType` is `record` and whose `Status`
    is not `informational`, and anti-joins `entity` against their `RecordIndex`
    values on the `__record_index__` column.

    Args:
        self: Object exposing a Spark session as `spark_session`.
        working_directory: Location used to resolve the feedback errors file.
        entity: DataFrame to filter; must carry a `__record_index__` column.
        entity_name: Name matched against the `Entity` field of each error.

    Returns:
        `entity` unchanged when the errors file does not exist or holds no
        matching rejection, otherwise `entity` without the rejected records.
    """
    contract_error_location = get_feedback_errors_uri(working_directory, "data_contract")
    # NOTE(review): the location is checked as a local filesystem path - confirm
    # `get_feedback_errors_uri` never returns a scheme-prefixed or remote URI here.
    if not Path(contract_error_location).exists():
        return entity

    # NOTE(review): RecordIndex is read as a 32-bit integer - confirm this is wide
    # enough for the type of the entity's `__record_index__` column.
    feedback_schema = st.StructType(
        [
            st.StructField("RecordIndex", st.IntegerType()),
            st.StructField("FailureType", st.StringType()),
            st.StructField("Status", st.StringType()),
            st.StructField("Entity", st.StringType()),
        ]
    )

    # Only the join key is kept, so `distinct` de-duplicates record indexes rather
    # than whole error rows. No ordering is applied: the anti-join below only tests
    # key membership.
    rejected_record_indexes = (
        self.spark_session.read.json(path=contract_error_location, schema=feedback_schema)
        .filter(
            (sf.col("FailureType") == sf.lit("record"))
            & (sf.col("Status") != sf.lit("informational"))
            & (sf.col("Entity") == sf.lit(entity_name))
        )
        .select("RecordIndex")
        .distinct()
    )

    if df_is_empty(rejected_record_indexes):
        return entity

    return entity.join(
        rejected_record_indexes,
        on=entity["__record_index__"] == rejected_record_indexes["RecordIndex"],
        how="anti",
    )


def spark_filter_contract_errors(cls):
    """Class decorator to filter out records that failed casting and have record rejection scope.

    Attaches `_spark_filter_contract_errors` to `cls` under the name
    `filter_data_contract_record_rejections` and returns the same class.
    """
    setattr(cls, "filter_data_contract_record_rejections", _spark_filter_contract_errors)
    return cls


@staticmethod # type: ignore
def _spark_get_entity_count(entity: DataFrame) -> int:
"""Method to obtain entity count from a persisted parquet entity"""
Expand Down
2 changes: 1 addition & 1 deletion src/dve/core_engine/backends/metadata/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def schemas(self) -> dict[EntityName, type[BaseModel]]:
"""The per-entity schemas, as pydantic models."""
if not self._schemas:
for entity_name, validator in self.validators.items():
self._schemas[entity_name] = validator.model # type: ignore # pylint: disable=E1137
self._schemas[entity_name] = validator.model # type: ignore # pylint: disable=E1137
return self._schemas.copy() # pylint: disable=E1101

@root_validator(allow_reuse=True)
Expand Down
4 changes: 2 additions & 2 deletions src/dve/core_engine/backends/readers/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
MissingHeaderError,
)
from dve.core_engine.backends.readers.utilities import (
raise_message_bearing_error_on_header_differences
get_all_model_fields,
raise_message_bearing_error_on_header_differences,
)
from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model
from dve.core_engine.constants import RECORD_INDEX_COLUMN_NAME
from dve.core_engine.backends.readers.utilities import get_all_model_fields
from dve.core_engine.type_hints import EntityName
from dve.parser.file_handling import get_content_length, open_stream
from dve.parser.file_handling.implementations.file import file_uri_to_local_path
Expand Down
14 changes: 7 additions & 7 deletions src/dve/core_engine/backends/readers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ def raise_message_bearing_error_on_header_differences(
header or vice versa.
"""
missing, additional = check_csv_header_expected(
resource,
expected_schema,
all_model_fields,
delimiter,
quote_char
resource, expected_schema, all_model_fields, delimiter, quote_char
)

if missing or additional:
record_details_missing = f"missing fields: {', '.join(sorted(missing))};" if missing else "" # pylint: disable=C0301
record_details_additional = f"additional fields: {', '.join(sorted(additional))};" if additional else "" # pylint: disable=C0301
record_details_missing = (
f"missing fields: {', '.join(sorted(missing))};" if missing else ""
) # pylint: disable=C0301
record_details_additional = (
f"additional fields: {', '.join(sorted(additional))};" if additional else ""
) # pylint: disable=C0301
raise MessageBearingError(
"The CSV header doesn't match what is expected",
messages=[
Expand Down
4 changes: 3 additions & 1 deletion src/dve/core_engine/configuration/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ def _load_rules_and_vars(self) -> tuple[list[Rule], list[TemplateVariables]]:
rules, local_variable_list = [], []
added_rules: set[RuleName] = set()

for index, complex_rule_config in enumerate(self.transformations.complex_rules): # pylint: disable=E1101
for index, complex_rule_config in enumerate(
self.transformations.complex_rules
): # pylint: disable=E1101
rule, local_params, deps = self._resolve_business_rule(complex_rule_config)
missing_rules = deps - added_rules
if missing_rules:
Expand Down
16 changes: 12 additions & 4 deletions src/dve/core_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def __exit__(
exc_value: Optional[Exception],
traceback: Optional[TracebackType],
) -> None:
self.main_log.info(f"Exiting pipeline context, clearing {self.cache_prefix!r}") # pylint: disable=E1101
self.main_log.info(
f"Exiting pipeline context, clearing {self.cache_prefix!r}"
) # pylint: disable=E1101
cache_dir = self._cache_dir
self._cache_dir = None

Expand Down Expand Up @@ -198,17 +200,23 @@ def _write_entity_outputs(self, entities: SparkEntities) -> SparkEntities:
"""
output_entities = {}

self.main_log.info(f"Writing entities to the output location: {self.output_prefix_uri}") # pylint: disable=E1101
self.main_log.info(
f"Writing entities to the output location: {self.output_prefix_uri}"
) # pylint: disable=E1101
for entity_name, entity in entities.items():
entity = entity.drop(RECORD_INDEX_COLUMN_NAME)

self.main_log.info(f"Entity: {entity_name} {type(entity)}") # pylint: disable=E1101

output_uri = joinuri(self.output_prefix_uri, entity_name)
if get_resource_exists(output_uri):
self.main_log.info(f"{output_uri} already exists - will be overwritten") # pylint: disable=E1101
self.main_log.info(
f"{output_uri} already exists - will be overwritten"
) # pylint: disable=E1101

self.main_log.info(f"+ Writing parquet output to {output_uri!r}") # pylint: disable=E1101
self.main_log.info(
f"+ Writing parquet output to {output_uri!r}"
) # pylint: disable=E1101
entity.write.mode("overwrite").parquet(output_uri)
spark_session = SparkSession.builder.getOrCreate()
output_entities[entity_name] = spark_session.read.format("parquet").load(
Expand Down
Loading
Loading