From 18f6d74e52aecbd734b5b7cc75390d12596cfcf1 Mon Sep 17 00:00:00 2001 From: Tim Cuthbertson Date: Mon, 4 Jul 2022 10:02:26 +1000 Subject: [PATCH 1/8] Implement backpressure for sending streams --- .../scala/fs2/grpc/client/Fs2ClientCall.scala | 40 +++++++----- .../client/Fs2StreamClientCallListener.scala | 7 +- .../client/internal/Fs2UnaryCallHandler.scala | 26 ++++---- .../scala/fs2/grpc/server/Fs2ServerCall.scala | 2 +- .../grpc/server/Fs2ServerCallHandler.scala | 10 ++- .../grpc/server/Fs2ServerCallListener.scala | 13 ++-- .../server/Fs2StreamServerCallListener.scala | 8 ++- .../grpc/server/internal/Fs2ServerCall.scala | 6 +- .../internal/Fs2UnaryServerCallHandler.scala | 28 ++++---- .../scala/fs2/grpc/shared/StreamOutput.scala | 65 +++++++++++++++++++ .../scala/fs2/grpc/client/ClientSuite.scala | 58 +++++++++++++++++ .../fs2/grpc/client/DummyClientCall.scala | 10 +++ .../fs2/grpc/server/DummyServerCall.scala | 9 +++ .../scala/fs2/grpc/server/ServerSuite.scala | 53 ++++++++++++++- 14 files changed, 275 insertions(+), 60 deletions(-) create mode 100644 runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala diff --git a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala index 2915be94..b8a20d91 100644 --- a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala +++ b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala @@ -23,11 +23,12 @@ package fs2 package grpc package client -import cats.syntax.all._ -import cats.effect.{Async, Resource} import cats.effect.std.Dispatcher +import cats.effect.{Async, Resource, SyncIO} +import cats.syntax.all._ import fs2.grpc.client.internal.Fs2UnaryCallHandler -import io.grpc.{Metadata, _} +import fs2.grpc.shared.StreamOutput +import io.grpc._ final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus]) final case class GrpcStatus(status: Status, trailers: Metadata) @@ -51,17 +52,11 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( private def request(numMessages: Int): F[Unit] = F.delay(call.request(numMessages)) - private def sendMessage(message: Request): F[Unit] = - F.delay(call.sendMessage(message)) - private def start[A <: ClientCall.Listener[Response]](createListener: F[A], md: Metadata): F[A] = createListener.flatTap(l => F.delay(call.start(l, md))) private def sendSingleMessage(message: Request): F[Unit] = - sendMessage(message) *> halfClose - - private def sendStream(stream: Stream[F, Request]): Stream[F, Unit] = - stream.evalMap(sendMessage) ++ Stream.eval(halfClose) + F.delay(call.sendMessage(message)) *> halfClose // @@ -69,17 +64,27 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( Fs2UnaryCallHandler.unary(call, options, message, headers) def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[Response] = - Fs2UnaryCallHandler.stream(call, options, messages, headers) + StreamOutput.client(call, dispatcher).flatMap { output => + Fs2UnaryCallHandler.stream(call, options, messages, output, headers) + } def unaryToStreamingCall(message: Request, md: Metadata): Stream[F, Response] = Stream - .resource(mkStreamListenerR(md)) + .resource(mkStreamListenerR(md, SyncIO.unit)) .flatMap(Stream.exec(sendSingleMessage(message)) ++ _.stream.adaptError(ea)) - def streamingToStreamingCall(messages: Stream[F, Request], md: Metadata): Stream[F, Response] = + def streamingToStreamingCall(messages: Stream[F, Request], md: Metadata): Stream[F, Response] = { + val listenerAndOutput = Resource.eval(StreamOutput.client(call, dispatcher)).flatMap { output => + mkStreamListenerR(md, output.onReady).map((_, output)) + } + Stream - .resource(mkStreamListenerR(md)) - .flatMap(_.stream.adaptError(ea).concurrently(sendStream(messages))) + .resource(listenerAndOutput) + .flatMap { case (listener, output) => + listener.stream.adaptError(ea) + .concurrently(output.writeStream(messages) ++ Stream.eval(halfClose)) + } + } // @@ -89,10 +94,9 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( case (_, Resource.ExitCase.Errored(t)) => cancel(t.getMessage.some, t.some) } - private def mkStreamListenerR(md: Metadata): Resource[F, Fs2StreamClientCallListener[F, Response]] = { - + private def mkStreamListenerR(md: Metadata, signalReadiness: SyncIO[Unit]): Resource[F, Fs2StreamClientCallListener[F, Response]] = { val prefetchN = options.prefetchN.max(1) - val create = Fs2StreamClientCallListener.create[F, Response](request, dispatcher, prefetchN) + val create = Fs2StreamClientCallListener.create[F, Response](request, signalReadiness, dispatcher, prefetchN) val acquire = start(create, md) <* request(prefetchN) val release = handleExitCase(cancelSucceed = true) diff --git a/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala b/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala index ef657391..348f01e9 100644 --- a/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala @@ -23,6 +23,7 @@ package fs2 package grpc package client +import cats.effect.SyncIO import cats.implicits._ import cats.effect.kernel.Concurrent import cats.effect.std.Dispatcher @@ -30,6 +31,7 @@ import io.grpc.{ClientCall, Metadata, Status} class Fs2StreamClientCallListener[F[_], Response] private ( ingest: StreamIngest[F, Response], + signalReadiness: SyncIO[Unit], dispatcher: Dispatcher[F] ) extends ClientCall.Listener[Response] { @@ -39,6 +41,8 @@ class Fs2StreamClientCallListener[F[_], Response] private ( override def onClose(status: Status, trailers: Metadata): Unit = dispatcher.unsafeRunSync(ingest.onClose(GrpcStatus(status, trailers))) + override def onReady(): Unit = signalReadiness.unsafeRunSync() + val stream: Stream[F, Response] = ingest.messages } @@ -46,11 +50,12 @@ object Fs2StreamClientCallListener { private[client] def create[F[_]: Concurrent, Response]( request: Int => F[Unit], + signalReadiness: SyncIO[Unit], dispatcher: Dispatcher[F], prefetchN: Int ): F[Fs2StreamClientCallListener[F, Response]] = StreamIngest[F, Response](request, prefetchN).map( - new Fs2StreamClientCallListener[F, Response](_, dispatcher) + new Fs2StreamClientCallListener[F, Response](_, signalReadiness, dispatcher) ) } diff --git a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala index 40491575..45ac759a 100644 --- a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala @@ -21,16 +21,14 @@ package fs2.grpc.client.internal -import cats.effect.Sync -import cats.effect.SyncIO +import cats.effect.kernel.{Async, Outcome, Ref} import cats.effect.syntax.all._ -import cats.effect.kernel.Async -import cats.effect.kernel.Outcome -import cats.effect.kernel.Ref -import cats.syntax.functor._ +import cats.effect.{Sync, SyncIO} import cats.syntax.flatMap._ +import cats.syntax.functor._ import fs2._ import fs2.grpc.client.ClientOptions +import fs2.grpc.shared.StreamOutput import io.grpc._ private[client] object Fs2UnaryCallHandler { @@ -65,7 +63,8 @@ private[client] object Fs2UnaryCallHandler { class Done[R] extends ReceiveState[R] private def mkListener[Response]( - state: Ref[SyncIO, ReceiveState[Response]] + state: Ref[SyncIO, ReceiveState[Response]], + signalReadiness: SyncIO[Unit] ): ClientCall.Listener[Response] = new ClientCall.Listener[Response] { override def onMessage(message: Response): Unit = @@ -110,6 +109,8 @@ private[client] object Fs2UnaryCallHandler { } } }.unsafeRunSync() + + override def onReady(): Unit = signalReadiness.unsafeRunSync() } def unary[F[_], Request, Response]( @@ -119,7 +120,7 @@ private[client] object Fs2UnaryCallHandler { headers: Metadata )(implicit F: Async[F]): F[Response] = F.async[Response] { cb => ReceiveState.init(cb, options.errorAdapter).map { state => - call.start(mkListener[Response](state), headers) + call.start(mkListener[Response](state, SyncIO.unit), headers) // Initially ask for two responses from flow-control so that if a misbehaving server // sends more than one responses, we can catch it and fail it in the listener. call.request(2) @@ -133,18 +134,15 @@ private[client] object Fs2UnaryCallHandler { call: ClientCall[Request, Response], options: ClientOptions, messages: Stream[F, Request], + output: StreamOutput[F, Request], headers: Metadata )(implicit F: Async[F]): F[Response] = F.async[Response] { cb => ReceiveState.init(cb, options.errorAdapter).flatMap { state => - call.start(mkListener[Response](state), headers) + call.start(mkListener[Response](state, output.onReady), headers) // Initially ask for two responses from flow-control so that if a misbehaving server // sends more than one responses, we can catch it and fail it in the listener. call.request(2) - messages - .map(call.sendMessage) - .compile - .drain - .guaranteeCase { + output.writeStream(messages).compile.drain.guaranteeCase { case Outcome.Succeeded(_) => F.delay(call.halfClose()) case Outcome.Errored(e) => F.delay(call.cancel(e.getMessage, e)) case Outcome.Canceled() => onCancel(call) diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCall.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCall.scala index a7dba3c6..c31acb9c 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCall.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCall.scala @@ -33,7 +33,7 @@ private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCal def closeStream(status: Status, trailers: Metadata)(implicit F: Sync[F]): F[Unit] = F.delay(call.close(status, trailers)) - def sendMessage(message: Response)(implicit F: Sync[F]): F[Unit] = + def sendSingleMessage(message: Response)(implicit F: Sync[F]): F[Unit] = F.delay(call.sendMessage(message)) def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] = diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala index b98c9526..831a4d4c 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala @@ -25,7 +25,9 @@ package server import cats.effect._ import cats.effect.std.Dispatcher +import cats.syntax.all._ import fs2.grpc.server.internal.Fs2UnaryServerCallHandler +import fs2.grpc.shared.StreamOutput import io.grpc._ class Fs2ServerCallHandler[F[_]: Async] private ( @@ -47,7 +49,7 @@ class Fs2ServerCallHandler[F[_]: Async] private ( implementation: (Stream[F, Request], Metadata) => F[Response] ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options)) + val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, SyncIO.unit, dispatcher, options)) listener.unsafeUnaryResponse(new Metadata(), implementation(_, headers)) listener } @@ -57,8 +59,10 @@ class Fs2ServerCallHandler[F[_]: Async] private ( implementation: (Stream[F, Request], Metadata) => Stream[F, Response] ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options)) - listener.unsafeStreamResponse(new Metadata(), implementation(_, headers)) + val (listener, streamOutput) = dispatcher.unsafeRunSync(StreamOutput.server(call, dispatcher).flatMap { output => + Fs2StreamServerCallListener[F](call, output.onReady, dispatcher, options).map((_, output)) + }) + listener.unsafeStreamResponse(streamOutput, new Metadata(), implementation(_, headers)) listener } } diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala index 1c2b22c9..de8ea01a 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala @@ -23,9 +23,10 @@ package fs2 package grpc package server -import cats.syntax.all._ import cats.effect._ import cats.effect.std.Dispatcher +import cats.syntax.all._ +import fs2.grpc.shared.StreamOutput import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException} private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { @@ -49,10 +50,10 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { } private def handleUnaryResponse(headers: Metadata, response: F[Response])(implicit F: Sync[F]): F[Unit] = - call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendMessage + call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendSingleMessage - private def handleStreamResponse(headers: Metadata, response: Stream[F, Response])(implicit F: Sync[F]): F[Unit] = - call.sendHeaders(headers) *> call.request(1) *> response.evalMap(call.sendMessage).compile.drain + private def handleStreamResponse(headers: Metadata, sendResponse: Stream[F, Unit])(implicit F: Sync[F]): F[Unit] = + call.sendHeaders(headers) *> call.request(1) *> sendResponse.compile.drain private def unsafeRun(f: F[Unit])(implicit F: Async[F]): Unit = { val bracketed = F.guaranteeCase(f) { @@ -70,8 +71,8 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { ): Unit = unsafeRun(handleUnaryResponse(headers, implementation(source))) - def unsafeStreamResponse(headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit + def unsafeStreamResponse(streamOutput: StreamOutput[F, Response], headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit F: Async[F] ): Unit = - unsafeRun(handleStreamResponse(headers, implementation(source))) + unsafeRun(handleStreamResponse(headers, streamOutput.writeStream(implementation(source)))) } diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala index 137e9174..e6de51ee 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala @@ -26,12 +26,13 @@ package server import cats.Functor import cats.syntax.all._ import cats.effect.kernel.Deferred -import cats.effect.Async +import cats.effect.{Async, SyncIO} import cats.effect.std.{Dispatcher, Queue} import io.grpc.ServerCall class Fs2StreamServerCallListener[F[_], Request, Response] private ( requestQ: Queue[F, Option[Request]], + signalReadiness: SyncIO[Unit], val isCancelled: Deferred[F, Unit], val call: Fs2ServerCall[F, Request, Response], val dispatcher: Dispatcher[F] @@ -47,6 +48,8 @@ class Fs2StreamServerCallListener[F[_], Request, Response] private ( dispatcher.unsafeRunSync(requestQ.offer(message.some)) } + override def onReady(): Unit = signalReadiness.unsafeRunSync() + override def onHalfClose(): Unit = dispatcher.unsafeRunSync(requestQ.offer(none)) @@ -60,13 +63,14 @@ object Fs2StreamServerCallListener { private[server] def apply[Request, Response]( call: ServerCall[Request, Response], + signalReadiness: SyncIO[Unit], dispatcher: Dispatcher[F], options: ServerOptions )(implicit F: Async[F]): F[Fs2StreamServerCallListener[F, Request, Response]] = for { inputQ <- Queue.unbounded[F, Option[Request]] isCancelled <- Deferred[F, Unit] serverCall <- Fs2ServerCall[F, Request, Response](call, options) - } yield new Fs2StreamServerCallListener[F, Request, Response](inputQ, isCancelled, serverCall, dispatcher) + } yield new Fs2StreamServerCallListener[F, Request, Response](inputQ, signalReadiness, isCancelled, serverCall, dispatcher) } diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala index 760735bb..44bbff77 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala @@ -42,19 +42,19 @@ private[server] object Fs2ServerCall { } private[server] final class Fs2ServerCall[Request, Response]( - call: ServerCall[Request, Response] + call: ServerCall[Request, Response], ) { import Fs2ServerCall.Cancel - def stream[F[_]](response: Stream[F, Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = + def stream[F[_]](sendStream: Stream[F, Response] => Stream[F, Unit], response: Stream[F, Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = run( response.pull.peek1 .flatMap { case Some((_, stream)) => Pull.suspend { call.sendHeaders(new Metadata()) - stream.map(call.sendMessage).pull.echo + sendStream(stream).pull.echo } case None => Pull.done } diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala index e7a0733e..2547acb3 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala @@ -21,12 +21,10 @@ package fs2.grpc.server.internal -import cats.effect.Ref -import cats.effect.Sync -import cats.effect.SyncIO import cats.effect.std.Dispatcher -import fs2.grpc.server.ServerCallOptions -import fs2.grpc.server.ServerOptions +import cats.effect.{Async, Ref, Sync, SyncIO} +import fs2.grpc.server.{ServerCallOptions, ServerOptions} +import fs2.grpc.shared.StreamOutput import io.grpc._ private[server] object Fs2UnaryServerCallHandler { @@ -49,6 +47,7 @@ private[server] object Fs2UnaryServerCallHandler { private def mkListener[Request, Response]( call: Fs2ServerCall[Request, Response], + signalReadiness: SyncIO[Unit], state: Ref[SyncIO, CallerState[Request]] ): ServerCall.Listener[Request] = new ServerCall.Listener[Request] { @@ -84,6 +83,8 @@ private[server] object Fs2UnaryServerCallHandler { } .unsafeRunSync() + override def onReady(): Unit = signalReadiness.unsafeRunSync() + private def sendError(status: Status): SyncIO[Unit] = state.set(Cancelled()) >> call.close(status, new Metadata()) } @@ -97,23 +98,28 @@ private[server] object Fs2UnaryServerCallHandler { private val opt = options.callOptionsFn(ServerCallOptions.default) def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = - startCallSync(call, opt)(call => req => call.unary(impl(req, headers), dispatcher)).unsafeRunSync() + startCallSync(call, SyncIO.unit, opt)(call => req => call.unary(impl(req, headers), dispatcher)).unsafeRunSync() } - def stream[F[_]: Sync, Request, Response]( + def stream[F[_], Request, Response]( impl: (Request, Metadata) => fs2.Stream[F, Response], options: ServerOptions, dispatcher: Dispatcher[F] - ): ServerCallHandler[Request, Response] = + )(implicit F: Async[F]): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { private val opt = options.callOptionsFn(ServerCallOptions.default) - def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = - startCallSync(call, opt)(call => req => call.stream(impl(req, headers), dispatcher)).unsafeRunSync() + def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { + val outputStream = dispatcher.unsafeRunSync(StreamOutput.server(call, dispatcher)) + startCallSync(call, outputStream.onReady, opt)(call => req => { + call.stream(outputStream.writeStream, impl(req, headers), dispatcher) + }).unsafeRunSync() + } } private def startCallSync[F[_], Request, Response]( call: ServerCall[Request, Response], + signalReadiness: SyncIO[Unit], options: ServerCallOptions )(f: Fs2ServerCall[Request, Response] => Request => SyncIO[Cancel]): SyncIO[ServerCall.Listener[Request]] = { for { @@ -122,6 +128,6 @@ private[server] object Fs2UnaryServerCallHandler { // sends more than 1 requests, ServerCall will catch it. _ <- call.request(2) state <- CallerState.init(f(call)) - } yield mkListener[Request, Response](call, state) + } yield mkListener[Request, Response](call, signalReadiness, state) } } diff --git a/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala b/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala new file mode 100644 index 00000000..b6e9fb10 --- /dev/null +++ b/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala @@ -0,0 +1,65 @@ +package fs2.grpc.shared + +import cats.effect.std.Dispatcher +import cats.effect.{Async, Deferred, Ref, SyncIO} +import cats.syntax.all._ +import fs2.Stream +import io.grpc.{ClientCall, ServerCall} + +private[grpc] trait StreamOutput[F[_], T] { + def onReady: SyncIO[Unit] + + def writeStream(s: Stream[F, T]): Stream[F, Unit] +} + +private [grpc] object StreamOutput { + def client[F[_], Request, Response]( + c: ClientCall[Request, Response], + dispatcher: Dispatcher[F] + )(implicit F: Async[F]): F[StreamOutput[F, Request]] = { + Ref[F].of(Option.empty[Deferred[F, Unit]]).map { waiting => + new StreamOutputImpl[F, Request](waiting, dispatcher, + isReady = F.delay(c.isReady), + sendMessage = m => F.delay(c.sendMessage(m))) + } + } + + def server[F[_], Request, Response]( + c: ServerCall[Request, Response], + dispatcher: Dispatcher[F] + )(implicit F: Async[F]): F[StreamOutput[F, Response]] = { + Ref[F].of(Option.empty[Deferred[F, Unit]]).map { waiting => + new StreamOutputImpl[F, Response](waiting, dispatcher, + isReady = F.delay(c.isReady), + sendMessage = m => F.delay(c.sendMessage(m))) + } + } +} + +private[grpc] class StreamOutputImpl[F[_], T]( + waiting: Ref[F, Option[Deferred[F, Unit]]], + dispatcher: Dispatcher[F], + isReady: F[Boolean], + sendMessage: T => F[Unit], +)(implicit F: Async[F]) extends StreamOutput[F, T] { + override def onReady: SyncIO[Unit] = SyncIO.delay(dispatcher.unsafeRunAndForget(signal)) + + private def signal: F[Unit] = waiting.getAndSet(None).flatMap { + case None => F.unit + case Some(wake) => wake.complete(()).void + } + + override def writeStream(s: Stream[F, T]): Stream[F, Unit] = s.evalMap(sendWhenReady) + + private def sendWhenReady(msg: T): F[Unit] = { + val send = sendMessage(msg) + isReady.ifM(send, { + Deferred[F, Unit].flatMap { wakeup => + waiting.set(wakeup.some) *> + isReady.ifM(signal, F.unit) *> // trigger manually in case onReady was invoked before we installed wakeup + wakeup.get *> + send + } + }) + } +} diff --git a/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala b/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala index c8f77d28..a375f526 100644 --- a/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala @@ -146,6 +146,34 @@ class ClientSuite extends Fs2GrpcSuite { } + runTest0("stream to streamingToUnary - send respects readiness") { (tc, io, d) => + val dummy = new DummyClientCall() + val client = fs2ClientCall(dummy, d) + val requests = Stream.emits(List("a", "b", "c", "d", "e")) + .chunkLimit(1) + .unchunks + .map { value => + if (value == "c") dummy.setIsReady(false) + value + } + + val result = client + .streamingToUnaryCall(requests, new Metadata()) + .unsafeToFuture()(io) + + tc.tick() + + // Check that client has not sent messages when isReady == false + assertEquals(dummy.messagesSent.size, 2) + assertEquals(result.value, None) + + dummy.setIsReady(true) + tc.tick() + + // Check that client sends remaining messages after channel is ready + assertEquals(dummy.messagesSent.size, 5) + } + runTest0("0-length to streamingToUnary") { (tc, io, d) => val dummy = new DummyClientCall() val client = fs2ClientCall(dummy, d) @@ -248,6 +276,36 @@ class ClientSuite extends Fs2GrpcSuite { } + runTest0("stream to streamingToStreaming - send respects readiness") { (tc, io, d) => + val dummy = new DummyClientCall() + val client = fs2ClientCall(dummy, d) + val requests = Stream.emits(List("a", "b", "c", "d", "e")) + .chunkLimit(1) + .unchunks + .map { value => + if (value == "c") dummy.setIsReady(false) + value + } + + val result = client + .streamingToStreamingCall(requests, new Metadata()) + .compile + .toList + .unsafeToFuture()(io) + + tc.tick() + + // Check that client has not sent messages when isReady == false + assertEquals(dummy.messagesSent.size, 2) + assertEquals(result.value, None) + + dummy.setIsReady(true) + tc.tick() + + // Check that client sends remaining messages after channel is ready + assertEquals(dummy.messagesSent.size, 5) + } + runTest0("cancellation for streamingToStreaming") { (tc, io, d) => val dummy = new DummyClientCall() val client = fs2ClientCall(dummy, d) diff --git a/runtime/src/test/scala/fs2/grpc/client/DummyClientCall.scala b/runtime/src/test/scala/fs2/grpc/client/DummyClientCall.scala index 8cd5dabf..d2190e49 100644 --- a/runtime/src/test/scala/fs2/grpc/client/DummyClientCall.scala +++ b/runtime/src/test/scala/fs2/grpc/client/DummyClientCall.scala @@ -30,6 +30,7 @@ class DummyClientCall extends ClientCall[String, Int] { var requested: Int = 0 val messagesSent: ArrayBuffer[String] = ArrayBuffer[String]() var listener: Option[ClientCall.Listener[Int]] = None + var ready = true var cancelled: Option[(String, Throwable)] = None var halfClosed = false @@ -48,4 +49,13 @@ class DummyClientCall extends ClientCall[String, Int] { messagesSent += message () } + + override def isReady: Boolean = ready + + def setIsReady(value: Boolean): Unit = { + ready = value + if (value) { + listener.get.onReady() + } + } } diff --git a/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala b/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala index 31036241..6ce628b9 100644 --- a/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala +++ b/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala @@ -30,6 +30,7 @@ import scala.collection.mutable.ArrayBuffer class DummyServerCall extends ServerCall[String, Int] { val messages: ArrayBuffer[Int] = ArrayBuffer[Int]() var currentStatus: Option[Status] = None + var ready: Boolean = true override def request(numMessages: Int): Unit = () override def sendMessage(message: Int): Unit = { @@ -44,4 +45,12 @@ class DummyServerCall extends ServerCall[String, Int] { currentStatus = Some(status) } override def isCancelled: Boolean = false + + override def isReady: Boolean = ready + def setIsReady(value: Boolean, listener: ServerCall.Listener[_]): Unit = { + ready = value + if (ready) { + listener.onReady() + } + } } diff --git a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala index 68ded68e..1a1b2860 100644 --- a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala @@ -23,13 +23,14 @@ package fs2 package grpc package server -import scala.concurrent.duration._ import cats.effect._ import cats.effect.std.Dispatcher import cats.effect.testkit.TestContext import fs2.grpc.server.internal.Fs2UnaryServerCallHandler import io.grpc._ +import scala.concurrent.duration._ + class ServerSuite extends Fs2GrpcSuite { private val compressionOps = @@ -226,6 +227,56 @@ class ServerSuite extends Fs2GrpcSuite { assertEquals(dummy.currentStatus.get.isOk, false) } + runTest("streamingToStreaming send respects isReady") { (tc, d) => + val dummy = new DummyServerCall + + val listenerRef = Ref.unsafe[SyncIO, Option[ServerCall.Listener[_]]](None) + val handler = Fs2ServerCallHandler[IO](d, ServerOptions.default) + .streamingToStreamingCall[String, Int]((req, _) => unreadyAfterTwoEmissions(dummy, listenerRef).concurrently(req)) + val listener = handler.startCall(dummy, new Metadata()) + listenerRef.set(Some(listener)).unsafeRunSync() + + tc.tick() + + assertEquals(dummy.messages.toList, List(1, 2)) + + dummy.setIsReady(true, listener) + tc.tick() + + assertEquals(dummy.messages.toList, List(1, 2, 3, 4, 5)) + } + + runTest("unaryToStreaming send respects isReady") { (tc, d) => + val dummy = new DummyServerCall + + val listenerRef = Ref.unsafe[SyncIO, Option[ServerCall.Listener[_]]](None) + val handler = + Fs2UnaryServerCallHandler.stream[IO, String, Int]((_, _) => unreadyAfterTwoEmissions(dummy, listenerRef), ServerOptions.default, d) + + val listener = handler.startCall(dummy, new Metadata()) + listenerRef.set(Some(listener)).unsafeRunSync() + + listener.onMessage("a") + listener.onHalfClose() + tc.tick() + + assertEquals(dummy.messages.toList, List(1, 2)) + + dummy.setIsReady(true, listener) + tc.tick() + + assertEquals(dummy.messages.toList, List(1, 2, 3, 4, 5)) + } + + private def unreadyAfterTwoEmissions(dummy: DummyServerCall, listener: Ref[SyncIO, Option[ServerCall.Listener[_]]]) = + Stream.emits(List(1, 2, 3, 4, 5)) + .chunkLimit(1) + .unchunks + .map { value => + if (value == 3) dummy.setIsReady(false, listener.get.unsafeRunSync().get) + value + } + runTest("streaming to unary")(streamingToUnary()) runTest("streaming to unary with compression")(streamingToUnary(compressionOps)) From 8e366c78124b882a2a2ff443a3234fdf8af4d425 Mon Sep 17 00:00:00 2001 From: Tim Cuthbertson Date: Sat, 23 Jul 2022 22:25:45 +1000 Subject: [PATCH 2/8] Server: reuse client StreamIngest for incoming stream --- .../client/Fs2StreamClientCallListener.scala | 6 ++- .../scala/fs2/grpc/client/StreamIngest.scala | 19 +++++---- .../server/Fs2StreamServerCallListener.scala | 19 +++++---- .../fs2/grpc/server/DummyServerCall.scala | 5 ++- .../scala/fs2/grpc/server/ServerSuite.scala | 39 +++++++++++++++++++ 5 files changed, 65 insertions(+), 23 deletions(-) diff --git a/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala b/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala index 348f01e9..e0e34081 100644 --- a/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala @@ -38,8 +38,10 @@ class Fs2StreamClientCallListener[F[_], Response] private ( override def onMessage(message: Response): Unit = dispatcher.unsafeRunSync(ingest.onMessage(message)) - override def onClose(status: Status, trailers: Metadata): Unit = - dispatcher.unsafeRunSync(ingest.onClose(GrpcStatus(status, trailers))) + override def onClose(status: Status, trailers: Metadata): Unit = { + val error = Option.when(!status.isOk)(status.asRuntimeException(trailers)) + dispatcher.unsafeRunSync(ingest.onClose(error)) + } override def onReady(): Unit = signalReadiness.unsafeRunSync() diff --git a/runtime/src/main/scala/fs2/grpc/client/StreamIngest.scala b/runtime/src/main/scala/fs2/grpc/client/StreamIngest.scala index 9d2775c7..5779c6bd 100644 --- a/runtime/src/main/scala/fs2/grpc/client/StreamIngest.scala +++ b/runtime/src/main/scala/fs2/grpc/client/StreamIngest.scala @@ -27,26 +27,26 @@ import cats.implicits._ import cats.effect.Concurrent import cats.effect.std.Queue -private[client] trait StreamIngest[F[_], T] { +private[grpc] trait StreamIngest[F[_], T] { def onMessage(msg: T): F[Unit] - def onClose(status: GrpcStatus): F[Unit] + def onClose(error: Option[Throwable]): F[Unit] def messages: Stream[F, T] } -private[client] object StreamIngest { +private[grpc] object StreamIngest { def apply[F[_]: Concurrent, T]( request: Int => F[Unit], prefetchN: Int ): F[StreamIngest[F, T]] = Queue - .unbounded[F, Either[GrpcStatus, T]] + .unbounded[F, Either[Option[Throwable], T]] .map(q => create[F, T](request, prefetchN, q)) def create[F[_], T]( request: Int => F[Unit], prefetchN: Int, - queue: Queue[F, Either[GrpcStatus, T]] + queue: Queue[F, Either[Option[Throwable], T]] )(implicit F: Concurrent[F]): StreamIngest[F, T] = new StreamIngest[F, T] { val limit: Int = @@ -58,17 +58,16 @@ private[client] object StreamIngest { def onMessage(msg: T): F[Unit] = queue.offer(msg.asRight) *> ensureMessages - def onClose(status: GrpcStatus): F[Unit] = - queue.offer(status.asLeft) + def onClose(error: Option[Throwable]): F[Unit] = + queue.offer(error.asLeft) val messages: Stream[F, T] = { val run: F[Option[T]] = queue.take.flatMap { case Right(v) => ensureMessages *> v.some.pure[F] - case Left(GrpcStatus(status, trailers)) => - if (!status.isOk) F.raiseError(status.asRuntimeException(trailers)) - else none[T].pure[F] + case Left(Some(error)) => F.raiseError(error) + case Left(None) => none[T].pure[F] } Stream.repeatEval(run).unNoneTerminate diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala index e6de51ee..8131c811 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala @@ -28,10 +28,11 @@ import cats.syntax.all._ import cats.effect.kernel.Deferred import cats.effect.{Async, SyncIO} import cats.effect.std.{Dispatcher, Queue} +import fs2.grpc.client.StreamIngest import io.grpc.ServerCall class Fs2StreamServerCallListener[F[_], Request, Response] private ( - requestQ: Queue[F, Option[Request]], + ingest: StreamIngest[F, Request], signalReadiness: SyncIO[Unit], val isCancelled: Deferred[F, Unit], val call: Fs2ServerCall[F, Request, Response], @@ -43,18 +44,15 @@ class Fs2StreamServerCallListener[F[_], Request, Response] private ( override def onCancel(): Unit = dispatcher.unsafeRunSync(isCancelled.complete(()).void) - override def onMessage(message: Request): Unit = { - call.call.request(1) - dispatcher.unsafeRunSync(requestQ.offer(message.some)) - } + override def onMessage(message: Request): Unit = + dispatcher.unsafeRunSync(ingest.onMessage(message)) override def onReady(): Unit = signalReadiness.unsafeRunSync() override def onHalfClose(): Unit = - dispatcher.unsafeRunSync(requestQ.offer(none)) + dispatcher.unsafeRunSync(ingest.onClose(None)) - override def source: Stream[F, Request] = - Stream.repeatEval(requestQ.take).unNoneTerminate + override def source: Stream[F, Request] = ingest.messages } object Fs2StreamServerCallListener { @@ -67,10 +65,11 @@ object Fs2StreamServerCallListener { dispatcher: Dispatcher[F], options: ServerOptions )(implicit F: Async[F]): F[Fs2StreamServerCallListener[F, Request, Response]] = for { - inputQ <- Queue.unbounded[F, Option[Request]] isCancelled <- Deferred[F, Unit] + request = (n: Int) => F.delay(call.request(n)) + ingest <- StreamIngest[F, Request](request, prefetchN = 1) serverCall <- Fs2ServerCall[F, Request, Response](call, options) - } yield new Fs2StreamServerCallListener[F, Request, Response](inputQ, signalReadiness, isCancelled, serverCall, dispatcher) + } yield new Fs2StreamServerCallListener[F, Request, Response](ingest, signalReadiness, isCancelled, serverCall, dispatcher) } diff --git a/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala b/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala index 6ce628b9..ba0663c2 100644 --- a/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala +++ b/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala @@ -30,9 +30,12 @@ import scala.collection.mutable.ArrayBuffer class DummyServerCall extends ServerCall[String, Int] { val messages: ArrayBuffer[Int] = ArrayBuffer[Int]() var currentStatus: Option[Status] = None + var requested: Int = 0 var ready: Boolean = true - override def request(numMessages: Int): Unit = () + override def request(numMessages: Int): Unit = { + requested += numMessages + } override def sendMessage(message: Int): Unit = { messages += message () diff --git a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala index 1a1b2860..69db3a04 100644 --- a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala @@ -303,4 +303,43 @@ class ServerSuite extends Fs2GrpcSuite { assertEquals(dummy.currentStatus.get.isOk, true) } + runTest("streamingToUnary back pressure") { (tc, d) => + val dummy = new DummyServerCall + val deferred = d.unsafeRunSync(Deferred[IO, Unit]) + val handler = Fs2ServerCallHandler[IO](d, ServerOptions.default) + .streamingToUnaryCall[String, Int]((requests, _) => { + requests.evalMap(_ => deferred.get).compile.drain.as(1) + }) + val listener = handler.startCall(dummy, new Metadata()) + + tc.tick() + + assertEquals(dummy.requested, 1) + + listener.onMessage("1") + tc.tick() + + listener.onMessage("2") + listener.onMessage("3") + tc.tick() + + // requested should ideally be 2, however StreamIngest can double-request in some execution + // orderings if the push() is followed by pop() before the push checks the queue length. + val initialRequested = dummy.requested + assert(initialRequested == 2 || initialRequested == 3, s"expected requested to be 2 or 3, got ${initialRequested}") + + // don't request any more messages while downstream is blocked + listener.onMessage("4") + listener.onMessage("5") + listener.onMessage("6") + tc.tick() + + assertEquals(dummy.requested - initialRequested, 0) + + // allow all messages through, the final pop() will trigger a new request + d.unsafeRunAndForget(deferred.complete(())) + tc.tick() + + assertEquals(dummy.requested - initialRequested, 1) + } } From a1668d78e8b2d0c363e7f073d67a724896c2df66 Mon Sep 17 00:00:00 2001 From: Tim Cuthbertson Date: Mon, 19 Sep 2022 13:51:05 +1000 Subject: [PATCH 3/8] StreamOutput: use SignallingRef instead of Ref[Deferred] --- .../scala/fs2/grpc/client/Fs2ClientCall.scala | 8 +-- .../client/internal/Fs2UnaryCallHandler.scala | 4 +- .../grpc/server/Fs2ServerCallHandler.scala | 4 +- .../internal/Fs2UnaryServerCallHandler.scala | 4 +- .../scala/fs2/grpc/shared/StreamOutput.scala | 50 +++++++++---------- 5 files changed, 34 insertions(+), 36 deletions(-) diff --git a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala index b8a20d91..01ca1a9b 100644 --- a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala +++ b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala @@ -64,8 +64,8 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( Fs2UnaryCallHandler.unary(call, options, message, headers) def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[Response] = - StreamOutput.client(call, dispatcher).flatMap { output => - Fs2UnaryCallHandler.stream(call, options, messages, output, headers) + StreamOutput.client(call).flatMap { output => + Fs2UnaryCallHandler.stream(call, options, dispatcher, messages, output, headers) } def unaryToStreamingCall(message: Request, md: Metadata): Stream[F, Response] = @@ -74,8 +74,8 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( .flatMap(Stream.exec(sendSingleMessage(message)) ++ _.stream.adaptError(ea)) def streamingToStreamingCall(messages: Stream[F, Request], md: Metadata): Stream[F, Response] = { - val listenerAndOutput = Resource.eval(StreamOutput.client(call, dispatcher)).flatMap { output => - mkStreamListenerR(md, output.onReady).map((_, output)) + val listenerAndOutput = Resource.eval(StreamOutput.client(call)).flatMap { output => + mkStreamListenerR(md, output.onReadySync(dispatcher)).map((_, output)) } Stream diff --git a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala index 45ac759a..bfc83ec6 100644 --- a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala @@ -22,6 +22,7 @@ package fs2.grpc.client.internal import cats.effect.kernel.{Async, Outcome, Ref} +import cats.effect.std.Dispatcher import cats.effect.syntax.all._ import cats.effect.{Sync, SyncIO} import cats.syntax.flatMap._ @@ -133,12 +134,13 @@ private[client] object Fs2UnaryCallHandler { def stream[F[_], Request, Response]( call: ClientCall[Request, Response], options: ClientOptions, + dispatcher: Dispatcher[F], messages: Stream[F, Request], output: StreamOutput[F, Request], headers: Metadata )(implicit F: Async[F]): F[Response] = F.async[Response] { cb => ReceiveState.init(cb, options.errorAdapter).flatMap { state => - call.start(mkListener[Response](state, output.onReady), headers) + call.start(mkListener[Response](state, output.onReadySync(dispatcher)), headers) // Initially ask for two responses from flow-control so that if a misbehaving server // sends more than one responses, we can catch it and fail it in the listener. call.request(2) diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala index 831a4d4c..37064909 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala @@ -59,8 +59,8 @@ class Fs2ServerCallHandler[F[_]: Async] private ( implementation: (Stream[F, Request], Metadata) => Stream[F, Response] ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val (listener, streamOutput) = dispatcher.unsafeRunSync(StreamOutput.server(call, dispatcher).flatMap { output => - Fs2StreamServerCallListener[F](call, output.onReady, dispatcher, options).map((_, output)) + val (listener, streamOutput) = dispatcher.unsafeRunSync(StreamOutput.server(call).flatMap { output => + Fs2StreamServerCallListener[F](call, output.onReadySync(dispatcher), dispatcher, options).map((_, output)) }) listener.unsafeStreamResponse(streamOutput, new Metadata(), implementation(_, headers)) listener diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala index 2547acb3..57b0b9ad 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala @@ -110,8 +110,8 @@ private[server] object Fs2UnaryServerCallHandler { private val opt = options.callOptionsFn(ServerCallOptions.default) def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val outputStream = dispatcher.unsafeRunSync(StreamOutput.server(call, dispatcher)) - startCallSync(call, outputStream.onReady, opt)(call => req => { + val outputStream = dispatcher.unsafeRunSync(StreamOutput.server(call)) + startCallSync(call, outputStream.onReadySync(dispatcher), opt)(call => req => { call.stream(outputStream.writeStream, impl(req, headers), dispatcher) }).unsafeRunSync() } diff --git a/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala b/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala index b6e9fb10..769c344b 100644 --- a/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala +++ b/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala @@ -1,35 +1,36 @@ package fs2.grpc.shared import cats.effect.std.Dispatcher -import cats.effect.{Async, Deferred, Ref, SyncIO} +import cats.effect.{Async, SyncIO} import cats.syntax.all._ import fs2.Stream +import fs2.concurrent.SignallingRef import io.grpc.{ClientCall, ServerCall} private[grpc] trait StreamOutput[F[_], T] { - def onReady: SyncIO[Unit] + def onReady: F[Unit] + + def onReadySync(dispatcher: Dispatcher[F]): SyncIO[Unit] = SyncIO.delay(dispatcher.unsafeRunSync(onReady)) def writeStream(s: Stream[F, T]): Stream[F, Unit] } private [grpc] object StreamOutput { - def client[F[_], Request, Response]( - c: ClientCall[Request, Response], - dispatcher: Dispatcher[F] - )(implicit F: Async[F]): F[StreamOutput[F, Request]] = { - Ref[F].of(Option.empty[Deferred[F, Unit]]).map { waiting => - new StreamOutputImpl[F, Request](waiting, dispatcher, + def client[F[_], Request, Response](c: ClientCall[Request, Response]) + (implicit F: Async[F]): F[StreamOutput[F, Request]] = { + SignallingRef[F].of(0L).map { readyState => + new StreamOutputImpl[F, Request]( + readyState, isReady = F.delay(c.isReady), sendMessage = m => F.delay(c.sendMessage(m))) } } - def server[F[_], Request, Response]( - c: ServerCall[Request, Response], - dispatcher: Dispatcher[F] - )(implicit F: Async[F]): F[StreamOutput[F, Response]] = { - Ref[F].of(Option.empty[Deferred[F, Unit]]).map { waiting => - new StreamOutputImpl[F, Response](waiting, dispatcher, + def server[F[_], Request, Response](c: ServerCall[Request, Response]) + (implicit F: Async[F]): F[StreamOutput[F, Response]] = { + SignallingRef[F].of(0L).map { readyState => + new StreamOutputImpl[F, Response]( + readyState, isReady = F.delay(c.isReady), sendMessage = m => F.delay(c.sendMessage(m))) } @@ -37,28 +38,23 @@ private [grpc] object StreamOutput { } private[grpc] class StreamOutputImpl[F[_], T]( - waiting: Ref[F, Option[Deferred[F, Unit]]], - dispatcher: Dispatcher[F], + readyCountRef: SignallingRef[F, Long], isReady: F[Boolean], sendMessage: T => F[Unit], )(implicit F: Async[F]) extends StreamOutput[F, T] { - override def onReady: SyncIO[Unit] = SyncIO.delay(dispatcher.unsafeRunAndForget(signal)) - - private def signal: F[Unit] = waiting.getAndSet(None).flatMap { - case None => F.unit - case Some(wake) => wake.complete(()).void - } + override def onReady: F[Unit] = readyCountRef.update(_ + 1L) override def writeStream(s: Stream[F, T]): Stream[F, Unit] = s.evalMap(sendWhenReady) private def sendWhenReady(msg: T): F[Unit] = { val send = sendMessage(msg) isReady.ifM(send, { - Deferred[F, Unit].flatMap { wakeup => - waiting.set(wakeup.some) *> - isReady.ifM(signal, F.unit) *> // trigger manually in case onReady was invoked before we installed wakeup - wakeup.get *> - send + readyCountRef.get.flatMap { readyState => + // If isReady is now true, don't wait (we may have missed the onReady signal) + isReady.ifM(send, { + // otherwise wait until readyState has been incremented + readyCountRef.waitUntil(_ > readyState) *> send + }) } }) } From 0da8407b79520de66520e172b371777e0a00808e Mon Sep 17 00:00:00 2001 From: Tim Cuthbertson Date: Tue, 20 Sep 2022 13:49:02 +1000 Subject: [PATCH 4/8] add header comment --- .../scala/fs2/grpc/shared/StreamOutput.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala b/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala index 769c344b..acb7f782 100644 --- a/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala +++ b/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala @@ -1,3 +1,24 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + package fs2.grpc.shared import cats.effect.std.Dispatcher From 318b282c347c5a0a81c508b34bee7ebfecc55e1b Mon Sep 17 00:00:00 2001 From: Tim Cuthbertson Date: Tue, 20 Sep 2022 13:51:06 +1000 Subject: [PATCH 5/8] remove Option.when as it's not available in 2.11 --- .../scala/fs2/grpc/client/Fs2StreamClientCallListener.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala b/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala index e0e34081..d1594f45 100644 --- a/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala @@ -39,7 +39,7 @@ class Fs2StreamClientCallListener[F[_], Response] private ( dispatcher.unsafeRunSync(ingest.onMessage(message)) override def onClose(status: Status, trailers: Metadata): Unit = { - val error = Option.when(!status.isOk)(status.asRuntimeException(trailers)) + val error = if (status.isOk) None else Some(status.asRuntimeException(trailers)) dispatcher.unsafeRunSync(ingest.onClose(error)) } From d77c526b976851572e17f83a23695c2a8b86dd95 Mon Sep 17 00:00:00 2001 From: Tim Cuthbertson Date: Tue, 20 Sep 2022 21:10:00 +1000 Subject: [PATCH 6/8] scalafmt --- .../scala/fs2/grpc/client/Fs2ClientCall.scala | 8 +++- .../client/internal/Fs2UnaryCallHandler.scala | 6 ++- .../grpc/server/Fs2ServerCallListener.scala | 6 ++- .../server/Fs2StreamServerCallListener.scala | 8 +++- .../grpc/server/internal/Fs2ServerCall.scala | 8 +++- .../internal/Fs2UnaryServerCallHandler.scala | 8 ++-- .../scala/fs2/grpc/shared/StreamOutput.scala | 47 +++++++++++-------- .../scala/fs2/grpc/client/ClientSuite.scala | 6 ++- .../scala/fs2/grpc/server/ServerSuite.scala | 9 +++- 9 files changed, 73 insertions(+), 33 deletions(-) diff --git a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala index 01ca1a9b..f9116816 100644 --- a/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala +++ b/runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala @@ -81,7 +81,8 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( Stream .resource(listenerAndOutput) .flatMap { case (listener, output) => - listener.stream.adaptError(ea) + listener.stream + .adaptError(ea) .concurrently(output.writeStream(messages) ++ Stream.eval(halfClose)) } } @@ -94,7 +95,10 @@ class Fs2ClientCall[F[_], Request, Response] private[client] ( case (_, Resource.ExitCase.Errored(t)) => cancel(t.getMessage.some, t.some) } - private def mkStreamListenerR(md: Metadata, signalReadiness: SyncIO[Unit]): Resource[F, Fs2StreamClientCallListener[F, Response]] = { + private def mkStreamListenerR( + md: Metadata, + signalReadiness: SyncIO[Unit] + ): Resource[F, Fs2StreamClientCallListener[F, Response]] = { val prefetchN = options.prefetchN.max(1) val create = Fs2StreamClientCallListener.create[F, Response](request, signalReadiness, dispatcher, prefetchN) val acquire = start(create, md) <* request(prefetchN) diff --git a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala index bfc83ec6..005e9c67 100644 --- a/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/client/internal/Fs2UnaryCallHandler.scala @@ -144,7 +144,11 @@ private[client] object Fs2UnaryCallHandler { // Initially ask for two responses from flow-control so that if a misbehaving server // sends more than one responses, we can catch it and fail it in the listener. call.request(2) - output.writeStream(messages).compile.drain.guaranteeCase { + output + .writeStream(messages) + .compile + .drain + .guaranteeCase { case Outcome.Succeeded(_) => F.delay(call.halfClose()) case Outcome.Errored(e) => F.delay(call.cancel(e.getMessage, e)) case Outcome.Canceled() => onCancel(call) diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala index de8ea01a..98692636 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala @@ -71,7 +71,11 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] { ): Unit = unsafeRun(handleUnaryResponse(headers, implementation(source))) - def unsafeStreamResponse(streamOutput: StreamOutput[F, Response], headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit + def unsafeStreamResponse( + streamOutput: StreamOutput[F, Response], + headers: Metadata, + implementation: G[Request] => Stream[F, Response] + )(implicit F: Async[F] ): Unit = unsafeRun(handleStreamResponse(headers, streamOutput.writeStream(implementation(source)))) diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala index 8131c811..37e699a8 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala @@ -69,7 +69,13 @@ object Fs2StreamServerCallListener { request = (n: Int) => F.delay(call.request(n)) ingest <- StreamIngest[F, Request](request, prefetchN = 1) serverCall <- Fs2ServerCall[F, Request, Response](call, options) - } yield new Fs2StreamServerCallListener[F, Request, Response](ingest, signalReadiness, isCancelled, serverCall, dispatcher) + } yield new Fs2StreamServerCallListener[F, Request, Response]( + ingest, + signalReadiness, + isCancelled, + serverCall, + dispatcher + ) } diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala index 44bbff77..e71426c3 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala @@ -42,12 +42,16 @@ private[server] object Fs2ServerCall { } private[server] final class Fs2ServerCall[Request, Response]( - call: ServerCall[Request, Response], + call: ServerCall[Request, Response] ) { import Fs2ServerCall.Cancel - def stream[F[_]](sendStream: Stream[F, Response] => Stream[F, Unit], response: Stream[F, Response], dispatcher: Dispatcher[F])(implicit F: Sync[F]): SyncIO[Cancel] = + def stream[F[_]]( + sendStream: Stream[F, Response] => Stream[F, Unit], + response: Stream[F, Response], + dispatcher: Dispatcher[F] + )(implicit F: Sync[F]): SyncIO[Cancel] = run( response.pull.peek1 .flatMap { diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala index 57b0b9ad..1af54faa 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala @@ -111,9 +111,11 @@ private[server] object Fs2UnaryServerCallHandler { def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { val outputStream = dispatcher.unsafeRunSync(StreamOutput.server(call)) - startCallSync(call, outputStream.onReadySync(dispatcher), opt)(call => req => { - call.stream(outputStream.writeStream, impl(req, headers), dispatcher) - }).unsafeRunSync() + startCallSync(call, outputStream.onReadySync(dispatcher), opt)(call => + req => { + call.stream(outputStream.writeStream, impl(req, headers), dispatcher) + } + ).unsafeRunSync() } } diff --git a/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala b/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala index acb7f782..3aea9e74 100644 --- a/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala +++ b/runtime/src/main/scala/fs2/grpc/shared/StreamOutput.scala @@ -36,47 +36,56 @@ private[grpc] trait StreamOutput[F[_], T] { def writeStream(s: Stream[F, T]): Stream[F, Unit] } -private [grpc] object StreamOutput { - def client[F[_], Request, Response](c: ClientCall[Request, Response]) - (implicit F: Async[F]): F[StreamOutput[F, Request]] = { +private[grpc] object StreamOutput { + def client[F[_], Request, Response]( + c: ClientCall[Request, Response] + )(implicit F: Async[F]): F[StreamOutput[F, Request]] = { SignallingRef[F].of(0L).map { readyState => new StreamOutputImpl[F, Request]( readyState, isReady = F.delay(c.isReady), - sendMessage = m => F.delay(c.sendMessage(m))) + sendMessage = m => F.delay(c.sendMessage(m)) + ) } } - def server[F[_], Request, Response](c: ServerCall[Request, Response]) - (implicit F: Async[F]): F[StreamOutput[F, Response]] = { + def server[F[_], Request, Response]( + c: ServerCall[Request, Response] + )(implicit F: Async[F]): F[StreamOutput[F, Response]] = { SignallingRef[F].of(0L).map { readyState => new StreamOutputImpl[F, Response]( readyState, isReady = F.delay(c.isReady), - sendMessage = m => F.delay(c.sendMessage(m))) + sendMessage = m => F.delay(c.sendMessage(m)) + ) } } } private[grpc] class StreamOutputImpl[F[_], T]( - readyCountRef: SignallingRef[F, Long], - isReady: F[Boolean], - sendMessage: T => F[Unit], -)(implicit F: Async[F]) extends StreamOutput[F, T] { + readyCountRef: SignallingRef[F, Long], + isReady: F[Boolean], + sendMessage: T => F[Unit] +)(implicit F: Async[F]) + extends StreamOutput[F, T] { override def onReady: F[Unit] = readyCountRef.update(_ + 1L) override def writeStream(s: Stream[F, T]): Stream[F, Unit] = s.evalMap(sendWhenReady) private def sendWhenReady(msg: T): F[Unit] = { val send = sendMessage(msg) - isReady.ifM(send, { - readyCountRef.get.flatMap { readyState => - // If isReady is now true, don't wait (we may have missed the onReady signal) - isReady.ifM(send, { - // otherwise wait until readyState has been incremented - readyCountRef.waitUntil(_ > readyState) *> send - }) + isReady.ifM( + send, { + readyCountRef.get.flatMap { readyState => + // If isReady is now true, don't wait (we may have missed the onReady signal) + isReady.ifM( + send, { + // otherwise wait until readyState has been incremented + readyCountRef.waitUntil(_ > readyState) *> send + } + ) + } } - }) + ) } } diff --git a/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala b/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala index a375f526..c327d422 100644 --- a/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala @@ -149,7 +149,8 @@ class ClientSuite extends Fs2GrpcSuite { runTest0("stream to streamingToUnary - send respects readiness") { (tc, io, d) => val dummy = new DummyClientCall() val client = fs2ClientCall(dummy, d) - val requests = Stream.emits(List("a", "b", "c", "d", "e")) + val requests = Stream + .emits(List("a", "b", "c", "d", "e")) .chunkLimit(1) .unchunks .map { value => @@ -279,7 +280,8 @@ class ClientSuite extends Fs2GrpcSuite { runTest0("stream to streamingToStreaming - send respects readiness") { (tc, io, d) => val dummy = new DummyClientCall() val client = fs2ClientCall(dummy, d) - val requests = Stream.emits(List("a", "b", "c", "d", "e")) + val requests = Stream + .emits(List("a", "b", "c", "d", "e")) .chunkLimit(1) .unchunks .map { value => diff --git a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala index 69db3a04..5f67458b 100644 --- a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala @@ -251,7 +251,11 @@ class ServerSuite extends Fs2GrpcSuite { val listenerRef = Ref.unsafe[SyncIO, Option[ServerCall.Listener[_]]](None) val handler = - Fs2UnaryServerCallHandler.stream[IO, String, Int]((_, _) => unreadyAfterTwoEmissions(dummy, listenerRef), ServerOptions.default, d) + Fs2UnaryServerCallHandler.stream[IO, String, Int]( + (_, _) => unreadyAfterTwoEmissions(dummy, listenerRef), + ServerOptions.default, + d + ) val listener = handler.startCall(dummy, new Metadata()) listenerRef.set(Some(listener)).unsafeRunSync() @@ -269,7 +273,8 @@ class ServerSuite extends Fs2GrpcSuite { } private def unreadyAfterTwoEmissions(dummy: DummyServerCall, listener: Ref[SyncIO, Option[ServerCall.Listener[_]]]) = - Stream.emits(List(1, 2, 3, 4, 5)) + Stream + .emits(List(1, 2, 3, 4, 5)) .chunkLimit(1) .unchunks .map { value => From 55e2a8c8120a73b223064d98b7a7ec7adb996b15 Mon Sep 17 00:00:00 2001 From: Tim Cuthbertson Date: Tue, 20 Sep 2022 22:34:35 +1000 Subject: [PATCH 7/8] remove unused import --- .../scala/fs2/grpc/server/Fs2StreamServerCallListener.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala index 37e699a8..faef2996 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala @@ -27,7 +27,7 @@ import cats.Functor import cats.syntax.all._ import cats.effect.kernel.Deferred import cats.effect.{Async, SyncIO} -import cats.effect.std.{Dispatcher, Queue} +import cats.effect.std.Dispatcher import fs2.grpc.client.StreamIngest import io.grpc.ServerCall From 67b0ee6b38d7e572effa4d9a0d8c6fbbd2b97b10 Mon Sep 17 00:00:00 2001 From: Tim Cuthbertson Date: Wed, 21 Sep 2022 21:03:05 +1000 Subject: [PATCH 8/8] mark some classes as private to fix mima issues --- build.sbt | 3 ++- .../fs2/grpc/client/Fs2StreamClientCallListener.scala | 6 +++--- .../fs2/grpc/server/Fs2StreamServerCallListener.scala | 8 ++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/build.sbt b/build.sbt index 4f8d7f1d..7984997e 100644 --- a/build.sbt +++ b/build.sbt @@ -33,8 +33,9 @@ inThisBuild( mimaBinaryIssueFilters ++= Seq( // API that is not extended by end-users ProblemFilters.exclude[ReversedMissingMethodProblem]("fs2.grpc.GeneratedCompanion.mkClient"), - // package private API + // package private APIs ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.grpc.client.StreamIngest.create"), + ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.grpc.server.Fs2StreamServerCallListener*"), // deleted private classes ProblemFilters.exclude[MissingClassProblem]("fs2.grpc.client.Fs2UnaryClientCallListener*"), ProblemFilters.exclude[MissingClassProblem]("fs2.grpc.server.Fs2UnaryServerCallListener*") diff --git a/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala b/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala index d1594f45..8e076f4d 100644 --- a/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/client/Fs2StreamClientCallListener.scala @@ -29,7 +29,7 @@ import cats.effect.kernel.Concurrent import cats.effect.std.Dispatcher import io.grpc.{ClientCall, Metadata, Status} -class Fs2StreamClientCallListener[F[_], Response] private ( +private[client] class Fs2StreamClientCallListener[F[_], Response] private ( ingest: StreamIngest[F, Response], signalReadiness: SyncIO[Unit], dispatcher: Dispatcher[F] @@ -48,9 +48,9 @@ class Fs2StreamClientCallListener[F[_], Response] private ( val stream: Stream[F, Response] = ingest.messages } -object Fs2StreamClientCallListener { +private[client] object Fs2StreamClientCallListener { - private[client] def create[F[_]: Concurrent, Response]( + def create[F[_]: Concurrent, Response]( request: Int => F[Unit], signalReadiness: SyncIO[Unit], dispatcher: Dispatcher[F], diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala index faef2996..013b839d 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2StreamServerCallListener.scala @@ -31,7 +31,7 @@ import cats.effect.std.Dispatcher import fs2.grpc.client.StreamIngest import io.grpc.ServerCall -class Fs2StreamServerCallListener[F[_], Request, Response] private ( +private[server] class Fs2StreamServerCallListener[F[_], Request, Response] private ( ingest: StreamIngest[F, Request], signalReadiness: SyncIO[Unit], val isCancelled: Deferred[F, Unit], @@ -55,11 +55,11 @@ class Fs2StreamServerCallListener[F[_], Request, Response] private ( override def source: Stream[F, Request] = ingest.messages } -object Fs2StreamServerCallListener { +private[server] object Fs2StreamServerCallListener { class PartialFs2StreamServerCallListener[F[_]](val dummy: Boolean = false) extends AnyVal { - private[server] def apply[Request, Response]( + def apply[Request, Response]( call: ServerCall[Request, Response], signalReadiness: SyncIO[Unit], dispatcher: Dispatcher[F], @@ -79,6 +79,6 @@ object Fs2StreamServerCallListener { } - private[server] def apply[F[_]] = new PartialFs2StreamServerCallListener[F] + def apply[F[_]] = new PartialFs2StreamServerCallListener[F] }