diff --git a/accelerator/extras.py b/accelerator/extras.py index 6fd96121..8b2b3e8b 100644 --- a/accelerator/extras.py +++ b/accelerator/extras.py @@ -4,7 +4,7 @@ # Copyright (c) 2017 eBay Inc. # # Modifications copyright (c) 2019-2020 Anders Berkeman # # Modifications copyright (c) 2019-2024 Carl Drougge # -# Modifications copyright (c) 2023 Pablo Correa Gómez # +# Modifications copyright (c) 2023-2024 Pablo Correa Gómez # # # # Licensed under the Apache License, Version 2.0 (the "License"); # # you may not use this file except in compliance with the License. # @@ -421,7 +421,7 @@ def __next__(self): return item next = __next__ - def merge_auto(self): + def merge_auto(self, allow_overwrite=False): """Merge values from iterator using magic. Currenly supports data that has .update, .itervalues and .iteritems methods. @@ -429,15 +429,20 @@ def merge_auto(self): level, otherwise the value will be overwritten by later slices. Don't try to use this if all your values don't have the same depth, or if you have empty dicts at the last level. + + If allow_overwrite is set to True, when there are keys in dict-like + objects in multiple slices, the merging might only take the values from + one of them, without any kind of warranties. This was historical + behavior, and has a slightly greater performance. """ if self._started: raise self._exc("Will not merge after iteration started") if self._inner._is_tupled: - return (self._merge_auto_single(it, ix) for ix, it in enumerate(self._inner._loaders)) + return (self._merge_auto_single(it, ix, allow_overwrite) for ix, it in enumerate(self._inner._loaders)) else: - return self._merge_auto_single(self, -1) + return self._merge_auto_single(self, -1, allow_overwrite) - def _merge_auto_single(self, it, ix): + def _merge_auto_single(self, it, ix, allow_overwrite): # find a non-empty one, so we can look at the data in it data = next(it) if isinstance(data, num_types): @@ -466,6 +471,9 @@ def _merge_auto_single(self, it, ix): raise self._exc("Top level has no .values (index %d)" % (ix,)) def upd(aggregate, part, level): if level == depth: + if not allow_overwrite: + for k in part: + assert k not in aggregate, "duplicate %s" % (k,) aggregate.update(part) else: for k, v in iteritems(part): diff --git a/accelerator/standard_methods/a_dataset_fanout_collect.py b/accelerator/standard_methods/a_dataset_fanout_collect.py index 8cb4e00d..139e173b 100644 --- a/accelerator/standard_methods/a_dataset_fanout_collect.py +++ b/accelerator/standard_methods/a_dataset_fanout_collect.py @@ -42,4 +42,4 @@ def analysis(sliceno): return set(imap(unicode, chain.iterate(sliceno, options.column))) def synthesis(analysis_res): - return analysis_res.merge_auto() + return analysis_res.merge_auto(allow_overwrite=True) diff --git a/accelerator/test_methods/a_test_dataset_fanout.py b/accelerator/test_methods/a_test_dataset_fanout.py index 1919eff3..8a84ba89 100644 --- a/accelerator/test_methods/a_test_dataset_fanout.py +++ b/accelerator/test_methods/a_test_dataset_fanout.py @@ -67,7 +67,7 @@ def chk(job, colnames, types, ds2lines, previous={}, hashlabel=None): j_a_C = subjobs.build('dataset_fanout', source=a, column='C') chk(j_a_C, 'AB', ('unicode', 'ascii'), {'1': [('a', 'a')], '2': [('b', 'b')], '3': [('a', 'c')]}, hashlabel='A') - b = mk('b', ('ascii', 'unicode', 'int32', 'int32'), [('a', 'aa', 11, 111), ('b', 'bb', 12, 112), ('a', 'cc', 13, 113), ('d', 'dd', 14, 114)], previous=a) + b = mk('b', ('ascii', 'unicode', 'int32', 'int32'), [('a', 'aa', 11, 111), ('b', 'bb', 12, 112), ('a', 'cc', 13, 113), ('d', 'dd', 14, 114)], hashlabel=a.hashlabel, previous=a) # with previous j_b_A = subjobs.build('dataset_fanout', source=b, column='A', previous=j_a_A) chk(