diff --git a/runtime/src/main/scala/fs2/grpc/shared/StreamIngest.scala b/runtime/src/main/scala/fs2/grpc/shared/StreamIngest.scala index 8768474d..a9319008 100644 --- a/runtime/src/main/scala/fs2/grpc/shared/StreamIngest.scala +++ b/runtime/src/main/scala/fs2/grpc/shared/StreamIngest.scala @@ -35,6 +35,69 @@ private[grpc] trait StreamIngest[F[_], T] { private[grpc] object StreamIngest { + // Minimum value of `prefetchN` parameter to use + // internal buffering of incoming messages in `StreamIngest` + val BufferingThreshold = 14 // 12 + 4 + 14 * 8 = 128 bytes for compressed pointers + + sealed trait State[+T] + final case class Done(opt: Option[Throwable]) extends State[Nothing] + + sealed trait Buffering[+T] extends State[T] { + def requested: Int + + def append[T2 >: T](value: T2): (Option[Chunk[T2]], Buffering[T2]) + + def requestForMore(): Option[(Int, Buffering[T])] + def split(): Option[(Chunk[T], Buffering[T])] + + protected def receivedOne: Int = math.max(0, requested - 1) + } + final case class ChunkBuffer[+T](requested: Int, limit: Int, chunk: Chunk[T]) extends Buffering[T] { + def append[T2 >: T](value: T2): (Option[Chunk[T2]], Buffering[T2]) = { + val updChunk = chunk ++ Chunk.singleton(value) + + if (updChunk.size < limit) (none, copy(requested = receivedOne, chunk = updChunk)) + else (updChunk.toIndexedChunk.some, copy(requested = receivedOne, chunk = Chunk.empty)) + } + + def requestForMore(): Option[(Int, Buffering[T])] = { + val additional = limit - requested + if (additional <= 0) none + else (additional, copy(requested = requested + additional)).some + } + + def split(): Option[(Chunk[T], Buffering[T])] = + if (chunk.isEmpty) none + else (chunk.toIndexedChunk, copy(chunk = Chunk.empty[T])).some + } + final case class ArrayBuffer[+T](requested: Int, array: Array[AnyRef], offset: Int, count: Int) extends Buffering[T] { + private def chunk(length: Int): Chunk[T] = Chunk.array(array, offset, length).asInstanceOf[Chunk[T]] + + def append[T2 >: T](value: T2): (Option[Chunk[T2]], Buffering[T2]) = { + array(offset + count) = value.asInstanceOf[AnyRef] + val updCount = count + 1 + + if (offset + updCount < array.length) (none, copy(requested = receivedOne, count = updCount)) + else (chunk(updCount).some, copy(receivedOne, new Array[AnyRef](array.length), 0, 0)) + } + + def requestForMore(): Option[(Int, Buffering[T])] = { + val additional = array.length - requested + if (additional <= 0) none + else (additional, copy(requested = requested + additional)).some + } + + def split(): Option[(Chunk[T], Buffering[T])] = + if (count == 0) none + else (chunk(count), copy(offset = offset + count, count = 0)).some + } + + object State { + def apply[T](limit: Int): State[T] = + if (limit < BufferingThreshold) ChunkBuffer(0, limit, Chunk.empty[T]) + else ArrayBuffer(0, new Array[AnyRef](limit), 0, 0) + } + def apply[F[_]: Concurrent, T]( request: Int => F[Unit], prefetchN: Int @@ -57,39 +120,44 @@ private[grpc] object StreamIngest { queue.offer(error.asLeft) val messages: Stream[F, T] = { - type Requested = Int - type S = Either[Option[Throwable], (Requested, Chunk[T])] - def receivedOne(requested: Requested): Requested = math.max(0, requested - 1) - def requestIfNeeded(requested: Requested): F[Requested] = { - val additional = math.max(0, limit - requested) - request(additional).whenA(additional > 0).as(requested + additional) - } + def loop(state: State[T]): F[Option[(Chunk[T], State[T])]] = state match { + case Done(None) => F.pure(none) + case Done(Some(err)) => F.raiseError(err) + case buf: Buffering[T] => + def requestIfNeeded: F[Buffering[T]] = buf.requestForMore() match { + case Some((adds, buf)) => request(adds).as(buf) + case None => F.pure(buf) + } - def zero(requested: Requested): S = (requested, Chunk.empty).asRight - - def loop(state: S): F[Option[(Chunk[T], S)]] = - state match { - case Left(None) => F.pure(none) - case Left(Some(err)) => F.raiseError(err) - case Right((requested, acc)) => - queue.tryTake.flatMap { - case Some(Right(value)) => loop((receivedOne(requested), (acc ++ Chunk.singleton(value))).asRight) - case Some(Left(err)) => - if (acc.isEmpty) loop(err.asLeft) - else F.pure((acc.toIndexedChunk, err.asLeft).some) - case None => - def await(requested: Requested) = if (acc.isEmpty) queue.take.flatMap { - case Right(value) => - loop((receivedOne(requested), Chunk.singleton(value)).asRight) - case Left(err) => loop(err.asLeft) - } - else F.pure((acc.toIndexedChunk, zero(requested)).some) - - requestIfNeeded(requested) >>= await + def bufferOrEmit(buf: Buffering[T], value: T): F[Option[(Chunk[T], State[T])]] = + buf.append(value) match { + case (Some(chunk), buf) => F.pure((chunk, buf).some) + case (None, buf) => loop(buf) } - } - Stream.unfoldChunkEval[F, S, T](zero(0))(loop) + def waitOrEmit(buf: Buffering[T]): F[Option[(Chunk[T], State[T])]] = buf.split() match { + case some: Some[_] => F.pure(some) + case None => + queue.take.flatMap { + case Right(value) => bufferOrEmit(buf, value) + case Left(opt) => loop(Done(opt)) + } + } + + queue.tryTake.flatMap { + case None => requestIfNeeded >>= waitOrEmit + case Some(Right(value)) => bufferOrEmit(buf, value) + case Some(Left(err)) => + buf.split() match { + case Some((chunk, _)) => F.pure((chunk, Done(err)).some) + case None => loop(Done(err)) + } + } + } + + Stream.eval(F.pure(limit).map(State(_))).flatMap { z => + Stream.unfoldChunkEval[F, State[T], T](z)(loop) + } } } diff --git a/runtime/src/test/scala/fs2/grpc/shared/StreamIngestSuite.scala b/runtime/src/test/scala/fs2/grpc/shared/StreamIngestSuite.scala index dd1e0d07..24a9ce98 100644 --- a/runtime/src/test/scala/fs2/grpc/shared/StreamIngestSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/shared/StreamIngestSuite.scala @@ -23,6 +23,7 @@ package fs2 package grpc package shared +import cats.syntax.all._ import cats.effect._ import munit._ @@ -30,24 +31,36 @@ class StreamIngestSuite extends CatsEffectSuite with CatsEffectFunFixtures { test("basic") { - def run(prefetchN: Int, takeN: Int, expectedReq: Int, expectedCount: Int) = { + def run(emitN: Int, prefetchN: Int, takeN: Int, expectedReq: Int, expectedCount: Int) = { + def capture(s: Stream[IO, Int]): IO[List[Int]] = s.take(takeN.toLong).compile.toList + for { ref <- IO.ref(0) ingest <- StreamIngest[IO, Int](req => ref.update(_ + req), prefetchN) - _ <- Stream.emits((1 to prefetchN)).evalTap(ingest.onMessage).compile.drain - messages <- ingest.messages.take(takeN.toLong).compile.toList + emitted = Stream.emits((1 to emitN)).covary[IO] + _ <- emitted.evalTap(ingest.onMessage).compile.drain + messages <- capture(ingest.messages) + expectedMsgs <- capture(emitted) requested <- ref.get } yield { assertEquals(messages.size, expectedCount) + assertEquals(messages, expectedMsgs) assertEquals(requested, expectedReq) } } - run(prefetchN = 1, takeN = 1, expectedReq = 1, expectedCount = 1) *> - run(prefetchN = 2, takeN = 1, expectedReq = 2, expectedCount = 1) *> - run(prefetchN = 2, takeN = 2, expectedReq = 2, expectedCount = 2) *> - run(prefetchN = 1024, takeN = 1024, expectedReq = 1024, expectedCount = 1024) *> - run(prefetchN = 1024, takeN = 1023, expectedReq = 1024, expectedCount = 1023) + List( + run(emitN = 1, prefetchN = 1, takeN = 1, expectedReq = 0, expectedCount = 1), + run(emitN = 1, prefetchN = 2, takeN = 1, expectedReq = 2, expectedCount = 1), + run(emitN = 2, prefetchN = 2, takeN = 1, expectedReq = 0, expectedCount = 1), + run(emitN = 2, prefetchN = 4, takeN = 1, expectedReq = 4, expectedCount = 1), + run(emitN = 2, prefetchN = 1, takeN = 2, expectedReq = 0, expectedCount = 2), + run(emitN = 2, prefetchN = 4, takeN = 2, expectedReq = 4, expectedCount = 2), + run(emitN = 1024, prefetchN = 1024, takeN = 1024, expectedReq = 0, expectedCount = 1024), + run(emitN = 1024, prefetchN = 2048, takeN = 1024, expectedReq = 2048, expectedCount = 1024), + run(emitN = 1024, prefetchN = 1024, takeN = 1023, expectedReq = 0, expectedCount = 1023), + run(emitN = 1024, prefetchN = 2048, takeN = 1023, expectedReq = 2048, expectedCount = 1023) + ).combineAll }