diff --git a/build.sbt b/build.sbt index 92a42cc6..fc32de03 100644 --- a/build.sbt +++ b/build.sbt @@ -37,7 +37,8 @@ inThisBuild( ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.grpc.client.StreamIngest.create"), // deleted private classes ProblemFilters.exclude[MissingClassProblem]("fs2.grpc.client.Fs2UnaryClientCallListener*"), - ProblemFilters.exclude[MissingClassProblem]("fs2.grpc.server.Fs2UnaryServerCallListener*") + ProblemFilters.exclude[MissingClassProblem]("fs2.grpc.server.Fs2UnaryServerCallListener*"), + ProblemFilters.exclude[DirectMissingMethodProblem]("fs2.grpc.server.internal.*") ) ) ) diff --git a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala index b98c9526..9a92fab3 100644 --- a/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/Fs2ServerCallHandler.scala @@ -25,6 +25,7 @@ package server import cats.effect._ import cats.effect.std.Dispatcher +import fs2.grpc.server.internal.Fs2StreamServerCallHandler import fs2.grpc.server.internal.Fs2UnaryServerCallHandler import io.grpc._ @@ -36,32 +37,22 @@ class Fs2ServerCallHandler[F[_]: Async] private ( def unaryToUnaryCall[Request, Response]( implementation: (Request, Metadata) => F[Response] ): ServerCallHandler[Request, Response] = - Fs2UnaryServerCallHandler.unary(implementation, options, dispatcher) + Fs2UnaryServerCallHandler.mkHandler(implementation, options)(_.unary(_, dispatcher)) def unaryToStreamingCall[Request, Response]( implementation: (Request, Metadata) => Stream[F, Response] ): ServerCallHandler[Request, Response] = - Fs2UnaryServerCallHandler.stream(implementation, options, dispatcher) + Fs2UnaryServerCallHandler.mkHandler(implementation, options)(_.stream(_, dispatcher)) def streamingToUnaryCall[Request, Response]( implementation: (Stream[F, Request], Metadata) => F[Response] - ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { - def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options)) - listener.unsafeUnaryResponse(new Metadata(), implementation(_, headers)) - listener - } - } + ): ServerCallHandler[Request, Response] = + Fs2StreamServerCallHandler.mkHandler(implementation, options)(_.unary(_, dispatcher)) def streamingToStreamingCall[Request, Response]( implementation: (Stream[F, Request], Metadata) => Stream[F, Response] - ): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { - def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { - val listener = dispatcher.unsafeRunSync(Fs2StreamServerCallListener[F](call, dispatcher, options)) - listener.unsafeStreamResponse(new Metadata(), implementation(_, headers)) - listener - } - } + ): ServerCallHandler[Request, Response] = + Fs2StreamServerCallHandler.mkHandler(implementation, options)(_.stream(_, dispatcher)) } object Fs2ServerCallHandler { diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala index 760735bb..42463d63 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2ServerCall.scala @@ -73,6 +73,9 @@ private[server] final class Fs2ServerCall[Request, Response]( dispatcher ) + def requestOnPull[F[_]](implicit F: Sync[F]): Pipe[F, Request, Request] = + _.chunks.flatMap(chunk => Stream.evalUnChunk(F.as(F.delay(call.request(chunk.size)), chunk))) + def request(n: Int): SyncIO[Unit] = SyncIO(call.request(n)) diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2StreamServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2StreamServerCallHandler.scala new file mode 100644 index 00000000..4027e746 --- /dev/null +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2StreamServerCallHandler.scala @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2.grpc.server.internal + +import cats.effect.Async +import cats.effect.SyncIO +import fs2._ +import fs2.grpc.server.ServerCallOptions +import fs2.grpc.server.ServerOptions +import fs2.grpc.server.internal.Fs2ServerCall.Cancel +import io.grpc.ServerCall +import io.grpc._ + +object Fs2StreamServerCallHandler { + + private def mkListener[Request]( + channel: OneShotChannel[Request], + cancel: Cancel + ): ServerCall.Listener[Request] = + new ServerCall.Listener[Request] { + override def onCancel(): Unit = + cancel.unsafeRunSync() + + override def onMessage(message: Request): Unit = + channel.send(message).unsafeRunSync() + + override def onHalfClose(): Unit = + channel.close().unsafeRunSync() + } + + def mkHandler[F[_]: Async, G[_], Request, Response]( + impl: (Stream[F, Request], Metadata) => G[Response], + options: ServerOptions + )(start: (Fs2ServerCall[Request, Response], G[Response]) => SyncIO[Cancel]): ServerCallHandler[Request, Response] = + new ServerCallHandler[Request, Response] { + private val opt = options.callOptionsFn(ServerCallOptions.default) + + def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { + for { + call <- Fs2ServerCall.setup(opt, call) + _ <- call.request(1) // prefetch size + channel <- OneShotChannel.empty[Request] + cancel <- start(call, impl(channel.stream.through(call.requestOnPull), headers)) + } yield mkListener(channel, cancel) + }.unsafeRunSync() + } +} diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala index e7a0733e..ab43f7d7 100644 --- a/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala +++ b/runtime/src/main/scala/fs2/grpc/server/internal/Fs2UnaryServerCallHandler.scala @@ -22,9 +22,7 @@ package fs2.grpc.server.internal import cats.effect.Ref -import cats.effect.Sync import cats.effect.SyncIO -import cats.effect.std.Dispatcher import fs2.grpc.server.ServerCallOptions import fs2.grpc.server.ServerOptions import io.grpc._ @@ -88,40 +86,21 @@ private[server] object Fs2UnaryServerCallHandler { state.set(Cancelled()) >> call.close(status, new Metadata()) } - def unary[F[_]: Sync, Request, Response]( - impl: (Request, Metadata) => F[Response], - options: ServerOptions, - dispatcher: Dispatcher[F] - ): ServerCallHandler[Request, Response] = + def mkHandler[G[_], Request, Response]( + impl: (Request, Metadata) => G[Response], + options: ServerOptions + )(start: (Fs2ServerCall[Request, Response], G[Response]) => SyncIO[Cancel]): ServerCallHandler[Request, Response] = new ServerCallHandler[Request, Response] { private val opt = options.callOptionsFn(ServerCallOptions.default) - def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = - startCallSync(call, opt)(call => req => call.unary(impl(req, headers), dispatcher)).unsafeRunSync() + def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = { + for { + call <- Fs2ServerCall.setup(opt, call) + // We expect only 1 request, but we ask for 2 requests here so that if a misbehaving client + // sends more than 1 requests, ServerCall will catch it. + _ <- call.request(2) + state <- CallerState.init[Request](req => start(call, impl(req, headers))) + } yield mkListener[Request, Response](call, state) + }.unsafeRunSync() } - - def stream[F[_]: Sync, Request, Response]( - impl: (Request, Metadata) => fs2.Stream[F, Response], - options: ServerOptions, - dispatcher: Dispatcher[F] - ): ServerCallHandler[Request, Response] = - new ServerCallHandler[Request, Response] { - private val opt = options.callOptionsFn(ServerCallOptions.default) - - def startCall(call: ServerCall[Request, Response], headers: Metadata): ServerCall.Listener[Request] = - startCallSync(call, opt)(call => req => call.stream(impl(req, headers), dispatcher)).unsafeRunSync() - } - - private def startCallSync[F[_], Request, Response]( - call: ServerCall[Request, Response], - options: ServerCallOptions - )(f: Fs2ServerCall[Request, Response] => Request => SyncIO[Cancel]): SyncIO[ServerCall.Listener[Request]] = { - for { - call <- Fs2ServerCall.setup(options, call) - // We expect only 1 request, but we ask for 2 requests here so that if a misbehaving client - // sends more than 1 requests, ServerCall will catch it. - _ <- call.request(2) - state <- CallerState.init(f(call)) - } yield mkListener[Request, Response](call, state) - } } diff --git a/runtime/src/main/scala/fs2/grpc/server/internal/OneShotChannel.scala b/runtime/src/main/scala/fs2/grpc/server/internal/OneShotChannel.scala new file mode 100644 index 00000000..9cbbc8dc --- /dev/null +++ b/runtime/src/main/scala/fs2/grpc/server/internal/OneShotChannel.scala @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2.grpc.server.internal + +import cats.effect._ +import cats.syntax.functor._ +import fs2._ +import fs2.grpc.server.internal.OneShotChannel.State +import scala.collection.immutable.Queue + +private[server] final class OneShotChannel[A](val state: Ref[SyncIO, State[A]]) extends AnyVal { + + import State._ + + /** Send message to stream. + */ + def send(a: A): SyncIO[Unit] = + state + .modify { + case open: Open[A] => (open.append(a), SyncIO.unit) + case s: Suspended[A] => (State.consumed, s.resume(State.open(a))) + case closed => (closed, SyncIO.unit) + } + .flatMap(identity) + + /** Close stream. + */ + def close(): SyncIO[Unit] = + state + .modify { + case open: Open[A] => (open.close(), SyncIO.unit) + case s: Suspended[A] => (State.done, s.resume(State.done)) + case closed => (closed, SyncIO.unit) + } + .flatMap(identity) + + import fs2._ + + /** This method can be called at most once + */ + def stream[F[_]](implicit F: Async[F]): Stream[F, A] = { + def go(): Pull[F, A, Unit] = + Pull + .eval(state.getAndSet(State.consumed).to[F]) + .flatMap { + case Consumed => + Pull.eval(F.async[State[A]] { cb => + val next = new Suspended[A](s => cb(Right(s))) + state + .modify { + case Consumed => (next, None) + case other => (State.consumed, Some(other)) + } + .to[F] + .map { + case Some(received) => + cb(Right(received)) + None + case None => + Some(state.set(State.consumed).to[F]) + } + }) + case other => Pull.pure(other) + } + .flatMap { + case open: Open[A] => open.toPull >> go() + case other => other.toPull + } + + go().stream + } +} + +private[server] object OneShotChannel { + def empty[A]: SyncIO[OneShotChannel[A]] = + Ref[SyncIO].of[State[A]](State.consumed).map(new OneShotChannel[A](_)) + + sealed trait State[A] { + def toPull[F[_]: Sync]: Pull[F, A, Unit] + } + + object State { + class UnexpectedState extends RuntimeException + private[OneShotChannel] val Consumed: State[Nothing] = new Open(Queue.empty) + def consumed[A]: State[A] = Consumed.asInstanceOf[State[A]] + + def done[A]: State[A] = new Closed(Queue.empty) + + def open[A](a: A): Open[A] = new Open(Queue(a)) + + class Open[A](queue: Queue[A]) extends State[A] { + def append(a: A): Open[A] = new Open(queue.enqueue(a)) + + def toPull[F[_]: Sync]: Pull[F, A, Unit] = Pull.output(Chunk.queue(queue)) + + def close(): State[A] = new Closed(queue) + } + + class Closed[A](queue: Queue[A]) extends State[A] { + def toPull[F[_]: Sync]: Pull[F, A, Unit] = Pull.output(Chunk.queue(queue)) + } + + class Suspended[A](f: State[A] => Unit) extends State[A] { + def resume(state: State[A]): SyncIO[Unit] = SyncIO(f(state)) + + def toPull[F[_]: Sync]: Pull[F, A, Unit] = Pull.raiseError(new UnexpectedState) // never happened + } + } +} diff --git a/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala b/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala index c8f77d28..12f977ce 100644 --- a/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/client/ClientSuite.scala @@ -95,6 +95,7 @@ class ClientSuite extends Fs2GrpcSuite { assertEquals(dummy.messagesSent.size, 1) assertEquals(dummy.requested, 2) + Thread.sleep(10) } runTest0("error response to unaryToUnary") { (tc, io, d) => diff --git a/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala b/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala index 31036241..90960f50 100644 --- a/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala +++ b/runtime/src/test/scala/fs2/grpc/server/DummyServerCall.scala @@ -30,6 +30,7 @@ import scala.collection.mutable.ArrayBuffer class DummyServerCall extends ServerCall[String, Int] { val messages: ArrayBuffer[Int] = ArrayBuffer[Int]() var currentStatus: Option[Status] = None + var explicitCompressor: Option[String] = None override def request(numMessages: Int): Unit = () override def sendMessage(message: Int): Unit = { @@ -43,5 +44,10 @@ class DummyServerCall extends ServerCall[String, Int] { override def close(status: Status, trailers: Metadata): Unit = { currentStatus = Some(status) } + + override def setCompression(compressor: String): Unit = { + explicitCompressor = Some(compressor) + } + override def isCancelled: Boolean = false } diff --git a/runtime/src/test/scala/fs2/grpc/server/OneShotChannelSuite.scala b/runtime/src/test/scala/fs2/grpc/server/OneShotChannelSuite.scala new file mode 100644 index 00000000..98ee36b8 --- /dev/null +++ b/runtime/src/test/scala/fs2/grpc/server/OneShotChannelSuite.scala @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2018 Gary Coady / Fs2 Grpc Developers + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2.grpc.server + +import cats.effect._ +import cats.effect.testkit.TestControl +import fs2._ +import fs2.grpc.server.internal.OneShotChannel +import munit.CatsEffectFunFixtures +import munit.CatsEffectSuite +import scala.concurrent.duration._ + +class OneShotChannelSuite extends CatsEffectSuite with CatsEffectFunFixtures { + + def run(size: Int, producerDelay: Int, consumerDelay: Int): IO[Unit] = { + for { + ch <- OneShotChannel.empty[Int].to[IO] + chunkN <- Ref[IO].of(0) + stream <- ch + .stream[IO] + .chunks + .evalTap(_ => chunkN.update(_ + 1)) + .unchunks + .zipLeft(Stream.awakeDelay[IO](consumerDelay.seconds)) + .compile + .toList + .start + _ <- Stream + .range(0, size) + .evalMap(c => ch.send(c).to[IO]) + .append(Stream.exec(ch.close().to[IO])) + .zipLeft(Stream.awakeDelay[IO](producerDelay.seconds)) + .compile + .drain + stream <- stream.joinWithNever + chunkN <- chunkN.get + } yield { + assertEquals(stream, Stream.range(0, size).toList) + if (consumerDelay > producerDelay) { + assert(clue(chunkN) < clue(size)) + } else { + assertEquals(chunkN, size) + } + } + } + + test("basic") { + TestControl.executeEmbed( + run(0, 1, 1) >> + run(1, 1, 1) >> + run(2, 1, 1) >> + run(8, 1, 1) + ) + } + test("slow producer") { + TestControl.executeEmbed( + run(8, 2, 1) >> + run(16, 4, 3) + ) + } + test("slow consumer") { + TestControl.executeEmbed( + run(8, 1, 2) >> + run(16, 3, 4) + ) + } +} diff --git a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala index 68ded68e..24425edd 100644 --- a/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala +++ b/runtime/src/test/scala/fs2/grpc/server/ServerSuite.scala @@ -23,58 +23,99 @@ package fs2 package grpc package server -import scala.concurrent.duration._ import cats.effect._ import cats.effect.std.Dispatcher import cats.effect.testkit.TestContext -import fs2.grpc.server.internal.Fs2UnaryServerCallHandler +import cats.effect.testkit.TestControl import io.grpc._ +import scala.concurrent.duration._ class ServerSuite extends Fs2GrpcSuite { private val compressionOps = ServerOptions.default.configureCallOptions(_.withServerCompressor(Some(GzipCompressor))) - runTest("single message to unaryToUnary")(singleUnaryToUnary()) - runTest("single message to unaryToUnary with compression")(singleUnaryToUnary(compressionOps)) - - private[this] def singleUnaryToUnary( - options: ServerOptions = ServerOptions.default - ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => - val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) - val listener = handler.startCall(dummy, new Metadata()) + private def startCall( + implement: Fs2ServerCallHandler[IO] => ServerCallHandler[String, Int], + serverOptions: ServerOptions = ServerOptions.default + )(call: ServerCall[String, Int], thunk: ServerCall.Listener[String] => IO[Unit]): IO[Unit] = + for { + releaseRef <- IO.ref[IO[Unit]](IO.unit) + startBarrier <- Deferred[IO, Unit] + tc <- TestControl.execute { + for { + allocated <- Dispatcher[IO].map(Fs2ServerCallHandler[IO](_, serverOptions)).allocated + (handler, release) = allocated + _ <- releaseRef.set(release) + listener <- IO(implement(handler).startCall(call, new Metadata())) + _ <- startBarrier.get + _ <- IO.defer(thunk(listener)) + } yield () + } + _ <- tc.tick + _ <- startBarrier.complete(()) + _ <- tc.tickAll + _ <- releaseRef.get + } yield () + + private def syncCall( + fs: (ServerCall.Listener[String] => Unit)* + ): ServerCall.Listener[String] => IO[Unit] = + listener => IO(fs.foreach(_.apply(listener))) + + test("unaryToUnary with compression") { + testCompression(_.unaryToUnaryCall((req, _) => IO(req.length))) + } - listener.onMessage("123") - listener.onHalfClose() - tc.tick() + test("unaryToStream with compression") { + testCompression(_.unaryToStreamingCall((req, _) => Stream.emit(req.length).repeatN(5))) + } - assertEquals(dummy.messages.size, 1) - assertEquals(dummy.messages(0), 3) - assertEquals(dummy.currentStatus.isDefined, true) - assertEquals(dummy.currentStatus.get.isOk, true) + test("streamToUnary with compression") { + testCompression(_.streamingToUnaryCall((req, _) => req.compile.foldMonoid.map(_.length))) } - runTest("cancellation for unaryToUnary") { (tc, d) => + test("streamToStream with compression")( + testCompression(_.streamingToStreamingCall((req, _) => req.map(_.length))) + ) + + private def testCompression(sync: Fs2ServerCallHandler[IO] => ServerCallHandler[String, Int]): IO[Unit] = { val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), ServerOptions.default, d) - val listener = handler.startCall(dummy, new Metadata()) + startCall(sync, compressionOps)(dummy, _ => IO.unit) >> IO { + assertEquals(dummy.explicitCompressor, Some("gzip")) + } + } - listener.onCancel() - tc.tick() + test("single message to unaryToUnary") { + val dummy = new DummyServerCall + startCall(_.unaryToUnaryCall((req, _) => IO(req.length)))( + dummy, + syncCall(_.onMessage("123"), _.onHalfClose()) + ) >> IO { + assertEquals(dummy.explicitCompressor, None) + assertEquals(dummy.messages.size, 1) + assertEquals(dummy.messages(0), 3) + assertEquals(dummy.currentStatus.isDefined, true) + assertEquals(dummy.currentStatus.get.isOk, true) + } + } - assertEquals(dummy.currentStatus, None) - assertEquals(dummy.messages.length, 0) + test("cancellation for unaryToUnary") { + val dummy = new DummyServerCall + startCall(_.unaryToUnaryCall((req, _) => IO(req.length)))( + dummy, + syncCall(_.onCancel()) + ) >> IO { + assertEquals(dummy.currentStatus, None) + assertEquals(dummy.messages.length, 0) + } } runTest("cancellation on the fly for unaryToUnary") { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]( - (req, _) => IO(req.length).delayBy(10.seconds), - ServerOptions.default, - d - ) - val listener = handler.startCall(dummy, new Metadata()) + val listener = Fs2ServerCallHandler[IO](d, ServerOptions.default) + .unaryToUnaryCall[String, Int]((req, _) => IO(req.length).delayBy(10.seconds)) + .startCall(dummy, new Metadata()) listener.onMessage("123") listener.onHalfClose() @@ -93,8 +134,9 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) - val listener = handler.startCall(dummy, new Metadata()) + val listener = Fs2ServerCallHandler[IO](d, options) + .unaryToUnaryCall[String, Int]((req, _) => IO(req.length)) + .startCall(dummy, new Metadata()) listener.onMessage("123") listener.onMessage("456") @@ -104,20 +146,14 @@ class ServerSuite extends Fs2GrpcSuite { assertEquals(dummy.currentStatus.map(_.getCode), Some(Status.Code.INTERNAL)) } - runTest("no messages to unaryToUnary")(noMessageUnaryToUnary()) - runTest("no messages to unaryToUnary with compression")(noMessageUnaryToUnary(compressionOps)) - - private def noMessageUnaryToUnary( - options: ServerOptions = ServerOptions.default - ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => + test("no messages to unaryToUnary") { val dummy = new DummyServerCall - val handler = Fs2UnaryServerCallHandler.unary[IO, String, Int]((req, _) => IO(req.length), options, d) - val listener = handler.startCall(dummy, new Metadata()) - - listener.onHalfClose() - tc.tick() - - assertEquals(dummy.currentStatus.map(_.getCode), Some(Status.Code.INTERNAL)) + startCall(_.unaryToUnaryCall((req, _) => IO(req.length)))( + dummy, + syncCall(_.onHalfClose()) + ) >> IO { + assertEquals(dummy.currentStatus.map(_.getCode), Some(Status.Code.INTERNAL)) + } } runTest0("resource awaits termination of server") { (tc, r, _) => @@ -138,9 +174,9 @@ class ServerSuite extends Fs2GrpcSuite { options: ServerOptions = ServerOptions.default ): (TestContext, Dispatcher[IO]) => Unit = { (tc, d) => val dummy = new DummyServerCall - val handler = - Fs2UnaryServerCallHandler.stream[IO, String, Int]((s, _) => Stream(s).map(_.length).repeat.take(5), options, d) - val listener = handler.startCall(dummy, new Metadata()) + val listener = Fs2ServerCallHandler[IO](d, options) + .unaryToStreamingCall[String, Int]((s, _) => Stream(s).map(_.length).repeat.take(5)) + .startCall(dummy, new Metadata()) listener.onMessage("123") listener.onHalfClose() @@ -168,19 +204,16 @@ class ServerSuite extends Fs2GrpcSuite { assertEquals(dummy.currentStatus.get.isOk, true) } - runTest("cancellation for streamingToStreaming") { (tc, d) => + test("cancellation for streamingToStreaming") { val dummy = new DummyServerCall - val handler = Fs2ServerCallHandler[IO](d, ServerOptions.default) - .streamingToStreamingCall[String, Int]((_, _) => - Stream.emit(3).repeat.take(5).zipLeft(Stream.awakeDelay[IO](1.seconds)) - ) - val listener = handler.startCall(dummy, new Metadata()) - - tc.tick() - listener.onCancel() - tc.tick() - - assertEquals(dummy.currentStatus.map(_.getCode), Some(Status.Code.CANCELLED)) + startCall( + _.streamingToStreamingCall((_, _) => Stream.emit(3).repeat.take(5).zipLeft(Stream.awakeDelay[IO](1.seconds))) + )( + dummy, + syncCall(_.onCancel()) + ) >> IO { + assertEquals(dummy.currentStatus.map(_.getCode), Some(Status.Code.CANCELLED)) + } } runTest("messages to streamingToStreaming")(multipleStreamingToStreaming())