From 049421bc26ae566408f8be0841a4ac44d5782c24 Mon Sep 17 00:00:00 2001 From: Gopinath Nelluri Date: Tue, 10 Feb 2026 13:33:12 -0500 Subject: [PATCH] Fix OOM on large spooled result sets Previously, TrinoQuery.fetch() eagerness caused all segments to load into memory at once when using fault-tolerant execution. This led to OOM errors on large datasets. Changes: - Enable lazy loading by returning SegmentIterator directly in fetch(). - Update execute() to handle result rows as iterators instead of requiring lists. - Add unit test to verify lazy fetching implementation. --- tests/unit/test_spooling.py | 84 +++++++++++++++++++++++++++++++++++++ trino/client.py | 34 ++++++++++++--- 2 files changed, 113 insertions(+), 5 deletions(-) create mode 100644 tests/unit/test_spooling.py diff --git a/tests/unit/test_spooling.py b/tests/unit/test_spooling.py new file mode 100644 index 00000000..de83c3f1 --- /dev/null +++ b/tests/unit/test_spooling.py @@ -0,0 +1,84 @@ + +import unittest +from unittest.mock import MagicMock, patch +from trino.client import TrinoQuery, TrinoRequest, ClientSession, TrinoResult +from trino.client import SegmentIterator + +class TestTrinoQueryLazy(unittest.TestCase): + def setUp(self): + self.mock_request = MagicMock(spec=TrinoRequest) + self.client_session = ClientSession("user") + self.mock_request.client_session = self.client_session + + def test_fetch_returns_iterator_for_spooled_segments(self): + # Mock the initial POST response + post_response = MagicMock() + post_response.id = "query_1" + post_response.stats = {} + post_response.info_uri = "info" + post_response.next_uri = "next_1" + post_response.rows = [] # No rows initially + + self.mock_request.process.return_value = post_response + self.mock_request.post.return_value = MagicMock() + + query = TrinoQuery(self.mock_request, "SELECT 1") + + # Execute should return empty result initially but try to fetch + # We need to mock fetch behavior too since execute calls it if rows are empty + + # Mock the GET response for fetch() + get_response_status = MagicMock() + get_response_status.next_uri = None # Finished + get_response_status.stats = {} + # Status rows as dict indicates spooling protocol + get_response_status.rows = { + "encoding": "json", + "segments": [ + {"type": "spooled", "uri": "u1", "ackUri": "a1", "metadata": {"segmentSize": "10", "uncompressedSize": "10"}} + ], + "metadata": {} + } + + # When execute calls fetch(), it calls request.get -> process -> returns get_response_status + self.mock_request.process.side_effect = [post_response, get_response_status] + self.mock_request.get.return_value = MagicMock() + + # Mock _to_segments to return a list of decodable segments + # We can just verify that fetch returns a SegmentIterator + # But _to_segments is internal. + + # We need to patch SegmentIterator or check the return type + + result = query.execute() + + # Verify result.rows is a SegmentIterator, NOT a list + self.assertIsInstance(result.rows, SegmentIterator) + self.assertNotIsInstance(result.rows, list) + + def test_fetch_returns_list_for_normal_segments(self): + # Mock the initial POST response + post_response = MagicMock() + post_response.id = "query_1" + post_response.stats = {} + post_response.info_uri = "info" + post_response.next_uri = "next_1" + post_response.rows = [] + + # Mock the GET response for fetch() + get_response_status = MagicMock() + get_response_status.next_uri = None + get_response_status.stats = {} + get_response_status.rows = [[1], [2]] # Normal list rows + + self.mock_request.process.side_effect = [post_response, get_response_status] + + query = TrinoQuery(self.mock_request, "SELECT 1") + result = query.execute() + + # Verify result.rows is a list (appended) + self.assertIsInstance(result.rows, list) + self.assertEqual(result.rows, [[1], [2]]) + +if __name__ == '__main__': + unittest.main() diff --git a/trino/client.py b/trino/client.py index 3ab27e33..8e98559a 100644 --- a/trino/client.py +++ b/trino/client.py @@ -904,9 +904,32 @@ def execute(self, additional_http_headers=None) -> TrinoResult: rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows self._result = TrinoResult(self, rows) - # Execute should block until at least one row is received or query is finished or cancelled - while not self.finished and not self.cancelled and len(self._result.rows) == 0: - self._result.rows += self.fetch() + """ + Execute should block until at least one row is received or query is finished or cancelled + + For Standard Execution, rows is a list, we can check len. the first response usually contains no rows (just stats), + so we need to continue fetching until we get some rows or query is finished or cancelled. + + For Spooled Execution, rows start as empty list and eventually fetch returns the rows as iterator, + we can't check len of an iterator easily without peeking. + + So, if we get rows as non empty list or iterator, we stop blocking and return it to the caller to consume it. + """ + + while not self.finished and not self.cancelled: + if isinstance(self._result.rows, list) and len(self._result.rows) == 0: + new_rows = self.fetch() + if isinstance(new_rows, list): + self._result.rows += new_rows + else: + # It's an iterator (spooled segments), replace rows with it + self._result.rows = new_rows + # We have an iterator now, so we can return result to user + break + else: + # We have data (list with items or an iterator), so return + break + return self._result def _update_state(self, status): @@ -920,7 +943,7 @@ def _update_state(self, status): if status.columns: self._columns = status.columns - def fetch(self) -> List[Union[List[Any]], Any]: + def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]: """Continue fetching data for the current query_id""" try: response = self._request.get(self._request.next_uri) @@ -941,7 +964,8 @@ def fetch(self) -> List[Union[List[Any]], Any]: spooled = self._to_segments(rows) if self._fetch_mode == "segments": return spooled - return list(SegmentIterator(spooled, self._row_mapper)) + # Return iterator directly, do NOT materialize with list() + return SegmentIterator(spooled, self._row_mapper) elif isinstance(status.rows, list): return self._row_mapper.map(rows) else: