diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py index 4e45d0324ee2..b7f846d79e74 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py @@ -82,6 +82,31 @@ _SLEEP_DURATION_BETWEEN_POLLS = 10 +def _has_partitioning_load_parameters(additional_parameters): + return ('timePartitioning' in additional_parameters or + 'rangePartitioning' in additional_parameters) + + +def _add_destination_partitioning_load_parameters( + additional_parameters, destination_table): + if not isinstance(destination_table, bigquery_tools.bigquery.Table): + return additional_parameters + + additional_parameters = dict(additional_parameters) + + if ('timePartitioning' not in additional_parameters and + getattr(destination_table, 'timePartitioning', None) is not None): + additional_parameters['timePartitioning'] = ( + destination_table.timePartitioning) + + if ('rangePartitioning' not in additional_parameters and + getattr(destination_table, 'rangePartitioning', None) is not None): + additional_parameters['rangePartitioning'] = ( + destination_table.rangePartitioning) + + return additional_parameters + + def _generate_job_name(job_name, job_type, step_name): return bigquery_tools.generate_bq_job_name( job_name=job_name, @@ -716,6 +741,7 @@ def process( additional_parameters = self.additional_bq_parameters.get() else: additional_parameters = self.additional_bq_parameters + additional_parameters = dict(additional_parameters or {}) table_reference = bigquery_tools.parse_table_reference(destination) if table_reference.projectId is None: @@ -735,19 +761,36 @@ def process( create_disposition = self.create_disposition if self.temporary_tables: + destination_table = None + hashed_dest = bigquery_tools.get_hashable_destination(table_reference) + should_lookup_destination_table = ( + schema is None or + not _has_partitioning_load_parameters(additional_parameters)) + if should_lookup_destination_table: + try: + destination_table = self.bq_wrapper.get_table( + project_id=table_reference.projectId, + dataset_id=table_reference.datasetId, + table_id=table_reference.tableId) + except Exception as e: + if schema is None: + _LOGGER.warning( + "Input schema is absent and could not fetch the final " + "destination table's schema [%s]. Creating temp table [%s] " + "will likely fail: %s", + hashed_dest, + job_name, + e) + destination_table = None + # we need to create temp tables, so we need a schema. # if there is no input schema, fetch the destination table's schema if schema is None: - hashed_dest = bigquery_tools.get_hashable_destination(table_reference) if hashed_dest in self.schema_cache: schema = self.schema_cache[hashed_dest] - else: + elif destination_table is not None: try: - schema = bigquery_tools.table_schema_to_dict( - bigquery_tools.BigQueryWrapper().get_table( - project_id=table_reference.projectId, - dataset_id=table_reference.datasetId, - table_id=table_reference.tableId).schema) + schema = bigquery_tools.table_schema_to_dict(destination_table.schema) self.schema_cache[hashed_dest] = schema except Exception as e: _LOGGER.warning( @@ -758,6 +801,11 @@ def process( job_name, e) + if (destination_table is not None and + not _has_partitioning_load_parameters(additional_parameters)): + additional_parameters = _add_destination_partitioning_load_parameters( + additional_parameters, destination_table) + # If we are using temporary tables, then we must always create the # temporary tables, so we replace the create_disposition. create_disposition = 'CREATE_IF_NEEDED' diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py index 191719e6a208..f57b10c06cf6 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py @@ -703,6 +703,77 @@ def test_one_load_job_failed_after_waiting(self, sleep_mock): sleep_mock.assert_called_once() + def test_temporary_table_load_inherits_destination_time_partitioning(self): + destination = 'project1:dataset1.table1' + partition = (destination, (0, ['gs://bucket/file1'])) + job_reference = bigquery_api.JobReference(projectId='project1', + jobId='job_name1') + destination_table = bigquery_api.Table( + timePartitioning=bigquery_api.TimePartitioning(type='DAY')) + + dofn = bqfl.TriggerLoadJobs( + schema=_ELEMENTS_SCHEMA, + test_client=mock.Mock(), + temporary_tables=True) + dofn.start_bundle() + dofn.bq_wrapper.get_table = mock.Mock(return_value=destination_table) + dofn.bq_wrapper.perform_load_job = mock.Mock(return_value=job_reference) + + list(dofn.process(partition, 'test_job', pane_info=mock.Mock(index=0))) + + load_call = dofn.bq_wrapper.perform_load_job.call_args.kwargs + self.assertEqual( + load_call['additional_load_parameters']['timePartitioning'], + destination_table.timePartitioning) + dofn.bq_wrapper.get_table.assert_called_once() + + def test_temporary_table_load_inherits_destination_range_partitioning(self): + destination = 'project1:dataset1.table1' + partition = (destination, (0, ['gs://bucket/file1'])) + job_reference = bigquery_api.JobReference(projectId='project1', + jobId='job_name1') + destination_table = bigquery_api.Table( + rangePartitioning=bigquery_api.RangePartitioning()) + + dofn = bqfl.TriggerLoadJobs( + schema=_ELEMENTS_SCHEMA, + test_client=mock.Mock(), + temporary_tables=True) + dofn.start_bundle() + dofn.bq_wrapper.get_table = mock.Mock(return_value=destination_table) + dofn.bq_wrapper.perform_load_job = mock.Mock(return_value=job_reference) + + list(dofn.process(partition, 'test_job', pane_info=mock.Mock(index=0))) + + load_call = dofn.bq_wrapper.perform_load_job.call_args.kwargs + self.assertEqual( + load_call['additional_load_parameters']['rangePartitioning'], + destination_table.rangePartitioning) + dofn.bq_wrapper.get_table.assert_called_once() + + def test_temporary_table_load_keeps_explicit_partitioning_parameters(self): + destination = 'project1:dataset1.table1' + partition = (destination, (0, ['gs://bucket/file1'])) + explicit_partitioning = {'timePartitioning': {'type': 'DAY'}} + job_reference = bigquery_api.JobReference(projectId='project1', + jobId='job_name1') + + dofn = bqfl.TriggerLoadJobs( + schema=_ELEMENTS_SCHEMA, + test_client=mock.Mock(), + temporary_tables=True, + additional_bq_parameters=explicit_partitioning) + dofn.start_bundle() + dofn.bq_wrapper.get_table = mock.Mock() + dofn.bq_wrapper.perform_load_job = mock.Mock(return_value=job_reference) + + list(dofn.process(partition, 'test_job', pane_info=mock.Mock(index=0))) + + load_call = dofn.bq_wrapper.perform_load_job.call_args.kwargs + self.assertEqual(load_call['additional_load_parameters'], + explicit_partitioning) + dofn.bq_wrapper.get_table.assert_not_called() + def test_multiple_partition_files(self): destination = 'project1:dataset1.table1'