From 880081ed2181be9fb75143ba8b2b97fd67d6c948 Mon Sep 17 00:00:00 2001 From: waiho-gumloop Date: Wed, 25 Feb 2026 23:02:41 -0800 Subject: [PATCH 1/4] feat(dbapi): use inline begin to eliminate BeginTransaction RPC --- google/cloud/spanner_dbapi/connection.py | 10 ++++++++-- tests/unit/spanner_dbapi/test_connection.py | 20 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 871eb152da..107a022b4c 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -392,7 +392,14 @@ def transaction_checkout(self): this connection yet. Return the started one otherwise. This method is a no-op if the connection is in autocommit mode and no - explicit transaction has been started + explicit transaction has been started. + + The transaction is returned without calling ``begin()``. The + underlying ``Transaction.execute_sql`` and ``execute_update`` + methods detect ``_transaction_id is None`` and use *inline begin* + — piggybacking a ``BeginTransaction`` on the first RPC via + ``TransactionSelector(begin=...)``. This eliminates a separate + ``BeginTransaction`` RPC round-trip per transaction. :rtype: :class:`google.cloud.spanner_v1.transaction.Transaction` :returns: A Cloud Spanner transaction object, ready to use. @@ -410,7 +417,6 @@ def transaction_checkout(self): self.transaction_tag = None self._snapshot = None self._spanner_transaction_started = True - self._transaction.begin() return self._transaction diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 6e8159425f..83d813243c 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -211,6 +211,26 @@ def test_transaction_checkout(self): connection._autocommit = True self.assertIsNone(connection.transaction_checkout()) + def test_transaction_checkout_does_not_call_begin(self): + """transaction_checkout must not call Transaction.begin(). + + The transaction should be returned with _transaction_id=None so that + execute_sql/execute_update can use inline begin via + TransactionSelector(begin=...), eliminating a separate + BeginTransaction RPC. + """ + connection = Connection(INSTANCE, DATABASE) + mock_session = mock.MagicMock() + mock_transaction = mock.MagicMock() + mock_session.transaction.return_value = mock_transaction + connection._session_checkout = mock.MagicMock(return_value=mock_session) + + txn = connection.transaction_checkout() + + self.assertEqual(txn, mock_transaction) + self.assertTrue(connection._spanner_transaction_started) + mock_transaction.begin.assert_not_called() + def test_snapshot_checkout(self): connection = build_connection(read_only=True) connection.autocommit = False From 48d4df819e8a17164cb6f7a19d4811ac31b47c2f Mon Sep 17 00:00:00 2001 From: waiho-gumloop Date: Fri, 27 Feb 2026 18:15:09 -0800 Subject: [PATCH 2/4] test(dbapi): add mockserver tests for inline begin and fix existing tests Add test_dbapi_inline_begin.py with 7 mockserver tests that verify: - Read-write DBAPI transactions send no BeginTransactionRequest - First ExecuteSqlRequest uses TransactionSelector(begin=...) - Read + write + commit request sequence is correct - DML-only transactions use inline begin - Read-only transactions still use explicit BeginTransaction - Transaction retry after abort works with inline begin Update existing mockserver tests that expected BeginTransactionRequest for read-write DBAPI transactions: - test_tags.py: Remove BeginTransactionRequest from expected sequences for all read-write tag tests, adjust tag index offsets - test_dbapi_isolation_level.py: Verify isolation level on the inline begin field of ExecuteSqlRequest instead of BeginTransactionRequest Made-with: Cursor --- .../test_dbapi_inline_begin.py | 248 ++++++++++++++++++ .../test_dbapi_isolation_level.py | 67 +++-- tests/mockserver_tests/test_tags.py | 10 +- 3 files changed, 282 insertions(+), 43 deletions(-) create mode 100644 tests/mockserver_tests/test_dbapi_inline_begin.py diff --git a/tests/mockserver_tests/test_dbapi_inline_begin.py b/tests/mockserver_tests/test_dbapi_inline_begin.py new file mode 100644 index 0000000000..d502325437 --- /dev/null +++ b/tests/mockserver_tests/test_dbapi_inline_begin.py @@ -0,0 +1,248 @@ +# Copyright 2026 Google LLC All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests that the DBAPI uses inline begin for read-write transactions. + +After removing the explicit ``Transaction.begin()`` call from +``Connection.transaction_checkout()``, the DBAPI should piggyback +``BeginTransaction`` on the first ``ExecuteSql`` / ``ExecuteUpdate`` request +via ``TransactionSelector(begin=...)``, eliminating one gRPC round-trip +per transaction. + +Read-only transactions are unaffected — they still use an explicit +``BeginTransaction`` RPC via ``snapshot_checkout()``. +""" + +from google.cloud.spanner_dbapi import Connection +from google.cloud.spanner_v1 import ( + BeginTransactionRequest, + CommitRequest, + ExecuteSqlRequest, + TransactionOptions, + TypeCode, +) +from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer +from google.cloud.spanner_v1.database_sessions_manager import TransactionType + +from tests.mockserver_tests.mock_server_test_base import ( + MockServerTestBase, + add_single_result, + add_update_count, + add_error, + aborted_status, +) + + +class TestDbapiInlineBegin(MockServerTestBase): + @classmethod + def setup_class(cls): + super().setup_class() + add_single_result( + "select name from singers", "name", TypeCode.STRING, [("Some Singer",)] + ) + add_update_count( + "insert into singers (id, name) values (1, 'Some Singer')", 1 + ) + + def test_read_write_no_begin_transaction_rpc(self): + """Read-write DBAPI transaction must not send BeginTransactionRequest.""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute("select name from singers") + cursor.fetchall() + connection.commit() + + begin_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, BeginTransactionRequest) + ] + self.assertEqual(0, len(begin_requests), + "Read-write DBAPI transactions should not send " + "a separate BeginTransactionRequest") + + def test_read_write_uses_inline_begin(self): + """The first ExecuteSqlRequest must carry TransactionSelector(begin=...).""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute("select name from singers") + cursor.fetchall() + connection.commit() + + sql_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, ExecuteSqlRequest) + ] + self.assertGreaterEqual(len(sql_requests), 1) + first_sql = sql_requests[0] + self.assertIn( + "read_write", first_sql.transaction.begin, + "First ExecuteSqlRequest should use inline begin with " + "TransactionSelector(begin=ReadWrite(...))", + ) + + def test_read_write_request_sequence(self): + """Read-write DBAPI transaction: ExecuteSql + Commit (no BeginTransaction).""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute("select name from singers") + cursor.fetchall() + connection.commit() + + self.assert_requests_sequence( + self.spanner_service.requests, + [ExecuteSqlRequest, CommitRequest], + TransactionType.READ_WRITE, + ) + + def test_read_write_dml_request_sequence(self): + """DML write via DBAPI: ExecuteSql + Commit (no BeginTransaction).""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute( + "insert into singers (id, name) values (1, 'Some Singer')" + ) + connection.commit() + + self.assert_requests_sequence( + self.spanner_service.requests, + [ExecuteSqlRequest, CommitRequest], + TransactionType.READ_WRITE, + ) + + def test_read_then_write_request_sequence(self): + """Read + write in same transaction: 2x ExecuteSql + Commit.""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute("select name from singers") + cursor.fetchall() + cursor.execute( + "insert into singers (id, name) values (1, 'Some Singer')" + ) + connection.commit() + + self.assert_requests_sequence( + self.spanner_service.requests, + [ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest], + TransactionType.READ_WRITE, + ) + + def test_read_only_still_uses_explicit_begin(self): + """Read-only transactions should still use explicit BeginTransaction.""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + connection.read_only = True + with connection.cursor() as cursor: + cursor.execute("select name from singers") + cursor.fetchall() + connection.commit() + + self.assert_requests_sequence( + self.spanner_service.requests, + [BeginTransactionRequest, ExecuteSqlRequest], + TransactionType.READ_ONLY, + ) + + def test_second_statement_uses_transaction_id(self): + """After inline begin, subsequent statements must use TransactionSelector(id=...). + + This verifies that the DBAPI correctly extracts the transaction_id from + the inline begin response and passes it to subsequent requests — proving + the transaction lifecycle is maintained. + """ + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute("select name from singers") + cursor.fetchall() + cursor.execute( + "insert into singers (id, name) values (1, 'Some Singer')" + ) + connection.commit() + + sql_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, ExecuteSqlRequest) + ] + self.assertEqual(2, len(sql_requests)) + + first = sql_requests[0] + self.assertIn( + "read_write", first.transaction.begin, + "First statement should use inline begin", + ) + + second = sql_requests[1] + self.assertNotEqual( + b"", second.transaction.id, + "Second statement should use TransactionSelector(id=...) " + "with the transaction_id returned from inline begin, " + "not another TransactionSelector(begin=...)", + ) + + def test_rollback(self): + """Rollback should work without error after inline begin.""" + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute( + "insert into singers (id, name) values (1, 'Some Singer')" + ) + connection.rollback() + + begin_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, BeginTransactionRequest) + ] + self.assertEqual(0, len(begin_requests)) + + def test_inline_begin_with_abort_retry(self): + """Transaction retry after abort should work with inline begin. + + The DBAPI replays recorded statements on abort. With inline begin, + the retried ExecuteSqlRequest should again use inline begin. + """ + add_error(SpannerServicer.Commit.__name__, aborted_status()) + + connection = Connection(self.instance, self.database) + connection.autocommit = False + with connection.cursor() as cursor: + cursor.execute( + "insert into singers (id, name) values (1, 'Some Singer')" + ) + connection.commit() + + begin_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, BeginTransactionRequest) + ] + self.assertEqual(0, len(begin_requests), + "Retried transaction should also use inline begin, " + "not explicit BeginTransactionRequest") + + sql_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, ExecuteSqlRequest) + ] + self.assertEqual(2, len(sql_requests), + "Expected 2 ExecuteSqlRequests: original + retry") + for i, req in enumerate(sql_requests): + self.assertIn( + "read_write", req.transaction.begin, + f"ExecuteSqlRequest[{i}] should use inline begin", + ) diff --git a/tests/mockserver_tests/test_dbapi_isolation_level.py b/tests/mockserver_tests/test_dbapi_isolation_level.py index e912914b19..a5c37e0eef 100644 --- a/tests/mockserver_tests/test_dbapi_isolation_level.py +++ b/tests/mockserver_tests/test_dbapi_isolation_level.py @@ -15,7 +15,7 @@ from google.api_core.exceptions import Unknown from google.cloud.spanner_dbapi import Connection from google.cloud.spanner_v1 import ( - BeginTransactionRequest, + ExecuteSqlRequest, TransactionOptions, ) from tests.mockserver_tests.mock_server_test_base import ( @@ -24,6 +24,13 @@ ) +def _get_first_execute_sql_request(requests): + """Return the first ExecuteSqlRequest from the captured requests.""" + return next( + req for req in requests if isinstance(req, ExecuteSqlRequest) + ) + + class TestDbapiIsolationLevel(MockServerTestBase): @classmethod def setup_class(cls): @@ -36,15 +43,9 @@ def test_isolation_level_default(self): cursor.execute("insert into singers (id, name) values (1, 'Some Singer')") self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) - ) - self.assertEqual(1, len(begin_requests)) + sql_request = _get_first_execute_sql_request(self.spanner_service.requests) self.assertEqual( - begin_requests[0].options.isolation_level, + sql_request.transaction.begin.isolation_level, TransactionOptions.IsolationLevel.ISOLATION_LEVEL_UNSPECIFIED, ) @@ -62,14 +63,12 @@ def test_custom_isolation_level(self): ) self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) + sql_request = _get_first_execute_sql_request( + self.spanner_service.requests + ) + self.assertEqual( + sql_request.transaction.begin.isolation_level, level ) - self.assertEqual(1, len(begin_requests)) - self.assertEqual(begin_requests[0].options.isolation_level, level) MockServerTestBase.spanner_service.clear_requests() def test_isolation_level_in_connection_kwargs(self): @@ -85,14 +84,12 @@ def test_isolation_level_in_connection_kwargs(self): ) self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) + sql_request = _get_first_execute_sql_request( + self.spanner_service.requests + ) + self.assertEqual( + sql_request.transaction.begin.isolation_level, level ) - self.assertEqual(1, len(begin_requests)) - self.assertEqual(begin_requests[0].options.isolation_level, level) MockServerTestBase.spanner_service.clear_requests() def test_transaction_isolation_level(self): @@ -109,14 +106,12 @@ def test_transaction_isolation_level(self): ) self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) + sql_request = _get_first_execute_sql_request( + self.spanner_service.requests + ) + self.assertEqual( + sql_request.transaction.begin.isolation_level, level ) - self.assertEqual(1, len(begin_requests)) - self.assertEqual(begin_requests[0].options.isolation_level, level) MockServerTestBase.spanner_service.clear_requests() def test_begin_isolation_level(self): @@ -133,14 +128,12 @@ def test_begin_isolation_level(self): ) self.assertEqual(1, cursor.rowcount) connection.commit() - begin_requests = list( - filter( - lambda msg: isinstance(msg, BeginTransactionRequest), - self.spanner_service.requests, - ) + sql_request = _get_first_execute_sql_request( + self.spanner_service.requests + ) + self.assertEqual( + sql_request.transaction.begin.isolation_level, level ) - self.assertEqual(1, len(begin_requests)) - self.assertEqual(begin_requests[0].options.isolation_level, level) MockServerTestBase.spanner_service.clear_requests() def test_begin_invalid_isolation_level(self): diff --git a/tests/mockserver_tests/test_tags.py b/tests/mockserver_tests/test_tags.py index 9e35517797..4d975c8ef7 100644 --- a/tests/mockserver_tests/test_tags.py +++ b/tests/mockserver_tests/test_tags.py @@ -115,7 +115,7 @@ def test_select_read_write_transaction_no_tags(self): requests = self.spanner_service.requests self.assert_requests_sequence( requests, - [BeginTransactionRequest, ExecuteSqlRequest, CommitRequest], + [ExecuteSqlRequest, CommitRequest], TransactionType.READ_WRITE, ) @@ -131,7 +131,7 @@ def test_select_read_write_transaction_with_request_tag(self): requests = self.spanner_service.requests self.assert_requests_sequence( requests, - [BeginTransactionRequest, ExecuteSqlRequest, CommitRequest], + [ExecuteSqlRequest, CommitRequest], TransactionType.READ_WRITE, ) @@ -148,7 +148,6 @@ def test_select_read_write_transaction_with_transaction_tag(self): self.assert_requests_sequence( requests, [ - BeginTransactionRequest, ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest, @@ -156,7 +155,7 @@ def test_select_read_write_transaction_with_transaction_tag(self): TransactionType.READ_WRITE, ) mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE) - tag_idx = 3 if mux_enabled else 2 + tag_idx = 2 if mux_enabled else 1 self.assertEqual( "my_transaction_tag", requests[tag_idx].request_options.transaction_tag ) @@ -180,7 +179,6 @@ def test_select_read_write_transaction_with_transaction_and_request_tag(self): self.assert_requests_sequence( requests, [ - BeginTransactionRequest, ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest, @@ -188,7 +186,7 @@ def test_select_read_write_transaction_with_transaction_and_request_tag(self): TransactionType.READ_WRITE, ) mux_enabled = is_multiplexed_enabled(TransactionType.READ_WRITE) - tag_idx = 3 if mux_enabled else 2 + tag_idx = 2 if mux_enabled else 1 self.assertEqual( "my_transaction_tag", requests[tag_idx].request_options.transaction_tag ) From 8aeedeb5a085041c94bb2251082be31ad1ed3c92 Mon Sep 17 00:00:00 2001 From: waiho-gumloop Date: Sat, 28 Feb 2026 07:43:03 -0800 Subject: [PATCH 3/4] chore: remove unused TransactionOptions import Made-with: Cursor --- tests/mockserver_tests/test_dbapi_inline_begin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/mockserver_tests/test_dbapi_inline_begin.py b/tests/mockserver_tests/test_dbapi_inline_begin.py index d502325437..eeb2a791cc 100644 --- a/tests/mockserver_tests/test_dbapi_inline_begin.py +++ b/tests/mockserver_tests/test_dbapi_inline_begin.py @@ -29,7 +29,6 @@ BeginTransactionRequest, CommitRequest, ExecuteSqlRequest, - TransactionOptions, TypeCode, ) from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer From 7db9b0c693942b45714a089d8bdc5154842a2af3 Mon Sep 17 00:00:00 2001 From: waiho-gumloop Date: Sat, 28 Feb 2026 08:01:16 -0800 Subject: [PATCH 4/4] test(dbapi): strengthen inline begin mockserver test assertions - Consolidate 3 redundant single-read tests into one comprehensive test that verifies: no BeginTransactionRequest, inline begin on first ExecuteSqlRequest, correct request sequence, and correct query results - Rename test_second_statement_uses_transaction_id to test_read_then_write_full_lifecycle with additional assertions: CommitRequest.transaction_id matches the transaction ID from inline begin - Strengthen test_rollback to verify RollbackRequest is sent with a non-empty transaction_id (was only checking no BeginTransactionRequest) - Add CommitRequest assertions to abort retry test: both the aborted and successful commits carry valid transaction IDs - Assert cursor.fetchall() return values in read tests to verify inline begin doesn't corrupt result set metadata - Add RollbackRequest import Made-with: Cursor --- .../test_dbapi_inline_begin.py | 168 +++++++++++------- 1 file changed, 108 insertions(+), 60 deletions(-) diff --git a/tests/mockserver_tests/test_dbapi_inline_begin.py b/tests/mockserver_tests/test_dbapi_inline_begin.py index eeb2a791cc..b8d61c7729 100644 --- a/tests/mockserver_tests/test_dbapi_inline_begin.py +++ b/tests/mockserver_tests/test_dbapi_inline_begin.py @@ -29,6 +29,7 @@ BeginTransactionRequest, CommitRequest, ExecuteSqlRequest, + RollbackRequest, TypeCode, ) from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer @@ -54,15 +55,27 @@ def setup_class(cls): "insert into singers (id, name) values (1, 'Some Singer')", 1 ) - def test_read_write_no_begin_transaction_rpc(self): - """Read-write DBAPI transaction must not send BeginTransactionRequest.""" + def test_read_write_inline_begin(self): + """Comprehensive check for a single-statement read-write transaction. + + Verifies: + - No BeginTransactionRequest is sent + - The ExecuteSqlRequest uses TransactionSelector(begin=ReadWrite(...)) + - The request sequence is [ExecuteSqlRequest, CommitRequest] + - The query returns correct data + """ connection = Connection(self.instance, self.database) connection.autocommit = False with connection.cursor() as cursor: cursor.execute("select name from singers") - cursor.fetchall() + rows = cursor.fetchall() connection.commit() + self.assertEqual( + [("Some Singer",)], rows, + "Query should return the mocked result set", + ) + begin_requests = [ r for r in self.spanner_service.requests if isinstance(r, BeginTransactionRequest) @@ -71,36 +84,21 @@ def test_read_write_no_begin_transaction_rpc(self): "Read-write DBAPI transactions should not send " "a separate BeginTransactionRequest") - def test_read_write_uses_inline_begin(self): - """The first ExecuteSqlRequest must carry TransactionSelector(begin=...).""" - connection = Connection(self.instance, self.database) - connection.autocommit = False - with connection.cursor() as cursor: - cursor.execute("select name from singers") - cursor.fetchall() - connection.commit() - sql_requests = [ r for r in self.spanner_service.requests if isinstance(r, ExecuteSqlRequest) ] self.assertGreaterEqual(len(sql_requests), 1) first_sql = sql_requests[0] + self.assertTrue( + first_sql.transaction.begin.read_write == first_sql.transaction.begin.read_write, + ) self.assertIn( "read_write", first_sql.transaction.begin, "First ExecuteSqlRequest should use inline begin with " "TransactionSelector(begin=ReadWrite(...))", ) - def test_read_write_request_sequence(self): - """Read-write DBAPI transaction: ExecuteSql + Commit (no BeginTransaction).""" - connection = Connection(self.instance, self.database) - connection.autocommit = False - with connection.cursor() as cursor: - cursor.execute("select name from singers") - cursor.fetchall() - connection.commit() - self.assert_requests_sequence( self.spanner_service.requests, [ExecuteSqlRequest, CommitRequest], @@ -123,24 +121,67 @@ def test_read_write_dml_request_sequence(self): TransactionType.READ_WRITE, ) - def test_read_then_write_request_sequence(self): - """Read + write in same transaction: 2x ExecuteSql + Commit.""" + def test_read_then_write_full_lifecycle(self): + """Read + write in same transaction: verifies the complete inline begin lifecycle. + + Checks: + - First ExecuteSqlRequest uses TransactionSelector(begin=ReadWrite(...)) + - Second ExecuteSqlRequest uses TransactionSelector(id=) + - CommitRequest uses the same transaction_id as the second statement + - Query returns correct data + - Request sequence is [ExecuteSql, ExecuteSql, Commit] + """ connection = Connection(self.instance, self.database) connection.autocommit = False with connection.cursor() as cursor: cursor.execute("select name from singers") - cursor.fetchall() + rows = cursor.fetchall() cursor.execute( "insert into singers (id, name) values (1, 'Some Singer')" ) connection.commit() + self.assertEqual( + [("Some Singer",)], rows, + "Query should return the mocked result set", + ) + self.assert_requests_sequence( self.spanner_service.requests, [ExecuteSqlRequest, ExecuteSqlRequest, CommitRequest], TransactionType.READ_WRITE, ) + sql_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, ExecuteSqlRequest) + ] + self.assertEqual(2, len(sql_requests)) + + first = sql_requests[0] + self.assertIn( + "read_write", first.transaction.begin, + "First statement should use inline begin", + ) + + second = sql_requests[1] + self.assertNotEqual( + b"", second.transaction.id, + "Second statement should use TransactionSelector(id=...) " + "with the transaction_id returned from inline begin", + ) + + commit_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, CommitRequest) + ] + self.assertEqual(1, len(commit_requests)) + self.assertEqual( + second.transaction.id, commit_requests[0].transaction_id, + "CommitRequest must reference the same transaction_id " + "that the second ExecuteSqlRequest used", + ) + def test_read_only_still_uses_explicit_begin(self): """Read-only transactions should still use explicit BeginTransaction.""" connection = Connection(self.instance, self.database) @@ -148,68 +189,61 @@ def test_read_only_still_uses_explicit_begin(self): connection.read_only = True with connection.cursor() as cursor: cursor.execute("select name from singers") - cursor.fetchall() + rows = cursor.fetchall() connection.commit() + self.assertEqual( + [("Some Singer",)], rows, + "Read-only query should return the mocked result set", + ) + self.assert_requests_sequence( self.spanner_service.requests, [BeginTransactionRequest, ExecuteSqlRequest], TransactionType.READ_ONLY, ) - def test_second_statement_uses_transaction_id(self): - """After inline begin, subsequent statements must use TransactionSelector(id=...). - - This verifies that the DBAPI correctly extracts the transaction_id from - the inline begin response and passes it to subsequent requests — proving - the transaction lifecycle is maintained. - """ + def test_rollback_after_inline_begin(self): + """Rollback after DML sends RollbackRequest with the correct transaction_id.""" connection = Connection(self.instance, self.database) connection.autocommit = False with connection.cursor() as cursor: - cursor.execute("select name from singers") - cursor.fetchall() cursor.execute( "insert into singers (id, name) values (1, 'Some Singer')" ) - connection.commit() + connection.rollback() + + begin_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, BeginTransactionRequest) + ] + self.assertEqual(0, len(begin_requests), + "Rollback path should not use BeginTransactionRequest") sql_requests = [ r for r in self.spanner_service.requests if isinstance(r, ExecuteSqlRequest) ] - self.assertEqual(2, len(sql_requests)) + self.assertEqual(1, len(sql_requests)) - first = sql_requests[0] + rollback_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, RollbackRequest) + ] + self.assertEqual(1, len(rollback_requests), + "A RollbackRequest should be sent after DML + rollback") + + txn_id_from_inline_begin = sql_requests[0].transaction.begin self.assertIn( - "read_write", first.transaction.begin, - "First statement should use inline begin", + "read_write", txn_id_from_inline_begin, + "DML should have used inline begin", ) - second = sql_requests[1] self.assertNotEqual( - b"", second.transaction.id, - "Second statement should use TransactionSelector(id=...) " - "with the transaction_id returned from inline begin, " - "not another TransactionSelector(begin=...)", + b"", rollback_requests[0].transaction_id, + "RollbackRequest must carry the transaction_id obtained via inline begin", ) - def test_rollback(self): - """Rollback should work without error after inline begin.""" - connection = Connection(self.instance, self.database) - connection.autocommit = False - with connection.cursor() as cursor: - cursor.execute( - "insert into singers (id, name) values (1, 'Some Singer')" - ) - connection.rollback() - - begin_requests = [ - r for r in self.spanner_service.requests - if isinstance(r, BeginTransactionRequest) - ] - self.assertEqual(0, len(begin_requests)) - def test_inline_begin_with_abort_retry(self): """Transaction retry after abort should work with inline begin. @@ -245,3 +279,17 @@ def test_inline_begin_with_abort_retry(self): "read_write", req.transaction.begin, f"ExecuteSqlRequest[{i}] should use inline begin", ) + + commit_requests = [ + r for r in self.spanner_service.requests + if isinstance(r, CommitRequest) + ] + self.assertEqual(2, len(commit_requests), + "Expected 2 CommitRequests: the aborted original + " + "the successful retry") + for i, cr in enumerate(commit_requests): + self.assertNotEqual( + b"", cr.transaction_id, + f"CommitRequest[{i}] must carry a transaction_id " + "from inline begin", + )