Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions python/lib/sift_client/_internal/low_level_wrappers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import logging
from datetime import datetime, timezone
from math import ceil
from typing import TYPE_CHECKING, Any, cast

import pandas as pd
Expand Down Expand Up @@ -231,7 +230,8 @@ async def get_channel_data(
run_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int | None = None,
max_results: int | None = None,
page_size: int | None = None,
ignore_cache: bool = False,
) -> dict[str, pd.DataFrame]:
"""Get the data for a channel during a run."""
Expand All @@ -247,8 +247,6 @@ async def get_channel_data(
)

tasks = []
page_size = limit if limit and limit < 1000 else 1000
limit = ceil(limit / page_size) if limit else 10
# Queue up calls for non-cached channels in batches.
batch_size = REQUEST_BATCH_SIZE
for i in range(0, len(not_cached_channels), batch_size): # type: ignore
Expand All @@ -264,7 +262,7 @@ async def get_channel_data(
"end_time": end_time,
},
page_size=page_size,
max_results=limit,
max_results=max_results,
)
)
tasks.append(task)
Expand Down Expand Up @@ -294,7 +292,7 @@ async def get_channel_data(
"end_time": new_end_time or end_time,
},
page_size=page_size,
max_results=limit,
max_results=max_results,
)
)
tasks.append(task)
Expand Down
12 changes: 8 additions & 4 deletions python/lib/sift_client/_tests/sift_types/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def test_data_method_calls_get_data(self, mock_channel, mock_client):
run_id="run123",
start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
end_time=datetime(2024, 1, 2, tzinfo=timezone.utc),
limit=100,
max_results=100,
page_size=None,
)

# Verify client method was called with correct parameters
Expand All @@ -78,7 +79,8 @@ def test_data_method_calls_get_data(self, mock_channel, mock_client):
run="run123",
start_time=datetime(2024, 1, 1, tzinfo=timezone.utc),
end_time=datetime(2024, 1, 2, tzinfo=timezone.utc),
limit=100,
max_results=100,
page_size=None,
)
assert result == mock_data

Expand All @@ -99,7 +101,8 @@ def test_data_method_as_arrow(self, mock_channel, mock_client):
run="run123",
start_time=None,
end_time=None,
limit=None,
max_results=None,
page_size=None,
)
mock_client.channels.get_data.assert_not_called()
assert result == mock_data
Expand All @@ -118,6 +121,7 @@ def test_data_method_with_minimal_params(self, mock_channel, mock_client):
run=None,
start_time=None,
end_time=None,
limit=None,
max_results=None,
page_size=None,
)
assert result == mock_data
2 changes: 1 addition & 1 deletion python/lib/sift_client/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def get_data():
run="run123",
start_time=datetime.now() - timedelta(hours=1),
end_time=datetime.now(),
limit=10000
max_results=10000
)

# data is a dict mapping channel names to DataFrames
Expand Down
17 changes: 11 additions & 6 deletions python/lib/sift_client/resources/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ async def get_data(
run: Run | str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int | None = None,
max_results: int | None = None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

other resources used limit for the purpose, why change the arg name?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the real issue here was the math i was doing w/ page size but was just a rush to unblock Gui. Opened a different PR to fix that and do things the same way we do for channels list #399

page_size: int | None = None,
) -> dict[str, pd.DataFrame]:
"""Get data for one or more channels.

Expand All @@ -186,7 +187,8 @@ async def get_data(
run: The Run or run_id to get data for.
start_time: The start time to get data for.
end_time: The end time to get data for.
limit: The maximum number of data points to return. Will be in increments of page_size or default page size defined by the call if no page_size is provided.
max_results: The maximum number of data points to return.
page_size: The number of data points to return per page.

Returns:
A dictionary mapping channel names to pandas DataFrames containing the channel data.
Expand All @@ -199,7 +201,8 @@ async def get_data(
run_id=run_id,
start_time=start_time,
end_time=end_time,
limit=limit,
max_results=max_results,
page_size=page_size,
)

async def get_data_as_arrow(
Expand All @@ -209,9 +212,10 @@ async def get_data_as_arrow(
run: Run | str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int | None = None,
max_results: int | None = None,
page_size: int | None = None,
) -> dict[str, pa.Table]:
"""Get data for one or more channels as pyarrow tables."""
"""Same as get_data but returns data as pyarrow tables."""
from pyarrow import Table as ArrowTable

run_id = run.id_ if isinstance(run, Run) else run
Expand All @@ -220,6 +224,7 @@ async def get_data_as_arrow(
run=run_id,
start_time=start_time,
end_time=end_time,
limit=limit,
max_results=max_results,
page_size=page_size,
)
return {k: ArrowTable.from_pandas(v) for k, v in data.items()}
11 changes: 7 additions & 4 deletions python/lib/sift_client/resources/sync_stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,8 @@ class ChannelsAPI:
run: Run | str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int | None = None,
max_results: int | None = None,
page_size: int | None = None,
) -> dict[str, pd.DataFrame]:
"""Get data for one or more channels.

Expand All @@ -440,7 +441,8 @@ class ChannelsAPI:
run: The Run or run_id to get data for.
start_time: The start time to get data for.
end_time: The end time to get data for.
limit: The maximum number of data points to return. Will be in increments of page_size or default page size defined by the call if no page_size is provided.
max_results: The maximum number of data points to return.
page_size: The number of data points to return per page.

Returns:
A dictionary mapping channel names to pandas DataFrames containing the channel data.
Expand All @@ -454,9 +456,10 @@ class ChannelsAPI:
run: Run | str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int | None = None,
max_results: int | None = None,
page_size: int | None = None,
) -> dict[str, pa.Table]:
"""Get data for one or more channels as pyarrow tables."""
"""Same as get_data but returns data as pyarrow tables."""
...

def list_(
Expand Down
12 changes: 8 additions & 4 deletions python/lib/sift_client/sift_types/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,8 @@ def data(
run_id: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
limit: int | None = None,
page_size: int | None = None,
max_results: int | None = None,
as_arrow: bool = False,
):
"""Retrieve channel data for this channel during the specified run.
Expand All @@ -307,7 +308,8 @@ def data(
run_id: The run ID to get data for.
start_time: The start time to get data for.
end_time: The end time to get data for.
limit: The maximum number of data points to return.
page_size: The number of data points to return per page.
max_results: The maximum number of data points to return.
as_arrow: Whether to return the data as an Arrow table.

Returns:
Expand All @@ -319,15 +321,17 @@ def data(
run=run_id,
start_time=start_time,
end_time=end_time,
limit=limit, # type: ignore
max_results=max_results,
page_size=page_size,
)
else:
data = self.client.channels.get_data(
channels=[self],
run=run_id,
start_time=start_time,
end_time=end_time,
limit=limit, # type: ignore
max_results=max_results,
page_size=page_size,
)
return data

Expand Down
Loading