diff --git a/lighter/ws_client.py b/lighter/ws_client.py index 1369417..d226dec 100644 --- a/lighter/ws_client.py +++ b/lighter/ws_client.py @@ -1,4 +1,6 @@ import json +import threading +import asyncio from websockets.sync.client import connect from websockets.client import connect as connect_async from lighter.configuration import Configuration @@ -12,6 +14,7 @@ def __init__( account_ids=[], on_order_book_update=print, on_account_update=print, + ping_interval=30, ): if host is None: host = Configuration.get_default().host.replace("https://", "") @@ -33,6 +36,7 @@ def __init__( self.on_account_update = on_account_update self.ws = None + self.ping_interval = ping_interval def on_message(self, ws, message): if isinstance(message, str): @@ -53,6 +57,8 @@ def on_message(self, ws, message): elif message_type == "ping": # Respond to ping with pong ws.send(json.dumps({"type": "pong"})) + elif message_type == "pong": + pass else: self.handle_unhandled_message(message) @@ -65,6 +71,9 @@ async def on_message_async(self, ws, message): elif message_type == "ping": # Respond to ping with pong await ws.send(json.dumps({"type": "pong"})) + elif message_type == "pong": + # Noop + pass else: self.on_message(ws, message) @@ -153,16 +162,60 @@ def on_error(self, ws, error): def on_close(self, ws, close_status_code, close_msg): raise Exception(f"Closed: {close_status_code} {close_msg}") - def run(self): - ws = connect(self.base_url) - self.ws = ws + def _ping_loop(self, stop_event): + while not stop_event.is_set(): + stop_event.wait(self.ping_interval) + if self.ws and not stop_event.is_set(): + try: + self.ws.send(json.dumps({"type": "ping"})) + except Exception as e: + print(f"Ping failed: {e}") + break + + async def _ping_loop_async(self, stop_event): + while not stop_event.is_set(): + try: + await asyncio.sleep(self.ping_interval) + if self.ws and not stop_event.is_set(): + await self.ws.send(json.dumps({"type": "ping"})) + except asyncio.CancelledError: + break + except Exception as e: + print(f"Async ping failed: {e}") + break - for message in ws: - self.on_message(ws, message) + def run(self): + stop_event = threading.Event() + ping_thread = None + try: + with connect(self.base_url) as ws: + self.ws = ws + ping_thread = threading.Thread(target=self._ping_loop, args=(stop_event,), daemon=True) + ping_thread.start() + + for message in ws: + self.on_message(ws, message) + finally: + stop_event.set() + if ping_thread: + ping_thread.join(timeout=1) + self.ws = None # clear after thread has exited async def run_async(self): - ws = await connect_async(self.base_url) - self.ws = ws + stop_event = asyncio.Event() + ping_task = None + try: + async with connect_async(self.base_url) as ws: + self.ws = ws + ping_task = asyncio.create_task(self._ping_loop_async(stop_event)) + + async for message in ws: + await self.on_message_async(ws, message) + finally: + stop_event.set() + if ping_task: + ping_task.cancel() + # Wait for the task to acknowledge cancellation + await asyncio.gather(ping_task, return_exceptions=True) + self.ws = None - async for message in ws: - await self.on_message_async(ws, message)