Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ celerybeat.pid

# Environments
.env
.venv
.venv*
env/
venv/
ENV/
Expand Down Expand Up @@ -169,3 +169,6 @@ cython_debug/

# PyPI configuration file
.pypirc

.idea
.DS_Store
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## 2.2.1 /2026-06-29

## What's Changed
* Fix websocket poison connection and leaks on failed requests by @basfroman in https://github.com/latent-to/async-substrate-interface/pull/367

**Full Changelog**: https://github.com/latent-to/async-substrate-interface/compare/v2.2.0...v2.2.1

## 2.2.0 /2026-06-11

## What's Changed
Expand Down
220 changes: 136 additions & 84 deletions async_substrate_interface/async_substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,10 +641,35 @@ def state(self):
return self.ws.state

async def __aenter__(self):
await self._restart_handler_if_dead()
if self.state not in (State.CONNECTING, State.OPEN):
await self.connect()
return self

async def _restart_handler_if_dead(self) -> None:
"""
Revive the background send/recv handler if it has terminated.

When `_handler` finishes (for example by returning `TimeoutError("Max retries exceeded.")` after exhausting its
retries), the underlying socket may still be in the OPEN state. In that case neither `__aenter__` nor `connect`
recreate the handler, since both only act when the state is not OPEN/CONNECTING, so the connection is wedged
permanently and every `retrieve` re-raises the dead task's stored error. Detect that here and force a clean
reconnect (fresh socket and handler) under the lock.
"""
task = self._send_recv_task
if task is None or not task.done():
return
async with self._lock:
task = self._send_recv_task
if task is None or not task.done():
# Another caller already revived the handler.
return
if not task.cancelled():
# Consume the dead task's outcome so it is not later reported as an unretrieved exception.
task.exception()
self._attempts = 0
await self._connect_internal(force=True)

async def mark_waiting_for_response(self):
"""
Mark that a response is expected. This will cause the websocket to not automatically close.
Expand Down Expand Up @@ -1226,8 +1251,8 @@ async def retrieve(self, item_id: str) -> Optional[dict]:
item: Optional[asyncio.Future] = self._received.get(item_id)
if item is not None:
if item.done():
self.max_subscriptions.release()
res = item.result()
self.max_subscriptions.release()
del self._received[item_id]
return res
else:
Expand All @@ -1253,6 +1278,23 @@ async def retrieve(self, item_id: str) -> Optional[dict]:
raise e
return None

async def discard_request(self, item_id: str) -> None:
"""
Drop a request that never completed and release the subscription permit that `send` acquired for it.

