Skip to content

Commit 71ba3df

Browse files
authored
Make client accept a function for websocket uri and hadnshake metadata (#62)
Why === We're seeing stale tokens as part of the connection to pid2, we have no way of regenerating the token What changed ============ Make client accept a function for websocket uri and hadnshake metadata Test plan ========= Updated tests, unfortunately no comprehensive suite here so we'll test this in our internal repo
1 parent ba7c087 commit 71ba3df

3 files changed

Lines changed: 26 additions & 13 deletions

File tree

replit_river/client.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
2-
from collections.abc import AsyncIterable, AsyncIterator
3-
from typing import Any, Callable, Optional, Union
2+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
3+
from typing import Any, Optional, Union
44

55
from replit_river.client_transport import ClientTransport
66
from replit_river.transport_options import TransportOptions
@@ -16,22 +16,23 @@
1616

1717

1818
class Client:
19+
1920
def __init__(
2021
self,
21-
websocket_uri: str,
22+
websocket_uri_factory: Callable[[], Awaitable[str]],
2223
client_id: str,
2324
server_id: str,
2425
transport_options: TransportOptions,
25-
handshake_metadata: Optional[Any] = None,
26+
handshake_metadata_factory: Optional[Callable[[], Awaitable[Any]]] = None,
2627
) -> None:
2728
self._client_id = client_id
2829
self._server_id = server_id
2930
self._transport = ClientTransport(
30-
websocket_uri=websocket_uri,
31+
websocket_uri_factory=websocket_uri_factory,
3132
client_id=client_id,
3233
server_id=server_id,
3334
transport_options=transport_options,
34-
handshake_metadata=handshake_metadata,
35+
handshake_metadata_factory=handshake_metadata_factory,
3536
)
3637

3738
async def close(self) -> None:

replit_river/client_transport.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
from collections.abc import Awaitable, Callable
34
from typing import Any, Optional, Tuple
45

56
import websockets
@@ -42,26 +43,27 @@
4243

4344

4445
class ClientTransport(Transport):
46+
4547
def __init__(
4648
self,
47-
websocket_uri: str,
49+
websocket_uri_factory: Callable[[], Awaitable[str]],
4850
client_id: str,
4951
server_id: str,
5052
transport_options: TransportOptions,
51-
handshake_metadata: Optional[Any] = None,
53+
handshake_metadata_factory: Optional[Callable[[], Awaitable[Any]]] = None,
5254
):
5355
super().__init__(
5456
transport_id=client_id,
5557
transport_options=transport_options,
5658
is_server=False,
5759
)
58-
self._websocket_uri = websocket_uri
60+
self._websocket_uri_factory = websocket_uri_factory
5961
self._client_id = client_id
6062
self._server_id = server_id
6163
self._rate_limiter = LeakyBucketRateLimit(
6264
transport_options.connection_retry_options
6365
)
64-
self._handshake_metadata = handshake_metadata
66+
self._handshake_metadata_factory = handshake_metadata_factory
6567
# We want to make sure there's only one session creation at a time
6668
self._create_session_lock = asyncio.Lock()
6769

@@ -107,12 +109,18 @@ async def _establish_new_connection(
107109
break
108110
rate_limit.consume_budget(client_id)
109111
try:
110-
ws = await websockets.connect(self._websocket_uri)
112+
websocket_uri = await self._websocket_uri_factory()
113+
ws = await websockets.connect(websocket_uri)
111114
session_id = (
112115
self.generate_session_id()
113116
if not old_session
114117
else old_session.session_id
115118
)
119+
120+
handshake_metadata = None
121+
if self._handshake_metadata_factory is not None:
122+
handshake_metadata = await self._handshake_metadata_factory()
123+
116124
try:
117125
(
118126
handshake_request,
@@ -121,7 +129,7 @@ async def _establish_new_connection(
121129
self._transport_id,
122130
self._server_id,
123131
session_id,
124-
self._handshake_metadata,
132+
handshake_metadata,
125133
ws,
126134
old_session,
127135
)

tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,14 @@ async def client(
137137
transport_options: TransportOptions,
138138
no_logging_error: NoErrors,
139139
) -> AsyncGenerator[Client, None]:
140+
141+
async def websocket_uri_factory() -> str:
142+
return "ws://localhost:8765"
143+
140144
try:
141145
async with serve(server.serve, "localhost", 8765):
142146
client = Client(
143-
"ws://localhost:8765",
147+
websocket_uri_factory,
144148
client_id="test_client",
145149
server_id="test_server",
146150
transport_options=transport_options,

0 commit comments

Comments
 (0)