Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ inThisBuild(
ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.grpc.client.StreamIngest.create"),
ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.grpc.server.Fs2StreamServerCallListener*"),
ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.grpc.client.Fs2StreamClientCallListener*"),
ProblemFilters.exclude[MissingClassProblem]("fs2.grpc.client.StreamIngest*"),
ProblemFilters.exclude[MissingClassProblem]("fs2.grpc.codegen.Fs2GrpcServicePrinter$constants$"),
ProblemFilters.exclude[MissingFieldProblem]("fs2.grpc.codegen.Fs2GrpcServicePrinter.constants"),
// deleted private classes
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/main/scala/fs2/grpc/client/Fs2ClientCall.scala
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class Fs2ClientCall[F[_], Request, Response] private[client] (
): Resource[F, Fs2StreamClientCallListener[F, Response]] = {
val prefetchN = options.prefetchN.max(1)
val create = Fs2StreamClientCallListener.create[F, Response](request, signalReadiness, dispatcher, prefetchN)
val acquire = start(create, md) <* request(prefetchN)
Comment thread
ahjohannessen marked this conversation as resolved.
val acquire = start(create, md)
val release = handleExitCase(cancelSucceed = true)

Resource.makeCase(acquire)(release)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import cats.effect.SyncIO
import cats.implicits._
import cats.effect.kernel.Concurrent
import cats.effect.std.Dispatcher
import fs2.grpc.shared.StreamIngest
import io.grpc.{ClientCall, Metadata, Status}

private[client] class Fs2StreamClientCallListener[F[_], Response] private (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import cats.syntax.all._
import cats.effect.kernel.Deferred
import cats.effect.{Async, SyncIO}
import cats.effect.std.Dispatcher
import fs2.grpc.client.StreamIngest
import fs2.grpc.shared.StreamIngest
import io.grpc.ServerCall

private[server] class Fs2StreamServerCallListener[F[_], Request, Response] private (
Expand Down Expand Up @@ -67,7 +67,8 @@ private[server] object Fs2StreamServerCallListener {
)(implicit F: Async[F]): F[Fs2StreamServerCallListener[F, Request, Response]] = for {
isCancelled <- Deferred[F, Unit]
request = (n: Int) => F.delay(call.request(n))
ingest <- StreamIngest[F, Request](request, prefetchN = 1)
prefetchN = math.max(options.prefetchN, 1)
ingest <- StreamIngest[F, Request](request, prefetchN)
serverCall <- Fs2ServerCall[F, Request, Response](call, options)
} yield new Fs2StreamServerCallListener[F, Request, Response](
ingest,
Expand Down
15 changes: 13 additions & 2 deletions runtime/src/main/scala/fs2/grpc/server/ServerOptions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,22 @@ package grpc
package server

sealed abstract class ServerOptions private (
val prefetchN: Int,
val callOptionsFn: ServerCallOptions => ServerCallOptions
) {

private def copy(
callOptionsFn: ServerCallOptions => ServerCallOptions
): ServerOptions = new ServerOptions(callOptionsFn) {}
prefetchN: Int = this.prefetchN,
callOptionsFn: ServerCallOptions => ServerCallOptions = this.callOptionsFn
): ServerOptions = new ServerOptions(prefetchN, callOptionsFn) {}

/** Prefetch up to @param n messages from a client. The server will try to keep the internal buffer filled according
* to the provided value.
*
* If the provided value is less than 1 it defaults to 1.
*/
def withPrefetchN(n: Int): ServerOptions =
copy(prefetchN = math.max(n, 1))

/** Function that is applied on `fs2.grpc.ServerCallOptions.default` for each new RPC call.
*/
Expand All @@ -40,6 +50,7 @@ sealed abstract class ServerOptions private (
object ServerOptions {

val default: ServerOptions = new ServerOptions(
prefetchN = 1,
callOptionsFn = identity
) {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@

package fs2
package grpc
package client
package shared

import cats.implicits._
import cats.effect.Concurrent
import cats.effect.{Concurrent, Ref}
import cats.effect.std.Queue

private[grpc] trait StreamIngest[F[_], T] {
Expand All @@ -39,41 +39,63 @@ private[grpc] object StreamIngest {
request: Int => F[Unit],
prefetchN: Int
): F[StreamIngest[F, T]] =
Queue
.unbounded[F, Either[Option[Throwable], T]]
.map(q => create[F, T](request, prefetchN, q))
(Ref[F].of(0), Queue.unbounded[F, Either[Option[Throwable], T]])
.mapN((r, q) => create[F, T](request, prefetchN, r, q))

def create[F[_], T](
request: Int => F[Unit],
prefetchN: Int,
requested: Ref[F, Int],
queue: Queue[F, Either[Option[Throwable], T]]
)(implicit F: Concurrent[F]): StreamIngest[F, T] = new StreamIngest[F, T] {

val limit: Int =
math.max(1, prefetchN)

val ensureMessages: F[Unit] =
queue.size.flatMap(qs => request(1).whenA(qs < limit))
private val limit: Int = math.max(1, prefetchN)
private def updateRequests: F[Unit] = {
queue.size.flatMap { queued =>
requested.flatModify { requested =>
val total = queued + requested
val additional = math.max(0, limit - total)

(
requested + additional,
request(additional).whenA(additional > 0)
)
}
}
}

def onMessage(msg: T): F[Unit] =
queue.offer(msg.asRight) *> ensureMessages
queue.offer(msg.asRight) *> requested.update(r => math.max(0, r - 1))

def onClose(error: Option[Throwable]): F[Unit] =
queue.offer(error.asLeft)

val messages: Stream[F, T] = {

val run: F[Option[T]] =
queue.take.flatMap {
case Right(v) => ensureMessages *> v.some.pure[F]
case Left(Some(error)) => F.raiseError(error)
case Left(None) => none[T].pure[F]
type S = Either[Option[Throwable], Chunk[T]]

def zero: S = Chunk.empty.asRight
def loop(state: S): F[Option[(Chunk[T], S)]] =
state match {
case Left(None) => F.pure(none)
case Left(Some(err)) => F.raiseError(err)
case Right(acc) =>
queue.tryTake.flatMap {
case Some(Right(value)) => loop((acc ++ Chunk.singleton(value)).asRight)
case Some(Left(err)) =>
if (acc.isEmpty) loop(err.asLeft)
else F.pure((acc.toIndexedChunk, err.asLeft).some)
case None =>
val await = if (acc.isEmpty) queue.take.flatMap {
case Right(value) => loop(Chunk.singleton(value).asRight)
case Left(err) => loop(err.asLeft)
}
else F.pure((acc.toIndexedChunk, zero).some)

updateRequests *> await
}
}

Stream.repeatEval(run).unNoneTerminate

Stream.unfoldChunkEval(zero)(loop)
}

}

}
2 changes: 1 addition & 1 deletion runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class ServerSuite extends Fs2GrpcSuite {

tc.tick()

assertEquals(dummy.requested, 1)
assertEquals(dummy.requested, 2)

listener.onMessage("1")
tc.tick()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

package fs2
package grpc
package client
package shared

import cats.effect._
import munit._
Expand All @@ -45,9 +45,9 @@ class StreamIngestSuite extends CatsEffectSuite with CatsEffectFunFixtures {

run(prefetchN = 1, takeN = 1, expectedReq = 1, expectedCount = 1) *>
run(prefetchN = 2, takeN = 1, expectedReq = 2, expectedCount = 1) *>
run(prefetchN = 2, takeN = 2, expectedReq = 3, expectedCount = 2) *>
run(prefetchN = 1024, takeN = 1024, expectedReq = 2047, expectedCount = 1024) *>
run(prefetchN = 1024, takeN = 1023, expectedReq = 2046, expectedCount = 1023)
run(prefetchN = 2, takeN = 2, expectedReq = 2, expectedCount = 2) *>
run(prefetchN = 1024, takeN = 1024, expectedReq = 1024, expectedCount = 1024) *>
run(prefetchN = 1024, takeN = 1023, expectedReq = 1024, expectedCount = 1023)

}

Expand Down