Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,74 @@ assert rows[0][0] == "-2001-08-22"
assert cur.description[0][1] == "date"
```

## Progress Callback

The Trino client supports progress callbacks to track query execution progress in real time. You can provide a callback function that is called whenever the query status is updated.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The Trino client supports progress callbacks to track query execution progress in real-time. you can provide a callback function that gets called whenever the query status is updated.
The Trino client supports progress callbacks to track query execution progress in real-time. You can provide a callback that is called whenever the query status is updated.


### Basic Usage

```python
from trino.client import TrinoQuery, TrinoRequest, ClientSession, TrinoStatus
from typing import Dict, Any

def progress_callback(status: TrinoStatus, stats: Dict[str, Any]) -> None:
    """Print a summary line, and split progress when known, for a status update.

    Args:
        status: Status object for the query; only its ``id`` is read here.
        stats: Statistics dictionary; missing keys fall back to defaults.
    """
    query_state = stats.get('state', 'UNKNOWN')
    bytes_done = stats.get('processedBytes', 0)
    rows_done = stats.get('processedRows', 0)
    splits_done = stats.get('completedSplits', 0)
    splits_total = stats.get('totalSplits', 0)

    print(f"Query {status.id}: {query_state} - {bytes_done} bytes, {rows_done} rows")

    # Without a positive split total there is nothing to divide by.
    if not splits_total > 0:
        return

    percent = (splits_done / splits_total) * 100.0
    print(f"Progress: {percent:.1f}% ({splits_done}/{splits_total} splits)")

# Session settings used by the request below.
session = ClientSession(
    user="test_user",
    catalog="memory",
    schema="default",
)

request = TrinoRequest(
    host="localhost",
    port=8080,
    client_session=session,
    http_scheme="http",
)

# The callback is supplied when the query object is created.
query = TrinoQuery(
    request=request,
    query="SELECT * FROM large_table",
    progress_callback=progress_callback,
)

result = query.execute()

# Keep fetching until the query reports that it has finished.
while not query.finished:
    rows = query.fetch()
```

### Progress Calculation

The callback receives a `stats` dictionary containing various metrics that can be used to calculate progress:

- `state`: Query state (RUNNING, FINISHED, FAILED, etc.)
- `processedBytes`: Total bytes processed
- `processedRows`: Total rows processed
- `completedSplits`: Number of completed splits
- `totalSplits`: Total number of splits

The most accurate progress estimate is based on split completion:

```python
def calculate_progress(stats: Dict[str, Any]) -> float:
    """Return query progress as a percentage derived from split counts.

    Args:
        stats: Statistics dictionary; ``completedSplits`` and ``totalSplits``
            default to 0 when absent.

    Returns:
        The completed/total split ratio as a percentage capped at 100.0.
        When the split total is not positive, 100.0 if ``state`` is
        ``'FINISHED'`` and 0.0 otherwise.
    """
    done = stats.get('completedSplits', 0)
    total = stats.get('totalSplits', 0)

    # No usable split total: fall back to the terminal state, if any.
    if not total > 0:
        return 100.0 if stats.get('state') == 'FINISHED' else 0.0

    percent = (done / total) * 100.0
    return min(100.0, percent)
