diff --git a/frictionless/detector/detector.py b/frictionless/detector/detector.py index 4f3ba3154c..c59e86323a 100644 --- a/frictionless/detector/detector.py +++ b/frictionless/detector/detector.py @@ -9,7 +9,6 @@ from .. import helpers, settings from ..dialect import Dialect -from ..exception import FrictionlessException from ..fields import AnyField from ..metadata import Metadata from ..platform import platform @@ -411,33 +410,6 @@ def detect_schema( fields[index] = AnyField(name=name, schema=schema) # type: ignore schema.fields = fields # type: ignore - # Sync schema - if self.schema_sync: - if labels: - case_sensitive = options["header_case"] - - if not case_sensitive: - labels = [label.lower() for label in labels] - - if len(labels) != len(set(labels)): - note = '"schema_sync" requires unique labels in the header' - raise FrictionlessException(note) - - mapped_fields = self.mapped_schema_fields_names( - schema.fields, # type: ignore - case_sensitive, - ) - - self.rearrange_schema_fields_given_labels( - mapped_fields, - schema, - labels, - ) - - self.add_missing_required_labels_to_schema_fields( - mapped_fields, schema, labels, case_sensitive - ) - # Patch schema if self.schema_patch: patch = deepcopy(self.schema_patch) @@ -452,56 +424,3 @@ def detect_schema( return schema - @staticmethod - def mapped_schema_fields_names( - fields: List[Field], case_sensitive: bool - ) -> Dict[str, Field]: - """Create a dictionnary to map field names with schema fields""" - if case_sensitive: - return {field.name: field for field in fields} - else: - return {field.name.lower(): field for field in fields} - - @staticmethod - def rearrange_schema_fields_given_labels( - fields_mapping: Dict[str, Field], - schema: Schema, - labels: List[str], - ): - """Rearrange fields according to the order of labels. All fields - missing from labels are dropped""" - schema.clear_fields() - - for name in labels: - default_field = Field.from_descriptor({"name": name, "type": "any"}) - field = fields_mapping.get(name, default_field) - schema.add_field(field) - - def add_missing_required_labels_to_schema_fields( - self, - fields_mapping: Dict[str, Field], - schema: Schema, - labels: List[str], - case_sensitive: bool, - ): - """This method aims to add missing required labels and - primary key field not in labels to schema fields. - """ - for name, field in fields_mapping.items(): - if ( - self.field_is_required(field, schema, case_sensitive) - and name not in labels - ): - schema.add_field(field) - - @staticmethod - def field_is_required( - field: Field, - schema: Schema, - case_sensitive: bool, - ) -> bool: - if case_sensitive: - return field.required or field.name in schema.primary_key - else: - lower_primary_key = [pk.lower() for pk in schema.primary_key] - return field.required or field.name.lower() in lower_primary_key diff --git a/frictionless/resource/__spec__/test_validate.py b/frictionless/resource/__spec__/test_validate.py index e5a539358f..7fe64fab63 100644 --- a/frictionless/resource/__spec__/test_validate.py +++ b/frictionless/resource/__spec__/test_validate.py @@ -509,10 +509,12 @@ def test_resource_validate_detector_sync_schema(): ) report = resource.validate() assert report.valid + # schema_sync no longer mutates the user-provided schema: the order + # given by the user is preserved. assert resource.schema.to_descriptor() == { "fields": [ - {"name": "name", "type": "string"}, {"name": "id", "type": "integer"}, + {"name": "name", "type": "string"}, ], } diff --git a/frictionless/resources/table.py b/frictionless/resources/table.py index 056fee6e53..a1dbd10343 100644 --- a/frictionless/resources/table.py +++ b/frictionless/resources/table.py @@ -217,6 +217,7 @@ def __open_header(self): fields=self.schema.fields, row_numbers=self.dialect.header_rows, ignore_case=not self.dialect.header_case, + schema_sync=self.detector.schema_sync, ) # Handle errors @@ -270,24 +271,9 @@ def __open_lookup(self): self.__lookup[source_name][source_key].add(cells) def __open_row_stream(self): - # TODO: we need to rework this field_info / row code - # During row streaming we create a field info structure - # This structure is optimized and detached version of schema.fields - # We create all data structures in-advance to share them between rows - - # Create field info - field_number = 0 - field_info: Dict[str, Any] = {"names": [], "objects": [], "mapping": {}} - for field in self.schema.fields: - field_number += 1 - field_info["names"].append(field.name) - field_info["objects"].append(field.to_copy()) - field_info["mapping"][field.name] = ( - field, - field_number, - field.create_cell_reader(), - field.create_cell_writer(), - ) + # The header knows the fields to expect in the data (in order, and + # accounting for schema_sync rules). + expected_fields: List[Field] = self.header.get_expected_fields() # Create state memory_unique: Dict[str, Any] = {} @@ -320,7 +306,7 @@ def row_stream(): row = Row( cells, - field_info=field_info, + fields=expected_fields, row_number=row_number, ) @@ -400,50 +386,9 @@ def row_stream(): # Yield row yield row - if self.detector.schema_sync: - # Missing required labels are not included in the - # field_info parameter used for row creation - for field in self.schema.fields: - self.remove_missing_required_label_from_field_info(field, field_info) - # Create row stream self.__row_stream = row_stream() - def remove_missing_required_label_from_field_info( - self, field: Field, field_info: Dict[str, Any] - ): - is_case_sensitive = self.dialect.header_case - if self.label_is_missing( - field.name, field_info["names"], self.labels, is_case_sensitive - ): - self.remove_field_from_field_info(field.name, field_info) - - @staticmethod - def label_is_missing( - field_name: str, - expected_field_names: List[str], - table_labels: types.ILabels, - case_sensitive: bool, - ) -> bool: - """Check if a schema field name is missing from the TableResource - labels. - """ - if not case_sensitive: - field_name = field_name.lower() - table_labels = [label.lower() for label in table_labels] - expected_field_names = [ - field_name.lower() for field_name in expected_field_names - ] - - return field_name not in table_labels and field_name in expected_field_names - - @staticmethod - def remove_field_from_field_info(field_name: str, field_info: Dict[str, Any]): - field_index = field_info["names"].index(field_name) - del field_info["names"][field_index] - del field_info["objects"][field_index] - del field_info["mapping"][field_name] - def primary_key_cells(self, row: Row, case_sensitive: bool) -> Tuple[Any, ...]: """Create a tuple containg all cells from a given row associated to primary keys""" diff --git a/frictionless/table/__spec__/test_header.py b/frictionless/table/__spec__/test_header.py index b4c43600b1..bfe6e51351 100644 --- a/frictionless/table/__spec__/test_header.py +++ b/frictionless/table/__spec__/test_header.py @@ -3,6 +3,7 @@ import frictionless from frictionless import Schema, fields from frictionless.resources import TableResource +from frictionless.table.header import Header # General @@ -42,6 +43,70 @@ def test_missing_label(): assert header.valid is False +# get_expected_fields + + +def _make_header(labels, field_names, *, schema_sync=False, ignore_case=False): + return Header( + labels, + fields=[fields.AnyField(name=name) for name in field_names], + row_numbers=[1], + ignore_case=ignore_case, + schema_sync=schema_sync, + ) + + +@pytest.mark.parametrize( + "labels, field_names, schema_sync, ignore_case, expected_names", + [ + pytest.param( + ["a", "b"], ["a", "b"], False, False, ["a", "b"], + id="no-sync: schema fields are returned as-is", + ), + pytest.param( + ["b", "a"], ["a", "b"], False, False, ["a", "b"], + id="no-sync: schema order is kept even if labels differ", + ), + pytest.param( + ["b", "a"], ["a", "b"], True, False, ["b", "a"], + id="sync: fields are reordered to match labels", + ), + pytest.param( + ["a", "extra"], ["a"], True, False, ["a", "extra"], + id="sync: extra labels get a default any-typed field", + ), + pytest.param( + ["a"], ["a", "b"], True, False, ["a"], + id="sync: fields absent from labels are dropped", + ), + pytest.param( + ["B", "A"], ["a", "b"], True, True, ["b", "a"], + id="sync + ignore_case: matching is case-insensitive", + ), + ], +) +def test_get_expected_fields( + labels, field_names, schema_sync, ignore_case, expected_names +): + header = _make_header( + labels, field_names, schema_sync=schema_sync, ignore_case=ignore_case + ) + actual = [f.name for f in header.get_expected_fields()] + assert actual == expected_names + + +def test_get_expected_fields_sync_default_field_is_any_typed(): + header = _make_header(["a", "extra"], ["a"], schema_sync=True) + expected = header.get_expected_fields() + assert expected[1].type == "any" + + +def test_get_expected_fields_sync_raises_on_duplicate_labels(): + header = _make_header(["a", "a"], ["a"], schema_sync=True) + with pytest.raises(frictionless.FrictionlessException): + header.get_expected_fields() + + @pytest.mark.parametrize( "source, required, valid_report, nb_errors, types_errors_expected, header_case", [ diff --git a/frictionless/table/__spec__/test_row.py b/frictionless/table/__spec__/test_row.py index a9033eaf9f..6a5e384612 100644 --- a/frictionless/table/__spec__/test_row.py +++ b/frictionless/table/__spec__/test_row.py @@ -1,7 +1,9 @@ import json from decimal import Decimal +from frictionless import fields from frictionless.resources import TableResource +from frictionless.table.row import Row # General @@ -19,6 +21,21 @@ def test_basic(): assert row.to_dict() == {"field1": 1, "field2": 2, "field3": 3} +def test_row_can_be_built_from_fields_list(): + row = Row( + ["1", "2"], + fields=[fields.IntegerField(name="a"), fields.IntegerField(name="b")], + row_number=2, + ) + assert row == {"a": 1, "b": 2} + assert row.field_names == ["a", "b"] + assert row.field_numbers == [1, 2] + assert row.row_number == 2 + assert row.errors == [] + assert row.to_list() == [1, 2] + assert row.to_dict() == {"a": 1, "b": 2} + + # Convert diff --git a/frictionless/table/header.py b/frictionless/table/header.py index ed485a1ad5..daa8f8d330 100644 --- a/frictionless/table/header.py +++ b/frictionless/table/header.py @@ -1,12 +1,11 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, List +from typing import List, Optional, Tuple from .. import errors, helpers - -if TYPE_CHECKING: - from ..schema import Field +from ..exception import FrictionlessException +from ..schema import Field class Header(List[str]): # type: ignore @@ -29,14 +28,24 @@ def __init__( fields: List[Field], row_numbers: List[int], ignore_case: bool = False, + schema_sync: bool = False, ): super().__init__(field.name for field in fields) - self.__fields = [field.to_copy() for field in fields] + self.__fields = [] + for field in fields: + copy = field.to_copy() + # to_copy() goes through the descriptor and drops the back-reference + # to the schema; restore it so checks like "field belongs to schema's + # primary_key" remain accurate. + copy.schema = field.schema + self.__fields.append(copy) self.__field_names = self.copy() self.__row_numbers = row_numbers self.__ignore_case = ignore_case + self.__schema_sync = schema_sync self.__labels = labels self.__errors: List[errors.HeaderError] = [] + self.__expected_fields: Optional[List[Field]] = None self.__process() @cached_property @@ -103,6 +112,96 @@ def valid(self): """ return not self.__errors + # Schema sync / expectations + + def get_expected_fields(self) -> List[Field]: + """Returns the fields, in the order expected in the data. + + Without `schema_sync`, this is just the schema fields unchanged. + + With `schema_sync`, fields are reordered to match the labels; labels + without a matching field get a fresh `any`-typed field, and fields not + present in labels are dropped. Duplicate labels are rejected. + """ + if self.__expected_fields is not None: + return self.__expected_fields + + if not self.__schema_sync: + self.__expected_fields = self.__fields + return self.__expected_fields + + if len(self.__labels) != len(set(self.__labels)): + note = '"schema_sync" requires unique labels in the header' + raise FrictionlessException(note) + + expected: List[Field] = [] + for label in self.__labels: + field = self.__find_field_by_name(label) + if field is None: + field = Field.from_descriptor({"name": label, "type": "any"}) + expected.append(field) + self.__expected_fields = expected + return self.__expected_fields + + def _get_extra_labels(self) -> List[str]: + """Returns labels in the data that don't correspond to any schema field. + + Without `schema_sync`, labels beyond the schema's field count are + considered extras. With `schema_sync`, extras are accepted, so an + empty list is returned. + """ + if not self.__schema_sync: + if len(self.__fields) < len(self.__labels): + return self.__labels[len(self.__fields) :] + return [] + + def _get_missing_fields(self) -> List[Tuple[int, Field]]: + """Returns (field_number, field) pairs for schema fields that don't + have a corresponding label. + + Without `schema_sync`, fields beyond the labels count are considered + missing. With `schema_sync`, only required fields whose name is not + among the labels are missing. + + The field_number is `len(labels) + offset + 1` in both modes: under + no-sync the missing fields are precisely the tail of the schema, so + this matches their position; under sync the missing fields have no + natural position in the data, so we place them after the labels by + convention. + """ + fields = self.__fields + labels = self.__labels + + if not self.__schema_sync: + missing = fields[len(labels) :] if len(fields) > len(labels) else [] + else: + normalized_labels = [self.__normalize(label) for label in labels] + + def required_and_missing(field: Field) -> bool: + required = field.required or ( + field.schema is not None + and field.name in field.schema.primary_key + ) + return ( + required + and self.__normalize(field.name) not in normalized_labels + ) + + missing = [field for field in fields if required_and_missing(field)] + + start = len(labels) + 1 + return [(start + offset, field) for offset, field in enumerate(missing)] + + def __find_field_by_name(self, name: str) -> Optional[Field]: + target = self.__normalize(name) + for f in self.__fields: + if self.__normalize(f.name) == target: + return f + return None + + def __normalize(self, s: str) -> str: + return s.lower() if self.__ignore_case else s + # Convert def to_str(self): @@ -129,40 +228,43 @@ def __process(self): labels = self.__labels fields = self.__fields - # Extra label - if len(fields) < len(labels): - start = len(fields) + 1 - iterator = labels[len(fields) :] - for field_number, label in enumerate(iterator, start=start): - self.__errors.append( - errors.ExtraLabelError( - note="", - labels=list(map(str, labels)), - row_numbers=self.__row_numbers, - label="", - field_name="", - field_number=field_number, - ) + # Extra labels + extra_start = len(fields) + 1 + for offset, label in enumerate(self._get_extra_labels()): + self.__errors.append( + errors.ExtraLabelError( + note="", + labels=list(map(str, labels)), + row_numbers=self.__row_numbers, + label="", + field_name="", + field_number=extra_start + offset, ) + ) - # Missing label - if len(fields) > len(labels): - start = len(labels) + 1 - iterator = fields[len(labels) :] - for field_number, field in enumerate(iterator, start=start): - if field is not None: # type: ignore - self.__errors.append( - errors.MissingLabelError( - note="", - labels=list(map(str, labels)), - row_numbers=self.__row_numbers, - label="", - field_name=field.name, - field_number=field_number, - ) - ) + # Missing fields + for field_number, field in self._get_missing_fields(): + self.__errors.append( + errors.MissingLabelError( + note="", + labels=list(map(str, labels)), + row_numbers=self.__row_numbers, + label="", + field_name=field.name, + field_number=field_number, + ) + ) # Iterate items + # Under schema_sync, labels and fields are matched by name (not by + # position), so the positional comparisons below (blank label vs + # field at the same index, incorrect label vs field name at the same + # index) don't apply. Duplicate labels are still invalid, but they + # are rejected earlier by get_expected_fields(), which raises a + # FrictionlessException — so detecting them here would be redundant. + if self.__schema_sync: + return + field_number = 0 for field, label in zip(fields, labels): field_number += 1 diff --git a/frictionless/table/row.py b/frictionless/table/row.py index a31f42c6b9..8a28094e77 100644 --- a/frictionless/table/row.py +++ b/frictionless/table/row.py @@ -2,16 +2,24 @@ from functools import cached_property from itertools import zip_longest -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, NamedTuple, Optional from .. import errors, helpers from ..platform import platform +from ..schema import Field # NOTE: # Currently dict.update/setdefault/pop/popitem/clear is not disabled (can be confusing) # We can consider adding row.header property to provide more comprehensive API +class _CellHandler(NamedTuple): + field: Field + field_number: int + reader: Callable[..., Any] + writer: Callable[..., Any] + + # TODO: add types class Row(Dict[str, Any]): """Row representation @@ -28,7 +36,7 @@ class Row(Dict[str, Any]): Parameters: cells (any[]): array of cells - field_info (dict): special field info structure + fields (Field[]): schema fields, in the order expected in the data row_number (int): row number from 1 """ @@ -36,11 +44,20 @@ def __init__( self, cells: List[Any], *, - field_info: Dict[str, Any], + fields: List[Field], row_number: int, ): self.__cells = cells - self.__field_info = field_info + self.__field_copies: List[Field] = [field.to_copy() for field in fields] + self.__handlers: Dict[str, _CellHandler] = { + field.name: _CellHandler( + field=field, + field_number=field_number, + reader=field.create_cell_reader(), + writer=field.create_cell_writer(), + ) + for field_number, field in enumerate(fields, start=1) + } self.__row_number = row_number self.__processed: bool = False self.__blank_cells: Dict[str, Any] = {} @@ -65,7 +82,7 @@ def __repr__(self): def __setitem__(self, key: str, value: Any): try: - _, field_number, _, _ = self.__field_info["mapping"][key] + field_number = self.__handlers[key].field_number except KeyError: raise KeyError(f"Row does not have a field {key}") if len(self.__cells) < field_number: @@ -77,30 +94,30 @@ def __missing__(self, key: str): return self.__process(key) def __iter__(self): - return iter(self.__field_info["names"]) + return iter(self.__handlers) def __len__(self): - return len(self.__field_info["names"]) + return len(self.__handlers) def __contains__(self, key: object): - return key in self.__field_info["mapping"] + return key in self.__handlers def __reversed__(self): - return reversed(self.__field_info["names"]) + return reversed(self.__handlers) def keys(self): - return iter(self.__field_info["names"]) + return iter(self.__handlers) def values(self): # type: ignore - for name in self.__field_info["names"]: + for name in self.__handlers: yield self[name] def items(self): # type: ignore - for name in self.__field_info["names"]: + for name in self.__handlers: yield (name, self[name]) def get(self, key: str, default: Optional[Any] = None): - if key not in self.__field_info["names"]: + if key not in self.__handlers: return default return self[key] @@ -118,7 +135,7 @@ def fields(self): Returns: Field[]: table schema fields """ - return self.__field_info["objects"] + return self.__field_copies @cached_property def field_names(self) -> List[str]: @@ -126,7 +143,7 @@ def field_names(self) -> List[str]: Returns: str[]: field names """ - return self.__field_info["names"] + return list(self.__handlers) @cached_property def field_numbers(self): @@ -134,7 +151,7 @@ def field_numbers(self): Returns: str[]: field numbers """ - return list(range(1, len(self.__field_info["names"]) + 1)) + return list(range(1, len(self.__handlers) + 1)) @cached_property def row_number(self) -> int: @@ -205,14 +222,14 @@ def to_list(self, *, json: bool = False, types: Optional[List[str]] = None): # Prepare self.__process() - result = [self[name] for name in self.__field_info["names"]] + result = [self[name] for name in self.__handlers] if types is None and json: types = platform.frictionless_formats.JsonParser.supported_types # Convert if types is not None: - for index, field_mapping in enumerate(self.__field_info["mapping"].values()): - field, _, _, cell_writer = field_mapping + for index, handler in enumerate(self.__handlers.values()): + field = handler.field # Here we can optimize performance if we use a types mapping if field.type in types: continue @@ -220,7 +237,7 @@ def to_list(self, *, json: bool = False, types: Optional[List[str]] = None): if json is True and field.type == "number" and field.float_number: continue cell = result[index] - cell, _ = cell_writer(cell, ignore_missing=True) + cell, _ = handler.writer(cell, ignore_missing=True) result[index] = cell # Return @@ -239,7 +256,7 @@ def to_dict( # Prepare self.__process() - result = {name: self[name] for name in self.__field_info["names"]} + result = {name: self[name] for name in self.__handlers} if types is None and json: types = platform.frictionless_formats.JsonParser.supported_types if types is None and csv: @@ -247,12 +264,12 @@ def to_dict( # Convert if types is not None: - for field_mapping in self.__field_info["mapping"].values(): - field, _, _, cell_writer = field_mapping + for handler in self.__handlers.values(): + field = handler.field # Here we can optimize performance if we use a types mapping if field.type not in types: cell = result[field.name] - cell, _ = cell_writer(cell, ignore_missing=True) + cell, _ = handler.writer(cell, ignore_missing=True) result[field.name] = cell # Return @@ -272,31 +289,33 @@ def __process(self, key: Optional[str] = None): # Prepare context cells = self.__cells to_str = lambda v: str(v) if v is not None else "" # type: ignore - fields = self.__field_info["objects"] - field_mapping = self.__field_info["mapping"] - iterator = zip_longest(field_mapping.values(), cells) + handlers = self.__handlers is_empty = not bool(super().__len__()) if key: try: - field, field_number, cell_reader, cell_writer = self.__field_info[ - "mapping" - ][key] + handler = handlers[key] except KeyError: raise KeyError(f"Row does not have a field {key}") - cell = cells[field_number - 1] if len(cells) >= field_number else None - iterator = zip([(field, field_number, cell_reader, cell_writer)], [cell]) + cell = ( + cells[handler.field_number - 1] + if len(cells) >= handler.field_number + else None + ) + iterator = zip([handler], [cell]) + else: + iterator = zip_longest(handlers.values(), cells) # Iterate cells - for field_mapping, source in iterator: + for handler, source in iterator: # Prepare context - if field_mapping is None: + if handler is None: break - field, field_number, cell_reader, _ = field_mapping + field = handler.field if not is_empty and super().__contains__(field.name): continue # Read cell - target, notes = cell_reader(source) + target, notes = handler.reader(source) type_note = notes.pop("type", None) if notes else None if target is None and not type_note: self.__blank_cells[field.name] = source @@ -311,7 +330,7 @@ def __process(self, key: Optional[str] = None): row_number=self.__row_number, cell=str(source), field_name=field.name, - field_number=field_number, + field_number=handler.field_number, ) ) @@ -325,7 +344,7 @@ def __process(self, key: Optional[str] = None): row_number=self.__row_number, cell=str(source), field_name=field.name, - field_number=field_number, + field_number=handler.field_number, ) ) @@ -335,10 +354,10 @@ def __process(self, key: Optional[str] = None): return target # Extra cells - if len(fields) < len(cells): - start = len(fields) + 1 - iterator = cells[len(fields) :] - for field_number, cell in enumerate(iterator, start=start): + n_fields = len(handlers) + if n_fields < len(cells): + start = n_fields + 1 + for field_number, cell in enumerate(cells[n_fields:], start=start): self.__errors.append( errors.ExtraCellError( note="", @@ -351,24 +370,22 @@ def __process(self, key: Optional[str] = None): ) # Missing cells - if len(fields) > len(cells): - start = len(cells) + 1 - iterator = fields[len(cells) :] - for field_number, field in enumerate(iterator, start=start): - if field is not None: - self.__errors.append( - errors.MissingCellError( - note="", - cells=list(map(to_str, cells)), # type: ignore - row_number=self.__row_number, - cell="", - field_name=field.name, - field_number=field_number, - ) + if n_fields > len(cells): + missing_handlers = list(handlers.values())[len(cells) :] + for handler in missing_handlers: + self.__errors.append( + errors.MissingCellError( + note="", + cells=list(map(to_str, cells)), # type: ignore + row_number=self.__row_number, + cell="", + field_name=handler.field.name, + field_number=handler.field_number, ) + ) # Blank row - if len(fields) == len(self.__blank_cells): + if n_fields == len(self.__blank_cells): self.__errors = [ errors.BlankRowError( note="",