1111from . import _client_shared
1212from ._asyncio_timeout import timeout as asyncio_timeout
1313from ._codec import Codec , get_proto_binary_codec , get_proto_json_codec
14- from ._envelope import EnvelopeReader
1514from ._interceptor_async import (
1615 BidiStreamInterceptor ,
1716 ClientStreamInterceptor ,
2120 resolve_interceptors ,
2221)
2322from ._protocol import ConnectWireError
24- from ._protocol_connect import (
25- CONNECT_STREAMING_HEADER_COMPRESSION ,
26- ConnectEnvelopeWriter ,
27- )
23+ from ._protocol_connect import ConnectClientProtocol , ConnectEnvelopeWriter
24+ from ._protocol_grpc import GRPCClientProtocol
2825from ._response_metadata import handle_response_headers
2926from .code import Code
3027from .errors import ConnectError
4239 from types import TracebackType
4340
4441 from ._compression import Compression
42+ from ._envelope import EnvelopeReader
4543 from .method import MethodInfo
4644 from .request import Headers , RequestContext
4745
@@ -91,6 +89,7 @@ def __init__(
9189 address : str ,
9290 * ,
9391 proto_json : bool = False ,
92+ grpc : bool = False ,
9493 accept_compression : Iterable [str ] | None = None ,
9594 send_compression : str | None = None ,
9695 timeout_ms : int | None = None ,
@@ -128,6 +127,11 @@ def __init__(
128127 self ._close_client = True
129128 self ._closed = False
130129
130+ if grpc :
131+ self ._protocol = GRPCClientProtocol ()
132+ else :
133+ self ._protocol = ConnectClientProtocol ()
134+
131135 interceptors = resolve_interceptors (interceptors )
132136 execute_unary = self ._send_request_unary
133137 for interceptor in (
@@ -192,7 +196,7 @@ async def execute_unary(
192196 timeout_ms : int | None = None ,
193197 use_get : bool = False ,
194198 ) -> RES :
195- ctx = _client_shared .create_request_context (
199+ ctx = self . _protocol .create_request_context (
196200 method = method ,
197201 http_method = "GET" if use_get else "POST" ,
198202 user_headers = headers ,
@@ -212,7 +216,7 @@ async def execute_client_stream(
212216 headers : Headers | Mapping [str , str ] | None = None ,
213217 timeout_ms : int | None = None ,
214218 ) -> RES :
215- ctx = _client_shared .create_request_context (
219+ ctx = self . _protocol .create_request_context (
216220 method = method ,
217221 http_method = "POST" ,
218222 user_headers = headers ,
@@ -232,7 +236,7 @@ def execute_server_stream(
232236 headers : Headers | Mapping [str , str ] | None = None ,
233237 timeout_ms : int | None = None ,
234238 ) -> AsyncIterator [RES ]:
235- ctx = _client_shared .create_request_context (
239+ ctx = self . _protocol .create_request_context (
236240 method = method ,
237241 http_method = "POST" ,
238242 user_headers = headers ,
@@ -252,7 +256,7 @@ def execute_bidi_stream(
252256 headers : Headers | Mapping [str , str ] | None = None ,
253257 timeout_ms : int | None = None ,
254258 ) -> AsyncIterator [RES ]:
255- ctx = _client_shared .create_request_context (
259+ ctx = self . _protocol .create_request_context (
256260 method = method ,
257261 http_method = "POST" ,
258262 user_headers = headers ,
@@ -267,6 +271,11 @@ def execute_bidi_stream(
267271 async def _send_request_unary (
268272 self , request : REQ , ctx : RequestContext [REQ , RES ]
269273 ) -> RES :
274+ if isinstance (self ._protocol , GRPCClientProtocol ):
275+ return await _consume_single_response (
276+ self ._send_request_bidi_stream (_yield_single_message (request ), ctx )
277+ )
278+
270279 request_headers = httpx .Headers (list (ctx .request_headers ().allitems ()))
271280 url = f"{ self ._address } /{ ctx .method ().service_name } /{ ctx .method ().name } "
272281 if (timeout_ms := ctx .timeout_ms ()) is not None :
@@ -303,14 +312,14 @@ async def _send_request_unary(
303312 timeout_s ,
304313 )
305314
306- _client_shared .validate_response_content_encoding (
307- resp .headers .get ("content-encoding" , "" )
308- )
309- _client_shared .validate_response_content_type (
315+ self ._protocol .validate_response (
310316 self ._codec .name (),
311317 resp .status_code ,
312318 resp .headers .get ("content-type" , "" ),
313319 )
320+ # Decompression itself is handled by httpx, but we validate it
321+ # by resolving it.
322+ self ._protocol .handle_response_compression (resp .headers , stream = False )
314323 handle_response_headers (resp .headers )
315324
316325 if resp .status_code == 200 :
@@ -360,51 +369,61 @@ async def _send_request_bidi_stream(
360369 timeout_s = None
361370 timeout = USE_CLIENT_DEFAULT
362371
372+ reader : EnvelopeReader | None = None
373+ resp : httpx .Response | None = None
363374 try :
364375 request_data = _streaming_request_content (
365376 request , self ._codec , self ._send_compression
366377 )
367378
368- async with (
369- asyncio_timeout (timeout_s ),
370- self ._session .stream (
379+ async with asyncio_timeout (timeout_s ):
380+ httpx_req = self ._session .build_request (
371381 method = "POST" ,
372382 url = url ,
373383 headers = request_headers ,
374384 content = request_data ,
375385 timeout = timeout ,
376- ) as resp ,
377- ):
378- compression = _client_shared .validate_response_content_encoding (
379- resp .headers .get (CONNECT_STREAMING_HEADER_COMPRESSION , "" )
380386 )
381- _client_shared .validate_stream_response_content_type (
382- self ._codec .name (), resp .headers .get ("content-type" , "" )
383- )
384- handle_response_headers (resp .headers )
385-
386- if resp .status_code == 200 :
387- reader = EnvelopeReader (
388- ctx .method ().output ,
389- self ._codec ,
390- compression ,
391- self ._read_max_bytes ,
392- )
393- async for chunk in resp .aiter_bytes ():
394- for message in reader .feed (chunk ):
395- yield message
396- # Check for cancellation each message. While this seems heavyweight,
397- # conformance tests require it.
398- await sleep (0 )
399- else :
400- raise ConnectWireError .from_response (resp ).to_exception ()
387+ resp = await self ._session .send (httpx_req , stream = True )
388+ try :
389+ handle_response_headers (resp .headers )
390+ if resp .status_code == 200 :
391+ self ._protocol .validate_stream_response (
392+ self ._codec .name (), resp .headers .get ("content-type" , "" )
393+ )
394+ compression = self ._protocol .handle_response_compression (
395+ resp .headers , stream = True
396+ )
397+ reader = self ._protocol .create_envelope_reader (
398+ ctx .method ().output ,
399+ self ._codec ,
400+ compression ,
401+ self ._read_max_bytes ,
402+ )
403+ async for chunk in resp .aiter_bytes ():
404+ for message in reader .feed (chunk ):
405+ yield message
406+ # Check for cancellation each message. While this seems heavyweight,
407+ # conformance tests require it.
408+ await sleep (0 )
409+ reader .handle_response_complete (resp )
410+ else :
411+ raise ConnectWireError .from_response (resp ).to_exception ()
412+ finally :
413+ await asyncio .shield (resp .aclose ())
401414 except (httpx .TimeoutException , TimeoutError , asyncio .TimeoutError ) as e :
402415 raise ConnectError (Code .DEADLINE_EXCEEDED , "Request timed out" ) from e
403416 except ConnectError :
404417 raise
405418 except CancelledError as e :
406419 raise ConnectError (Code .CANCELED , "Request was cancelled" ) from e
407420 except Exception as e :
421+ if rst_err := _client_shared .maybe_map_stream_reset (e , ctx ):
422+ # It is possible for a reset to come with trailers which should
423+ # be used.
424+ if reader and resp :
425+ reader .handle_response_complete (resp , rst_err )
426+ raise rst_err from e
408427 raise ConnectError (Code .UNAVAILABLE , str (e )) from e
409428
410429
0 commit comments