Skip to content

Commit 659369e

Browse files
authored
Add support for gRPC client protocol (#80)
Signed-off-by: Anuraag Agrawal <anuraaga@gmail.com>
1 parent a3f0804 commit 659369e

13 files changed

Lines changed: 985 additions & 293 deletions

conformance/test/client.py

Lines changed: 274 additions & 109 deletions
Large diffs are not rendered by default.

conformance/test/server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -708,17 +708,18 @@ def _find_free_port():
708708
return s.getsockname()[1]
709709

710710

711+
Mode = Literal["sync", "async"]
711712
Server = Literal["daphne", "granian", "gunicorn", "hypercorn", "pyvoy", "uvicorn"]
712713

713714

714715
class Args(argparse.Namespace):
715-
mode: Literal["sync", "async"]
716+
mode: Mode
716717
server: Server
717718

718719

719720
async def main() -> None:
720721
parser = argparse.ArgumentParser(description="Conformance server")
721-
parser.add_argument("--mode", choices=["sync", "async"])
722+
parser.add_argument("--mode", choices=get_args(Mode))
722723
parser.add_argument("--server", choices=get_args(Server))
723724
args = parser.parse_args(namespace=Args())
724725

conformance/test/test_client.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,14 @@
1212
_client_py_path = str(_current_dir / "client.py")
1313
_config_path = str(_current_dir / "config.yaml")
1414

15-
_skipped_tests = [
16-
# Not implemented yet,
15+
_skipped_tests_sync = [
16+
# Need to use async APIs for proper cancellation support in Python.
17+
"--skip",
18+
"Client Cancellation/**",
19+
]
20+
21+
_httpx_opts = [
22+
# Trailers not supported
1723
"--skip",
1824
"**/Protocol:PROTOCOL_GRPC/**",
1925
"--skip",
@@ -24,20 +30,28 @@
2430
"gRPC Empty Responses/**",
2531
"--skip",
2632
"gRPC Proto Sub-Format Responses/**",
27-
]
28-
29-
_skipped_tests_sync = [
30-
*_skipped_tests,
31-
# Need to use async APIs for proper cancellation support in Python.
33+
# Bidirectional streaming not supported
3234
"--skip",
35+
"**/full-duplex/**",
36+
# Cancellation delivery isn't reliable
37+
"--known-flaky",
3338
"Client Cancellation/**",
39+
"--known-flaky",
40+
"Timeouts/**",
3441
]
3542

3643

37-
def test_client_sync() -> None:
44+
@pytest.mark.parametrize("client", ["httpx", "pyqwest"])
45+
def test_client_sync(client: str) -> None:
3846
args = maybe_patch_args_with_debug(
39-
[sys.executable, _client_py_path, "--mode", "sync"]
47+
[sys.executable, _client_py_path, "--mode", "sync", "--client", client]
4048
)
49+
50+
opts = []
51+
match client:
52+
case "httpx":
53+
opts = _httpx_opts
54+
4155
result = subprocess.run(
4256
[
4357
"go",
@@ -47,6 +61,7 @@ def test_client_sync() -> None:
4761
_config_path,
4862
"--mode",
4963
"client",
64+
*opts,
5065
*_skipped_tests_sync,
5166
"--",
5267
*args,
@@ -59,18 +74,17 @@ def test_client_sync() -> None:
5974
pytest.fail(f"\n{result.stdout}\n{result.stderr}")
6075

6176

62-
_skipped_tests_async = [
63-
*_skipped_tests,
64-
# Cancellation currently not working for full duplex
65-
"--skip",
66-
"Client Cancellation/**/full-duplex/**",
67-
]
68-
69-
70-
def test_client_async() -> None:
77+
@pytest.mark.parametrize("client", ["httpx", "pyqwest"])
78+
def test_client_async(client: str) -> None:
7179
args = maybe_patch_args_with_debug(
72-
[sys.executable, _client_py_path, "--mode", "async"]
80+
[sys.executable, _client_py_path, "--mode", "async", "--client", client]
7381
)
82+
83+
opts = []
84+
match client:
85+
case "httpx":
86+
opts = _httpx_opts
87+
7488
result = subprocess.run(
7589
[
7690
"go",
@@ -80,9 +94,7 @@ def test_client_async() -> None:
8094
_config_path,
8195
"--mode",
8296
"client",
83-
*_skipped_tests_async,
84-
"--known-flaky",
85-
"Client Cancellation/**",
97+
*opts,
8698
"--",
8799
*args,
88100
],

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ dev = [
5050
"mkdocs==1.6.1",
5151
"mkdocs-material==9.6.20",
5252
"mkdocstrings[python]==0.30.1",
53+
"pyqwest==0.1.0",
5354
"pyright[nodejs]==1.1.405",
5455
"pytest-timeout==2.4.0",
5556
"pyvoy==0.2.0",

src/connectrpc/_client_async.py

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from . import _client_shared
1212
from ._asyncio_timeout import timeout as asyncio_timeout
1313
from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec
14-
from ._envelope import EnvelopeReader
1514
from ._interceptor_async import (
1615
BidiStreamInterceptor,
1716
ClientStreamInterceptor,
@@ -21,10 +20,8 @@
2120
resolve_interceptors,
2221
)
2322
from ._protocol import ConnectWireError
24-
from ._protocol_connect import (
25-
CONNECT_STREAMING_HEADER_COMPRESSION,
26-
ConnectEnvelopeWriter,
27-
)
23+
from ._protocol_connect import ConnectClientProtocol, ConnectEnvelopeWriter
24+
from ._protocol_grpc import GRPCClientProtocol
2825
from ._response_metadata import handle_response_headers
2926
from .code import Code
3027
from .errors import ConnectError
@@ -42,6 +39,7 @@
4239
from types import TracebackType
4340

4441
from ._compression import Compression
42+
from ._envelope import EnvelopeReader
4543
from .method import MethodInfo
4644
from .request import Headers, RequestContext
4745

@@ -91,6 +89,7 @@ def __init__(
9189
address: str,
9290
*,
9391
proto_json: bool = False,
92+
grpc: bool = False,
9493
accept_compression: Iterable[str] | None = None,
9594
send_compression: str | None = None,
9695
timeout_ms: int | None = None,
@@ -128,6 +127,11 @@ def __init__(
128127
self._close_client = True
129128
self._closed = False
130129

130+
if grpc:
131+
self._protocol = GRPCClientProtocol()
132+
else:
133+
self._protocol = ConnectClientProtocol()
134+
131135
interceptors = resolve_interceptors(interceptors)
132136
execute_unary = self._send_request_unary
133137
for interceptor in (
@@ -192,7 +196,7 @@ async def execute_unary(
192196
timeout_ms: int | None = None,
193197
use_get: bool = False,
194198
) -> RES:
195-
ctx = _client_shared.create_request_context(
199+
ctx = self._protocol.create_request_context(
196200
method=method,
197201
http_method="GET" if use_get else "POST",
198202
user_headers=headers,
@@ -212,7 +216,7 @@ async def execute_client_stream(
212216
headers: Headers | Mapping[str, str] | None = None,
213217
timeout_ms: int | None = None,
214218
) -> RES:
215-
ctx = _client_shared.create_request_context(
219+
ctx = self._protocol.create_request_context(
216220
method=method,
217221
http_method="POST",
218222
user_headers=headers,
@@ -232,7 +236,7 @@ def execute_server_stream(
232236
headers: Headers | Mapping[str, str] | None = None,
233237
timeout_ms: int | None = None,
234238
) -> AsyncIterator[RES]:
235-
ctx = _client_shared.create_request_context(
239+
ctx = self._protocol.create_request_context(
236240
method=method,
237241
http_method="POST",
238242
user_headers=headers,
@@ -252,7 +256,7 @@ def execute_bidi_stream(
252256
headers: Headers | Mapping[str, str] | None = None,
253257
timeout_ms: int | None = None,
254258
) -> AsyncIterator[RES]:
255-
ctx = _client_shared.create_request_context(
259+
ctx = self._protocol.create_request_context(
256260
method=method,
257261
http_method="POST",
258262
user_headers=headers,
@@ -267,6 +271,11 @@ def execute_bidi_stream(
267271
async def _send_request_unary(
268272
self, request: REQ, ctx: RequestContext[REQ, RES]
269273
) -> RES:
274+
if isinstance(self._protocol, GRPCClientProtocol):
275+
return await _consume_single_response(
276+
self._send_request_bidi_stream(_yield_single_message(request), ctx)
277+
)
278+
270279
request_headers = httpx.Headers(list(ctx.request_headers().allitems()))
271280
url = f"{self._address}/{ctx.method().service_name}/{ctx.method().name}"
272281
if (timeout_ms := ctx.timeout_ms()) is not None:
@@ -303,14 +312,14 @@ async def _send_request_unary(
303312
timeout_s,
304313
)
305314

306-
_client_shared.validate_response_content_encoding(
307-
resp.headers.get("content-encoding", "")
308-
)
309-
_client_shared.validate_response_content_type(
315+
self._protocol.validate_response(
310316
self._codec.name(),
311317
resp.status_code,
312318
resp.headers.get("content-type", ""),
313319
)
320+
# Decompression itself is handled by httpx, but we validate it
321+
# by resolving it.
322+
self._protocol.handle_response_compression(resp.headers, stream=False)
314323
handle_response_headers(resp.headers)
315324

316325
if resp.status_code == 200:
@@ -360,51 +369,61 @@ async def _send_request_bidi_stream(
360369
timeout_s = None
361370
timeout = USE_CLIENT_DEFAULT
362371

372+
reader: EnvelopeReader | None = None
373+
resp: httpx.Response | None = None
363374
try:
364375
request_data = _streaming_request_content(
365376
request, self._codec, self._send_compression
366377
)
367378

368-
async with (
369-
asyncio_timeout(timeout_s),
370-
self._session.stream(
379+
async with asyncio_timeout(timeout_s):
380+
httpx_req = self._session.build_request(
371381
method="POST",
372382
url=url,
373383
headers=request_headers,
374384
content=request_data,
375385
timeout=timeout,
376-
) as resp,
377-
):
378-
compression = _client_shared.validate_response_content_encoding(
379-
resp.headers.get(CONNECT_STREAMING_HEADER_COMPRESSION, "")
380386
)
381-
_client_shared.validate_stream_response_content_type(
382-
self._codec.name(), resp.headers.get("content-type", "")
383-
)
384-
handle_response_headers(resp.headers)
385-
386-
if resp.status_code == 200:
387-
reader = EnvelopeReader(
388-
ctx.method().output,
389-
self._codec,
390-
compression,
391-
self._read_max_bytes,
392-
)
393-
async for chunk in resp.aiter_bytes():
394-
for message in reader.feed(chunk):
395-
yield message
396-
# Check for cancellation each message. While this seems heavyweight,
397-
# conformance tests require it.
398-
await sleep(0)
399-
else:
400-
raise ConnectWireError.from_response(resp).to_exception()
387+
resp = await self._session.send(httpx_req, stream=True)
388+
try:
389+
handle_response_headers(resp.headers)
390+
if resp.status_code == 200:
391+
self._protocol.validate_stream_response(
392+
self._codec.name(), resp.headers.get("content-type", "")
393+
)
394+
compression = self._protocol.handle_response_compression(
395+
resp.headers, stream=True
396+
)
397+
reader = self._protocol.create_envelope_reader(
398+
ctx.method().output,
399+
self._codec,
400+
compression,
401+
self._read_max_bytes,
402+
)
403+
async for chunk in resp.aiter_bytes():
404+
for message in reader.feed(chunk):
405+
yield message
406+
# Check for cancellation each message. While this seems heavyweight,
407+
# conformance tests require it.
408+
await sleep(0)
409+
reader.handle_response_complete(resp)
410+
else:
411+
raise ConnectWireError.from_response(resp).to_exception()
412+
finally:
413+
await asyncio.shield(resp.aclose())
401414
except (httpx.TimeoutException, TimeoutError, asyncio.TimeoutError) as e:
402415
raise ConnectError(Code.DEADLINE_EXCEEDED, "Request timed out") from e
403416
except ConnectError:
404417
raise
405418
except CancelledError as e:
406419
raise ConnectError(Code.CANCELED, "Request was cancelled") from e
407420
except Exception as e:
421+
if rst_err := _client_shared.maybe_map_stream_reset(e, ctx):
422+
# It is possible for a reset to come with trailers which should
423+
# be used.
424+
if reader and resp:
425+
reader.handle_response_complete(resp, rst_err)
426+
raise rst_err from e
408427
raise ConnectError(Code.UNAVAILABLE, str(e)) from e
409428

410429

0 commit comments

Comments
 (0)