Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ inThisBuild(
mimaBinaryIssueFilters ++= Seq(
// API that is not extended by end-users
ProblemFilters.exclude[ReversedMissingMethodProblem]("fs2.grpc.GeneratedCompanion.mkClient"),
// package private API
// package private APIs
ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.grpc.client.StreamIngest.create"),
ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.grpc.server.Fs2StreamServerCallListener*"),
// deleted private classes
ProblemFilters.exclude[MissingClassProblem]("fs2.grpc.client.Fs2UnaryClientCallListener*"),
ProblemFilters.exclude[MissingClassProblem]("fs2.grpc.server.Fs2UnaryServerCallListener*")
Expand Down
44 changes: 26 additions & 18 deletions runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ package fs2
package grpc
package client

import cats.syntax.all._
import cats.effect.{Async, Resource}
import cats.effect.std.Dispatcher
import cats.effect.{Async, Resource, SyncIO}
import cats.syntax.all._
import fs2.grpc.client.internal.Fs2UnaryCallHandler
import io.grpc.{Metadata, _}
import fs2.grpc.shared.StreamOutput
import io.grpc._

final case class UnaryResult[A](value: Option[A], status: Option[GrpcStatus])
final case class GrpcStatus(status: Status, trailers: Metadata)
Expand All @@ -51,35 +52,40 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
private def request(numMessages: Int): F[Unit] =
F.delay(call.request(numMessages))

private def sendMessage(message: Request): F[Unit] =
F.delay(call.sendMessage(message))

private def start[A <: ClientCall.Listener[Response]](createListener: F[A], md: Metadata): F[A] =
createListener.flatTap(l => F.delay(call.start(l, md)))

private def sendSingleMessage(message: Request): F[Unit] =
sendMessage(message) *> halfClose

private def sendStream(stream: Stream[F, Request]): Stream[F, Unit] =
stream.evalMap(sendMessage) ++ Stream.eval(halfClose)
F.delay(call.sendMessage(message)) *> halfClose

//

def unaryToUnaryCall(message: Request, headers: Metadata): F[Response] =
Fs2UnaryCallHandler.unary(call, options, message, headers)

def streamingToUnaryCall(messages: Stream[F, Request], headers: Metadata): F[Response] =
Fs2UnaryCallHandler.stream(call, options, messages, headers)
StreamOutput.client(call).flatMap { output =>
Fs2UnaryCallHandler.stream(call, options, dispatcher, messages, output, headers)
}

def unaryToStreamingCall(message: Request, md: Metadata): Stream[F, Response] =
Stream
.resource(mkStreamListenerR(md))
.resource(mkStreamListenerR(md, SyncIO.unit))
.flatMap(Stream.exec(sendSingleMessage(message)) ++ _.stream.adaptError(ea))

def streamingToStreamingCall(messages: Stream[F, Request], md: Metadata): Stream[F, Response] =
def streamingToStreamingCall(messages: Stream[F, Request], md: Metadata): Stream[F, Response] = {
val listenerAndOutput = Resource.eval(StreamOutput.client(call)).flatMap { output =>
mkStreamListenerR(md, output.onReadySync(dispatcher)).map((_, output))
}

Stream
.resource(mkStreamListenerR(md))
.flatMap(_.stream.adaptError(ea).concurrently(sendStream(messages)))
.resource(listenerAndOutput)
.flatMap { case (listener, output) =>
listener.stream
.adaptError(ea)
.concurrently(output.writeStream(messages) ++ Stream.eval(halfClose))
}
}

//

Expand All @@ -89,10 +95,12 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
case (_, Resource.ExitCase.Errored(t)) => cancel(t.getMessage.some, t.some)
}

private def mkStreamListenerR(md: Metadata): Resource[F, Fs2StreamClientCallListener[F, Response]] = {

private def mkStreamListenerR(
md: Metadata,
signalReadiness: SyncIO[Unit]
): Resource[F, Fs2StreamClientCallListener[F, Response]] = {
val prefetchN = options.prefetchN.max(1)
val create = Fs2StreamClientCallListener.create[F, Response](request, dispatcher, prefetchN)
val create = Fs2StreamClientCallListener.create[F, Response](request, signalReadiness, dispatcher, prefetchN)
val acquire = start(create, md) <* request(prefetchN)
val release = handleExitCase(cancelSucceed = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,41 @@ package fs2
package grpc
package client

import cats.effect.SyncIO
import cats.implicits._
import cats.effect.kernel.Concurrent
import cats.effect.std.Dispatcher
import io.grpc.{ClientCall, Metadata, Status}

class Fs2StreamClientCallListener[F[_], Response] private (
private[client] class Fs2StreamClientCallListener[F[_], Response] private (
ingest: StreamIngest[F, Response],
signalReadiness: SyncIO[Unit],
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any special reason for SyncIO? Have you considered SignallingRef from fs2?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super fluent in the guts of cats/fs2, so happy to take recommendations.

SyncIO seems useful somewhere in the mix, because:

  • the invocations from the GRPC listeners need to run sync, so we need either SyncIO or a Dispatcher
  • the outgoing unary call types set signalReadiness=SyncIO.unit, because there's no need to respect readiness with only one outgoing message.
    • It'd be wasteful to call a dispatcher in that case just to run F.unit.
    • Though now that I think about it, a single outgoing message probably never triggers the onReady code path so maybe it doesn't matter? 🤷

I refactored to use SignallingRef, which seems nicer thanks: a1668d7.

As part of that I removed the SyncIO from the StreamOutput class, but it's still used in various listeners due to the above reasoning.

If you prefer consistency and aren't worried about onReady for unary calls being a bit wasteful (since it's probably never called), I think I can just use F across the board.

dispatcher: Dispatcher[F]
) extends ClientCall.Listener[Response] {

override def onMessage(message: Response): Unit =
dispatcher.unsafeRunSync(ingest.onMessage(message))

override def onClose(status: Status, trailers: Metadata): Unit =
dispatcher.unsafeRunSync(ingest.onClose(GrpcStatus(status, trailers)))
override def onClose(status: Status, trailers: Metadata): Unit = {
val error = if (status.isOk) None else Some(status.asRuntimeException(trailers))
dispatcher.unsafeRunSync(ingest.onClose(error))
}

override def onReady(): Unit = signalReadiness.unsafeRunSync()

val stream: Stream[F, Response] = ingest.messages
}

object Fs2StreamClientCallListener {
private[client] object Fs2StreamClientCallListener {

private[client] def create[F[_]: Concurrent, Response](
def create[F[_]: Concurrent, Response](
request: Int => F[Unit],
signalReadiness: SyncIO[Unit],
dispatcher: Dispatcher[F],
prefetchN: Int
): F[Fs2StreamClientCallListener[F, Response]] =
StreamIngest[F, Response](request, prefetchN).map(
new Fs2StreamClientCallListener[F, Response](_, dispatcher)
new Fs2StreamClientCallListener[F, Response](_, signalReadiness, dispatcher)
)

}
19 changes: 9 additions & 10 deletions runtime/src/main/scala/fs2/grpc/client/StreamIngest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,26 @@ import cats.implicits._
import cats.effect.Concurrent
import cats.effect.std.Queue

private[client] trait StreamIngest[F[_], T] {
private[grpc] trait StreamIngest[F[_], T] {
def onMessage(msg: T): F[Unit]
def onClose(status: GrpcStatus): F[Unit]
def onClose(error: Option[Throwable]): F[Unit]
def messages: Stream[F, T]
}

private[client] object StreamIngest {
private[grpc] object StreamIngest {

def apply[F[_]: Concurrent, T](
request: Int => F[Unit],
prefetchN: Int
): F[StreamIngest[F, T]] =
Queue
.unbounded[F, Either[GrpcStatus, T]]
.unbounded[F, Either[Option[Throwable], T]]
.map(q => create[F, T](request, prefetchN, q))

def create[F[_], T](
request: Int => F[Unit],
prefetchN: Int,
queue: Queue[F, Either[GrpcStatus, T]]
queue: Queue[F, Either[Option[Throwable], T]]
)(implicit F: Concurrent[F]): StreamIngest[F, T] = new StreamIngest[F, T] {

val limit: Int =
Expand All @@ -58,17 +58,16 @@ private[client] object StreamIngest {
def onMessage(msg: T): F[Unit] =
queue.offer(msg.asRight) *> ensureMessages

def onClose(status: GrpcStatus): F[Unit] =
queue.offer(status.asLeft)
def onClose(error: Option[Throwable]): F[Unit] =
queue.offer(error.asLeft)

val messages: Stream[F, T] = {

val run: F[Option[T]] =
queue.take.flatMap {
case Right(v) => ensureMessages *> v.some.pure[F]
case Left(GrpcStatus(status, trailers)) =>
if (!status.isOk) F.raiseError(status.asRuntimeException(trailers))
else none[T].pure[F]
case Left(Some(error)) => F.raiseError(error)
case Left(None) => none[T].pure[F]
}

Stream.repeatEval(run).unNoneTerminate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,15 @@

package fs2.grpc.client.internal

import cats.effect.Sync
import cats.effect.SyncIO
import cats.effect.kernel.{Async, Outcome, Ref}
import cats.effect.std.Dispatcher
import cats.effect.syntax.all._
import cats.effect.kernel.Async
import cats.effect.kernel.Outcome
import cats.effect.kernel.Ref
import cats.syntax.functor._
import cats.effect.{Sync, SyncIO}
import cats.syntax.flatMap._
import cats.syntax.functor._
import fs2._
import fs2.grpc.client.ClientOptions
import fs2.grpc.shared.StreamOutput
import io.grpc._

private[client] object Fs2UnaryCallHandler {
Expand Down Expand Up @@ -65,7 +64,8 @@ private[client] object Fs2UnaryCallHandler {
class Done[R] extends ReceiveState[R]

private def mkListener[Response](
state: Ref[SyncIO, ReceiveState[Response]]
state: Ref[SyncIO, ReceiveState[Response]],
signalReadiness: SyncIO[Unit]
): ClientCall.Listener[Response] =
new ClientCall.Listener[Response] {
override def onMessage(message: Response): Unit =
Expand Down Expand Up @@ -110,6 +110,8 @@ private[client] object Fs2UnaryCallHandler {
}
}
}.unsafeRunSync()

override def onReady(): Unit = signalReadiness.unsafeRunSync()
}

def unary[F[_], Request, Response](
Expand All @@ -119,7 +121,7 @@ private[client] object Fs2UnaryCallHandler {
headers: Metadata
)(implicit F: Async[F]): F[Response] = F.async[Response] { cb =>
ReceiveState.init(cb, options.errorAdapter).map { state =>
call.start(mkListener[Response](state), headers)
call.start(mkListener[Response](state, SyncIO.unit), headers)
// Initially ask for two responses from flow-control so that if a misbehaving server
// sends more than one responses, we can catch it and fail it in the listener.
call.request(2)
Expand All @@ -132,16 +134,18 @@ private[client] object Fs2UnaryCallHandler {
def stream[F[_], Request, Response](
call: ClientCall[Request, Response],
options: ClientOptions,
dispatcher: Dispatcher[F],
messages: Stream[F, Request],
output: StreamOutput[F, Request],
headers: Metadata
)(implicit F: Async[F]): F[Response] = F.async[Response] { cb =>
ReceiveState.init(cb, options.errorAdapter).flatMap { state =>
call.start(mkListener[Response](state), headers)
call.start(mkListener[Response](state, output.onReadySync(dispatcher)), headers)
// Initially ask for two responses from flow-control so that if a misbehaving server
// sends more than one responses, we can catch it and fail it in the listener.
call.request(2)
messages
.map(call.sendMessage)
output
.writeStream(messages)
.compile
.drain
.guaranteeCase {
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/main/scala/fs2/grpc/server/Fs2ServerCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private[server] class Fs2ServerCall[F[_], Request, Response](val call: ServerCal
def closeStream(status: Status, trailers: Metadata)(implicit F: Sync[F]): F[Unit] =
F.delay(call.close(status, trailers))

def sendMessage(message: Response)(implicit F: Sync[F]): F[Unit] =
def sendSingleMessage(message: Response)(implicit F: Sync[F]): F[Unit] =
F.delay(call.sendMessage(message))

def request(numMessages: Int)(implicit F: Sync[F]): F[Unit] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ package server

import cats.effect._
import cats.effect.std.Dispatcher
import cats.syntax.all._
import fs2.grpc.server.internal.Fs2UnaryServerCallHandler
import fs2.grpc.shared.StreamOutput
import io.grpc._

class Fs2ServerCallHandler[F[_]: Async] private (
Expand All @@ -47,7 +49,7 @@ class Fs2ServerCallHandler[F[_]: Async] private (
implementation: (Stream[F, Request], Metadata) => F[Response]
): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options))
val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, SyncIO.unit, dispatcher, options))
listener.unsafeUnaryResponse(new Metadata(), implementation(_, headers))
listener
}
Expand All @@ -57,8 +59,10 @@ class Fs2ServerCallHandler[F[_]: Async] private (
implementation: (Stream[F, Request], Metadata) => Stream[F, Response]
): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] {
def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = {
val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options))
listener.unsafeStreamResponse(new Metadata(), implementation(_, headers))
val (listener, streamOutput) = dispatcher.unsafeRunSync(StreamOutput.server(call).flatMap { output =>
Fs2StreamServerCallListener[F](call, output.onReadySync(dispatcher), dispatcher, options).map((_, output))
})
listener.unsafeStreamResponse(streamOutput, new Metadata(), implementation(_, headers))
listener
}
}
Expand Down
17 changes: 11 additions & 6 deletions runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallListener.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ package fs2
package grpc
package server

import cats.syntax.all._
import cats.effect._
import cats.effect.std.Dispatcher
import cats.syntax.all._
import fs2.grpc.shared.StreamOutput
import io.grpc.{Metadata, Status, StatusException, StatusRuntimeException}

private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
Expand All @@ -49,10 +50,10 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
}

private def handleUnaryResponse(headers: Metadata, response: F[Response])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendMessage
call.sendHeaders(headers) *> call.request(1) *> response >>= call.sendSingleMessage

private def handleStreamResponse(headers: Metadata, response: Stream[F, Response])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> response.evalMap(call.sendMessage).compile.drain
private def handleStreamResponse(headers: Metadata, sendResponse: Stream[F, Unit])(implicit F: Sync[F]): F[Unit] =
call.sendHeaders(headers) *> call.request(1) *> sendResponse.compile.drain

private def unsafeRun(f: F[Unit])(implicit F: Async[F]): Unit = {
val bracketed = F.guaranteeCase(f) {
Expand All @@ -70,8 +71,12 @@ private[server] trait Fs2ServerCallListener[F[_], G[_], Request, Response] {
): Unit =
unsafeRun(handleUnaryResponse(headers, implementation(source)))

def unsafeStreamResponse(headers: Metadata, implementation: G[Request] => Stream[F, Response])(implicit
def unsafeStreamResponse(
streamOutput: StreamOutput[F, Response],
headers: Metadata,
implementation: G[Request] => Stream[F, Response]
)(implicit
F: Async[F]
): Unit =
unsafeRun(handleStreamResponse(headers, implementation(source)))
unsafeRun(handleStreamResponse(headers, streamOutput.writeStream(implementation(source))))
}
Loading