```

### Trino to Python type mappings

| Trino type | Python type |
Expand Down
166 changes: 166 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,3 +1450,169 @@ def delete_password(self, servicename, username):
return None

os.remove(file_path)


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_success(mock_requests, sample_post_response_data, sample_get_response_data):
    """Test that heartbeat is sent periodically and does not stop on success."""
    head_call_count = 0

    def mock_head_response(url, timeout=10, **kwargs):
        nonlocal head_call_count
        head_call_count += 1

        class Resp:
            status_code = 200
        return Resp()

    # Mock the Session's head method
    mock_session = mock.Mock()
    mock_session.head.side_effect = mock_head_response
    mock_requests.Session.return_value = mock_session
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval_seconds=0.1)
    # The heartbeat loop never calls fetch(), so no fetch stub is needed here.
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])
    query._start_heartbeat()
    try:
        time.sleep(0.3)
        # Successful responses must leave the heartbeat enabled.
        assert query._heartbeat_enabled
    finally:
        # Always stop the background thread, even if the assertion fails.
        query._stop_heartbeat()
    assert head_call_count >= 2


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_failure_stops(mock_requests, sample_post_response_data, sample_get_response_data):
    """Test that heartbeat stops after 3 consecutive failures."""
    # Route HEAD through the mocked Session, like the other heartbeat tests.
    # Using return_value (rather than a fixed-signature function) accepts the
    # headers/timeout/proxies keywords that send_heartbeat passes.
    mock_session = mock.Mock()
    mock_session.head.return_value = mock.Mock(status_code=500)
    mock_requests.Session.return_value = mock_session
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval_seconds=0.05)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])
    query._start_heartbeat()
    try:
        time.sleep(0.3)
        assert not query._heartbeat_enabled
        # Exactly three failing HEAD requests are sent before giving up.
        assert mock_session.head.call_count == 3
    finally:
        query._stop_heartbeat()


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_404_405_stops(mock_requests, sample_post_response_data, sample_get_response_data):
    """Test that heartbeat stops if server returns 404 or 405."""
    for code in (404, 405):
        # Fresh Session mock per status code so call counts do not leak
        # between iterations; HEAD goes through the Session like the other
        # heartbeat tests.
        mock_session = mock.Mock()
        mock_session.head.return_value = mock.Mock(status_code=code)
        mock_requests.Session.return_value = mock_session
        mock_requests.Response.return_value.json.return_value = sample_post_response_data
        mock_requests.get.return_value.json.return_value = sample_get_response_data
        mock_requests.post.return_value.json.return_value = sample_post_response_data
        req = TrinoRequest(
            host="coordinator",
            port=8080,
            client_session=ClientSession(user="test"),
            http_scheme="http",
        )
        query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval_seconds=0.05)
        query._next_uri = "http://coordinator/v1/statement/next"
        query._row_mapper = mock.Mock(map=lambda x: [])
        query._start_heartbeat()
        try:
            time.sleep(0.2)
            assert not query._heartbeat_enabled
            # A 404/405 stops the loop on the first response, which
            # distinguishes it from the three-failure path.
            assert mock_session.head.call_count == 1
        finally:
            query._stop_heartbeat()


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_stops_on_finish(mock_requests, sample_post_response_data, sample_get_response_data):
    """Test that heartbeat stops when the query is finished."""
    head_call_count = 0

    def mock_head_response(url, timeout=10, **kwargs):
        nonlocal head_call_count
        head_call_count += 1

        class Resp:
            status_code = 200
        return Resp()

    # Mock the Session's head method
    mock_session = mock.Mock()
    mock_session.head.side_effect = mock_head_response
    mock_requests.Session.return_value = mock_session
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval_seconds=0.05)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])
    query._start_heartbeat()
    time.sleep(0.1)
    query._finished = True
    time.sleep(0.1)
    query._stop_heartbeat()
    # At least one heartbeat was sent while the query was running
    calls_at_stop = head_call_count
    assert calls_at_stop >= 1
    # Heartbeat should have stopped after query finished: waiting several
    # more intervals must not produce any further HEAD requests
    time.sleep(0.15)
    assert head_call_count == calls_at_stop


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_stops_on_cancel(mock_requests, sample_post_response_data, sample_get_response_data):
    """Test that heartbeat stops when the query is cancelled."""
    head_call_count = 0

    def mock_head_response(url, timeout=10, **kwargs):
        nonlocal head_call_count
        head_call_count += 1

        class Resp:
            status_code = 200
        return Resp()

    # Mock the Session's head method
    mock_session = mock.Mock()
    mock_session.head.side_effect = mock_head_response
    mock_requests.Session.return_value = mock_session
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval_seconds=0.05)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])
    query._start_heartbeat()
    time.sleep(0.1)
    query._cancelled = True
    time.sleep(0.1)
    query._stop_heartbeat()
    # At least one heartbeat was sent while the query was running
    calls_at_stop = head_call_count
    assert calls_at_stop >= 1
    # Heartbeat should have stopped after query cancelled: waiting several
    # more intervals must not produce any further HEAD requests
    time.sleep(0.15)
    assert head_call_count == calls_at_stop
90 changes: 89 additions & 1 deletion trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,15 @@ def get(self, url: str) -> Response:
def delete(self, url: str) -> Response:
return self._delete(url, timeout=self._request_timeout, proxies=PROXIES)

def send_heartbeat(self, url: str) -> Response:
    """Issue a HEAD request to ``url`` as a query heartbeat.

    The request is sent on this object's HTTP session with its standard
    headers, its request timeout and the module-level proxy settings.

    Args:
        url: Address to send the HEAD request to.

    Returns:
        The response returned by the session's ``head`` call.
    """
    send_head = self._http_session.head
    return send_head(
        url,
        headers=self.http_headers,
        timeout=self._request_timeout,
        proxies=PROXIES,
    )

@staticmethod
def _process_error(error, query_id: Optional[str]) -> Union[TrinoExternalError, TrinoQueryError, TrinoUserError]:
error_type = error["errorType"]
Expand Down Expand Up @@ -817,7 +826,8 @@ def __init__(
request: TrinoRequest,
query: str,
legacy_primitive_types: bool = False,
fetch_mode: Literal["mapped", "segments"] = "mapped"
fetch_mode: Literal["mapped", "segments"] = "mapped",
heartbeat_interval_seconds: Optional[float] = None,
) -> None:
self._query_id: Optional[str] = None
self._stats: Dict[Any, Any] = {}
Expand All @@ -835,6 +845,10 @@ def __init__(
self._legacy_primitive_types = legacy_primitive_types
self._row_mapper: Optional[RowMapper] = None
self._fetch_mode = fetch_mode
self._heartbeat_interval_seconds = heartbeat_interval_seconds
self._heartbeat_enabled = False
self._heartbeat_thread: Optional[threading.Thread] = None
self._heartbeat_failures = 0

@property
def query_id(self) -> Optional[str]:
Expand Down Expand Up @@ -904,6 +918,10 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
self._result = TrinoResult(self, rows)

# Start heartbeat if interval is set and next_uri is available
if self._heartbeat_interval_seconds is not None and self._next_uri is not None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if self._heartbeat_interval_seconds is not None and self._next_uri is not None:
if self._heartbeat_interval_seconds and self._next_uri:

Isn't this the same?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

below as well

self._start_heartbeat()

# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
Expand All @@ -928,8 +946,15 @@ def fetch(self) -> List[Union[List[Any]], Any]:
raise trino.exceptions.TrinoConnectionError("failed to fetch: {}".format(e))
status = self._request.process(response)
self._update_state(status)

# Start heartbeat if interval is set and next_uri is now available
if self._heartbeat_interval_seconds is not None and self._next_uri is not None and not self._heartbeat_enabled:
self._start_heartbeat()

if status.next_uri is None:
self._finished = True
# Stop heartbeat when query finishes
self._stop_heartbeat()

if not self._row_mapper:
return []
Expand Down Expand Up @@ -969,6 +994,9 @@ def cancel(self) -> None:
if self._next_uri is None:
return

# Stop heartbeat when query is cancelled
self._stop_heartbeat()

logger.debug("cancelling query: %s", self.query_id)
try:
response = self._request.delete(self._next_uri)
Expand All @@ -981,6 +1009,66 @@ def cancel(self) -> None:

self._request.raise_response_error(response)

def _start_heartbeat(self) -> None:
    """Start a daemon thread that sends periodic heartbeat requests.

    Does nothing when no positive heartbeat interval is configured, when the
    heartbeat is already enabled, or when there is no ``_next_uri`` to send
    the request to.

    The background loop sends a heartbeat to the current ``_next_uri`` and
    then sleeps for the configured interval. It ends when:

    * the heartbeat is disabled, or the query is finished or cancelled;
    * ``_next_uri`` becomes ``None``;
    * the server answers 404 or 405;
    * three consecutive heartbeats fail (a non-2xx status or an exception).

    Whenever the loop disables the heartbeat itself, it logs the reason so
    the heartbeat is never dropped silently.
    """
    if self._heartbeat_interval_seconds is None or self._heartbeat_interval_seconds <= 0:
        return

    if self._heartbeat_enabled or self._next_uri is None:
        return

    # Consecutive failures tolerated before the heartbeat gives up.
    max_consecutive_failures = 3

    self._heartbeat_enabled = True
    self._heartbeat_failures = 0

    def note_failure(reason: str) -> bool:
        """Count one failed heartbeat; return True when the loop must stop."""
        self._heartbeat_failures += 1
        if self._heartbeat_failures < max_consecutive_failures:
            return False
        self._heartbeat_enabled = False
        logger.warning(
            "stopping heartbeat for query %s after %d consecutive failures (last: %s)",
            self.query_id,
            self._heartbeat_failures,
            reason,
        )
        return True

    def heartbeat_loop() -> None:
        while self._heartbeat_enabled and not self._finished and not self._cancelled:
            # Snapshot the URI so it cannot become None between the check
            # and the request.
            next_uri = self._next_uri
            if next_uri is None:
                break

            try:
                status_code = self._request.send_heartbeat(next_uri).status_code

                # Stop heartbeat on 404 or 405 (query not found or method not allowed)
                if status_code in (404, 405):
                    self._heartbeat_enabled = False
                    logger.info(
                        "stopping heartbeat for query %s: server responded with HTTP %s",
                        self.query_id,
                        status_code,
                    )
                    break

                # Any 2xx response counts as a successful heartbeat.
                succeeded = 200 <= status_code < 300
                failure_reason = f"HTTP status {status_code}"
            except Exception as error:
                # This is the top of a background thread: nothing may
                # propagate, so every error is counted as a failed heartbeat.
                succeeded = False
                failure_reason = repr(error)

            if succeeded:
                # Reset failure count on success
                self._heartbeat_failures = 0
            elif note_failure(failure_reason):
                break

            # Sleep for the heartbeat interval
            sleep(self._heartbeat_interval_seconds)

    self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
    self._heartbeat_thread.start()

def _stop_heartbeat(self) -> None:
    """Disable the heartbeat and wait briefly for its thread to exit.

    Does nothing when the heartbeat is not currently enabled. Otherwise the
    enabled flag is cleared and, if a heartbeat thread exists, it is joined
    with a one-second timeout.
    """
    if self._heartbeat_enabled:
        self._heartbeat_enabled = False
        worker = self._heartbeat_thread
        if worker is not None:
            worker.join(timeout=1.0)

def is_finished(self) -> bool:
import warnings
warnings.warn("is_finished is deprecated, use finished instead", DeprecationWarning)
Expand Down