This is idempotent and safe to call on ids that already completed:
`retrieve` removes those from `_received` after releasing their permit, so this becomes a no-op and never
double-releases. The id is deliberately not returned to `_in_use_ids`, so a late response from the node for it
is dropped by `_dispatch_response` (which checks `_received`) rather than misrouted to a reused id.
"""
async with self._lock:
fut = self._received.pop(item_id, None)
self._inflight.pop(item_id, None)
if fut is not None:
self.max_subscriptions.release()
if not fut.done():
fut.cancel()


class AsyncSubstrateInterface(SubstrateMixin):
ws: "Websocket"
Expand Down Expand Up @@ -2728,84 +2770,88 @@ async def _make_rpc_request(

async with self.ws as ws:
await ws.mark_waiting_for_response()
for payload in payloads:
item_id = await ws.send(payload["payload"])
request_manager.add_request(item_id, payload["id"])
# truncate to 2000 chars for debug logging
if len(stringified_payload := str(payload)) < 2_000:
output_payload = stringified_payload
else:
output_payload = f"{stringified_payload[:2_000]} (truncated)"
logger.debug(
f"Submitted payload ID {payload['id']} with websocket ID {item_id}: {output_payload}"
)
try:
for payload in payloads:
item_id = await ws.send(payload["payload"])
request_manager.add_request(item_id, payload["id"])
# truncate to 2000 chars for debug logging
if len(stringified_payload := str(payload)) < 2_000:
output_payload = stringified_payload
else:
output_payload = f"{stringified_payload[:2_000]} (truncated)"
logger.debug(
f"Submitted payload ID {payload['id']} with websocket ID {item_id}: {output_payload}"
)

while True:
for item_id in request_manager.unresponded():
if (
item_id not in request_manager.responses
or inspect.iscoroutinefunction(result_handler)
):
if response := await ws.retrieve(item_id):
if (
inspect.iscoroutinefunction(result_handler)
and not subscription_added
):
# handles subscriptions, overwrites the previous mapping of {item_id : payload_id}
# with {subscription_id : payload_id}
try:
item_id = request_manager.overwrite_request(
item_id, response["result"]
)
subscription_added = True
except KeyError:
logger.error(
f"Error received from subtensor for {item_id}: {response}\n"
f"Currently received responses: {request_manager.get_results()}"
while True:
for item_id in request_manager.unresponded():
if (
item_id not in request_manager.responses
or inspect.iscoroutinefunction(result_handler)
):
if response := await ws.retrieve(item_id):
if (
inspect.iscoroutinefunction(result_handler)
and not subscription_added
):
# handles subscriptions, overwrites the previous mapping of {item_id : payload_id}
# with {subscription_id : payload_id}
try:
item_id = request_manager.overwrite_request(
item_id, response["result"]
)
subscription_added = True
except KeyError:
logger.error(
f"Error received from subtensor for {item_id}: {response}\n"
f"Currently received responses: {request_manager.get_results()}"
)
raise SubstrateRequestException(str(response))
(
decoded_response,
complete,
) = await self._process_response(
response,
item_id,
value_scale_type,
storage_item,
result_handler,
runtime=runtime,
)
if (
result_processor is not None
and not inspect.iscoroutinefunction(result_handler)
):
decoded_response = result_processor(
decoded_response, item_id
)
raise SubstrateRequestException(str(response))
(
decoded_response,
complete,
) = await self._process_response(
response,
item_id,
value_scale_type,
storage_item,
result_handler,
runtime=runtime,
)
if (
result_processor is not None
and not inspect.iscoroutinefunction(result_handler)
):
decoded_response = result_processor(
decoded_response, item_id
request_manager.add_response(
item_id, decoded_response, complete
)
request_manager.add_response(
item_id, decoded_response, complete
)
# truncate to 2000 chars for debug logging
if (
len(stringified_response := str(decoded_response))
< 2_000
):
output_response = stringified_response
# avoids clogging logs up needlessly (esp for Metadata stuff)
else:
output_response = (
f"{stringified_response[:2_000]} (truncated)"
# truncate to 2000 chars for debug logging
if (
len(stringified_response := str(decoded_response))
< 2_000
):
output_response = stringified_response
# avoids clogging logs up needlessly (esp for Metadata stuff)
else:
output_response = (
f"{stringified_response[:2_000]} (truncated)"
)
logger.debug(
f"Received response for item ID {item_id}:\n{output_response}\n"
f"Complete: {complete}"
)
logger.debug(
f"Received response for item ID {item_id}:\n{output_response}\n"
f"Complete: {complete}"
)

if request_manager.is_complete:
await ws.mark_response_received()
break
else:
await asyncio.sleep(0.01)
if request_manager.is_complete:
break
else:
await asyncio.sleep(0.01)
finally:
await ws.mark_response_received()
for item_id in request_manager.unresponded():
await ws.discard_request(item_id)

return request_manager.get_results()

Expand Down Expand Up @@ -3670,17 +3716,23 @@ async def runtime_calls(
# Send all calls as one JSON-RPC batch frame, then gather responses by id.
async with self.ws as ws:
await ws.mark_waiting_for_response()
item_ids = await ws.send_batch(payloads)
item_ids: list[str] = []
responses: dict[str, dict] = {}
pending = set(item_ids)
while pending:
for item_id in list(pending):
if (response := await ws.retrieve(item_id)) is not None:
responses[item_id] = response
pending.discard(item_id)
if pending:
await asyncio.sleep(0.01)
await ws.mark_response_received()
pending: set[str] = set()
try:
item_ids = await ws.send_batch(payloads)
pending = set(item_ids)
while pending:
for item_id in list(pending):
if (response := await ws.retrieve(item_id)) is not None:
responses[item_id] = response
pending.discard(item_id)
if pending:
await asyncio.sleep(0.01)
finally:
await ws.mark_response_received()
for item_id in pending:
await ws.discard_request(item_id)

# Decode each result against its own output type, preserving input order.
results: list[ScaleValue] = []
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "async-substrate-interface"
version = "2.2.0"
version = "2.2.1"
description = "Asyncio library for interacting with substrate. Mostly API-compatible with py-substrate-interface"
readme = "README.md"
license = { file = "LICENSE" }
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
import subprocess
from collections import namedtuple

CONTAINER_NAME_PREFIX = "test_local_chain_"
LOCALNET_IMAGE_NAME = "ghcr.io/opentensor/subtensor-localnet:devnet-ready"
LOCALNET_IMAGE_NAME = os.getenv(
"LOCALNET_IMAGE_NAME", "ghcr.io/opentensor/subtensor-localnet:devnet-ready"
)

Container = namedtuple("Container", ["process", "name", "uri"])

Expand Down
71 changes: 71 additions & 0 deletions tests/e2e_tests/test_websocket_poison_recovery_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import subprocess

import pytest
from websockets.protocol import State

from async_substrate_interface.async_substrate import AsyncSubstrateInterface
from tests.conftest import start_docker_container
from tests.e2e_tests.test_substrate_addons import wait_for_output
from tests.helpers.async_proxy import AsyncSilenceProxy


@pytest.fixture(scope="function")
def local_chain():
container = start_docker_container(9955, "poison")
try:
if not wait_for_output(container.process, "Imported #1", timeout=60):
raise TimeoutError(
"Docker container did not start properly - 'Imported #1' not found"
)
yield container
finally:
subprocess.run(["docker", "kill", container.name])
container.process.kill()


@pytest.mark.asyncio
async def test_poison_pill_recovers_after_silence(local_chain):
"""
A dead handler left on an OPEN socket (the relayed "Max retries exceeded." poison pill) must recover on the next
call instead of failing forever.

