diff --git a/README.md b/README.md index 8c5e867a..3cc9585f 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,10 @@ [![Maven Central](https://img.shields.io/maven-central/v/com.softwaremill.ox/core_3)](https://central.sonatype.com/artifact/com.softwaremill.ox/core_3) [![ScalaDoc](https://javadoc.io/badge2/com.softwaremill.ox/core_3/ScalaDoc.svg)](https://javadoc.io/doc/com.softwaremill.ox/core_3) -Safe direct-style streaming, concurrency and resiliency for Scala on the JVM. Requires JDK 21+ & Scala 3. Ox covers -the following areas: +Safe direct-style streaming, concurrency and resiliency for Scala on the JVM. +Requires JDK 21+ & Scala 3. Experimental support for Scala Native. + +Ox covers the following areas: * streaming: push-based backpressured streaming designed for direct-style, with a rich set of stream transformations, flexible stream source & sink definitions and reactive streams integration @@ -22,13 +24,13 @@ preserving developer-friendly stack traces, and without compromising performance To use Ox, add the following dependency, using either [sbt](https://www.scala-sbt.org): ```scala -"com.softwaremill.ox" %% "core" % "1.0.4" +"com.softwaremill.ox" %%% "core" % "1.0.4" ``` Or [scala-cli](https://scala-cli.virtuslab.org): ```scala -//> using dep "com.softwaremill.ox::core:1.0.4" +//> using dep "com.softwaremill.ox:::core:1.0.4" ``` Documentation is available at [https://ox.softwaremill.com](https://ox.softwaremill.com), ScalaDocs can be browsed at [https://javadoc.io](https://www.javadoc.io/doc/com.softwaremill.ox). @@ -265,4 +267,4 @@ We offer commercial development services. [Contact us](https://softwaremill.com) ## Copyright -Copyright (C) 2023-2025 SoftwareMill [https://softwaremill.com](https://softwaremill.com). +Copyright (C) 2023-2026 SoftwareMill [https://softwaremill.com](https://softwaremill.com). diff --git a/build.sbt b/build.sbt index 9008d764..f292a83b 100644 --- a/build.sbt +++ b/build.sbt @@ -2,10 +2,13 @@ import com.softwaremill.SbtSoftwareMillCommon.commonSmlBuildSettings import com.softwaremill.Publish.{ossPublishSettings, updateDocs} import com.softwaremill.UpdateVersionInDocs import com.typesafe.tools.mima.core.{MissingClassProblem, ProblemFilters} +import scalanative.build._ + +lazy val scala3 = "3.3.7" lazy val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq( organization := "com.softwaremill.ox", - scalaVersion := "3.3.7", + scalaVersion := scala3, updateDocs := Def.taskDyn { val files1 = UpdateVersionInDocs(sLog.value, organization.value, version.value) Def.task { @@ -50,21 +53,53 @@ compileDocumentation := { lazy val rootProject = (project in file(".")) .settings(commonSettings) .settings(publishArtifact := false, name := "ox") - .aggregate(core, kafka, mdcLogback, flowReactiveStreams, cron, otelContext) + .aggregate(core.projectRefs ++ examples.projectRefs ++ Seq[ProjectReference](kafka, mdcLogback, flowReactiveStreams, cron, otelContext): _*) -lazy val core: Project = (project in file("core")) +lazy val core = (projectMatrix in file("core")) .settings(commonSettings) .settings( name := "core", - libraryDependencies ++= Seq( - "com.softwaremill.jox" % "channels" % "1.1.2", - scalaTest, - "org.apache.pekko" %% "pekko-stream" % "1.6.0" % Test, - "org.reactivestreams" % "reactive-streams-tck-flow" % "1.0.4" % Test - ), - Test / fork := true + libraryDependencies += "org.scalatest" %%% "scalatest" % "3.2.20" % Test + ) + .jvmPlatform( + scalaVersions = Seq(scala3), + settings = enableMimaSettings ++ Seq( + Test / fork := true, + libraryDependencies ++= Seq( + "com.softwaremill.jox" % "channels" % "1.1.2", + "org.apache.pekko" %% "pekko-stream" % "1.6.0" % Test, + "org.reactivestreams" % "reactive-streams-tck-flow" % "1.0.4" % Test + ) + ) + ) + .nativePlatform( + scalaVersions = Seq(scala3), + settings = Seq( + Test / fork := false, + nativeConfig ~= { _.withMultithreading(true) } + ) ) - .settings(enableMimaSettings) + +lazy val examples = (projectMatrix in file("examples")) + .settings(commonSettings) + .settings( + name := "examples", + publishArtifact := false, + Compile / mainClass := Some("VirtualThreadsNativeJvmBenchmark") + ) + .jvmPlatform( + scalaVersions = Seq(scala3), + settings = Seq( + assembly / assemblyJarName := "examples-assembly.jar" + ) + ) + .nativePlatform( + scalaVersions = Seq(scala3), + settings = Seq( + nativeConfig ~= { _.withMultithreading(true) } + ) + ) + .dependsOn(core) lazy val kafka: Project = (project in file("kafka")) .settings(commonSettings) @@ -80,7 +115,7 @@ lazy val kafka: Project = (project in file("kafka")) scalaTest ) ) - .dependsOn(core) + .dependsOn(core.jvm(scala3)) lazy val mdcLogback: Project = (project in file("mdc-logback")) .settings(commonSettings) @@ -91,7 +126,7 @@ lazy val mdcLogback: Project = (project in file("mdc-logback")) scalaTest ) ) - .dependsOn(core) + .dependsOn(core.jvm(scala3)) lazy val flowReactiveStreams: Project = (project in file("flow-reactive-streams")) .settings(commonSettings) @@ -102,7 +137,7 @@ lazy val flowReactiveStreams: Project = (project in file("flow-reactive-streams" scalaTest ) ) - .dependsOn(core) + .dependsOn(core.jvm(scala3)) lazy val cron: Project = (project in file("cron")) .settings(commonSettings) @@ -113,7 +148,7 @@ lazy val cron: Project = (project in file("cron")) scalaTest ) ) - .dependsOn(core % "test->test;compile->compile") + .dependsOn(core.jvm(scala3) % "test->test;compile->compile") lazy val otelContext: Project = (project in file("otel-context")) .settings(commonSettings) @@ -124,7 +159,7 @@ lazy val otelContext: Project = (project in file("otel-context")) scalaTest ) ) - .dependsOn(core % "test->test;compile->compile") + .dependsOn(core.jvm(scala3) % "test->test;compile->compile") lazy val documentation: Project = (project in file("generated-doc")) // important: it must not be doc/ .enablePlugins(MdocPlugin) @@ -142,7 +177,7 @@ lazy val documentation: Project = (project in file("generated-doc")) // importan libraryDependencies ++= Seq(logback % Test) ) .dependsOn( - core, + core.jvm(scala3), kafka, mdcLogback, flowReactiveStreams, diff --git a/core/src/main/scala/ox/channels/Channel.scala b/core/src/main/scalajvm/ox/channels/Channel.scala similarity index 100% rename from core/src/main/scala/ox/channels/Channel.scala rename to core/src/main/scalajvm/ox/channels/Channel.scala diff --git a/core/src/main/scala/ox/channels/ChannelClosed.scala b/core/src/main/scalajvm/ox/channels/ChannelClosed.scala similarity index 100% rename from core/src/main/scala/ox/channels/ChannelClosed.scala rename to core/src/main/scalajvm/ox/channels/ChannelClosed.scala diff --git a/core/src/main/scala/ox/channels/select.scala b/core/src/main/scalajvm/ox/channels/select.scala similarity index 100% rename from core/src/main/scala/ox/channels/select.scala rename to core/src/main/scalajvm/ox/channels/select.scala diff --git a/core/src/main/scalanative/ox/channels/Channel.scala b/core/src/main/scalanative/ox/channels/Channel.scala new file mode 100644 index 00000000..1e417f37 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/Channel.scala @@ -0,0 +1,138 @@ +package ox.channels + +import ox.channels.jox as j +import ox.channels.jox.{Channel as JChannel, Select as JSelect, SelectClause as JSelectClause, Sink as JSink, Source as JSource} + +import ChannelClosedUnion.orThrow + +import scala.annotation.unchecked.uncheckedVariance + +// select result: needs to be defined here, as implementations are defined here as well + +/** Results of a [[select]] call, when clauses are passed (instead of a number of [[Source]]s). Each result corresponds to a clause, and can + * be pattern-matched (using a path-dependent type) to inspect which clause was selected. + */ +sealed trait SelectResult[+T]: + def value: T + +/** The result returned in case a [[Default]] clause was selected in [[select]]. */ +case class DefaultResult[T](value: T) extends SelectResult[T] + +// select clauses: needs to be defined here, as implementations are defined here as well + +/** A clause to use as part of [[select]]. Clauses can be created having a channel instance, using [[Source.receiveClause]] and + * [[Sink.sendClause]]. + * + * A clause instance is immutable and can be reused in multiple [[select]] calls. + */ +sealed trait SelectClause[+T]: + private[ox] def delegate: JSelectClause[Any] + type Result <: SelectResult[T] + +/** A default clause, which will be chosen if no other clause can be selected immediately, during a [[select]] call. + * + * There should be at most one default clause, and it should always come last in the list of clauses. + */ +case class Default[T](value: T) extends SelectClause[T]: + override private[ox] def delegate: JSelectClause[Any] = JSelect.defaultClause(() => DefaultResult(value)) + type Result = DefaultResult[T] + +// + +/** A channel source, which can be used to receive values from the channel. See [[Channel]] for more details. */ +trait Source[+T] extends SourceOps[T] with SourceDrainOps[T]: + protected def delegate: JSource[Any] + + case class Received private[channels] (value: T @uncheckedVariance) extends SelectResult[T] + + case class Receive private[channels] (delegate: JSelectClause[Any]) extends SelectClause[T]: + type Result = Received + + def receiveClause: Receive = Receive(delegate.receiveClause(t => Received(t.asInstanceOf[T]))) + + def tryReceive(): Option[T] = tryReceiveOrClosed().orThrow + + def tryReceiveOrClosed(): Option[T] | ChannelClosed = + val r = delegate.tryReceiveOrClosed() + if r == null then None + else + ChannelClosed.fromJox(r.asInstanceOf[AnyRef]) match + case c: ChannelClosed => c + case v: T @unchecked => Some(v) + + def receiveOrClosed(): T | ChannelClosed = ChannelClosed.fromJoxOrT(delegate.receiveOrClosed()) + + def receiveOrDone(): T | ChannelClosed.Done.type = receiveOrClosed() match + case e: ChannelClosed.Error => throw e.toThrowable + case ChannelClosed.Done => ChannelClosed.Done + case t: T @unchecked => t + + def receive(): T = receiveOrClosed().orThrow + + def isClosedForReceive: Boolean = delegate.isClosedForReceive + + def isClosedForReceiveDetail: Option[ChannelClosed] = Option(ChannelClosed.fromJoxOrT(delegate.closedForReceive())) +end Source + +object Source extends SourceCompanionOps + +// + +/** A channel sink, which can be used to send values to the channel. See [[Channel]] for more details. */ +trait Sink[-T]: + protected def delegate: JSink[Any] + + case class Sent private[channels] () extends SelectResult[Unit]: + override def value: Unit = () + + case class Send private[channels] (delegate: JSelectClause[Any]) extends SelectClause[Unit]: + type Result = Sent + + def sendClause(t: T): Send = Send(delegate.asInstanceOf[JSink[T]].sendClause(t, () => Sent())) + + def trySend(t: T): Boolean = trySendOrClosed(t).orThrow + + def trySendOrClosed(t: T): Boolean | ChannelClosed = + val r = delegate.asInstanceOf[JSink[T]].trySendOrClosed(t) + if r == null then true + else + ChannelClosed.fromJox(r.asInstanceOf[AnyRef]) match + case c: ChannelClosed => c + case _ => false + + def sendOrClosed(t: T): Unit | ChannelClosed = + val r = ChannelClosed.fromJoxOrUnit(delegate.asInstanceOf[JSink[T]].sendOrClosed(t)) + if r == null then () else r + + def send(t: T): Unit = sendOrClosed(t).orThrow + + def errorOrClosed(reason: Throwable): Unit | ChannelClosed = ChannelClosed.fromJoxOrUnit(delegate.errorOrClosed(reason)) + + def error(reason: Throwable): Unit = errorOrClosed(reason).orThrow + + def doneOrClosed(): Unit | ChannelClosed = ChannelClosed.fromJoxOrUnit(delegate.doneOrClosed()) + + def done(): Unit = doneOrClosed().orThrow + + def isClosedForSend: Boolean = delegate.isClosedForSend + + def isClosedForSendDetail: Option[ChannelClosed] = Option(ChannelClosed.fromJoxOrT(delegate.closedForSend())) +end Sink + +// + +class Channel[T] private (capacity: Int) extends Source[T] with Sink[T]: + protected override val delegate: JChannel[Any] = capacity match + case 0 => JChannel.newRendezvousChannel() + case -1 => JChannel.newUnlimitedChannel() + case _ => JChannel.newBufferedChannel(capacity) + + override def toString: String = delegate.toString + +object Channel: + def bufferedDefault[T]: Channel[T] = BufferCapacity.newChannel[T] + def buffered[T](capacity: Int): Channel[T] = new Channel(capacity) + def rendezvous[T]: Channel[T] = new Channel(0) + def unlimited[T]: Channel[T] = new Channel(-1) + def withCapacity[T](capacity: Int): Channel[T] = new Channel(capacity) +end Channel diff --git a/core/src/main/scalanative/ox/channels/ChannelClosed.scala b/core/src/main/scalanative/ox/channels/ChannelClosed.scala new file mode 100644 index 00000000..64535622 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/ChannelClosed.scala @@ -0,0 +1,28 @@ +package ox.channels + +import ox.channels.jox as j + +/** Returned by channel methods (e.g. [[Source.receiveOrClosed]], [[Sink.sendOrClosed]], [[selectOrClosed]]) when the channel is closed. */ +sealed trait ChannelClosed: + def toThrowable: Throwable = this match + case ChannelClosed.Error(reason) => ChannelClosedException.Error(reason) + case ChannelClosed.Done => ChannelClosedException.Done() + +object ChannelClosed: + case class Error(reason: Throwable) extends ChannelClosed + case object Done extends ChannelClosed + + private[ox] def fromJoxOrT[T](joxResult: AnyRef): T | ChannelClosed = fromJox(joxResult).asInstanceOf[T | ChannelClosed] + private[ox] def fromJoxOrUnit(joxResult: AnyRef): Unit | ChannelClosed = + if joxResult == null then () else fromJox(joxResult).asInstanceOf[ChannelClosed] + + private[ox] def fromJox(joxResult: AnyRef): AnyRef | ChannelClosed = + joxResult match + case _: j.ChannelDone => Done + case e: j.ChannelError => Error(e.cause) + case _ => joxResult +end ChannelClosed + +enum ChannelClosedException(cause: Option[Throwable]) extends Exception(cause.orNull): + case Error(cause: Throwable) extends ChannelClosedException(Some(cause)) + case Done() extends ChannelClosedException(None) diff --git a/core/src/main/scalanative/ox/channels/jox/CellState.scala b/core/src/main/scalanative/ox/channels/jox/CellState.scala new file mode 100644 index 00000000..7ab3d092 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/CellState.scala @@ -0,0 +1,43 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/Channel.java +// (inner enums/classes: CellState, SendResult, ReceiveResult, ExpandBufferResult, ContinuationMarker, ChannelClosedMarker, SentClauseMarker; +// plus RestartSelectMarker, SelectState, TimeoutMarker from Select.java) + +// Possible states of a cell: one of these enum constants, Continuation, StoredSelectClause, or a buffered value +enum CellState: + case DONE + case INTERRUPTED_SEND // the send/receive differentiation is important for expandBuffer + case INTERRUPTED_RECEIVE + case BROKEN + case IN_BUFFER // used to inform a potentially concurrent sender that the cell is now in the buffer + case RESUMING // expandBuffer is resuming a sender + case CLOSED +end CellState + +enum SendResult: + case AWAITED, BUFFERED, RESUMED, FAILED, CLOSED + +enum ReceiveResult: + case FAILED, CLOSED + +enum ExpandBufferResult: + case DONE, FAILED, CLOSED + +enum ContinuationMarker: + case INTERRUPTED + +enum ChannelClosedMarker: + case CLOSED + +enum SentClauseMarker: + case SENT + +enum RestartSelectMarker: + case RESTART + +enum SelectState: + case REGISTERING, INTERRUPTED + +enum TimeoutMarker: + case INSTANCE diff --git a/core/src/main/scalanative/ox/channels/jox/Channel.scala b/core/src/main/scalanative/ox/channels/jox/Channel.scala new file mode 100644 index 00000000..6a980457 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/Channel.scala @@ -0,0 +1,703 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/Channel.java + +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} +import java.util.concurrent.locks.LockSupport + +final class Channel[T] private (val capacity: Int) extends Source[T] with Sink[T]: + import CellState.* + import Channel.{getSendersCounter, isClosed, setClosedFlag, TRY_SEND_NOT_SENT} + import Segment.{SEGMENT_SIZE, NULL_SEGMENT, findAndMoveForward} + + val isRendezvous: Boolean = capacity == 0 + private inline def isUnlimited: Boolean = capacity < 0 + + private val sendersAndClosedFlag = new AtomicLong(0L) + private val receivers = new AtomicLong(0L) + private val bufferEnd = new AtomicLong(capacity.toLong) + + private val sendSegment: AtomicReference[Segment] = new AtomicReference(null) + private val receiveSegment: AtomicReference[Segment] = new AtomicReference(null) + private val bufferEndSegment: AtomicReference[Segment] = new AtomicReference(null) + private val closedReason: AtomicReference[ChannelClosed | Null] = new AtomicReference(null) + + locally: + val isRendezvousOrUnlimited = isRendezvous || isUnlimited + val firstSegment = new Segment(0, null, if isRendezvousOrUnlimited then 2 else 3, isRendezvousOrUnlimited) + sendSegment.set(firstSegment) + receiveSegment.set(firstSegment) + bufferEndSegment.set(if isRendezvousOrUnlimited then NULL_SEGMENT else firstSegment) + processInitialBuffer() + + private def processInitialBuffer(): Unit = + var currentSegment = bufferEndSegment.get() + val segmentsToProcess = + if capacity <= 0 then 0 + else ((capacity + SEGMENT_SIZE - 1L) / SEGMENT_SIZE).toInt + + var segmentId = 0 + while segmentId < segmentsToProcess do + currentSegment = findAndMoveForward(bufferEndSegment, currentSegment, segmentId.toLong).nn + val cellsToProcess = + val rem = if segmentId == segmentsToProcess - 1 then (capacity % SEGMENT_SIZE) else SEGMENT_SIZE + if rem == 0 then SEGMENT_SIZE else rem + currentSegment.setup_markCellsProcessed(cellsToProcess) + segmentId += 1 + end processInitialBuffer + + // ******* + // Sending + // ******* + + @throws[InterruptedException] + override def send(value: T): Unit = + val r = sendOrClosed(value) + r match + case c: ChannelClosed => throw c.toException() + case _ => + + @throws[InterruptedException] + override def sendOrClosed(value: T): AnyRef = + doSend(value, null, null) + + /** Returns null when sent, ChannelClosed when closed, or StoredSelectClause if select is provided. */ + private def doSend(value: T, select: SelectInstance | Null, selectClause: SelectClause[?] | Null): AnyRef = + if value == null then throw new NullPointerException() + while true do + val segment = sendSegment.get() + val scf = sendersAndClosedFlag.getAndAdd(1L) + val s = getSendersCounter(scf) + + val id = s / SEGMENT_SIZE + val i = (s % SEGMENT_SIZE).toInt + + var seg = segment + if segment.getId != id then + seg = findAndMoveForward(sendSegment, segment, id) + if seg == null then return closedReason.get().nn + + if seg.getId != id then + sendersAndClosedFlag.compareAndSet(s, seg.getId * SEGMENT_SIZE) + // continue - skipping interrupted cells + else if isClosed(scf) then return closedReason.get().nn + else + val sendResult = updateCellSend(seg, i, s, value, select, selectClause, true) + sendResult match + case SendResult.BUFFERED | SendResult.AWAITED => return null + case SendResult.RESUMED => + seg.cleanPrev() + return null + case ss: StoredSelectClause => return ss + case SendResult.FAILED => + seg.cleanPrev() + // continue - trying with a new cell + case SendResult.CLOSED => return closedReason.get().nn + case _ => throw new IllegalStateException(s"Unexpected result: $sendResult in channel: $this") + end match + end if + else if isClosed(scf) then return closedReason.get().nn + else + val sendResult = updateCellSend(seg, i, s, value, select, selectClause, true) + sendResult match + case SendResult.BUFFERED | SendResult.AWAITED => return null + case SendResult.RESUMED => + seg.cleanPrev() + return null + case ss: StoredSelectClause => return ss + case SendResult.FAILED => + seg.cleanPrev() + // continue - trying with a new cell + case SendResult.CLOSED => return closedReason.get().nn + case _ => throw new IllegalStateException(s"Unexpected result: $sendResult in channel: $this") + end match + end if + end while + throw new AssertionError("unreachable") + end doSend + + // Non-blocking send + override def trySendOrClosed(value: T): AnyRef = + if value == null then throw new NullPointerException() + while true do + val segment = sendSegment.get() + val scf = sendersAndClosedFlag.get() + val s = getSendersCounter(scf) + + if isClosed(scf) then return closedReason.get().nn + + // capacity pre-check + if capacity >= 0 then + val bufEnd = bufferEnd.get() + val r = receivers.get() + if capacity == 0 then + if s >= r then return Channel.TRY_SEND_NOT_SENT + else if s >= bufEnd && s >= r then return Channel.TRY_SEND_NOT_SENT + + if !sendersAndClosedFlag.compareAndSet(scf, scf + 1) then () // continue + else + val id = s / SEGMENT_SIZE + val i = (s % SEGMENT_SIZE).toInt + + var seg = segment + if segment.getId != id then + seg = findAndMoveForward(sendSegment, segment, id) + if seg == null then return closedReason.get().nn + if seg.getId != id then + sendersAndClosedFlag.compareAndSet(s + 1, seg.getId * SEGMENT_SIZE) + () // continue + else + val r = finishTrySend(seg, i, s, value) + if r ne RETRY_SENTINEL then return r + else + val r = finishTrySend(seg, i, s, value) + if r ne RETRY_SENTINEL then return r + end if + end if + end while + throw new AssertionError("unreachable") + end trySendOrClosed + + private val RETRY_SENTINEL: AnyRef = new AnyRef + + /** Returns result or RETRY_SENTINEL to indicate the caller should loop. */ + private def finishTrySend(segment: Segment, i: Int, s: Long, value: T): AnyRef = + val sendResult = + try updateCellSend(segment, i, s, value, null, null, false) + catch case e: InterruptedException => throw new AssertionError("unreachable: non-blocking send", e) + sendResult match + case SendResult.BUFFERED => null + case SendResult.RESUMED => + segment.cleanPrev() + null + case SendResult.FAILED => + segment.cleanPrev() + RETRY_SENTINEL + case SendResult.CLOSED => closedReason.get() + case r if r eq Channel.TRY_SEND_NOT_SENT => Channel.TRY_SEND_NOT_SENT + case _ => throw new IllegalStateException(s"Unexpected result: $sendResult") + end match + end finishTrySend + + // Non-blocking receive + override def tryReceiveOrClosed(): AnyRef = + while true do + val scf = sendersAndClosedFlag.get() + val s = getSendersCounter(scf) + val r = receivers.get() + + if s <= r then + if isClosed(scf) then return closedForReceive() + else return null + + val segment = receiveSegment.get() + if !receivers.compareAndSet(r, r + 1) then () // continue + else + val id = r / SEGMENT_SIZE + val i = (r % SEGMENT_SIZE).toInt + + var seg = segment + if segment.getId != id then + seg = findAndMoveForward(receiveSegment, segment, id) + if seg == null then return closedReason.get().nn + if seg.getId != id then + receivers.compareAndSet(r + 1, seg.getId * SEGMENT_SIZE) + () // continue + else + val res = finishTryReceive(seg, i) + if res ne RETRY_SENTINEL then return res + else + val res = finishTryReceive(seg, i) + if res ne RETRY_SENTINEL then return res + end if + end if + end while + throw new AssertionError("unreachable") + end tryReceiveOrClosed + + /** Returns result, null (nothing available), or RETRY_SENTINEL to indicate the caller should loop. */ + private def finishTryReceive(segment: Segment, i: Int): AnyRef | Null = + val r = receivers.get() - 1 // the cell index we just reserved + val result = + try updateCellReceive(segment, i, r, null, null, false) + catch case e: InterruptedException => throw new AssertionError("unreachable: non-blocking receive", e) + if result eq ReceiveResult.CLOSED then closedReason.get() + else if result eq ReceiveResult.FAILED then + segment.cleanPrev() + RETRY_SENTINEL + else if result == null then null + else + segment.cleanPrev() + result + end if + end finishTryReceive + + /** Core send logic. Returns SendResult, TRY_SEND_NOT_SENT, or StoredSelectClause. */ + @throws[InterruptedException] + private def updateCellSend( + segment: Segment, + i: Int, + s: Long, + value: T, + select: SelectInstance | Null, + selectClause: SelectClause[?] | Null, + suspend: Boolean + ): AnyRef = + while true do + val state = segment.getCell(i) + if state == null then + if capacity >= 0 && s >= (if isRendezvous then 0 else bufferEnd.get()) && s >= receivers.get() then + // no receiver, not in buffer + if !suspend then + if segment.casCell(i, null, INTERRUPTED_SEND) then + segment.cellInterruptedSender() + return Channel.TRY_SEND_NOT_SENT + else if select != null then + val storedSelect = new StoredSelectClause(select, segment, i, true, selectClause.nn, value.asInstanceOf[AnyRef]) + if segment.casCell(i, null, storedSelect) then return storedSelect + else + val c = new Continuation(value.asInstanceOf[AnyRef]) + if segment.casCell(i, null, c) then + if c.await(segment, i, isRendezvous) eq ChannelClosedMarker.CLOSED then return SendResult.CLOSED + else return SendResult.AWAITED + else + // receiver in progress or in buffer -> elimination + if segment.casCell(i, null, value.asInstanceOf[AnyRef]) then return SendResult.BUFFERED + else if state eq IN_BUFFER then + if segment.casCell(i, IN_BUFFER, value.asInstanceOf[AnyRef]) then return SendResult.BUFFERED + else + state match + case c: Continuation => + if c.tryResume(value.asInstanceOf[AnyRef]) then + segment.setCell(i, DONE) + return SendResult.RESUMED + else return SendResult.FAILED + case ss: StoredSelectClause => + ss.payload = value.asInstanceOf[AnyRef] + if ss.select.trySelect(ss) then + segment.setCell(i, DONE) + return SendResult.RESUMED + else return SendResult.FAILED + case INTERRUPTED_RECEIVE | BROKEN => + return SendResult.FAILED + case CLOSED => + return SendResult.CLOSED + case _ => + throw new IllegalStateException(s"Unexpected state: $state in channel: $this") + end if + end while + throw new AssertionError("unreachable") + end updateCellSend + + // ********* + // Receiving + // ********* + + @throws[InterruptedException] + override def receive(): T = + val r = receiveOrClosed() + r match + case c: ChannelClosed => throw c.toException() + case _ => r.asInstanceOf[T] + + @throws[InterruptedException] + override def receiveOrClosed(): AnyRef = + doReceive(null, null) + + private def doReceive(select: SelectInstance | Null, selectClause: SelectClause[?] | Null): AnyRef = + while true do + val segment = receiveSegment.get() + val r = receivers.getAndAdd(1L) + + val id = r / SEGMENT_SIZE + val i = (r % SEGMENT_SIZE).toInt + + var seg = segment + if segment.getId != id then + seg = findAndMoveForward(receiveSegment, segment, id) + if seg == null then return closedReason.get().nn + if seg.getId != id then + receivers.compareAndSet(r, seg.getId * SEGMENT_SIZE) + () // continue + else + val result = updateCellReceive(seg, i, r, select, selectClause, true) + if result eq ReceiveResult.CLOSED then return closedReason.get().nn + else + if !result.isInstanceOf[StoredSelectClause] then seg.cleanPrev() + if result ne ReceiveResult.FAILED then return result + end if + else + val result = updateCellReceive(seg, i, r, select, selectClause, true) + if result eq ReceiveResult.CLOSED then return closedReason.get().nn + else + if !result.isInstanceOf[StoredSelectClause] then seg.cleanPrev() + if result ne ReceiveResult.FAILED then return result + end if + end while + throw new AssertionError("unreachable") + end doReceive + + @throws[InterruptedException] + private def updateCellReceive( + segment: Segment, + i: Int, + r: Long, + select: SelectInstance | Null, + selectClause: SelectClause[?] | Null, + suspend: Boolean + ): AnyRef = + while true do + val state = segment.getCell(i) + if state == null || (state eq IN_BUFFER) then + if r >= getSendersCounter(sendersAndClosedFlag.get()) then + if !suspend then + if segment.casCell(i, state, INTERRUPTED_RECEIVE) then + segment.cellInterruptedReceiver() + expandBuffer() + return null + else if select != null then + val storedSelect = new StoredSelectClause(select, segment, i, false, selectClause.nn, null) + if segment.casCell(i, state, storedSelect) then + expandBuffer() + return storedSelect + else + val c = new Continuation(null) + if segment.casCell(i, state, c) then + expandBuffer() + val result = c.await(segment, i, isRendezvous) + if result eq ChannelClosedMarker.CLOSED then return ReceiveResult.CLOSED + else return result + else if segment.casCell(i, state, BROKEN) then + expandBuffer() + return ReceiveResult.FAILED + else + state match + case c: Continuation => + if segment.casCell(i, state, RESUMING) then + if c.tryResume(Integer.valueOf(0)) then + segment.setCell(i, DONE) + expandBuffer() + return c.payload + else return ReceiveResult.FAILED + case ss: StoredSelectClause => + if segment.casCell(i, state, RESUMING) then + if ss.select.trySelect(ss) then + segment.setCell(i, DONE) + expandBuffer() + return ss.payload.asInstanceOf[AnyRef] + else return ReceiveResult.FAILED + case cs: CellState => + cs match + case CellState.INTERRUPTED_SEND => return ReceiveResult.FAILED + case CellState.RESUMING => Thread.onSpinWait() + case CellState.CLOSED => return ReceiveResult.CLOSED + case _ => throw new IllegalStateException(s"Unexpected state: $state in channel: $this") + case _ => + // buffered value + segment.setCell(i, DONE) + expandBuffer() + return state.asInstanceOf[AnyRef] + end if + end while + throw new AssertionError("unreachable") + end updateCellReceive + + // **************** + // Buffer expansion + // **************** + + private def expandBuffer(): Unit = + if capacity <= 0 then return + while true do + val segment = bufferEndSegment.get() + val b = bufferEnd.getAndAdd(1L) + + val id = b / SEGMENT_SIZE + val i = (b % SEGMENT_SIZE).toInt + + var seg = segment + if segment.getId != id then + seg = findAndMoveForward(bufferEndSegment, segment, id) + if seg == null then return + if seg.getId != id then + bufferEnd.compareAndSet(b, seg.getId * SEGMENT_SIZE) + () // continue - this cell was an interrupted sender + else + val result = updateCellExpandBuffer(seg, i) + if result == ExpandBufferResult.DONE then + seg.cellProcessed_notInterruptedSender() + return + else if result == ExpandBufferResult.CLOSED then + seg.cellProcessed_notInterruptedSender() + () // continue to mark other closed cells as processed + end if + else + val result = updateCellExpandBuffer(seg, i) + if result == ExpandBufferResult.DONE then + seg.cellProcessed_notInterruptedSender() + return + else if result == ExpandBufferResult.CLOSED then + seg.cellProcessed_notInterruptedSender() + () // continue + end if + end while + end expandBuffer + + private def updateCellExpandBuffer(segment: Segment, i: Int): ExpandBufferResult = + while true do + val state = segment.getCell(i) + if state == null then + if segment.casCell(i, null, IN_BUFFER) then return ExpandBufferResult.DONE + else + state match + case DONE => return ExpandBufferResult.DONE + case c: Continuation if c.isSender => + if segment.casCell(i, state, RESUMING) then + if c.tryResume(Integer.valueOf(0)) then + segment.setCell(i, c.payload) + return ExpandBufferResult.DONE + else return ExpandBufferResult.FAILED + case _: Continuation => return ExpandBufferResult.DONE + case ss: StoredSelectClause if ss.isSender => + if segment.casCell(i, state, RESUMING) then + if ss.select.trySelect(ss) then + segment.setCell(i, ss.payload) + return ExpandBufferResult.DONE + else return ExpandBufferResult.FAILED + case _: StoredSelectClause => return ExpandBufferResult.DONE + case cs: CellState => + cs match + case CellState.INTERRUPTED_SEND => return ExpandBufferResult.FAILED + case CellState.INTERRUPTED_RECEIVE => return ExpandBufferResult.DONE + case CellState.BROKEN => return ExpandBufferResult.DONE + case CellState.RESUMING => Thread.onSpinWait() + case CellState.CLOSED => return ExpandBufferResult.CLOSED + case _ => throw new IllegalStateException(s"Unexpected state: $state in channel: $this") + case _ => + // buffered value + return ExpandBufferResult.DONE + end if + end while + throw new AssertionError("unreachable") + end updateCellExpandBuffer + + // ******* + // Closing + // ******* + + override def done(): Unit = + val r = doneOrClosed() + r match + case c: ChannelClosed => throw c.toException() + case _ => + + override def doneOrClosed(): AnyRef = + closeOrClosed(ChannelDone(this)) + + override def error(reason: Throwable): Unit = + if reason == null then throw new NullPointerException("Error reason cannot be null") + val r = errorOrClosed(reason) + r match + case c: ChannelClosed => throw c.toException() + case _ => + + override def errorOrClosed(reason: Throwable): AnyRef = + closeOrClosed(ChannelError(reason, this)) + + private def closeOrClosed(cc: ChannelClosed): AnyRef = + if !closedReason.compareAndSet(null, cc) then return closedReason.get().nn + + // set closed flag + var scfUpdated = false + var scf = 0L + while !scfUpdated do + val initialScf = sendersAndClosedFlag.get() + scf = setClosedFlag(initialScf) + scfUpdated = sendersAndClosedFlag.compareAndSet(initialScf, scf) + + val lastSender = getSendersCounter(scf) + val lastSegment = sendSegment.get().close() + + cc match + case _: ChannelError => closeCellsUntil(0, lastSegment) + case _ => closeCellsUntil(lastSender, lastSegment) + + if capacity > 0 then + val lastGlobalIndex = (lastSegment.getId + 1) * SEGMENT_SIZE - 1 + while bufferEnd.get() <= lastGlobalIndex do expandBuffer() + + null + end closeOrClosed + + private def closeCellsUntil(lastCellToClose: Long, segment: Segment | Null): Unit = + if segment == null then return + + val lastCellToCloseSegmentId = lastCellToClose / SEGMENT_SIZE + val lastIndexToCloseInSegment = + if lastCellToCloseSegmentId == segment.getId then (lastCellToClose % SEGMENT_SIZE).toInt + else if lastCellToCloseSegmentId < segment.getId then 0 + else return + + var i = SEGMENT_SIZE - 1 + while i >= lastIndexToCloseInSegment do + updateCellClose(segment, i) + i -= 1 + + closeCellsUntil(lastCellToClose, segment.getPrev) + end closeCellsUntil + + private def updateCellClose(segment: Segment, i: Int): Unit = + while true do + val state = segment.getCell(i) + if state == null || (state eq IN_BUFFER) then + if segment.casCell(i, state, CLOSED) then + segment.cellInterruptedReceiver() + return + else + state match + case c: Continuation => + if c.tryResume(ChannelClosedMarker.CLOSED) then + segment.setCell(i, CLOSED) + segment.cellInterruptedReceiver() + return + else Thread.onSpinWait() + case ss: StoredSelectClause => + if ss.select.channelClosed(closedReason.get().nn) then return + else Thread.onSpinWait() + case cs: CellState => + cs match + case CellState.DONE | CellState.BROKEN => return + case CellState.INTERRUPTED_RECEIVE | CellState.INTERRUPTED_SEND => return + case CellState.RESUMING => Thread.onSpinWait() + case _ => throw new IllegalStateException(s"Unexpected state: $state in channel: $this") + case _ => + // buffered value: discarding + if segment.casCell(i, state, CLOSED) then + segment.cellInterruptedReceiver() + return + end if + + override def closedForSend(): ChannelClosed | Null = + if isClosed(sendersAndClosedFlag.get()) then closedReason.get() else null + + override def closedForReceive(): ChannelClosed | Null = + if isClosed(sendersAndClosedFlag.get()) then + val cr = closedReason.get().nn + cr match + case _: ChannelError => cr + case _ => if hasValuesToReceive() then null else cr + else null + + private def hasValuesToReceive(): Boolean = + while true do + val segment = receiveSegment.get() + val r = receivers.get() + val s = getSendersCounter(sendersAndClosedFlag.get()) + if s <= r then return false + + val id = r / SEGMENT_SIZE + val i = (r % SEGMENT_SIZE).toInt + + var seg = segment + if segment.getId != id then + seg = findAndMoveForward(receiveSegment, segment, id) + if seg == null then return false + if seg.getId != id then + receivers.compareAndSet(r, seg.getId * SEGMENT_SIZE) + () // continue + else + seg.cleanPrev() + if hasValueToReceive(seg, i) then return true + else receivers.compareAndSet(r, r + 1) + else + seg.cleanPrev() + if hasValueToReceive(seg, i) then return true + else receivers.compareAndSet(r, r + 1) + end if + end while + false + end hasValuesToReceive + + private def hasValueToReceive(segment: Segment, i: Int): Boolean = + while true do + val state = segment.getCell(i) + if state == null || (state eq IN_BUFFER) then Thread.onSpinWait() + else + state match + case c: Continuation => return c.isSender + case ss: StoredSelectClause => return ss.isSender + case cs: CellState => + cs match + case CellState.INTERRUPTED_SEND | CellState.INTERRUPTED_RECEIVE => return false + case CellState.RESUMING => Thread.onSpinWait() + case CellState.CLOSED => return false + case CellState.DONE | CellState.BROKEN => return false + case _ => throw new IllegalStateException(s"Unexpected state: $state in channel: $this") + case _ => return true // buffered value + end if + end while + false + end hasValueToReceive + + // ************** + // Select clauses + // ************** + + override def receiveClause(): SelectClause[T] = receiveClause(identity) + + override def receiveClause[U](callback: T => U): SelectClause[U] = + val ch = this + new SelectClause[U]: + override private[jox] def getChannel: Channel[?] | Null = ch + override private[jox] def register(select: SelectInstance): AnyRef = + try ch.doReceive(select, this) + catch case e: InterruptedException => throw new IllegalStateException(e) + override private[jox] def transformedRawValue(rawValue: AnyRef): U = + callback(rawValue.asInstanceOf[T]) + end receiveClause + + override def sendClause(value: T): SelectClause[Null] = sendClause(value, () => null) + + override def sendClause[U](value: T, callback: () => U): SelectClause[U] = + val ch = this + new SelectClause[U]: + override private[jox] def getChannel: Channel[?] | Null = ch + override private[jox] def register(select: SelectInstance): AnyRef = + try + val result = ch.doSend(value, select, this) + if result == null then SentClauseMarker.SENT else result + catch case e: InterruptedException => throw new IllegalStateException(e) + override private[jox] def transformedRawValue(rawValue: AnyRef): U = + callback() + end new + end sendClause + + private[jox] def cleanupStoredSelectClause(segment: Segment, i: Int, isSender: Boolean): Unit = + segment.setCell(i, if isSender then INTERRUPTED_SEND else INTERRUPTED_RECEIVE) + if isSender then segment.cellInterruptedSender() + else segment.cellInterruptedReceiver() + + // **** + // Misc + // **** + + override def toString: String = s"Channel(capacity=$capacity)" + +end Channel + +object Channel: + val DEFAULT_BUFFER_SIZE: Int = 16 + val TRY_SEND_NOT_SENT: AnyRef = new AnyRef + + def newRendezvousChannel[T](): Channel[T] = new Channel(0) + def newBufferedChannel[T](capacity: Int): Channel[T] = new Channel(capacity) + def newBufferedDefaultChannel[T](): Channel[T] = new Channel(DEFAULT_BUFFER_SIZE) + def newUnlimitedChannel[T](): Channel[T] = new Channel(-1) + + private val SENDERS_AND_CLOSED_FLAG_SHIFT = 60 + private val SENDERS_COUNTER_MASK = (1L << SENDERS_AND_CLOSED_FLAG_SHIFT) - 1 + + private[jox] def getSendersCounter(scf: Long): Long = scf & SENDERS_COUNTER_MASK + private[jox] def isClosed(scf: Long): Boolean = (scf >> SENDERS_AND_CLOSED_FLAG_SHIFT) == 1 + private[jox] def setClosedFlag(scf: Long): Long = scf | (1L << SENDERS_AND_CLOSED_FLAG_SHIFT) +end Channel diff --git a/core/src/main/scalanative/ox/channels/jox/ChannelClosed.scala b/core/src/main/scalanative/ox/channels/jox/ChannelClosed.scala new file mode 100644 index 00000000..bdd261b9 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/ChannelClosed.scala @@ -0,0 +1,22 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/ChannelClosed.java +// https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/ChannelDone.java +// https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/ChannelError.java +// https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/ChannelClosedException.java + +sealed trait ChannelClosed: + def toException(): ChannelClosedException + def channel: Channel[?] + +case class ChannelDone(override val channel: Channel[?]) extends ChannelClosed: + override def toException(): ChannelClosedException = new ChannelDoneException() + +case class ChannelError(cause: Throwable, override val channel: Channel[?]) extends ChannelClosed: + override def toException(): ChannelClosedException = new ChannelErrorException(cause) + +sealed class ChannelClosedException(cause: Throwable) extends RuntimeException(cause): + def this() = this(null) + +final class ChannelDoneException extends ChannelClosedException() +final class ChannelErrorException(cause: Throwable) extends ChannelClosedException(cause) diff --git a/core/src/main/scalanative/ox/channels/jox/CloseableChannel.scala b/core/src/main/scalanative/ox/channels/jox/CloseableChannel.scala new file mode 100644 index 00000000..1a1cf8a4 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/CloseableChannel.scala @@ -0,0 +1,16 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/CloseableChannel.java + +trait CloseableChannel: + def done(): Unit + def doneOrClosed(): AnyRef + def error(reason: Throwable): Unit + def errorOrClosed(reason: Throwable): AnyRef + + def isClosedForSend: Boolean = closedForSend() != null + def isClosedForReceive: Boolean = closedForReceive() != null + + def closedForSend(): ChannelClosed | Null + def closedForReceive(): ChannelClosed | Null +end CloseableChannel diff --git a/core/src/main/scalanative/ox/channels/jox/Continuation.scala b/core/src/main/scalanative/ox/channels/jox/Continuation.scala new file mode 100644 index 00000000..8155b9e4 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/Continuation.scala @@ -0,0 +1,49 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/Channel.java +// (inner class Continuation, lines 1508-1622) + +import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.locks.LockSupport + +final class Continuation(val payload: AnyRef): + private val creatingThread: Thread = Thread.currentThread() + private val data: AtomicReference[AnyRef] = new AtomicReference(null) + + /** `true` if this continuation is for a sender; `false` for a receiver. */ + def isSender: Boolean = payload != null + + /** Resume the continuation with the given value. Returns `true` if successful. */ + def tryResume(value: AnyRef): Boolean = + val result = data.compareAndSet(null, value) + LockSupport.unpark(creatingThread) + result + + /** Await for the continuation to be resumed. May throw InterruptedException. */ + @throws[InterruptedException] + def await(segment: Segment, cellIndex: Int, isRendezvous: Boolean): AnyRef = + var spinIterations = if isRendezvous then Continuation.RENDEZVOUS_SPINS else 0 + while data.get() == null do + if spinIterations > 0 then + Thread.onSpinWait() + spinIterations -= 1 + else + LockSupport.park() + if Thread.interrupted() then + if data.compareAndSet(null, ContinuationMarker.INTERRUPTED) then + val _isSender = isSender + segment.setCell(cellIndex, if _isSender then CellState.INTERRUPTED_SEND else CellState.INTERRUPTED_RECEIVE) + if _isSender then segment.cellInterruptedSender() + else segment.cellInterruptedReceiver() + throw new InterruptedException() + else Thread.currentThread().interrupt() + end if + end while + data.get() + end await +end Continuation + +object Continuation: + val RENDEZVOUS_SPINS: Int = + val nproc = Runtime.getRuntime.availableProcessors() + if nproc == 1 then 0 else if nproc <= 4 then 1 << 7 else 1 << 10 diff --git a/core/src/main/scalanative/ox/channels/jox/README.md b/core/src/main/scalanative/ox/channels/jox/README.md new file mode 100644 index 00000000..8a4969b6 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/README.md @@ -0,0 +1,94 @@ +# Scala Port of Jox Channels + +Pure Scala port of [softwaremill/jox](https://github.com/softwaremill/jox) `channels` module, +enabling Scala Native support. This code lives in `core/native/` and is only compiled for the +Native platform; the JVM continues to use the Java Jox library directly. + +## Source Version + +Ported from: **jox 1.1.2** (`v1.1.2-channels` tag) + +## Target Platform + +Requires: **Scala Native 0.5.12+** (for virtual threads, `java.util.concurrent` atomics, `LockSupport`) + +## Divergences from Java Jox + +### 1. AtomicXxx instead of VarHandle + +Java Jox uses `java.lang.invoke.VarHandle` for all atomic field and array operations. +This port uses: + +| Java Jox | Scala Port | Reason | +|----------|-----------|--------| +| `VarHandle` on `long` fields | `AtomicLong` | VarHandle is a stub in Scala Native (not implemented) | +| `VarHandle` on reference fields | `AtomicReference[T]` | Same | +| `VarHandle` on `int` fields | `AtomicInteger` | Same | +| `MethodHandles.arrayElementVarHandle` | `AtomicReferenceArray` | Same | + +**Impact**: Slightly more indirection (fields are objects rather than plain volatiles with CAS via handles). +Performance difference is negligible for virtual-thread-based workloads. + +**Removal condition**: If Scala Native implements `java.lang.invoke.VarHandle` with `findVarHandle` +and `arrayElementVarHandle`, this port could switch back to VarHandle for parity. + +### 2. Segment.findAndMoveForward signature + +Java Jox passes a `VarHandle` + owning object to `findAndMoveForward` / `moveForward` for generic +field updates. The Scala port passes `AtomicReference[Segment]` directly. + +**Impact**: None on behavior. Slightly less generic but type-safe. + +### 3. No forEach / toList on Source + +The Java `Source` interface has default `forEach` and `toList` methods. These are omitted because +the Ox wrapper (`ox.channels.SourceOps` / `SourceDrainOps`) provides equivalent functionality +with better Scala ergonomics. + +### 4. No Sink.trySend(value, channels...) static method + +The Java `Sink` has a static `trySend` that selects across multiple sinks. Omitted because +Ox exposes this via its own select API. + +### 5. Channel.toString is simplified + +The Java version prints full segment-by-segment cell state. The Scala port returns a short +`Channel(capacity=N)` string. The verbose version can be added if needed for debugging. + +## File Mapping + +| Java Jox file | Scala port file | +|--------------|-----------------| +| `Channel.java` (1640 lines) | `Channel.scala` | +| `Segment.java` | `Segment.scala` | +| `Select.java` | `Select.scala` (includes `SelectInstance`) | +| `SelectClause.java` | `SelectClause.scala` | +| `Source.java` | `Source.scala` | +| `Sink.java` | `Sink.scala` | +| `CloseableChannel.java` | `CloseableChannel.scala` | +| `ChannelClosed.java`, `ChannelDone.java`, `ChannelError.java` | `ChannelClosed.scala` | +| `ChannelClosedException.java`, `*Exception.java` | `ChannelClosed.scala` | +| (inner classes in Channel.java) | `CellState.scala` | +| (inner class in Select.java) | `StoredSelectClause.scala` | +| (inner class in Channel.java) | `Continuation.scala` | + +## Test Mapping + +Tests are in `core/native/src/test/scala/ox/channels/jox/`. They require clang/LLVM to link and run. + +| Java test file | Scala test file | Notes | +|---------------|----------------|-------| +| `TestUtil.java` | `TestUtil.scala` | Scope/fork helpers | +| `ChannelRendezvousTest.java` | `ChannelRendezvousTest.scala` | All tests except perf benchmark | +| `ChannelBufferedTest.java` | `ChannelBufferedTest.scala` | All tests | +| `ChannelUnlimitedTest.java` | `ChannelUnlimitedTest.scala` | All tests | +| `ChannelClosedTest.java` | `ChannelClosedTest.scala` | All tests | +| `ChannelTrySendReceiveTest.java` | `ChannelTrySendReceiveTest.scala` | All tests | +| `ChannelInterruptionTest.java` | `ChannelInterruptionTest.scala` | Core interruption tests (no memory leak tests) | +| `SelectReceiveTest.java` | `SelectTest.scala` | Merged into one file | +| `SelectSendTest.java` | `SelectTest.scala` | Merged into one file | +| `SelectTest.java` | `SelectTest.scala` | Default clause tests | +| `StressTest.java` | — | Not ported (Fray-specific, requires JVM tooling) | +| `SegmentTest.java` | — | Not ported (tests internals via reflection) | +| `SegmentRendezvousTest.java` | — | Not ported (tests internals via reflection) | +| `SelectWithinTest.java` | — | Not ported (uses Thread.ofVirtual for timeout, covered by Ox wrapper tests) | diff --git a/core/src/main/scalanative/ox/channels/jox/Segment.scala b/core/src/main/scalanative/ox/channels/jox/Segment.scala new file mode 100644 index 00000000..e0f0c0e0 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/Segment.scala @@ -0,0 +1,187 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/Segment.java + +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference, AtomicReferenceArray} + +final class Segment( + private val id: Long, + initialPrev: Segment | Null, + pointers: Int, + val isRendezvousOrUnlimited: Boolean +): + import Segment.* + + private val data = new AtomicReferenceArray[AnyRef](SEGMENT_SIZE) + private val nextRef = new AtomicReference[Segment | Null](null) + private val prevRef = new AtomicReference[Segment | Null](initialPrev) + + // bits: [pointers(2)][notProcessed(6)][notInterrupted(6)] + private val pointersNotProcessedNotInterrupted = new AtomicInteger( + SEGMENT_SIZE + + (if isRendezvousOrUnlimited then 0 else SEGMENT_SIZE << PROCESSED_SHIFT) + + (pointers << POINTERS_SHIFT) + ) + + def getId: Long = id + + def cleanPrev(): Unit = prevRef.set(null) + + def getNext: Segment | Null = + val s = nextRef.get() + if s eq CLOSED_SENTINEL then null else s + + def getPrev: Segment | Null = prevRef.get() + + private def setNextIfNull(setTo: Segment): Boolean = + nextRef.compareAndSet(null, setTo) + + def getCell(index: Int): AnyRef | Null = data.get(index) + + def setCell(index: Int, value: AnyRef): Unit = data.set(index, value) + + def casCell(index: Int, expected: AnyRef | Null, newValue: AnyRef): Boolean = + data.compareAndSet(index, expected.asInstanceOf[AnyRef], newValue) + + private def isTail: Boolean = getNext == null + + def isRemoved: Boolean = pointersNotProcessedNotInterrupted.get() == 0 + + def tryIncPointers(): Boolean = + var p = pointersNotProcessedNotInterrupted.get() + while p != 0 do + if pointersNotProcessedNotInterrupted.compareAndSet(p, p + (1 << POINTERS_SHIFT)) then return true + p = pointersNotProcessedNotInterrupted.get() + false + + def decPointers(): Boolean = + val toAdd = -(1 << POINTERS_SHIFT) + var currentP = pointersNotProcessedNotInterrupted.get() + while true do + if pointersNotProcessedNotInterrupted.compareAndSet(currentP, currentP + toAdd) then return (currentP + toAdd) == 0 + currentP = pointersNotProcessedNotInterrupted.get() + false // unreachable + + def cellInterruptedReceiver(): Unit = + if pointersNotProcessedNotInterrupted.getAndDecrement() == 1 then remove() + + def cellInterruptedSender(): Unit = + if isRendezvousOrUnlimited then + if pointersNotProcessedNotInterrupted.getAndDecrement() == 1 then remove() + else if pointersNotProcessedNotInterrupted.getAndAdd(-ONE_PROCESSED_AND_INTERRUPTED) == ONE_PROCESSED_AND_INTERRUPTED then remove() + + def cellProcessed_notInterruptedSender(): Unit = + if pointersNotProcessedNotInterrupted.getAndAdd(-ONE_PROCESSED) == ONE_PROCESSED then remove() + + /** Marks cells as processed during channel setup. Not thread-safe. */ + def setup_markCellsProcessed(numberOfCells: Int): Unit = + pointersNotProcessedNotInterrupted.addAndGet(-ONE_PROCESSED * numberOfCells) + () + + def remove(): Unit = + var continue = true + while continue do + if isTail then return + val _prev = aliveSegmentLeft() + val _next = aliveSegmentRight() + + // link next.prev to _prev + var prevOfNextUpdated = false + while !prevOfNextUpdated do + val currentPrevOfNext = _next.prevRef.get() + if currentPrevOfNext == null then prevOfNextUpdated = true + else prevOfNextUpdated = _next.prevRef.compareAndSet(currentPrevOfNext, _prev) + + if _prev != null then _prev.nextRef.set(_next) + + if _next.isRemoved && !_next.isTail then () // continue loop + else if _prev != null && _prev.isRemoved then () // continue loop + else continue = false + end while + end remove + + def close(): Segment = + var s: Segment = this + while true do + val n = s.nextRef.get() + if n == null then + if s.nextRef.compareAndSet(null, CLOSED_SENTINEL) then return s + else if n eq CLOSED_SENTINEL then return s + else s = n + s // unreachable + end close + + private def aliveSegmentLeft(): Segment | Null = + var s = prevRef.get() + while s != null && s.isRemoved do s = s.prevRef.get() + s + + private def aliveSegmentRight(): Segment = + var n = nextRef.get() + while n.nn.isRemoved && !n.nn.isTail do n = n.nn.nextRef.get() + n.nn + + // for tests + def setNext(newNext: Segment | Null): Unit = nextRef.set(newNext) + + override def toString: String = + val n = nextRef.get() + val p = prevRef.get() + val c = pointersNotProcessedNotInterrupted.get() + val notInterrupted = c & ((1 << PROCESSED_SHIFT) - 1) + val notProcessed = (c & ((1 << POINTERS_SHIFT) - 1)) >> PROCESSED_SHIFT + val ptrs = c >> POINTERS_SHIFT + val nextStr = if n == null then "null" else if n eq CLOSED_SENTINEL then "closed" else n.id.toString + val prevStr = if p == null then "null" else p.id.toString + s"Segment{id=$id, next=$nextStr, prev=$prevStr, pointers=$ptrs, notProcessed=$notProcessed, notInterrupted=$notInterrupted}" + end toString + +end Segment + +object Segment: + val SEGMENT_SIZE: Int = + val env = System.getenv("JOX_SEGMENT_SIZE") + if env != null then Integer.parseInt(env) else 32 + + private val PROCESSED_SHIFT = 6 + private val POINTERS_SHIFT = 12 + private val ONE_PROCESSED = 1 << PROCESSED_SHIFT + private val ONE_PROCESSED_AND_INTERRUPTED = ONE_PROCESSED + 1 + + val NULL_SEGMENT: Segment = new Segment(-1, null, 0, false) + private val CLOSED_SENTINEL: Segment = new Segment(-1, null, 0, false) + + /** Finds or creates a non-removed segment with id >= `id`, and updates `ref` to it. */ + def findAndMoveForward(ref: AtomicReference[Segment], start: Segment, id: Long): Segment | Null = + var continue = true + while continue do + val segment = findSegment(start, id) + if segment == null then return null + if moveForward(ref, segment) then return segment + null // unreachable + + private def findSegment(start: Segment, id: Long): Segment | Null = + var current = start + while current.getId < id || current.isRemoved do + val n = current.nextRef.get() + if n eq CLOSED_SENTINEL then return null + else if n == null then + val newSegment = new Segment(current.getId + 1, current, 0, start.isRendezvousOrUnlimited) + if current.setNextIfNull(newSegment) then if current.isRemoved then current.remove() + else current = n.nn + current + end findSegment + + private def moveForward(ref: AtomicReference[Segment], to: Segment): Boolean = + while true do + val current = ref.get() + if current.getId >= to.getId then return true + if !to.tryIncPointers() then return false + if ref.compareAndSet(current, to) then + if current.decPointers() then current.remove() + return true + else if to.decPointers() then to.remove() + end while + false // unreachable + end moveForward +end Segment diff --git a/core/src/main/scalanative/ox/channels/jox/Select.scala b/core/src/main/scalanative/ox/channels/jox/Select.scala new file mode 100644 index 00000000..a8cc6617 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/Select.scala @@ -0,0 +1,227 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/Select.java + +import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.locks.LockSupport +import scala.collection.mutable + +object Select: + @throws[InterruptedException] + def select[U](clauses: SelectClause[? <: U]*): U = + val r = selectOrClosed(clauses*) + r match + case c: ChannelClosed => throw c.toException() + case _ => r.asInstanceOf[U] + + @throws[InterruptedException] + def selectOrClosed[U](clauses: SelectClause[? <: U]*): AnyRef = + if clauses == null || clauses.isEmpty then throw new IllegalArgumentException("No clauses given") + if clauses.exists(_ == null) then throw new IllegalArgumentException("Null clauses are not supported") + while true do + val r = doSelectOrClosed(clauses*) + if r ne RestartSelectMarker.RESTART then return r + throw new AssertionError("unreachable") + end selectOrClosed + + @throws[InterruptedException] + def defaultClause[T](value: T): SelectClause[T] = new DefaultClauseValue(value) + + def defaultClause[T](callback: () => T): SelectClause[T] = new DefaultClauseCallback(callback) + + @throws[InterruptedException] + private def doSelectOrClosed[U](clauses: SelectClause[? <: U]*): AnyRef = + // short-circuit if any channel is in error + val anyError = getAnyChannelInError(clauses) + if anyError != null then return anyError + + val allRendezvous = verifyChannelsUnique_getAreAllRendezvous(clauses) + val si = new SelectInstance(clauses.size) + var i = 0 + var done = false + while i < clauses.size && !done do + val clause = clauses(i) + clause match + case _: DefaultClause[?] if i != clauses.size - 1 => + throw new IllegalArgumentException("The default clause can only be the last one.") + case _ => + if !si.register(clause) then done = true + i += 1 + end while + si.checkStateAndWait(allRendezvous) + end doSelectOrClosed + + private def verifyChannelsUnique_getAreAllRendezvous(clauses: Seq[SelectClause[?]]): Boolean = + var allRendezvous = true + var i = 0 + while i < clauses.size do + val chi = clauses(i).getChannel + var j = i + 1 + while j < clauses.size do + if (chi ne null) && (chi eq clauses(j).getChannel) then + throw new IllegalArgumentException(s"Channel $chi is used in multiple clauses") + j += 1 + allRendezvous = allRendezvous && (chi == null || chi.isRendezvous) + i += 1 + end while + allRendezvous + end verifyChannelsUnique_getAreAllRendezvous + + private def getAnyChannelInError(clauses: Seq[SelectClause[?]]): ChannelError | Null = + for clause <- clauses do + val ch = clause.getChannel + if ch != null then + ch.closedForSend() match + case ce: ChannelError => return ce + case _ => + null + end getAnyChannelInError +end Select + +private[jox] final class SelectInstance(clausesCount: Int): + private val state: AtomicReference[AnyRef] = new AtomicReference(SelectState.REGISTERING) + private val storedClauses = mutable.ArrayBuffer.empty[StoredSelectClause] + private var resultSelectedDuringRegistration: AnyRef = _ + + def register[U](clause: SelectClause[U]): Boolean = + val result = clause.register(this) + result match + case ss: StoredSelectClause => + storedClauses += ss + true + case cc: ChannelClosed => + state.set(cc) + false + case _ => + // clause was selected immediately + resultSelectedDuringRegistration = result + state.set(clause) + false + end match + end register + + @throws[InterruptedException] + def checkStateAndWait(allRendezvous: Boolean): AnyRef = + while true do + val currentState = state.get() + currentState match + case SelectState.REGISTERING => + val currentThread = Thread.currentThread() + if state.compareAndSet(SelectState.REGISTERING, currentThread) then + var spinIterations = if allRendezvous then Continuation.RENDEZVOUS_SPINS else 0 + while state.get() eq currentThread do + if spinIterations > 0 then + Thread.onSpinWait() + spinIterations -= 1 + else + LockSupport.park() + if Thread.interrupted() then + if state.compareAndSet(currentThread, SelectState.INTERRUPTED) then + cleanup(null) + throw new InterruptedException() + else Thread.currentThread().interrupt() + end while + end if + + case clausesToReRegister: java.util.List[?] => + if state.compareAndSet(currentState, SelectState.REGISTERING) then + val iter = clausesToReRegister.iterator() + var done = false + while iter.hasNext && !done do + val clause = iter.next().asInstanceOf[SelectClause[?]] + // cleanup the stored select for the clause we'll re-register + val storedIter = storedClauses.iterator + var found = false + val newStored = mutable.ArrayBuffer.empty[StoredSelectClause] + for stored <- storedClauses do + if !found && (stored.clause eq clause) then + stored.cleanup() + found = true + else newStored += stored + storedClauses.clear() + storedClauses ++= newStored + + if !register(clause) then done = true + end while + + case selectedClause: SelectClause[?] @unchecked => + cleanup(selectedClause) + return selectedClause.transformedRawValue(resultSelectedDuringRegistration).asInstanceOf[AnyRef] + + case ss: StoredSelectClause => + val selectedClause = ss.clause + cleanup(selectedClause) + return selectedClause.transformedRawValue(ss.payload.asInstanceOf[AnyRef]).asInstanceOf[AnyRef] + + case cc: ChannelClosed => + cleanup(null) + return cc + + case _ => + throw new IllegalStateException(s"Unknown state: $currentState") + end match + end while + throw new AssertionError("unreachable") + end checkStateAndWait + + private def cleanup(selected: SelectClause[?] | Null): Unit = + for stored <- storedClauses do if !(stored.clause eq selected) then stored.cleanup() + storedClauses.clear() + + /** Called by another thread to try selecting this clause. */ + def trySelect(storedSelectClause: StoredSelectClause): Boolean = + while true do + val currentState = state.get() + currentState match + case SelectState.REGISTERING => + val list = new java.util.ArrayList[SelectClause[?]](1) + list.add(storedSelectClause.clause) + if state.compareAndSet(currentState, list) then return false + case clausesToReRegister: java.util.List[?] => + val newList = new java.util.ArrayList[SelectClause[?]](clausesToReRegister.size() + 1) + newList.addAll(clausesToReRegister.asInstanceOf[java.util.List[SelectClause[?]]]) + newList.add(storedSelectClause.clause) + if state.compareAndSet(currentState, newList) then return false + case _: SelectClause[?] => + return false // already selected + case _: StoredSelectClause => + return false // already selected + case t: Thread => + if state.compareAndSet(currentState, storedSelectClause) then + LockSupport.unpark(t) + return true + case SelectState.INTERRUPTED => + return false + case _: ChannelClosed => + return false + case _ => + throw new IllegalStateException(s"Unknown state: $currentState") + end match + end while + false // unreachable + end trySelect + + /** Called when a channel is closed. */ + def channelClosed(channelClosed: ChannelClosed): Boolean = + while true do + val currentState = state.get() + currentState match + case SelectState.REGISTERING | (_: java.util.List[?]) => + if state.compareAndSet(currentState, channelClosed) then return true + case _: SelectClause[?] | _: StoredSelectClause => + return false // already selected + case t: Thread => + if state.compareAndSet(currentState, channelClosed) then + LockSupport.unpark(t) + return true + case SelectState.INTERRUPTED => + return false + case _: ChannelClosed => + return false + case _ => + throw new IllegalStateException(s"Unknown state: $currentState") + end match + end while + false // unreachable + end channelClosed +end SelectInstance diff --git a/core/src/main/scalanative/ox/channels/jox/SelectClause.scala b/core/src/main/scalanative/ox/channels/jox/SelectClause.scala new file mode 100644 index 00000000..49ac5a21 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/SelectClause.scala @@ -0,0 +1,23 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/SelectClause.java + +/** A clause to use as part of `Select.select`. */ +abstract class SelectClause[T]: + private[jox] def getChannel: Channel[?] | Null = null + + /** Returns a StoredSelectClause, ChannelClosed, or the selected value (not null). */ + private[jox] def register(select: SelectInstance): AnyRef + + /** Transforms the raw value using the transformation function provided when creating the clause. */ + private[jox] def transformedRawValue(rawValue: AnyRef): T +end SelectClause + +private[jox] abstract class DefaultClause[T] extends SelectClause[T]: + override private[jox] def register(select: SelectInstance): AnyRef = this + +private[jox] final class DefaultClauseValue[T](value: T) extends DefaultClause[T]: + override private[jox] def transformedRawValue(rawValue: AnyRef): T = value + +private[jox] final class DefaultClauseCallback[T](callback: () => T) extends DefaultClause[T]: + override private[jox] def transformedRawValue(rawValue: AnyRef): T = callback() diff --git a/core/src/main/scalanative/ox/channels/jox/Sink.scala b/core/src/main/scalanative/ox/channels/jox/Sink.scala new file mode 100644 index 00000000..eb6d66d2 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/Sink.scala @@ -0,0 +1,17 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/Sink.java + +trait Sink[T] extends CloseableChannel: + @throws[InterruptedException] + def send(value: T): Unit + + @throws[InterruptedException] + def sendOrClosed(value: T): AnyRef + + def trySendOrClosed(value: T): AnyRef + + def sendClause(value: T): SelectClause[Null] + + def sendClause[U](value: T, callback: () => U): SelectClause[U] +end Sink diff --git a/core/src/main/scalanative/ox/channels/jox/Source.scala b/core/src/main/scalanative/ox/channels/jox/Source.scala new file mode 100644 index 00000000..e755e727 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/Source.scala @@ -0,0 +1,17 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/Source.java + +trait Source[T] extends CloseableChannel: + @throws[InterruptedException] + def receive(): T + + @throws[InterruptedException] + def receiveOrClosed(): AnyRef + + def tryReceiveOrClosed(): AnyRef + + def receiveClause(): SelectClause[T] + + def receiveClause[U](callback: T => U): SelectClause[U] +end Source diff --git a/core/src/main/scalanative/ox/channels/jox/StoredSelectClause.scala b/core/src/main/scalanative/ox/channels/jox/StoredSelectClause.scala new file mode 100644 index 00000000..eb2401d1 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/jox/StoredSelectClause.scala @@ -0,0 +1,17 @@ +package ox.channels.jox + +// Ported from: https://github.com/softwaremill/jox/blob/v1.1.2-channels/channels/src/main/java/com/softwaremill/jox/Select.java +// (inner class StoredSelectClause, lines 597-643) + +/** Keeps information about a select instance stored in a channel cell, awaiting completion. */ +private[jox] final class StoredSelectClause( + val select: SelectInstance, + val segment: Segment, + val cellIndex: Int, + val isSender: Boolean, + val clause: SelectClause[?], + @volatile var payload: AnyRef | Null +): + def cleanup(): Unit = + clause.getChannel.nn.cleanupStoredSelectClause(segment, cellIndex, isSender) +end StoredSelectClause diff --git a/core/src/main/scalanative/ox/channels/select.scala b/core/src/main/scalanative/ox/channels/select.scala new file mode 100644 index 00000000..d3c22911 --- /dev/null +++ b/core/src/main/scalanative/ox/channels/select.scala @@ -0,0 +1,444 @@ +package ox.channels + +import ox.channels.jox.Select as JSelect + +import ox.channels.ChannelClosedUnion.{map, orThrow} +import ox.{discard, forkUnsupervised, sleep, unsupervised} +import scala.concurrent.duration.FiniteDuration +import scala.concurrent.TimeoutException + +/** @see [[selectOrClosed(Seq[SelectClause])]]. */ +def selectOrClosed(clause1: SelectClause[?], clause2: SelectClause[?]): clause1.Result | clause2.Result | ChannelClosed = + selectOrClosed(List(clause1, clause2)).asInstanceOf[clause1.Result | clause2.Result | ChannelClosed] + +/** @see [[selectOrClosed(Seq[SelectClause])]]. */ +def selectOrClosed( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?] +): clause1.Result | clause2.Result | clause3.Result | ChannelClosed = + selectOrClosed(List(clause1, clause2, clause3)).asInstanceOf[clause1.Result | clause2.Result | clause3.Result | ChannelClosed] + +/** @see [[selectOrClosed(Seq[SelectClause])]]. */ +def selectOrClosed( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?], + clause4: SelectClause[?] +): clause1.Result | clause2.Result | clause3.Result | clause4.Result | ChannelClosed = + selectOrClosed(List(clause1, clause2, clause3, clause4)) + .asInstanceOf[clause1.Result | clause2.Result | clause3.Result | clause4.Result | ChannelClosed] + +/** @see [[selectOrClosed(Seq[SelectClause])]]. */ +def selectOrClosed( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?], + clause4: SelectClause[?], + clause5: SelectClause[?] +): clause1.Result | clause2.Result | clause3.Result | clause4.Result | clause5.Result | ChannelClosed = + selectOrClosed(List(clause1, clause2, clause3, clause4, clause5)) + .asInstanceOf[clause1.Result | clause2.Result | clause3.Result | clause4.Result | clause5.Result | ChannelClosed] + +def selectOrClosed[T](clauses: Seq[SelectClause[T]]): SelectResult[T] | ChannelClosed = + ChannelClosed.fromJoxOrT(JSelect.selectOrClosed(clauses.map(_.delegate)*)) + +// + +/** @see [[select(Seq[SelectClause])]]. */ +def select(clause1: SelectClause[?], clause2: SelectClause[?]): clause1.Result | clause2.Result = + select(List(clause1, clause2)).asInstanceOf[clause1.Result | clause2.Result] + +/** @see [[select(Seq[SelectClause])]]. */ +def select( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?] +): clause1.Result | clause2.Result | clause3.Result = + select(List(clause1, clause2, clause3)).asInstanceOf[clause1.Result | clause2.Result | clause3.Result] + +/** @see [[select(Seq[SelectClause])]]. */ +def select( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?], + clause4: SelectClause[?] +): clause1.Result | clause2.Result | clause3.Result | clause4.Result = + select(List(clause1, clause2, clause3, clause4)).asInstanceOf[clause1.Result | clause2.Result | clause3.Result | clause4.Result] + +/** @see [[select(Seq[SelectClause])]]. */ +def select( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?], + clause4: SelectClause[?], + clause5: SelectClause[?] +): clause1.Result | clause2.Result | clause3.Result | clause4.Result | clause5.Result = + select(List(clause1, clause2, clause3, clause4, clause5)) + .asInstanceOf[clause1.Result | clause2.Result | clause3.Result | clause4.Result | clause5.Result] + +def select[T](clauses: Seq[SelectClause[T]]): SelectResult[T] = selectOrClosed(clauses).orThrow + +// + +def selectOrClosed[T1, T2](source1: Source[T1], source2: Source[T2]): T1 | T2 | ChannelClosed = + selectOrClosed(source1.receiveClause, source2.receiveClause).map { + case source1.Received(v) => v + case source2.Received(v) => v + } + +def selectOrClosed[T1, T2, T3](source1: Source[T1], source2: Source[T2], source3: Source[T3]): T1 | T2 | T3 | ChannelClosed = + selectOrClosed(source1.receiveClause, source2.receiveClause, source3.receiveClause).map { + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + } + +def selectOrClosed[T1, T2, T3, T4](source1: Source[T1], source2: Source[T2], source3: Source[T3], source4: Source[T4]): T1 | T2 | T3 | T4 | + ChannelClosed = + selectOrClosed(source1.receiveClause, source2.receiveClause, source3.receiveClause, source4.receiveClause).map { + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + case source4.Received(v) => v + } + +def selectOrClosed[T1, T2, T3, T4, T5]( + source1: Source[T1], + source2: Source[T2], + source3: Source[T3], + source4: Source[T4], + source5: Source[T5] +): T1 | T2 | T3 | T4 | T5 | ChannelClosed = + selectOrClosed(source1.receiveClause, source2.receiveClause, source3.receiveClause, source4.receiveClause, source5.receiveClause).map { + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + case source4.Received(v) => v + case source5.Received(v) => v + } + +def selectOrClosed[T](sources: Seq[Source[T]])(using DummyImplicit): T | ChannelClosed = + selectOrClosed(sources.map(_.receiveClause: SelectClause[T])) match + case r: Source[T]#Received => r.value + case c: ChannelClosed => c + case _: Sink[?]#Sent => throw new IllegalStateException() + case _: DefaultResult[?] => throw new IllegalStateException() + +// + +def select[T1, T2](source1: Source[T1], source2: Source[T2]): T1 | T2 = + select(source1.receiveClause, source2.receiveClause) match + case source1.Received(v) => v + case source2.Received(v) => v + +def select[T1, T2, T3](source1: Source[T1], source2: Source[T2], source3: Source[T3]): T1 | T2 | T3 = + select(source1.receiveClause, source2.receiveClause, source3.receiveClause) match + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + +def select[T1, T2, T3, T4](source1: Source[T1], source2: Source[T2], source3: Source[T3], source4: Source[T4]): T1 | T2 | T3 | T4 = + select(source1.receiveClause, source2.receiveClause, source3.receiveClause, source4.receiveClause) match + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + case source4.Received(v) => v + +def select[T1, T2, T3, T4, T5]( + source1: Source[T1], + source2: Source[T2], + source3: Source[T3], + source4: Source[T4], + source5: Source[T5] +): T1 | T2 | T3 | T4 | T5 = + select(source1.receiveClause, source2.receiveClause, source3.receiveClause, source4.receiveClause, source5.receiveClause) match + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + case source4.Received(v) => v + case source5.Received(v) => v + +def select[T](sources: Seq[Source[T]])(using DummyImplicit): T | ChannelClosed = + selectOrClosed(sources).orThrow + +// + +def selectOrClosedWithin[TV]( + timeout: FiniteDuration, + timeoutValue: TV +)(clause1: SelectClause[?]): TV | clause1.Result | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)(List(clause1)) + .asInstanceOf[TV | clause1.Result | ChannelClosed] + +def selectOrClosedWithin[TV]( + timeout: FiniteDuration, + timeoutValue: TV +)(clause1: SelectClause[?], clause2: SelectClause[?]): TV | clause1.Result | clause2.Result | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)(List(clause1, clause2)) + .asInstanceOf[TV | clause1.Result | clause2.Result | ChannelClosed] + +def selectOrClosedWithin[TV]( + timeout: FiniteDuration, + timeoutValue: TV +)( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?] +): TV | clause1.Result | clause2.Result | clause3.Result | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)(List(clause1, clause2, clause3)) + .asInstanceOf[TV | clause1.Result | clause2.Result | clause3.Result | ChannelClosed] + +def selectOrClosedWithin[TV]( + timeout: FiniteDuration, + timeoutValue: TV +)( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?], + clause4: SelectClause[?] +): TV | clause1.Result | clause2.Result | clause3.Result | clause4.Result | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)(List(clause1, clause2, clause3, clause4)) + .asInstanceOf[TV | clause1.Result | clause2.Result | clause3.Result | clause4.Result | ChannelClosed] + +def selectOrClosedWithin[TV]( + timeout: FiniteDuration, + timeoutValue: TV +)( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?], + clause4: SelectClause[?], + clause5: SelectClause[?] +): TV | clause1.Result | clause2.Result | clause3.Result | clause4.Result | clause5.Result | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)(List(clause1, clause2, clause3, clause4, clause5)).asInstanceOf[ + TV | clause1.Result | clause2.Result | clause3.Result | clause4.Result | clause5.Result | ChannelClosed + ] + +def selectOrClosedWithin[TV, T]( + timeout: FiniteDuration, + timeoutValue: TV +)(clauses: Seq[SelectClause[T]]): TV | SelectResult[T] | ChannelClosed = + if clauses.isEmpty then timeoutValue + else + unsupervised { + val timeoutChannel = Channel.withCapacity[Unit](1) + + forkUnsupervised { + sleep(timeout) + timeoutChannel.sendOrClosed(()).discard + }.discard + + val clausesWithTimeout = clauses :+ timeoutChannel.receiveClause + + selectOrClosed(clausesWithTimeout) match + case timeoutChannel.Received(_) => timeoutValue + case c: ChannelClosed => c + case r: SelectResult[?] @unchecked => r.asInstanceOf[SelectResult[T]] + end match + } + +// + +def selectOrClosedWithin[TV, T1]( + timeout: FiniteDuration, + timeoutValue: TV +)(source1: Source[T1]): TV | T1 | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)(source1.receiveClause) match + case source1.Received(v) => v + case c: ChannelClosed => c + case tv => timeoutValue + +def selectOrClosedWithin[TV, T1, T2]( + timeout: FiniteDuration, + timeoutValue: TV +)(source1: Source[T1], source2: Source[T2]): TV | T1 | T2 | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)(source1.receiveClause, source2.receiveClause) match + case source1.Received(v) => v + case source2.Received(v) => v + case c: ChannelClosed => c + case tv => timeoutValue + +def selectOrClosedWithin[TimeoutValue, T1, T2, T3]( + timeout: FiniteDuration, + timeoutValue: TimeoutValue +)(source1: Source[T1], source2: Source[T2], source3: Source[T3]): TimeoutValue | T1 | T2 | T3 | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)(source1.receiveClause, source2.receiveClause, source3.receiveClause) match + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + case c: ChannelClosed => c + case tv => timeoutValue + +def selectOrClosedWithin[TV, T1, T2, T3, T4]( + timeout: FiniteDuration, + timeoutValue: TV +)(source1: Source[T1], source2: Source[T2], source3: Source[T3], source4: Source[T4]): TV | T1 | T2 | T3 | T4 | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)( + source1.receiveClause, + source2.receiveClause, + source3.receiveClause, + source4.receiveClause + ) match + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + case source4.Received(v) => v + case c: ChannelClosed => c + case tv => timeoutValue + +def selectOrClosedWithin[TV, T1, T2, T3, T4, T5]( + timeout: FiniteDuration, + timeoutValue: TV +)( + source1: Source[T1], + source2: Source[T2], + source3: Source[T3], + source4: Source[T4], + source5: Source[T5] +): TV | T1 | T2 | T3 | T4 | T5 | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)( + source1.receiveClause, + source2.receiveClause, + source3.receiveClause, + source4.receiveClause, + source5.receiveClause + ) match + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + case source4.Received(v) => v + case source5.Received(v) => v + case c: ChannelClosed => c + case tv => timeoutValue + +def selectOrClosedWithin[TV, T]( + timeout: FiniteDuration, + timeoutValue: TV +)(sources: Seq[Source[T]])(using DummyImplicit): TV | T | ChannelClosed = + selectOrClosedWithin(timeout, timeoutValue)(sources.map(_.receiveClause: SelectClause[T])) match + case r: Source[T]#Received => r.value + case c: ChannelClosed => c + case _: Sink[?]#Sent => throw new IllegalStateException() + case _: DefaultResult[?] => throw new IllegalStateException() + case _: TV @unchecked => timeoutValue + +// + +private object TimeoutMarker + +def selectWithin( + timeout: FiniteDuration +)(clause1: SelectClause[?]): clause1.Result = + selectWithin(timeout)(List(clause1)).asInstanceOf[clause1.Result] + +def selectWithin( + timeout: FiniteDuration +)(clause1: SelectClause[?], clause2: SelectClause[?]): clause1.Result | clause2.Result = + selectWithin(timeout)(List(clause1, clause2)).asInstanceOf[clause1.Result | clause2.Result] + +def selectWithin( + timeout: FiniteDuration +)( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?] +): clause1.Result | clause2.Result | clause3.Result = + selectWithin(timeout)(List(clause1, clause2, clause3)) + .asInstanceOf[clause1.Result | clause2.Result | clause3.Result] + +def selectWithin( + timeout: FiniteDuration +)( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?], + clause4: SelectClause[?] +): clause1.Result | clause2.Result | clause3.Result | clause4.Result = + selectWithin(timeout)(List(clause1, clause2, clause3, clause4)) + .asInstanceOf[clause1.Result | clause2.Result | clause3.Result | clause4.Result] + +def selectWithin( + timeout: FiniteDuration +)( + clause1: SelectClause[?], + clause2: SelectClause[?], + clause3: SelectClause[?], + clause4: SelectClause[?], + clause5: SelectClause[?] +): clause1.Result | clause2.Result | clause3.Result | clause4.Result | clause5.Result = + selectWithin(timeout)(List(clause1, clause2, clause3, clause4, clause5)) + .asInstanceOf[clause1.Result | clause2.Result | clause3.Result | clause4.Result | clause5.Result] + +def selectWithin[T]( + timeout: FiniteDuration +)(clauses: Seq[SelectClause[T]]): SelectResult[T] = + val result = selectOrClosedWithin(timeout, TimeoutMarker)(clauses) + if result == TimeoutMarker then throw new TimeoutException(s"select timed out after $timeout") + else result.asInstanceOf[SelectResult[T] | ChannelClosed].orThrow + +// + +def selectWithin[T1]( + timeout: FiniteDuration +)(source1: Source[T1]): T1 = + selectWithin(timeout)(source1.receiveClause) match + case source1.Received(v) => v + +def selectWithin[T1, T2]( + timeout: FiniteDuration +)(source1: Source[T1], source2: Source[T2]): T1 | T2 = + selectWithin(timeout)(source1.receiveClause, source2.receiveClause) match + case source1.Received(v) => v + case source2.Received(v) => v + +def selectWithin[T1, T2, T3]( + timeout: FiniteDuration +)(source1: Source[T1], source2: Source[T2], source3: Source[T3]): T1 | T2 | T3 = + selectWithin(timeout)(source1.receiveClause, source2.receiveClause, source3.receiveClause) match + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + +def selectWithin[T1, T2, T3, T4]( + timeout: FiniteDuration +)(source1: Source[T1], source2: Source[T2], source3: Source[T3], source4: Source[T4]): T1 | T2 | T3 | T4 = + selectWithin(timeout)( + source1.receiveClause, + source2.receiveClause, + source3.receiveClause, + source4.receiveClause + ) match + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + case source4.Received(v) => v + +def selectWithin[T1, T2, T3, T4, T5]( + timeout: FiniteDuration +)( + source1: Source[T1], + source2: Source[T2], + source3: Source[T3], + source4: Source[T4], + source5: Source[T5] +): T1 | T2 | T3 | T4 | T5 = + selectWithin(timeout)( + source1.receiveClause, + source2.receiveClause, + source3.receiveClause, + source4.receiveClause, + source5.receiveClause + ) match + case source1.Received(v) => v + case source2.Received(v) => v + case source3.Received(v) => v + case source4.Received(v) => v + case source5.Received(v) => v + +def selectWithin[T]( + timeout: FiniteDuration +)(sources: Seq[Source[T]])(using DummyImplicit): T = + val result = selectOrClosedWithin(timeout, TimeoutMarker)(sources) + if result == TimeoutMarker then throw new TimeoutException(s"select timed out after $timeout") + else result.asInstanceOf[T | ChannelClosed].orThrow diff --git a/core/src/test/scala/ox/util/Trail.scala b/core/src/test/scala/ox/util/Trail.scala index 134b7a70..ff4b8210 100644 --- a/core/src/test/scala/ox/util/Trail.scala +++ b/core/src/test/scala/ox/util/Trail.scala @@ -2,12 +2,12 @@ package ox.util import ox.discard -import java.time.Clock +import java.util.Date import java.util.concurrent.atomic.AtomicReference class Trail(trail: AtomicReference[Vector[String]] = AtomicReference(Vector.empty)): def add(s: String): Unit = - println(s"[${Clock.systemUTC().instant()}] [${Thread.currentThread().threadId()}] $s") + println(s"[${new Date()}] [${Thread.currentThread().threadId()}] $s") trail.updateAndGet(_ :+ s).discard def get: Vector[String] = trail.get diff --git a/core/src/test/scala/ox/flow/FlowCompanionIOOpsTest.scala b/core/src/test/scalajvm/ox/flow/FlowCompanionIOOpsTest.scala similarity index 100% rename from core/src/test/scala/ox/flow/FlowCompanionIOOpsTest.scala rename to core/src/test/scalajvm/ox/flow/FlowCompanionIOOpsTest.scala diff --git a/core/src/test/scala/ox/flow/FlowIOOpsTest.scala b/core/src/test/scalajvm/ox/flow/FlowIOOpsTest.scala similarity index 100% rename from core/src/test/scala/ox/flow/FlowIOOpsTest.scala rename to core/src/test/scalajvm/ox/flow/FlowIOOpsTest.scala diff --git a/core/src/test/scala/ox/flow/reactive/FlowPublisherPekkoTest.scala b/core/src/test/scalajvm/ox/flow/reactive/FlowPublisherPekkoTest.scala similarity index 100% rename from core/src/test/scala/ox/flow/reactive/FlowPublisherPekkoTest.scala rename to core/src/test/scalajvm/ox/flow/reactive/FlowPublisherPekkoTest.scala diff --git a/core/src/test/scala/ox/flow/reactive/FlowPublisherTckTest.scala b/core/src/test/scalajvm/ox/flow/reactive/FlowPublisherTckTest.scala similarity index 100% rename from core/src/test/scala/ox/flow/reactive/FlowPublisherTckTest.scala rename to core/src/test/scalajvm/ox/flow/reactive/FlowPublisherTckTest.scala diff --git a/core/src/test/scalanative/ox/channels/jox/ChannelBufferedTest.scala b/core/src/test/scalanative/ox/channels/jox/ChannelBufferedTest.scala new file mode 100644 index 00000000..1e94a868 --- /dev/null +++ b/core/src/test/scalanative/ox/channels/jox/ChannelBufferedTest.scala @@ -0,0 +1,74 @@ +package ox.channels.jox + +// Ported from: channels/src/test/java/com/softwaremill/jox/ChannelBufferedTest.java (jox 1.1.2) + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import java.util.concurrent.ConcurrentSkipListSet + +class ChannelBufferedTest extends AnyFlatSpec with Matchers: + import TestUtil.* + + "buffered channel" should "send and receive without blocking when buffer has space" in: + val ch = Channel.newBufferedChannel[Int](3) + ch.send(1) + ch.send(2) + ch.send(3) + ch.receive() shouldBe 1 + ch.receive() shouldBe 2 + ch.receive() shouldBe 3 + + it should "block when buffer is full" in scoped { scope => + val ch = Channel.newBufferedChannel[Int](2) + ch.send(1) + ch.send(2) + @volatile var sent = false + forkVoid( + scope, + () => + ch.send(3); sent = true + ) + Thread.sleep(50) + sent shouldBe false + ch.receive() shouldBe 1 + Thread.sleep(50) + sent shouldBe true + ch.receive() shouldBe 2 + ch.receive() shouldBe 3 + } + + // Java version uses 1000; reduced for Scala Native virtual thread scheduling reliability + it should "send and receive in many forks" in scoped { scope => + val n = 100 + val ch = Channel.newBufferedChannel[Int](16) + val s = new ConcurrentSkipListSet[Int]() + + for i <- 1 to n do forkVoid(scope, () => ch.send(i)) + val fs = (1 to n).map(_ => + forkVoid( + scope, + () => + s.add(ch.receive()); () + ) + ) + fs.foreach(_.get()) + s.size() shouldBe n + } + + it should "handle done with buffered values" in: + val ch = Channel.newBufferedChannel[Int](5) + ch.send(1) + ch.send(2) + ch.done() + ch.receive() shouldBe 1 + ch.receive() shouldBe 2 + ch.receiveOrClosed() shouldBe a[ChannelDone] + + it should "discard buffered values on error" in: + val ch = Channel.newBufferedChannel[Int](5) + ch.send(1) + ch.send(2) + val ex = new RuntimeException("test error") + ch.error(ex) + ch.receiveOrClosed() shouldBe a[ChannelError] +end ChannelBufferedTest diff --git a/core/src/test/scalanative/ox/channels/jox/ChannelClosedTest.scala b/core/src/test/scalanative/ox/channels/jox/ChannelClosedTest.scala new file mode 100644 index 00000000..1b7da5f6 --- /dev/null +++ b/core/src/test/scalanative/ox/channels/jox/ChannelClosedTest.scala @@ -0,0 +1,52 @@ +package ox.channels.jox + +// Ported from: channels/src/test/java/com/softwaremill/jox/ChannelClosedTest.java (jox 1.1.2) + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class ChannelClosedTest extends AnyFlatSpec with Matchers: + import TestUtil.* + + "closed channel" should "report closed with no values when error" in: + val c = Channel.newRendezvousChannel[Int]() + val reason = new RuntimeException() + c.error(reason) + c.isClosedForReceive shouldBe true + c.isClosedForSend shouldBe true + c.receiveOrClosed() shouldBe a[ChannelError] + + it should "report closed with no values when done" in: + val c = Channel.newRendezvousChannel[Int]() + c.done() + c.isClosedForReceive shouldBe true + c.isClosedForSend shouldBe true + c.receiveOrClosed() shouldBe a[ChannelDone] + + it should "not be closed for receive when done with suspended sender" in scoped { scope => + val c = Channel.newRendezvousChannel[Int]() + val f = forkCancelable(scope, () => c.send(1)) + try + Thread.sleep(100) + c.done() + c.isClosedForReceive shouldBe false + c.isClosedForSend shouldBe true + finally f.cancel() + } + + it should "not be closed for receive when done with buffered values" in: + val c = Channel.newBufferedChannel[Int](5) + c.send(1) + c.send(2) + c.done() + c.isClosedForReceive shouldBe false + c.isClosedForSend shouldBe true + + it should "be closed for receive when error with buffered values" in: + val c = Channel.newBufferedChannel[Int](5) + c.send(1) + c.send(2) + c.error(new RuntimeException()) + c.isClosedForReceive shouldBe true + c.isClosedForSend shouldBe true +end ChannelClosedTest diff --git a/core/src/test/scalanative/ox/channels/jox/ChannelInterruptionTest.scala b/core/src/test/scalanative/ox/channels/jox/ChannelInterruptionTest.scala new file mode 100644 index 00000000..5e71afd7 --- /dev/null +++ b/core/src/test/scalanative/ox/channels/jox/ChannelInterruptionTest.scala @@ -0,0 +1,72 @@ +package ox.channels.jox + +// Ported from: channels/src/test/java/com/softwaremill/jox/ChannelInterruptionTest.java (jox 1.1.2) +// Note: Memory leak tests and Fray tests not ported (require JVM-specific tooling) + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class ChannelInterruptionTest extends AnyFlatSpec with Matchers: + import TestUtil.* + + "interruption" should "interrupt a blocked send on rendezvous channel" in scoped { scope => + val ch = Channel.newRendezvousChannel[Int]() + val f = forkCancelable(scope, () => ch.send(1)) + Thread.sleep(50) + val result = f.cancel() + result shouldBe a[InterruptedException] + } + + it should "interrupt a blocked receive on rendezvous channel" in scoped { scope => + val ch = Channel.newRendezvousChannel[Int]() + val f = forkCancelable(scope, () => ch.receive()) + Thread.sleep(50) + val result = f.cancel() + result shouldBe a[InterruptedException] + } + + it should "interrupt a blocked send on full buffered channel" in scoped { scope => + val ch = Channel.newBufferedChannel[Int](1) + ch.send(1) // fill buffer + val f = forkCancelable(scope, () => ch.send(2)) + Thread.sleep(50) + val result = f.cancel() + result shouldBe a[InterruptedException] + } + + it should "interrupt a blocked receive on empty buffered channel" in scoped { scope => + val ch = Channel.newBufferedChannel[Int](1) + val f = forkCancelable(scope, () => ch.receive()) + Thread.sleep(50) + val result = f.cancel() + result shouldBe a[InterruptedException] + } + + it should "allow channel to continue working after interrupted send" in scoped { scope => + val ch = Channel.newRendezvousChannel[Int]() + + // start a sender, then interrupt it + val f = forkCancelable(scope, () => ch.send(1)) + Thread.sleep(50) + f.cancel() + + // channel should still work + forkVoid(scope, () => ch.send(2)) + Thread.sleep(50) + ch.receive() shouldBe 2 + } + + it should "allow channel to continue working after interrupted receive" in scoped { scope => + val ch = Channel.newRendezvousChannel[Int]() + + // start a receiver, then interrupt it + val f = forkCancelable(scope, () => ch.receive()) + Thread.sleep(50) + f.cancel() + + // channel should still work + forkVoid(scope, () => ch.send(3)) + Thread.sleep(50) + ch.receive() shouldBe 3 + } +end ChannelInterruptionTest diff --git a/core/src/test/scalanative/ox/channels/jox/ChannelRendezvousTest.scala b/core/src/test/scalanative/ox/channels/jox/ChannelRendezvousTest.scala new file mode 100644 index 00000000..17493bfa --- /dev/null +++ b/core/src/test/scalanative/ox/channels/jox/ChannelRendezvousTest.scala @@ -0,0 +1,95 @@ +package ox.channels.jox + +// Ported from: channels/src/test/java/com/softwaremill/jox/ChannelRendezvousTest.java (jox 1.1.2) + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import java.util.concurrent.{ConcurrentLinkedQueue, ConcurrentSkipListSet} + +class ChannelRendezvousTest extends AnyFlatSpec with Matchers: + import TestUtil.* + + "rendezvous channel" should "send and receive" in scoped { scope => + val channel = Channel.newRendezvousChannel[String]() + forkVoid(scope, () => channel.send("x")) + val t2 = fork(scope, () => channel.receive()) + t2.get() shouldBe "x" + } + + // Java version uses 1000; reduced for Scala Native virtual thread scheduling reliability + it should "send and receive in many forks" in scoped { scope => + val n = 100 + val channel = Channel.newRendezvousChannel[Int]() + val s = new ConcurrentSkipListSet[Int]() + + for i <- 1 to n do forkVoid(scope, () => channel.send(i)) + val fs = (1 to n).map(_ => + forkVoid( + scope, + () => + s.add(channel.receive()); () + ) + ) + fs.foreach(_.get()) + s.size() shouldBe n + } + + it should "send and receive many elements in two forks" in scoped { scope => + val channel = Channel.newRendezvousChannel[Int]() + val s = new ConcurrentSkipListSet[Int]() + + forkVoid(scope, () => for i <- 1 to 1000 do channel.send(i)) + forkVoid(scope, () => for _ <- 1 to 1000 do s.add(channel.receive())).get() + s.size() shouldBe 1000 + } + + it should "block sender until receiver arrives" in scoped { scope => + val channel = Channel.newRendezvousChannel[Int]() + val trail = new ConcurrentLinkedQueue[String]() + + forkVoid( + scope, + () => + channel.send(1); trail.add("S"); () + ) + forkVoid( + scope, + () => + channel.send(2); trail.add("S"); () + ) + forkVoid( + scope, + () => + Thread.sleep(100L) + trail.add("R1") + channel.receive() + Thread.sleep(100L) + trail.add("R2") + channel.receive() + ).get() + + Thread.sleep(100L) + import scala.jdk.CollectionConverters.* + trail.asScala.toList shouldBe List("R1", "S", "R2", "S") + } + + it should "notify pending receives when channel is done" in scoped { scope => + val c = Channel.newRendezvousChannel[Int]() + val f = fork(scope, () => c.receiveOrClosed()) + + Thread.sleep(100L) + c.done() + f.get() shouldBe a[ChannelDone] + c.receiveOrClosed() shouldBe a[ChannelDone] + } + + it should "notify pending sends when channel is errored" in scoped { scope => + val c = Channel.newRendezvousChannel[Int]() + val f = fork(scope, () => c.sendOrClosed(1)) + + Thread.sleep(100L) + c.error(new RuntimeException()) + f.get() shouldBe a[ChannelError] + c.sendOrClosed(2) shouldBe a[ChannelError] + } +end ChannelRendezvousTest diff --git a/core/src/test/scalanative/ox/channels/jox/ChannelTrySendReceiveTest.scala b/core/src/test/scalanative/ox/channels/jox/ChannelTrySendReceiveTest.scala new file mode 100644 index 00000000..c8eaed4f --- /dev/null +++ b/core/src/test/scalanative/ox/channels/jox/ChannelTrySendReceiveTest.scala @@ -0,0 +1,101 @@ +package ox.channels.jox + +// Ported from: channels/src/test/java/com/softwaremill/jox/ChannelTrySendReceiveTest.java (jox 1.1.2) + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class ChannelTrySendReceiveTest extends AnyFlatSpec with Matchers: + import TestUtil.* + "trySendOrClosed" should "return null when value is sent to buffered channel" in: + val ch = Channel.newBufferedChannel[Int](2) + ch.trySendOrClosed(1) shouldBe null + ch.trySendOrClosed(2) shouldBe null + + it should "return sentinel when buffered channel is full" in: + val ch = Channel.newBufferedChannel[Int](1) + ch.trySendOrClosed(1) shouldBe null + ch.trySendOrClosed(2) should be(Channel.TRY_SEND_NOT_SENT) + + it should "return ChannelClosed when channel is closed" in: + val ch = Channel.newBufferedChannel[Int](1) + ch.done() + ch.trySendOrClosed(1) shouldBe a[ChannelClosed] + + it should "send to unlimited channel" in: + val ch = Channel.newUnlimitedChannel[Int]() + ch.trySendOrClosed(1) shouldBe null + ch.trySendOrClosed(2) shouldBe null + ch.receive() shouldBe 1 + ch.receive() shouldBe 2 + + "tryReceiveOrClosed" should "return null when nothing is available" in: + val ch = Channel.newBufferedChannel[Int](2) + ch.tryReceiveOrClosed() shouldBe null + + it should "return a value when one is available" in: + val ch = Channel.newBufferedChannel[Int](2) + ch.send(42) + ch.tryReceiveOrClosed() shouldBe 42.asInstanceOf[AnyRef] + + it should "return ChannelClosed when channel is done and empty" in: + val ch = Channel.newBufferedChannel[Int](2) + ch.done() + ch.tryReceiveOrClosed() shouldBe a[ChannelClosed] + + it should "return buffered value even after done" in: + val ch = Channel.newBufferedChannel[Int](2) + ch.send(1) + ch.done() + ch.tryReceiveOrClosed() shouldBe 1.asInstanceOf[AnyRef] + ch.tryReceiveOrClosed() shouldBe a[ChannelClosed] + + "trySend on rendezvous" should "return false when no receiver" in: + val ch = Channel.newRendezvousChannel[String]() + val r = ch.trySendOrClosed("a") + r should not be null + r should not be a[ChannelClosed] + + it should "send when receiver is waiting" in scoped { scope => + val ch = Channel.newRendezvousChannel[String]() + fork(scope, () => ch.receive()) + var sent = false + for _ <- 0 until 10 if !sent do + Thread.sleep(10) + if ch.trySendOrClosed("x") == null then sent = true + sent shouldBe true + } + + "tryReceive on rendezvous" should "return null when no sender" in: + val ch = Channel.newRendezvousChannel[String]() + ch.tryReceiveOrClosed() shouldBe null + + it should "receive when sender is waiting" in scoped { scope => + val ch = Channel.newRendezvousChannel[String]() + forkVoid(scope, () => ch.send("x")) + var received: AnyRef | Null = null + for _ <- 0 until 10 if received == null do + Thread.sleep(10) + val r = ch.tryReceiveOrClosed() + if r != null && !r.isInstanceOf[ChannelClosed] then received = r + received shouldBe "x" + } + + "trySend on unlimited" should "always send" in: + val ch = Channel.newUnlimitedChannel[String]() + for i <- 0 until 1000 do ch.trySendOrClosed(s"v$i") shouldBe null + + "trySend on closed" should "throw on done" in: + val ch = Channel.newBufferedChannel[String](1) + ch.done() + ch.trySendOrClosed("x") shouldBe a[ChannelClosed] + + it should "throw on error" in: + val ch = Channel.newBufferedChannel[String](1) + ch.error(new RuntimeException("boom")) + ch.trySendOrClosed("x") shouldBe a[ChannelClosed] + + "trySend with null" should "throw NPE" in: + val ch = Channel.newBufferedChannel[String](1) + a[NullPointerException] should be thrownBy ch.trySendOrClosed(null.asInstanceOf[String]) +end ChannelTrySendReceiveTest diff --git a/core/src/test/scalanative/ox/channels/jox/ChannelUnlimitedTest.scala b/core/src/test/scalanative/ox/channels/jox/ChannelUnlimitedTest.scala new file mode 100644 index 00000000..9ad69a68 --- /dev/null +++ b/core/src/test/scalanative/ox/channels/jox/ChannelUnlimitedTest.scala @@ -0,0 +1,24 @@ +package ox.channels.jox + +// Ported from: channels/src/test/java/com/softwaremill/jox/ChannelUnlimitedTest.java (jox 1.1.2) + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class ChannelUnlimitedTest extends AnyFlatSpec with Matchers: + "unlimited channel" should "never block on send" in: + val ch = Channel.newUnlimitedChannel[Int]() + for i <- 1 to 1000 do ch.send(i) + for i <- 1 to 1000 do ch.receive() shouldBe i + + it should "handle done with buffered values" in: + val ch = Channel.newUnlimitedChannel[Int]() + ch.send(1) + ch.send(2) + ch.send(3) + ch.done() + ch.receive() shouldBe 1 + ch.receive() shouldBe 2 + ch.receive() shouldBe 3 + ch.receiveOrClosed() shouldBe a[ChannelDone] +end ChannelUnlimitedTest diff --git a/core/src/test/scalanative/ox/channels/jox/SelectTest.scala b/core/src/test/scalanative/ox/channels/jox/SelectTest.scala new file mode 100644 index 00000000..d8254508 --- /dev/null +++ b/core/src/test/scalanative/ox/channels/jox/SelectTest.scala @@ -0,0 +1,89 @@ +package ox.channels.jox + +// Ported from: channels/src/test/java/com/softwaremill/jox/SelectTest.java, +// channels/src/test/java/com/softwaremill/jox/SelectReceiveTest.java, +// channels/src/test/java/com/softwaremill/jox/SelectSendTest.java (jox 1.1.2) +// Note: Fray/stress tests not ported (require JVM-specific infrastructure) + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class SelectTest extends AnyFlatSpec with Matchers: + import TestUtil.* + + "selectOrClosed" should "receive from the first available channel" in: + val ch1 = Channel.newBufferedChannel[Int](1) + val ch2 = Channel.newBufferedChannel[String](1) + ch1.send(42) + val result = Select.selectOrClosed(ch1.receiveClause(), ch2.receiveClause()) + result shouldBe Integer.valueOf(42) + + it should "receive from the second channel if first is empty" in: + val ch1 = Channel.newRendezvousChannel[Int]() + val ch2 = Channel.newBufferedChannel[String](1) + ch2.send("hello") + val result = Select.selectOrClosed(ch1.receiveClause(), ch2.receiveClause()) + result shouldBe "hello" + + it should "select default when no channel is ready" in: + val ch1 = Channel.newRendezvousChannel[Int]() + val ch2 = Channel.newRendezvousChannel[String]() + val result = Select.selectOrClosed( + ch1.receiveClause(), + ch2.receiveClause(), + Select.defaultClause("default_value") + ) + result shouldBe "default_value" + + it should "return closed when channel is in error" in: + val ch1 = Channel.newRendezvousChannel[Int]() + val ex = new RuntimeException("error") + ch1.error(ex) + val result = Select.selectOrClosed(ch1.receiveClause()) + result shouldBe a[ChannelError] + + it should "support send clauses" in: + val ch = Channel.newBufferedChannel[Int](1) + val result = Select.selectOrClosed(ch.sendClause(42)) + result shouldBe null + ch.receive() shouldBe 42 + + it should "select from a channel with a waiting sender (rendezvous)" in scoped { scope => + val ch1 = Channel.newRendezvousChannel[Int]() + val ch2 = Channel.newRendezvousChannel[Int]() + forkVoid(scope, () => ch1.send(1)) + Thread.sleep(50) + val result = Select.selectOrClosed(ch1.receiveClause(), ch2.receiveClause()) + result shouldBe Integer.valueOf(1) + } + + it should "select a send clause when receiver is waiting" in scoped { scope => + val ch = Channel.newRendezvousChannel[Int]() + val f = fork(scope, () => ch.receive()) + Thread.sleep(50) + Select.selectOrClosed(ch.sendClause(99)) + f.get() shouldBe 99 + } + + it should "apply transformation callback on receive" in: + val ch = Channel.newBufferedChannel[Int](1) + ch.send(10) + val result = Select.selectOrClosed(ch.receiveClause(v => s"got:$v")) + result shouldBe "got:10" + + it should "apply callback on send" in: + val ch = Channel.newBufferedChannel[Int](1) + val result = Select.selectOrClosed(ch.sendClause(5, () => "sent!")) + result shouldBe "sent!" + ch.receive() shouldBe 5 + + it should "throw when default clause is not last" in: + val ch = Channel.newRendezvousChannel[Int]() + an[IllegalArgumentException] should be thrownBy + Select.selectOrClosed(Select.defaultClause(0), ch.receiveClause()) + + it should "throw when same channel appears in multiple clauses" in: + val ch = Channel.newBufferedChannel[Int](1) + an[IllegalArgumentException] should be thrownBy + Select.selectOrClosed(ch.receiveClause(), ch.receiveClause()) +end SelectTest diff --git a/core/src/test/scalanative/ox/channels/jox/TestUtil.scala b/core/src/test/scalanative/ox/channels/jox/TestUtil.scala new file mode 100644 index 00000000..1b838e02 --- /dev/null +++ b/core/src/test/scalanative/ox/channels/jox/TestUtil.scala @@ -0,0 +1,79 @@ +package ox.channels.jox + +// Ported from: channels/src/test/java/com/softwaremill/jox/TestUtil.java (jox 1.1.2) + +import java.util.concurrent.{CompletableFuture, ExecutionException, Future} +import scala.collection.mutable + +object TestUtil: + def scoped(f: Scope => Unit): Unit = + val scope = new Scope + val mainTask = Thread + .ofVirtual() + .start(() => + try f(scope) + catch case e: Exception => scope.completeExceptionally(e) + ) + mainTask.join() + scope.waitForCompletion() + end scoped + + def fork[T](scope: Scope, f: () => T): Future[T] = + val cf = new CompletableFuture[T]() + Thread + .ofVirtual() + .start(() => + try cf.complete(f()) + catch case ex: Exception => cf.completeExceptionally(ex) + ) + scope.addFuture(cf) + cf + end fork + + def forkVoid(scope: Scope, f: () => Unit): Future[Void] = + fork( + scope, + () => + f(); null + ) + + def forkCancelable[T](scope: Scope, f: () => T): CancelableFork[T] = + val cf = new CompletableFuture[T]() + val t = Thread + .ofVirtual() + .start(() => + try cf.complete(f()) + catch case ex: Exception => cf.completeExceptionally(ex) + ) + new CancelableFork(t, cf) + end forkCancelable + + class Scope: + private val futures = mutable.ArrayBuffer.empty[CompletableFuture[?]] + @volatile private var exception: Exception | Null = null + + def addFuture(f: CompletableFuture[?]): Unit = synchronized { futures += f } + + def completeExceptionally(e: Exception): Unit = exception = e + + def waitForCompletion(): Unit = + if exception != null then throw new ExecutionException(exception) + synchronized { + for f <- futures do + try f.get() + catch + case e: ExecutionException => + if exception == null then exception = e.getCause.asInstanceOf[Exception] + } + if exception != null then throw new ExecutionException(exception) + end waitForCompletion + end Scope + + class CancelableFork[T](thread: Thread, future: CompletableFuture[T]): + def get(): T = future.get() + def cancel(): AnyRef = + thread.interrupt() + thread.join() + if future.isCompletedExceptionally then future.exceptionNow() + else future.get().asInstanceOf[AnyRef] +end TestUtil diff --git a/doc/index.md b/doc/index.md index c489321d..e61cd4ed 100644 --- a/doc/index.md +++ b/doc/index.md @@ -1,8 +1,9 @@ # Ox -Safe direct-style streaming, concurrency and resiliency for Scala on the JVM. Requires JDK 21+ & Scala 3. +Safe direct-style streaming, concurrency and resiliency for Scala on the JVM. +Requires JDK 21+ & Scala 3. Experimental support for Scala Native. -To start using Ox, add the `com.softwaremill.ox::core:@VERSION@` [dependency](info/dependency.md) to your project. +To start using Ox, add the `com.softwaremill.ox:::core:@VERSION@` [dependency](info/dependency.md) to your project. Then, take a look at the tour of Ox, or follow one of the topics listed in the menu to get to know Ox's API! In addition to this documentation, ScalaDocs can be browsed at [https://javadoc.io](https://www.javadoc.io/doc/com.softwaremill.ox). diff --git a/doc/info/dependency.md b/doc/info/dependency.md index 034c4583..63c5ee99 100644 --- a/doc/info/dependency.md +++ b/doc/info/dependency.md @@ -4,12 +4,14 @@ To use ox core in your project, add: ```scala // sbt dependency -"com.softwaremill.ox" %% "core" % "@VERSION@" +"com.softwaremill.ox" %%% "core" % "@VERSION@" // scala-cli dependency -//> using dep com.softwaremill.ox::core:@VERSION@ +//> using dep com.softwaremill.ox:::core:@VERSION@ ``` -Ox core depends only on the Java [jox](https://github.com/softwaremill/jox) project, where channels are implemented. There are no other direct or transitive dependencies. +On the JVM, Ox core depends only on the Java [jox](https://github.com/softwaremill/jox) project, where channels are implemented. There are no other direct or transitive dependencies. + +For Scala Native, only the the `core` module is available. It contains a reimplementation of Jox Channels in pure Scala. Integration modules have separate dependencies. \ No newline at end of file diff --git a/examples/src/main/scala/NativeVirtualThreadScalabilityIssue.scala b/examples/src/main/scala/NativeVirtualThreadScalabilityIssue.scala new file mode 100644 index 00000000..f5a548ee --- /dev/null +++ b/examples/src/main/scala/NativeVirtualThreadScalabilityIssue.scala @@ -0,0 +1,46 @@ +import java.util.concurrent.CompletableFuture +import java.util.concurrent.atomic.AtomicInteger + +/** Reproduces a Scala Native 0.5.12 virtual thread scalability issue. + * + * Pattern: N virtual threads each block on CompletableFuture.get() while a single "actor" virtual thread processes + * requests sequentially. At N=100 this works. At N>=500 it livelocks on Native (works on JVM). + * + * This is the same pattern used by ox.channels.Actor.ask under high concurrency. + * + * To reproduce: + * {{{ + * sbt examples3/run # JVM — prints "All OK" + * sbt examplesNative3/run # Native — hangs at N>=500 + * }}} + */ +object NativeVirtualThreadScalabilityIssue: + def main(args: Array[String]): Unit = + for n <- List(10, 100, 500, 1000) do + println(s"Running with N=$n...") + run(n) + println(s" N=$n OK") + println("All OK") + + private def run(n: Int): Unit = + val queue = new java.util.concurrent.LinkedBlockingQueue[() => Unit]() + val counter = new AtomicInteger(0) + + // single "actor" thread processing requests sequentially + val actor = Thread.ofVirtual().start { () => + while true do + val msg = queue.take() + msg() + } + + // N concurrent virtual threads, each sending a request and blocking on the response + val threads = (1 to n).map { _ => + Thread.ofVirtual().start { () => + val result = new CompletableFuture[Int]() + queue.put { () => result.complete(counter.incrementAndGet()): Unit } + result.get(): Unit + } + } + + threads.foreach(_.join()) + actor.interrupt() diff --git a/examples/src/main/scala/VirtualThreadsNativeJvmBenchmark.scala b/examples/src/main/scala/VirtualThreadsNativeJvmBenchmark.scala new file mode 100644 index 00000000..4818c6da --- /dev/null +++ b/examples/src/main/scala/VirtualThreadsNativeJvmBenchmark.scala @@ -0,0 +1,44 @@ +import ox.* + +import java.util.concurrent.atomic.AtomicLong + +/** Spawns 100,000 virtual threads in a supervised scope, each incrementing a shared counter, then joins all. Measures wall-clock time to + * compare JVM vs Scala Native performance. + * + * Prerequisites: JDK 21+ (JVM), clang/LLVM 16+ (Native). + * + * To package & run: + * {{{ + * # JVM fat jar + * sbt examples3/assembly + * java -jar examples/target/jvm-3/examples-assembly.jar + * + * # Native binary + * sbt examplesNative3/nativeLink + * ./examples/target/native-3/example + * + * # Compare (3 iterations each): + * for i in 1 2 3; do java -jar examples/target/jvm-3/examples-assembly.jar; done + * for i in 1 2 3; do ./examples/target/native-3/example; done + * }}} + */ +object VirtualThreadsNativeJvmBenchmark: + def main(args: Array[String]): Unit = + val n = 100_000 + val counter = new AtomicLong(0L) + + val start = System.nanoTime() + + supervised { + for _ <- 1 to n do + fork { + counter.incrementAndGet() + } + } + + val elapsed = (System.nanoTime() - start) / 1_000_000 + + assert(counter.get() == n, s"Expected $n, got ${counter.get()}") + println(s"Spawned and joined $n virtual threads in ${elapsed}ms") + end main +end VirtualThreadsNativeJvmBenchmark diff --git a/project/plugins.sbt b/project/plugins.sbt index 4a1ee523..c0490e4b 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,3 +3,6 @@ addSbtPlugin("com.softwaremill.sbt-softwaremill" % "sbt-softwaremill-common" % s addSbtPlugin("com.softwaremill.sbt-softwaremill" % "sbt-softwaremill-publish" % sbtSoftwareMillVersion) addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.9.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "1.1.5") +addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.5.12") +addSbtPlugin("com.eed3si9n" % "sbt-projectmatrix" % "0.11.0") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.1")