From f66d4a733ed0e744f86378b46f8b21025ff8343c Mon Sep 17 00:00:00 2001 From: 0x26res Date: Fri, 13 Mar 2026 14:20:01 +0000 Subject: [PATCH] Fix warnings --- docs/contributing.md | 4 ++-- protarrow/arrow_to_proto.py | 10 +++++----- protarrow/cast_to_proto.py | 6 +++--- protarrow/message_extractor.py | 2 +- protarrow/proto_to_arrow.py | 12 ++++++------ tests/random_generator.py | 8 ++++---- tests/test_conversion.py | 4 ++-- tests/test_protobuf.py | 3 +-- 8 files changed, 24 insertions(+), 25 deletions(-) diff --git a/docs/contributing.md b/docs/contributing.md index c01f404..efe6039 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -17,10 +17,10 @@ pre-commit install This library relies on property based testing. Tests convert randomly generated data from protobuf to arrow and back, making sure the end result is the same as the input. -To run tests fast: +The tests take a long time to run. To run them faster: ```shell -pytest -n auto tests +pytest --numprocesses=auto -p no:benchmark ./tests ``` To Get coverage: diff --git a/protarrow/arrow_to_proto.py b/protarrow/arrow_to_proto.py index ac710aa..2daa70b 100644 --- a/protarrow/arrow_to_proto.py +++ b/protarrow/arrow_to_proto.py @@ -257,7 +257,7 @@ class RepeatedNestedIterable(collections.abc.Iterable): field_descriptor: FieldDescriptor def __post_init__(self): - assert self.field_descriptor.label == FieldDescriptor.LABEL_REPEATED + assert self.field_descriptor.is_repeated assert self.field_descriptor.type == FieldDescriptor.TYPE_MESSAGE def __iter__(self) -> Iterator[Any]: @@ -339,7 +339,7 @@ def __init__( ): self.messages = messages self.field_descriptor = field_descriptor - assert self.field_descriptor.label == FieldDescriptor.LABEL_REPEATED + assert self.field_descriptor.is_repeated self.sizes = sizes self.converter = converter self.attribute = None @@ -367,7 +367,7 @@ class MapKeyAssigner(collections.abc.Iterable): attribute: Any = None def __post_init__(self, key_arrow_type: pa.DataType): - assert self.field_descriptor.label == FieldDescriptor.LABEL_REPEATED + assert self.field_descriptor.is_repeated assert self.field_descriptor.message_type.GetOptions().map_entry self.converter = get_converter( self.field_descriptor.message_type.fields_by_name["key"], key_arrow_type @@ -410,7 +410,7 @@ class MapItemAssigner(collections.abc.Iterable): attribute: Optional[MessageMap] = None def __post_init__(self, key_arrow_type: pa.DataType, value_arrow_type: pa.DataType): - assert self.field_descriptor.label == FieldDescriptor.LABEL_REPEATED + assert self.field_descriptor.is_repeated assert self.field_descriptor.message_type.GetOptions().map_entry self.key_converter = get_converter( self.field_descriptor.message_type.fields_by_name["key"], key_arrow_type @@ -580,7 +580,7 @@ def _extract_repeated_message( def _extract_field( array: pa.Array, field_descriptor: FieldDescriptor, messages: Iterable[Message] ) -> None: - if field_descriptor.label == FieldDescriptor.LABEL_REPEATED: + if field_descriptor.is_repeated: _extract_repeated_field(array, field_descriptor, messages) elif field_descriptor.message_type in TEMPORAL_CONVERTERS: extractor = TEMPORAL_CONVERTERS[field_descriptor.message_type](array.type) diff --git a/protarrow/cast_to_proto.py b/protarrow/cast_to_proto.py index 856577f..d0c4cd9 100644 --- a/protarrow/cast_to_proto.py +++ b/protarrow/cast_to_proto.py @@ -30,7 +30,7 @@ def get_arrow_default_value( if field_descriptor.type == FieldDescriptor.TYPE_ENUM: default_value = ( field_descriptor.enum_type.values[0].number - if field_descriptor.label == FieldDescriptor.LABEL_REPEATED + if field_descriptor.is_repeated else field_descriptor.default_value ) if pa.types.is_integer(config.enum_type): @@ -128,7 +128,7 @@ def _cast_array( ) ) - elif field_descriptor.label == FieldDescriptor.LABEL_REPEATED: + elif field_descriptor.is_repeated: assert isinstance(array, (pa.ListArray, pa.LargeListArray)) item_array = _cast_flat_array(array.values, field_descriptor, config) return config.list_array_type.from_arrays( @@ -159,7 +159,7 @@ def get_casted_array( else: default_value = ( [] - if field_descriptor.label == FieldDescriptor.LABEL_REPEATED + if field_descriptor.is_repeated else get_arrow_default_value(field_descriptor, config) ) casted_array = pa.array( diff --git a/protarrow/message_extractor.py b/protarrow/message_extractor.py index 1bb5426..605c918 100644 --- a/protarrow/message_extractor.py +++ b/protarrow/message_extractor.py @@ -101,7 +101,7 @@ def get_field_converter( key, value = get_map_descriptors(field_descriptor) return MapConverterAdapter(field.type, key, value) else: - if field_descriptor.label == FieldDescriptor.LABEL_REPEATED: + if field_descriptor.is_repeated: return RepeatedConverterAdapter( get_flat_field_converter(field.type.value_type, field_descriptor) ) diff --git a/protarrow/proto_to_arrow.py b/protarrow/proto_to_arrow.py index 1788721..f9909e9 100644 --- a/protarrow/proto_to_arrow.py +++ b/protarrow/proto_to_arrow.py @@ -212,7 +212,7 @@ def _raise_recursion_error(trace: Tuple[Descriptor, ...]): def is_map(field_descriptor: FieldDescriptor) -> bool: return ( field_descriptor.type == FieldDescriptor.TYPE_MESSAGE - and field_descriptor.label == FieldDescriptor.LABEL_REPEATED + and field_descriptor.is_repeated and field_descriptor.message_type.GetOptions().map_entry ) @@ -277,7 +277,7 @@ def field_descriptor_to_field( nullable=config.map_nullable, metadata=config.field_metadata(field_descriptor.number), ) - elif field_descriptor.label == FieldDescriptor.LABEL_REPEATED: + elif field_descriptor.is_repeated: return pa.field( field_descriptor.name, config.list_( @@ -405,7 +405,7 @@ def _proto_field_to_array( field_descriptor.has_presence # We use none for repeated field as there should not # be any missing list elements, they are not nullable - or field_descriptor.label == FieldDescriptor.LABEL_REPEATED + or field_descriptor.is_repeated ) else converter(field_descriptor.default_value) ) @@ -506,7 +506,7 @@ def _proto_field_nullable( ) -> bool: if is_map(field_descriptor): return config.map_nullable - elif field_descriptor.label == FieldDescriptorProto.LABEL_REPEATED: + elif field_descriptor.is_repeated: return config.list_nullable else: return field_descriptor.has_presence @@ -540,7 +540,7 @@ def _messages_to_array( for field_descriptor in descriptor.fields: if ( field_descriptor.type == FieldDescriptor.TYPE_MESSAGE - and field_descriptor.label != FieldDescriptor.LABEL_REPEATED + and not field_descriptor.is_repeated ): field_values = NestedIterable( messages, NestedMessageGetter(field_descriptor.name) @@ -561,7 +561,7 @@ def _messages_to_array( array = _proto_map_to_array( field_values, field_descriptor, config, this_trace ) - elif field_descriptor.label == FieldDescriptorProto.LABEL_REPEATED: + elif field_descriptor.is_repeated: array = _repeated_proto_to_array( field_values, field_descriptor, config, this_trace ) diff --git a/tests/random_generator.py b/tests/random_generator.py index 97e0c94..f8b7201 100644 --- a/tests/random_generator.py +++ b/tests/random_generator.py @@ -89,7 +89,7 @@ def generate_message(message_type: typing.Type[M], repeated_count: int) -> M: for field in message_type.DESCRIPTOR.fields: if field.containing_oneof is None: if ( - field.label == FieldDescriptor.LABEL_REPEATED + field.is_repeated or field.type != FieldDescriptor.TYPE_MESSAGE or random.getrandbits(1) == 1 ): @@ -106,7 +106,7 @@ def generate_messages( def set_field(message: Message, field: FieldDescriptor, count: int) -> None: data = generate_field_data(field, count) - if field.label == FieldDescriptor.LABEL_REPEATED: + if field.is_repeated: field_value = getattr(message, field.name) if is_map(field): if ( @@ -131,7 +131,7 @@ def set_field(message: Message, field: FieldDescriptor, count: int) -> None: def generate_field_data(field: FieldDescriptor, count: int): - if field.label == FieldDescriptor.LABEL_REPEATED: + if field.is_repeated: size = random.randint(0, count) return [_generate_data(field, count) for _ in range(size)] else: @@ -204,7 +204,7 @@ def truncate_nanos_message( time_unit: str, timestamp_unit: str, ) -> None: - if field.label == FieldDescriptor.LABEL_REPEATED: + if field.is_repeated: field_value = getattr(message, field.name) if field.message_type is not None and field.message_type.GetOptions().map_entry: if ( diff --git a/tests/test_conversion.py b/tests/test_conversion.py index ad3450d..c6b9330 100644 --- a/tests/test_conversion.py +++ b/tests/test_conversion.py @@ -591,7 +591,7 @@ def test_only_messages_default_to_null_on_missing_array(config): expected = ( None if field_descriptor.type == FieldDescriptor.TYPE_MESSAGE - and field_descriptor.label != FieldDescriptor.LABEL_REPEATED + and not field_descriptor.is_repeated else [] ) assert get_casted_array(field_descriptor, None, 1, config)[0].to_pylist() == [ @@ -611,7 +611,7 @@ def test_only_messages_stay_to_null_on_casted_array(config): expected = ( None if field_descriptor.type == FieldDescriptor.TYPE_MESSAGE - and field_descriptor.label != FieldDescriptor.LABEL_REPEATED + and not field_descriptor.is_repeated else [] ) arrow_field = field_descriptor_to_field(field_descriptor, config) diff --git a/tests/test_protobuf.py b/tests/test_protobuf.py index f597543..273694a 100644 --- a/tests/test_protobuf.py +++ b/tests/test_protobuf.py @@ -2,7 +2,6 @@ Tests the behavior of Google protobuf """ -from google.protobuf.descriptor import FieldDescriptor from google.protobuf.empty_pb2 import Empty from google.protobuf.wrappers_pb2 import StringValue @@ -20,7 +19,7 @@ def test_empty_has_field(): def test_repeated_no_presence(): field = MessageWithOptional.DESCRIPTOR.fields_by_name["string_values"] - assert field.label == FieldDescriptor.LABEL_REPEATED + assert field.is_repeated assert not field.has_presence