diff --git a/tests/unit/build_tests/test_io_map.py b/tests/unit/build_tests/test_io_map.py index 3193562ba..663001853 100644 --- a/tests/unit/build_tests/test_io_map.py +++ b/tests/unit/build_tests/test_io_map.py @@ -1,6 +1,8 @@ from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, RefSpec -from hdmf.build import GroupBuilder, DatasetBuilder, ObjectMapper, BuildManager, TypeMap, LinkBuilder -from hdmf import Container +from hdmf.spec import Spec +from hdmf.build import GroupBuilder, DatasetBuilder, ObjectMapper, BuildManager, TypeMap, LinkBuilder, ReferenceBuilder +from hdmf.build.warnings import MissingRequiredWarning +from hdmf import Container, Data from hdmf.utils import docval, getargs, get_docval from hdmf.data_utils import DataChunkIterator from hdmf.backends.hdf5 import H5DataIO @@ -19,24 +21,35 @@ class Bar(Container): {'name': 'attr1', 'type': str, 'doc': 'an attribute'}, {'name': 'attr2', 'type': int, 'doc': 'another attribute'}, {'name': 'attr3', 'type': float, 'doc': 'a third attribute', 'default': 3.14}, - {'name': 'foo', 'type': 'Foo', 'doc': 'a group', 'default': None}) + {'name': 'foo', 'type': 'Foo', 'doc': 'a group', 'default': None}, + {'name': 'foo_data', 'type': 'FooData', 'doc': 'some data', 'default': None}, + {'name': 'foo_data_ref', 'type': 'FooData', 'doc': 'some data', 'default': None}, + {'name': 'extra_attr', 'type': str, 'doc': 'an extra attribute of a Foo/FooData', 'default': None}) def __init__(self, **kwargs): - name, data, attr1, attr2, attr3, foo = getargs('name', 'data', 'attr1', 'attr2', 'attr3', 'foo', kwargs) + name, data, attr1, attr2, attr3 = getargs('name', 'data', 'attr1', 'attr2', 'attr3', kwargs) + foo, foo_data, foo_data_ref, extra_attr = getargs('foo', 'foo_data', 'foo_data_ref', 'extra_attr', kwargs) super().__init__(name=name) self.__data = data self.__attr1 = attr1 self.__attr2 = attr2 self.__attr3 = attr3 self.__foo = foo - if self.__foo is not None and self.__foo.parent is None: + if foo is not None and self.__foo.parent is None: self.__foo.parent = self + self.__foo_data = foo_data + if foo_data is not None and self.__foo_data.parent is None: + self.__foo_data.parent = self + self.__foo_data_ref = foo_data_ref + if foo_data_ref is not None and self.__foo_data_ref.parent is None: + self.__foo_data_ref.parent = self + self.__extra_attr = extra_attr def __eq__(self, other): - attrs = ('name', 'data', 'attr1', 'attr2', 'attr3', 'foo') + attrs = ('name', 'data', 'attr1', 'attr2', 'attr3', 'foo', 'foo_data') return all(getattr(self, a) == getattr(other, a) for a in attrs) def __str__(self): - attrs = ('name', 'data', 'attr1', 'attr2', 'attr3', 'foo') + attrs = ('name', 'data', 'attr1', 'attr2', 'attr3', 'foo', 'foo_data') return ','.join('%s=%s' % (a, getattr(self, a)) for a in attrs) @property @@ -63,14 +76,35 @@ def attr3(self): def foo(self): return self.__foo + @property + def foo_data(self): + return self.__foo_data + + @property + def foo_data_ref(self): + return self.__foo_data_ref + + @property + def extra_attr(self): + return self.__extra_attr + class Foo(Container): + def __eq__(self, other): + return self.name == other.name + @property def data_type(self): return 'Foo' +class FooData(Data): + + def __eq__(self, other): + return self.name == other.name and self.data == other.data + + class TestGetSubSpec(TestCase): def setUp(self): @@ -89,7 +123,6 @@ def test_get_subspec_data_type_noname(self): parent_spec = GroupSpec('Something to hold a Bar', 'bar_bucket', groups=[self.bar_spec]) sub_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'namespace': CORE_NAMESPACE, 'object_id': -1}) - GroupBuilder('bar_bucket', groups={'my_bar': sub_builder}) result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, self.bar_spec) @@ -98,7 +131,6 @@ def test_get_subspec_named(self): parent_spec = GroupSpec('Something to hold a Bar', 'my_group', groups=[child_spec]) sub_builder = GroupBuilder('my_subgroup', attributes={'data_type': 'Bar', 'namespace': CORE_NAMESPACE, 'object_id': -1}) - GroupBuilder('my_group', groups={'my_bar': sub_builder}) result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, child_spec) @@ -243,7 +275,7 @@ def test_dynamic_container_creation(self): expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4'} received_args = set() for x in get_docval(cls.__init__): - if x['name'] != 'foo': + if x['name'] not in ('foo', 'foo_data', 'foo_data_ref', 'extra_attr'): received_args.add(x['name']) with self.subTest(name=x['name']): self.assertNotIn('default', x) @@ -265,7 +297,8 @@ def test_dynamic_container_creation_defaults(self): AttributeSpec('attr4', 'another example float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_container_cls(CORE_NAMESPACE, 'Baz') - expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo'} + expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo', 'foo_data', 'foo_data_ref', + 'extra_attr'} received_args = set(map(lambda x: x['name'], get_docval(cls.__init__))) self.assertSetEqual(expected_args, received_args) self.assertEqual(cls.__name__, 'Baz') @@ -724,3 +757,264 @@ def test_bool_spec(self): match = (value, np.bool_) self.assertTupleEqual(ret, match) self.assertIs(type(ret[0]), match[1]) + + +class BarWithFooMapper(ObjectMapper): + + @ObjectMapper.constructor_arg('extra_attr') + def extra_attr_carg(self, builder, manager): + if 'foo' in builder: + return builder['foo'].attributes.get('extra_attr') + return None + + +class FooMapper(ObjectMapper): + + @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, + {"name": "container", "type": Foo, "doc": "the container to get the attribute value from"}, + {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, + returns='the value of the attribute') + def get_attr_value(self, **kwargs): + ''' Get the value of the attribute corresponding to this spec from the given container ''' + spec, container, manager = getargs('spec', 'container', 'manager', kwargs) + if isinstance(container.parent, Bar) and spec.parent.name == 'foo' and spec.name == 'extra_attr': + return container.parent.extra_attr + return super().get_attr_value(spec, container, manager) + + +class TestExtendGroupAttrs(TestCase): + + def setUp(self): + self.foo_spec = GroupSpec('A test group specification with data type Foo', data_type_def='Foo') + self.foo_ext_spec = GroupSpec('An extended Foo without a name or data_type_def', + data_type_inc='Foo', + quantity='?', + name='foo', + attributes=[AttributeSpec('extra_attr', 'an example string attribute', 'text')]) + self.bar_spec = GroupSpec('A test group specification with a data type Bar containing extended Foos', + data_type_def='Bar', + datasets=[DatasetSpec('an example dataset', 'int', name='data')], + attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), + AttributeSpec('attr2', 'an example integer attribute', 'int')], + groups=[self.foo_ext_spec]) + + spec_catalog = SpecCatalog() + spec_catalog.register_spec(self.foo_spec, 'test.yaml') + spec_catalog.register_spec(self.bar_spec, 'test.yaml') + namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], + version='0.1.0', + catalog=spec_catalog) + namespace_catalog = NamespaceCatalog() + namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) + self.type_map = TypeMap(namespace_catalog) + self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) + self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) + self.type_map.register_map(Foo, FooMapper) + self.type_map.register_map(Bar, BarWithFooMapper) + self.manager = BuildManager(self.type_map) + + def test_build_missing_required_attr(self): + foo_ext1 = Foo('foo') + bar1 = Bar('my_bar', [1, 2, 3, 4], 'value1', 10, foo=foo_ext1) + + with self.assertWarnsWith(MissingRequiredWarning, "attribute 'extra_attr' for 'foo' (Foo)"): + self.manager.build(bar1, source='my_source') + + def test_build_required_attr(self): + foo_ext2 = Foo('foo') + bar2 = Bar('my_bar', [1, 2, 3, 4], 'value1', 10, foo=foo_ext2, extra_attr='hello') + + builder = self.manager.build(bar2, source='my_source') + self.assertTrue('extra_attr' in builder['foo'].attributes) + self.assertEqual(builder['foo'].attributes['extra_attr'], 'hello') + + def test_construct_required_attr(self): + dset_builder = DatasetBuilder('data', dtype='int', data=[1, 2, 3, 4]) + foo_ext_builder = GroupBuilder('foo', + attributes={'extra_attr': 'hello', + 'data_type': 'Foo', + 'namespace': CORE_NAMESPACE, + 'object_id': -1}) + bar_builder = GroupBuilder('my_bar', + attributes={'attr1': 'value1', + 'attr2': 10, + 'data_type': 'Bar', + 'namespace': CORE_NAMESPACE, + 'object_id': -1}, + datasets={'data': dset_builder}, + groups={'foo': foo_ext_builder}) + + bar3 = self.manager.construct(bar_builder) + + foo_ext2 = Foo('foo') + bar2 = Bar('my_bar', [1, 2, 3, 4], 'value1', 10, foo=foo_ext2, extra_attr='hello') + self.assertTrue(bar2 == bar3) + + +class BarWithFooDataMapper(ObjectMapper): + + @ObjectMapper.constructor_arg('extra_attr') + def extra_attr_carg(self, builder, manager): + if 'foo_data' in builder: + return builder['foo_data'].attributes.get('extra_attr') + return None + + +class FooDataMapper(ObjectMapper): + + @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, + {"name": "container", "type": FooData, "doc": "the container to get the attribute value from"}, + {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, + returns='the value of the attribute') + def get_attr_value(self, **kwargs): + ''' Get the value of the attribute corresponding to this spec from the given container ''' + spec, container, manager = getargs('spec', 'container', 'manager', kwargs) + if isinstance(container.parent, Bar) and spec.parent.name == 'foo_data' and spec.name == 'extra_attr': + return container.parent.extra_attr + return super().get_attr_value(spec, container, manager) + + +class TestExtendDatasetAttrs(TestCase): + + def setUp(self): + self.foo_spec = DatasetSpec('A test dataset specification with data type FooData', data_type_def='FooData') + self.foo_ext_spec = DatasetSpec('An extended FooData without a name or data_type_def', + data_type_inc='FooData', + quantity='?', + name='foo_data', + attributes=[AttributeSpec('extra_attr', 'an example string attribute', 'text')]) + self.bar_spec = GroupSpec('A test group specification with a data type Bar containing extended FooData', + data_type_def='Bar', + datasets=[DatasetSpec('an example dataset', 'int', name='data'), + self.foo_ext_spec], + attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), + AttributeSpec('attr2', 'an example integer attribute', 'int')]) + + spec_catalog = SpecCatalog() + spec_catalog.register_spec(self.foo_spec, 'test.yaml') + spec_catalog.register_spec(self.bar_spec, 'test.yaml') + namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], + version='0.1.0', + catalog=spec_catalog) + namespace_catalog = NamespaceCatalog() + namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) + self.type_map = TypeMap(namespace_catalog) + self.type_map.register_container_type(CORE_NAMESPACE, 'FooData', FooData) + self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) + self.type_map.register_map(FooData, FooDataMapper) + self.type_map.register_map(Bar, BarWithFooDataMapper) + self.manager = BuildManager(self.type_map) + + def test_build_missing_required_attr(self): + foo_ext1 = FooData('foo_data', [1, 2, 3]) + bar1 = Bar('my_bar', [1, 2, 3, 4], 'value1', 10, foo_data=foo_ext1) + + with self.assertWarnsWith(MissingRequiredWarning, "attribute 'extra_attr' for 'foo_data' (FooData)"): + self.manager.build(bar1, source='my_source') + + def test_build_required_attr(self): + foo_ext2 = FooData('foo_data', [1, 2, 3]) + bar2 = Bar('my_bar', [1, 2, 3, 4], 'value1', 10, foo_data=foo_ext2, extra_attr='hello') + + builder = self.manager.build(bar2, source='my_source') + self.assertTrue('extra_attr' in builder['foo_data'].attributes) + self.assertEqual(builder['foo_data'].attributes['extra_attr'], 'hello') + + def test_construct_required_attr(self): + dset_builder = DatasetBuilder('data', dtype='int', data=[1, 2, 3, 4]) + foo_ext_builder = DatasetBuilder('foo_data', + data=[1, 2, 3], + attributes={'extra_attr': 'hello', + 'data_type': 'FooData', + 'namespace': CORE_NAMESPACE, + 'object_id': -1}) + bar_builder = GroupBuilder('my_bar', + attributes={'attr1': 'value1', + 'attr2': 10, + 'data_type': 'Bar', + 'namespace': CORE_NAMESPACE, + 'object_id': -1}, + datasets={'data': dset_builder, + 'foo_data': foo_ext_builder}) + + bar3 = self.manager.construct(bar_builder) + + foo_ext2 = FooData('foo_data', [1, 2, 3]) + bar2 = Bar('my_bar', [1, 2, 3, 4], 'value1', 10, foo_data=foo_ext2, extra_attr='hello') + self.assertTrue(bar2 == bar3) + + +class TestExtendDatasetAttrsWithRef(TestCase): + """Test that extending dataset attributes works when the dataset is built from a reference before normal""" + + def setUp(self): + self.foo_spec = DatasetSpec('A test dataset specification with data type FooData', data_type_def='FooData') + self.foo_ext_ref_spec = DatasetSpec('A ref DatasetSpec', RefSpec(reftype='object', target_type='FooData'), + name='foo_data_ref') + self.foo_ext_spec = DatasetSpec('An extended FooData without a name or data_type_def', + data_type_inc='FooData', + quantity='?', + name='foo_data', + attributes=[AttributeSpec('extra_attr', 'an example string attribute', 'text')]) + self.bar_spec = GroupSpec('A test group specification with a data type Bar containing extended FooData', + data_type_def='Bar', + datasets=[DatasetSpec('an example dataset', 'int', name='data'), + self.foo_ext_ref_spec, + self.foo_ext_spec], + attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), + AttributeSpec('attr2', 'an example integer attribute', 'int')]) + + spec_catalog = SpecCatalog() + spec_catalog.register_spec(self.foo_spec, 'test.yaml') + spec_catalog.register_spec(self.bar_spec, 'test.yaml') + namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], + version='0.1.0', + catalog=spec_catalog) + namespace_catalog = NamespaceCatalog() + namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) + self.type_map = TypeMap(namespace_catalog) + self.type_map.register_container_type(CORE_NAMESPACE, 'FooData', FooData) + self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) + self.type_map.register_map(FooData, FooDataMapper) + self.type_map.register_map(Bar, BarWithFooDataMapper) + self.manager = BuildManager(self.type_map) + + def test_build_missing_required_attr(self): + foo_ext1 = FooData('foo_data', [1, 2, 3]) + bar1 = Bar('my_bar', [1, 2, 3, 4], 'value1', 10, foo_data=foo_ext1, foo_data_ref=foo_ext1) + + with self.assertWarnsWith(MissingRequiredWarning, "attribute 'extra_attr' for 'foo_data' (FooData)"): + self.manager.build(bar1, source='my_source') + + def test_build_required_attr(self): + foo_ext2 = FooData('foo_data', [1, 2, 3]) + bar2 = Bar('my_bar', [1, 2, 3, 4], 'value1', 10, foo_data=foo_ext2, foo_data_ref=foo_ext2, extra_attr='hello') + + builder = self.manager.build(bar2, source='my_source') + self.assertTrue('extra_attr' in builder['foo_data'].attributes) + self.assertEqual(builder['foo_data'].attributes['extra_attr'], 'hello') + + def test_construct_required_attr(self): + dset_builder = DatasetBuilder('data', dtype='int', data=[1, 2, 3, 4]) + foo_ext_builder = DatasetBuilder('foo_data', + data=[1, 2, 3], + attributes={'extra_attr': 'hello', + 'data_type': 'FooData', + 'namespace': CORE_NAMESPACE, + 'object_id': -1}) + foo_data_ref_builder = DatasetBuilder('foo_data_ref', [ReferenceBuilder(foo_ext_builder)], dtype='object') + bar_builder = GroupBuilder('my_bar', + attributes={'attr1': 'value1', + 'attr2': 10, + 'data_type': 'Bar', + 'namespace': CORE_NAMESPACE, + 'object_id': -1}, + datasets={'data': dset_builder, + 'foo_data_ref': foo_data_ref_builder, + 'foo_data': foo_ext_builder}) + + bar3 = self.manager.construct(bar_builder) + + foo_ext2 = FooData('foo_data', [1, 2, 3]) + bar2 = Bar('my_bar', [1, 2, 3, 4], 'value1', 10, foo_data=foo_ext2, foo_data_ref=foo_ext2, extra_attr='hello') + self.assertTrue(bar2 == bar3)