The AsyncSilenceProxy sits between the client and the localnet. Pausing it makes the node go silent without closing
the socket, so the client's retries exhaust and the background handler dies while `ws.state` stays OPEN. Resuming
and issuing one more call must transparently rebuild the connection and succeed.
"""
proxy = await AsyncSilenceProxy(local_chain.uri).start()
try:
substrate = AsyncSubstrateInterface(
proxy.url,
retry_timeout=2.0,
max_retries=2,
ws_shutdown_timer=None,
)
try:
# Baseline: traffic flows through the proxy.
head = await substrate.get_chain_head()
assert head.startswith("0x")

# Go silent: the socket stays open but no responses come back, so the client's retries exhaust and the
# handler dies.
proxy.pause()
poison = None
try:
await substrate.get_chain_head()
except Exception as exc: # noqa: BLE001
poison = exc
assert poison is not None
assert "Max retries exceeded" in str(poison)

# The exact poison condition: a finished handler on an OPEN socket.
assert substrate.ws._send_recv_task.done()
assert substrate.ws.state is State.OPEN

# Resume and prove the next call rebuilds the connection and succeeds.
proxy.resume()
recovered = await substrate.get_chain_head()
assert recovered.startswith("0x")
finally:
await substrate.close()
finally:
await proxy.close()
Loading
Loading