PoC for not using a CHM in #4388 #4426

Merged Jul 7, 2025 · 14 commits

core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala
@@ -39,7 +39,7 @@ import scala.concurrent.duration.{Duration, FiniteDuration}

import java.time.Instant
import java.time.temporal.ChronoField
import java.util.concurrent.{LinkedTransferQueue, ThreadLocalRandom}
import java.util.concurrent.{SynchronousQueue, ThreadLocalRandom}
import java.util.concurrent.atomic.{
AtomicBoolean,
AtomicInteger,
@@ -131,8 +131,15 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
*/
private[this] val state: AtomicInteger = new AtomicInteger(threadCount << UnparkShift)

private[unsafe] val cachedThreads: LinkedTransferQueue[WorkerThread[P]] =
new LinkedTransferQueue
private[unsafe] val transferStateStack: SynchronousQueue[WorkerThread.TransferState] =
new SynchronousQueue[WorkerThread.TransferState](
// Note: we use the queue in UNfair mode, so it's a stack really
// (we depend on an implementation detail of openjdk, where unfair
// SynchronousQueue is implemented with a stack). This is important
// so that older cached threads can time out and shut down even
// if there are frequent blocking operations (see issue #4382).
false
)

/**
* The shutdown latch of the work stealing thread pool.
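
As an aside (not part of the diff): the comment in the added code above leans on an OpenJDK implementation detail, namely that an unfair SynchronousQueue keeps its waiting threads on a stack, so an offer is paired with the most recently parked consumer. A minimal standalone sketch of that LIFO pairing, with made-up names (`UnfairQueueDemo`, `older`, `newer`), assuming an OpenJDK runtime:

import java.util.concurrent.{SynchronousQueue, TimeUnit}

object UnfairQueueDemo {
  def main(args: Array[String]): Unit = {
    val q = new SynchronousQueue[String](false) // fairness = false, i.e. a stack of waiters

    def consumer(name: String): Thread = {
      val t = new Thread(() => {
        val item = q.poll(5, TimeUnit.SECONDS) // park until a producer shows up (or time out)
        println(s"$name received: $item")
      })
      t.start()
      t
    }

    val older = consumer("older") // arrives first, ends up at the bottom of the waiter stack
    Thread.sleep(100) // crude arrival ordering, for the demo only
    val newer = consumer("newer") // arrives last, sits on top of the waiter stack
    Thread.sleep(100)

    q.put("first hand-off") // expected (on OpenJDK) to be received by `newer`
    q.put("second hand-off") // expected to be received by `older`
    older.join()
    newer.join()
  }
}
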
@@ -749,12 +756,8 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
system.close()
}

var t: WorkerThread[P] = null
while ({
t = cachedThreads.poll()
t ne null
}) {
t.interrupt()
// signal cached threads to shut down:
while (transferStateStack.offer(WorkerThread.transferStateSentinel)) {
// don't bother joining, cached threads are not doing anything interesting
}
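
For context (again not part of the diff): `SynchronousQueue.offer` has no buffer to place the element into, so it returns `true` only if some consumer is currently blocked in `poll`/`take`. That is what terminates the loop above: each successful offer wakes exactly one cached thread with the sentinel, and the first `false` means no cached threads are left. A hedged sketch of that semantics (`OfferDemo` is a made-up name):

import java.util.concurrent.{SynchronousQueue, TimeUnit}

object OfferDemo {
  def main(args: Array[String]): Unit = {
    val q = new SynchronousQueue[String](false)

    // Nobody is polling yet, so there is no rendezvous partner: the offer fails immediately.
    assert(!q.offer("sentinel"))

    // With a consumer parked in poll, the same non-blocking offer succeeds.
    val consumer = new Thread(() => println(q.poll(1, TimeUnit.SECONDS)))
    consumer.start()
    Thread.sleep(100) // crude way of letting the consumer park (demo only)
    assert(q.offer("sentinel"))
    consumer.join()
  }
}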

42 changes: 26 additions & 16 deletions core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
@@ -26,7 +26,7 @@ import scala.concurrent.{BlockContext, CanAwait}
import scala.concurrent.duration.{Duration, FiniteDuration}

import java.lang.Long.MIN_VALUE
import java.util.concurrent.{ArrayBlockingQueue, ThreadLocalRandom}
import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.atomic.AtomicBoolean

import WorkerThread.{Metrics, TransferState}
@@ -110,7 +110,6 @@ private[effect] final class WorkerThread[P <: AnyRef](
*/
private[this] var _active: Runnable = _

private val stateTransfer: ArrayBlockingQueue[TransferState] = new ArrayBlockingQueue(1)
private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration

private[effect] var currentIOFiber: IOFiber[?] = _
@@ -732,20 +731,24 @@ private[effect] final class WorkerThread[P <: AnyRef](
// by another thread in the future.
val len = runtimeBlockingExpiration.length
val unit = runtimeBlockingExpiration.unit
if (pool.cachedThreads.tryTransfer(this, len, unit)) {
// Someone accepted the transfer of this thread and will transfer the state soon.
val newState = stateTransfer.take()

// Try to poll for a new state from the transfer queue
val newState = pool.transferStateStack.poll(len, unit)

if ((newState ne null) && (newState ne WorkerThread.transferStateSentinel)) {
// Got a state to take over
init(newState)
} else {
// The timeout elapsed and no one woke up this thread. It's time to exit.
// No state to take over after timeout (or we're shutting down), exit
pool.blockedWorkerThreadCounter.decrementAndGet()
return
}
} catch {
case _: InterruptedException =>
// This thread was interrupted while cached. This should only happen
// during the shutdown of the pool. Nothing else to be done, just
// exit.
// exit. (Note, that if we're shutting down ourselves, we're doing
// that with `transferStateSentinel`, see above.)
return
}
}
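
To summarize the cached-thread side of the protocol, a simplified sketch rather than the actual WorkerThread code; `State`, `Protocol`, `shutdownSentinel` and `awaitWork` are made-up names:

import java.util.concurrent.{SynchronousQueue, TimeUnit}

final class State(val index: Int)

object Protocol {
  // Compared by reference (`ne`), never by value, just like `transferStateSentinel`.
  val shutdownSentinel: State = new State(-1)
  val transferStates = new SynchronousQueue[State](false)

  // What a cached worker does while parked: wait up to the expiration time for a hand-off.
  def awaitWork(timeoutMillis: Long): Option[State] = {
    val next = transferStates.poll(timeoutMillis, TimeUnit.MILLISECONDS)
    if ((next ne null) && (next ne shutdownSentinel)) Some(next) // a blocker handed over its state
    else None // timed out, or the pool is shutting down: the thread should exit
  }
}
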
@@ -928,15 +931,14 @@ private[effect] final class WorkerThread[P <: AnyRef](
// Set the name of this thread to a blocker prefixed name.
setName(s"$prefix-$nameIndex")

val cached = pool.cachedThreads.poll()
if (cached ne null) {
// There is a cached worker thread that can be reused.
val idx = index
pool.replaceWorker(idx, cached)
// Transfer the data structures to the cached thread and wake it up.
transferState.index = idx
transferState.tick = tick + 1
val _ = cached.stateTransfer.offer(transferState)
val idx = index

// Prepare the transfer state
transferState.index = idx
transferState.tick = tick + 1

if (pool.transferStateStack.offer(transferState)) {
// If successful, a waiting thread will pick it up
} else {
// Spawn a new `WorkerThread`, a literal clone of this one. It is safe to
// transfer ownership of the local queue and the parked signal to the new
@@ -1002,6 +1004,8 @@
setName(s"$prefix-${_index}")

blocking = false

pool.replaceWorker(newIdx, this)
}
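
The blocker side is the mirror image (continuing the same made-up Protocol sketch, not the real API): the thread that is about to block offers its transfer state without waiting, and a `false` result means no cached thread was parked in `poll`, so a fresh worker has to be spawned instead.

object BlockerSide {
  def handOffOrSpawn(state: State)(spawnReplacement: State => Unit): Unit =
    if (Protocol.transferStates.offer(state)) {
      // A parked cached thread was paired with this offer and takes over `state`.
      ()
    } else {
      // No cached thread was waiting: start a brand-new worker instead.
      spawnReplacement(state)
    }
}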

/**
@@ -1026,6 +1030,12 @@ private[effect] object WorkerThread {
var tick: Int = _
}

/**
* We use this to signal interrupt to cached threads
*/
private[unsafe] val transferStateSentinel: TransferState =
new TransferState

final class Metrics {
private[this] var idleTime: Long = 0
def getIdleTime(): Long = idleTime
34 changes: 34 additions & 0 deletions tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala
@@ -512,6 +512,40 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala
ok
}

"cached threads should be used in LIFO order" in {
val (pool, poller, shutdown) = IORuntime.createWorkStealingComputeThreadPool(
threads = 1,
pollingSystem = SleepSystem)

implicit val runtime: IORuntime =
IORuntime.builder().setCompute(pool, shutdown).addPoller(poller, () => ()).build()

try {
val test = for {
// create 3 blocker threads, which will be cached in order:
th1 <- IO(new AtomicReference[Thread])
fib1 <- IO.blocking { th1.set(Thread.currentThread()); Thread.sleep(200L) }.start
th2 <- IO(new AtomicReference[Thread])
fib2 <- IO.blocking { th2.set(Thread.currentThread()); Thread.sleep(400L) }.start
th3 <- IO(new AtomicReference[Thread])
fib3 <- IO.blocking { th3.set(Thread.currentThread()); Thread.sleep(600L) }.start
_ <- fib1.join
_ <- fib2.join
_ <- fib3.join
// now we have 3 cached threads, and when we do `blocking`, the LAST one should be used,
// so the first 2 should remain cached (blocked in SynchronousQueue#poll):
_ <- IO.blocking { Thread.sleep(100L) }
_ <- IO.cede // move back to the WSTP
_ <- IO { th1.get().getState() mustEqual Thread.State.TIMED_WAITING }
_ <- IO { th2.get().getState() mustEqual Thread.State.TIMED_WAITING }
_ <- IO { th3.get().getState() must not(beEqualTo(Thread.State.TIMED_WAITING)) }
} yield ok
test.unsafeRunSync()
} finally {
runtime.shutdown()
}
}

trait DummyPoller {
def poll: IO[Unit]
}