Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 98 additions & 30 deletions runtime/src/main/scala/fs2/grpc/shared/StreamIngest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,69 @@ private[grpc] trait StreamIngest[F[_], T] {

private[grpc] object StreamIngest {

// Minimum value of `prefetchN` parameter to use
// internal buffering of incoming messages in `StreamIngest`
val BufferingThreshold = 14 // 12 + 4 + 14 * 8 = 128 bytes for compressed pointers

sealed trait State[+T]
final case class Done(opt: Option[Throwable]) extends State[Nothing]

sealed trait Buffering[+T] extends State[T] {
def requested: Int

def append[T2 >: T](value: T2): (Option[Chunk[T2]], Buffering[T2])

def requestForMore(): Option[(Int, Buffering[T])]
def split(): Option[(Chunk[T], Buffering[T])]

protected def receivedOne: Int = math.max(0, requested - 1)
}
final case class ChunkBuffer[+T](requested: Int, limit: Int, chunk: Chunk[T]) extends Buffering[T] {
def append[T2 >: T](value: T2): (Option[Chunk[T2]], Buffering[T2]) = {
val updChunk = chunk ++ Chunk.singleton(value)

if (updChunk.size < limit) (none, copy(requested = receivedOne, chunk = updChunk))
else (updChunk.toIndexedChunk.some, copy(requested = receivedOne, chunk = Chunk.empty))
}

def requestForMore(): Option[(Int, Buffering[T])] = {
val additional = limit - requested
if (additional <= 0) none
else (additional, copy(requested = requested + additional)).some
}

def split(): Option[(Chunk[T], Buffering[T])] =
if (chunk.isEmpty) none
else (chunk.toIndexedChunk, copy(chunk = Chunk.empty[T])).some
}
final case class ArrayBuffer[+T](requested: Int, array: Array[AnyRef], offset: Int, count: Int) extends Buffering[T] {
private def chunk(length: Int): Chunk[T] = Chunk.array(array, offset, length).asInstanceOf[Chunk[T]]

def append[T2 >: T](value: T2): (Option[Chunk[T2]], Buffering[T2]) = {
array(offset + count) = value.asInstanceOf[AnyRef]
val updCount = count + 1

if (offset + updCount < array.length) (none, copy(requested = receivedOne, count = updCount))
else (chunk(updCount).some, copy(receivedOne, new Array[AnyRef](array.length), 0, 0))
}

def requestForMore(): Option[(Int, Buffering[T])] = {
val additional = array.length - requested
if (additional <= 0) none
else (additional, copy(requested = requested + additional)).some
}

def split(): Option[(Chunk[T], Buffering[T])] =
if (count == 0) none
else (chunk(count), copy(offset = offset + count, count = 0)).some
}

object State {
def apply[T](limit: Int): State[T] =
if (limit < BufferingThreshold) ChunkBuffer(0, limit, Chunk.empty[T])
else ArrayBuffer(0, new Array[AnyRef](limit), 0, 0)
}

def apply[F[_]: Concurrent, T](
request: Int => F[Unit],
prefetchN: Int
Expand All @@ -57,39 +120,44 @@ private[grpc] object StreamIngest {
queue.offer(error.asLeft)

val messages: Stream[F, T] = {
type Requested = Int
type S = Either[Option[Throwable], (Requested, Chunk[T])]
def receivedOne(requested: Requested): Requested = math.max(0, requested - 1)
def requestIfNeeded(requested: Requested): F[Requested] = {
val additional = math.max(0, limit - requested)
request(additional).whenA(additional > 0).as(requested + additional)
}
def loop(state: State[T]): F[Option[(Chunk[T], State[T])]] = state match {
case Done(None) => F.pure(none)
case Done(Some(err)) => F.raiseError(err)
case buf: Buffering[T] =>
def requestIfNeeded: F[Buffering[T]] = buf.requestForMore() match {
case Some((adds, buf)) => request(adds).as(buf)
case None => F.pure(buf)
}

def zero(requested: Requested): S = (requested, Chunk.empty).asRight

def loop(state: S): F[Option[(Chunk[T], S)]] =
state match {
case Left(None) => F.pure(none)
case Left(Some(err)) => F.raiseError(err)
case Right((requested, acc)) =>
queue.tryTake.flatMap {
case Some(Right(value)) => loop((receivedOne(requested), (acc ++ Chunk.singleton(value))).asRight)
case Some(Left(err)) =>
if (acc.isEmpty) loop(err.asLeft)
else F.pure((acc.toIndexedChunk, err.asLeft).some)
case None =>
def await(requested: Requested) = if (acc.isEmpty) queue.take.flatMap {
case Right(value) =>
loop((receivedOne(requested), Chunk.singleton(value)).asRight)
case Left(err) => loop(err.asLeft)
}
else F.pure((acc.toIndexedChunk, zero(requested)).some)

requestIfNeeded(requested) >>= await
def bufferOrEmit(buf: Buffering[T], value: T): F[Option[(Chunk[T], State[T])]] =
buf.append(value) match {
case (Some(chunk), buf) => F.pure((chunk, buf).some)
case (None, buf) => loop(buf)
}
}

Stream.unfoldChunkEval[F, S, T](zero(0))(loop)
def waitOrEmit(buf: Buffering[T]): F[Option[(Chunk[T], State[T])]] = buf.split() match {
case some: Some[_] => F.pure(some)
case None =>
queue.take.flatMap {
case Right(value) => bufferOrEmit(buf, value)
case Left(opt) => loop(Done(opt))
}
}

queue.tryTake.flatMap {
case None => requestIfNeeded >>= waitOrEmit
case Some(Right(value)) => bufferOrEmit(buf, value)
case Some(Left(err)) =>
buf.split() match {
case Some((chunk, _)) => F.pure((chunk, Done(err)).some)
case None => loop(Done(err))
}
}
}

Stream.eval(F.pure(limit).map(State(_))).flatMap { z =>
Stream.unfoldChunkEval[F, State[T], T](z)(loop)
}
}
}

Expand Down
29 changes: 21 additions & 8 deletions runtime/src/test/scala/fs2/grpc/shared/StreamIngestSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,44 @@ package fs2
package grpc
package shared

import cats.syntax.all._
import cats.effect._
import munit._

class StreamIngestSuite extends CatsEffectSuite with CatsEffectFunFixtures {

test("basic") {

def run(prefetchN: Int, takeN: Int, expectedReq: Int, expectedCount: Int) = {
def run(emitN: Int, prefetchN: Int, takeN: Int, expectedReq: Int, expectedCount: Int) = {
def capture(s: Stream[IO, Int]): IO[List[Int]] = s.take(takeN.toLong).compile.toList

for {
ref <- IO.ref(0)
ingest <- StreamIngest[IO, Int](req => ref.update(_ + req), prefetchN)
_ <- Stream.emits((1 to prefetchN)).evalTap(ingest.onMessage).compile.drain
messages <- ingest.messages.take(takeN.toLong).compile.toList
emitted = Stream.emits((1 to emitN)).covary[IO]
_ <- emitted.evalTap(ingest.onMessage).compile.drain
messages <- capture(ingest.messages)
expectedMsgs <- capture(emitted)
requested <- ref.get
} yield {
assertEquals(messages.size, expectedCount)
assertEquals(messages, expectedMsgs)
assertEquals(requested, expectedReq)
}
}

run(prefetchN = 1, takeN = 1, expectedReq = 1, expectedCount = 1) *>
run(prefetchN = 2, takeN = 1, expectedReq = 2, expectedCount = 1) *>
run(prefetchN = 2, takeN = 2, expectedReq = 2, expectedCount = 2) *>
run(prefetchN = 1024, takeN = 1024, expectedReq = 1024, expectedCount = 1024) *>
run(prefetchN = 1024, takeN = 1023, expectedReq = 1024, expectedCount = 1023)
List(
run(emitN = 1, prefetchN = 1, takeN = 1, expectedReq = 0, expectedCount = 1),
run(emitN = 1, prefetchN = 2, takeN = 1, expectedReq = 2, expectedCount = 1),
run(emitN = 2, prefetchN = 2, takeN = 1, expectedReq = 0, expectedCount = 1),
run(emitN = 2, prefetchN = 4, takeN = 1, expectedReq = 4, expectedCount = 1),
run(emitN = 2, prefetchN = 1, takeN = 2, expectedReq = 0, expectedCount = 2),
run(emitN = 2, prefetchN = 4, takeN = 2, expectedReq = 4, expectedCount = 2),
run(emitN = 1024, prefetchN = 1024, takeN = 1024, expectedReq = 0, expectedCount = 1024),
run(emitN = 1024, prefetchN = 2048, takeN = 1024, expectedReq = 2048, expectedCount = 1024),
run(emitN = 1024, prefetchN = 1024, takeN = 1023, expectedReq = 0, expectedCount = 1023),
run(emitN = 1024, prefetchN = 2048, takeN = 1023, expectedReq = 2048, expectedCount = 1023)
).combineAll

}

Expand Down