diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala
index a7139724e9..ab19182b2d 100644
--- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala
+++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala
@@ -39,7 +39,7 @@ import scala.concurrent.duration.{Duration, FiniteDuration}
 
 import java.time.Instant
 import java.time.temporal.ChronoField
-import java.util.concurrent.{LinkedTransferQueue, ThreadLocalRandom}
+import java.util.concurrent.{SynchronousQueue, ThreadLocalRandom}
 import java.util.concurrent.atomic.{
   AtomicBoolean,
   AtomicInteger,
@@ -131,8 +131,15 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
    */
   private[this] val state: AtomicInteger = new AtomicInteger(threadCount << UnparkShift)
 
-  private[unsafe] val cachedThreads: LinkedTransferQueue[WorkerThread[P]] =
-    new LinkedTransferQueue
+  private[unsafe] val transferStateStack: SynchronousQueue[WorkerThread.TransferState] =
+    new SynchronousQueue[WorkerThread.TransferState](
+      // Note: we use the queue in unfair mode, so it's really a stack
+      // (we rely on an OpenJDK implementation detail: the unfair
+      // SynchronousQueue is backed by a stack). This is important so
+      // that older cached threads can time out and shut down even if
+      // there are frequent blocking operations (see issue #4382).
+      false
+    )
 
   /**
    * The shutdown latch of the work stealing thread pool.
@@ -749,12 +756,8 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
       system.close()
     }
 
-    var t: WorkerThread[P] = null
-    while ({
-      t = cachedThreads.poll()
-      t ne null
-    }) {
-      t.interrupt()
+    // Signal cached threads to shut down:
+    while (transferStateStack.offer(WorkerThread.transferStateSentinel)) {
       // don't bother joining, cached threads are not doing anything interesting
     }
 
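The change above depends on the waiter ordering of an unfair SynchronousQueue. Below is a minimal standalone sketch of that OpenJDK behavior (not part of the patch; all names are illustrative): an offered element goes to the newest waiter, while older waiters stay parked and can time out.

import java.util.concurrent.{SynchronousQueue, TimeUnit}

object UnfairSynchronousQueueDemo {
  def main(args: Array[String]): Unit = {
    // false = unfair mode; on OpenJDK the waiters form a stack (LIFO)
    val q = new SynchronousQueue[String](false)

    def spawnPoller(name: String): Thread = {
      val t = new Thread(() => {
        // poll returns null if the timeout elapses without a hand-off
        val v = q.poll(1, TimeUnit.SECONDS)
        println(s"$name received: $v")
      }, name)
      t.start()
      t
    }

    val older = spawnPoller("older-waiter")
    Thread.sleep(100) // ensure "older-waiter" parks first
    val newer = spawnPoller("newer-waiter")
    Thread.sleep(100)

    // A consumer is waiting, so offer succeeds immediately; in unfair mode
    // it wakes the *newest* waiter, and "older-waiter" eventually times out.
    q.offer("element")

    older.join()
    newer.join()
  }
}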
diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
index 615fe8804e..b29c915a4e 100644
--- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
+++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
@@ -26,7 +26,7 @@ import scala.concurrent.{BlockContext, CanAwait}
 import scala.concurrent.duration.{Duration, FiniteDuration}
 
 import java.lang.Long.MIN_VALUE
-import java.util.concurrent.{ArrayBlockingQueue, ThreadLocalRandom}
+import java.util.concurrent.ThreadLocalRandom
 import java.util.concurrent.atomic.AtomicBoolean
 
 import WorkerThread.{Metrics, TransferState}
@@ -110,7 +110,6 @@ private[effect] final class WorkerThread[P <: AnyRef](
    */
   private[this] var _active: Runnable = _
 
-  private val stateTransfer: ArrayBlockingQueue[TransferState] = new ArrayBlockingQueue(1)
   private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration
 
   private[effect] var currentIOFiber: IOFiber[?] = _
@@ -732,12 +731,15 @@ private[effect] final class WorkerThread[P <: AnyRef](
           // by another thread in the future.
           val len = runtimeBlockingExpiration.length
           val unit = runtimeBlockingExpiration.unit
-          if (pool.cachedThreads.tryTransfer(this, len, unit)) {
-            // Someone accepted the transfer of this thread and will transfer the state soon.
-            val newState = stateTransfer.take()
+
+          // Try to poll for a new transfer state from the transfer stack
+          val newState = pool.transferStateStack.poll(len, unit)
+
+          if ((newState ne null) && (newState ne WorkerThread.transferStateSentinel)) {
+            // Got a state to take over
             init(newState)
           } else {
-            // The timeout elapsed and no one woke up this thread. It's time to exit.
+            // No state to take over (timeout elapsed or the pool is shutting down), exit
             pool.blockedWorkerThreadCounter.decrementAndGet()
             return
           }
@@ -745,7 +747,8 @@ private[effect] final class WorkerThread[P <: AnyRef](
       case _: InterruptedException =>
         // This thread was interrupted while cached. This should only happen
         // during the shutdown of the pool. Nothing else to be done, just
-        // exit.
+        // exit. (Note that if the pool itself is shutting down, that is
+        // signalled via `transferStateSentinel`, see above.)
         return
     }
   }
@@ -928,15 +931,14 @@ private[effect] final class WorkerThread[P <: AnyRef](
       // Set the name of this thread to a blocker prefixed name.
       setName(s"$prefix-$nameIndex")
 
-      val cached = pool.cachedThreads.poll()
-      if (cached ne null) {
-        // There is a cached worker thread that can be reused.
-        val idx = index
-        pool.replaceWorker(idx, cached)
-        // Transfer the data structures to the cached thread and wake it up.
-        transferState.index = idx
-        transferState.tick = tick + 1
-        val _ = cached.stateTransfer.offer(transferState)
+      val idx = index
+
+      // Prepare the transfer state
+      transferState.index = idx
+      transferState.tick = tick + 1
+
+      if (pool.transferStateStack.offer(transferState)) {
+        // A cached thread was parked in poll and will pick up the state
       } else {
         // Spawn a new `WorkerThread`, a literal clone of this one. It is safe to
         // transfer ownership of the local queue and the parked signal to the new
@@ -1002,6 +1004,8 @@ private[effect] final class WorkerThread[P <: AnyRef](
     setName(s"$prefix-${_index}")
 
     blocking = false
+
+    pool.replaceWorker(newIdx, this)
   }
 
   /**
@@ -1026,6 +1030,12 @@ private[effect] object WorkerThread {
     var tick: Int = _
   }
 
+  /**
+   * We use this to signal shutdown to cached threads.
+   */
+  private[unsafe] val transferStateSentinel: TransferState =
+    new TransferState
+
   final class Metrics {
     private[this] var idleTime: Long = 0
     def getIdleTime(): Long = idleTime
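For reference, here is a simplified, self-contained model of the hand-off protocol the two files above implement (all names assumed; this is a sketch, not the actual pool internals): a worker entering a blocking region offers its state, a cached thread polls with a timeout, and shutdown is signalled with a sentinel instead of interruption.

import java.util.concurrent.{SynchronousQueue, TimeUnit}

object HandOffProtocolModel {
  final class State(val index: Int)

  // Plays the role of WorkerThread.transferStateSentinel
  private val sentinel = new State(-1)

  // Unfair mode: parked pollers form a stack, so hand-offs are LIFO
  private val stack = new SynchronousQueue[State](false)

  // Worker side: try to hand the state to a cached thread. If nobody is
  // parked in poll, the offer fails and a real pool would spawn a clone.
  def enterBlockingRegion(state: State): Unit =
    if (!stack.offer(state)) println("no cached thread; would spawn a new worker")

  // Cached-thread side: park until a state arrives, the timeout expires,
  // or shutdown is signalled.
  def cachedThreadLoop(timeoutMillis: Long): Unit = {
    val s = stack.poll(timeoutMillis, TimeUnit.MILLISECONDS)
    if (s eq null) println("timed out; cached thread exits")
    else if (s eq sentinel) println("shutdown signalled; cached thread exits")
    else println(s"resuming as worker ${s.index}")
  }

  // Shutdown side: offer succeeds only while some thread is parked in poll,
  // so this loop wakes every cached thread exactly once and then stops.
  def shutdownCachedThreads(): Unit =
    while (stack.offer(sentinel)) ()
}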
diff --git a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala
index 9c35359f71..e619031782 100644
--- a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala
+++ b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala
@@ -512,6 +512,40 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with ScalaCheck =>
       ok
     }
 
+    "cached threads should be used in LIFO order" in {
+      val (pool, poller, shutdown) = IORuntime.createWorkStealingComputeThreadPool(
+        threads = 1,
+        pollingSystem = SleepSystem)
+
+      implicit val runtime: IORuntime =
+        IORuntime.builder().setCompute(pool, shutdown).addPoller(poller, () => ()).build()
+
+      try {
+        val test = for {
+          // create 3 blocker threads, which will be cached in order:
+          th1 <- IO(new AtomicReference[Thread])
+          fib1 <- IO.blocking { th1.set(Thread.currentThread()); Thread.sleep(200L) }.start
+          th2 <- IO(new AtomicReference[Thread])
+          fib2 <- IO.blocking { th2.set(Thread.currentThread()); Thread.sleep(400L) }.start
+          th3 <- IO(new AtomicReference[Thread])
+          fib3 <- IO.blocking { th3.set(Thread.currentThread()); Thread.sleep(600L) }.start
+          _ <- fib1.join
+          _ <- fib2.join
+          _ <- fib3.join
+          // now we have 3 cached threads, and when we do `blocking`, the LAST one should be used,
+          // so the first 2 should remain cached (blocked in SynchronousQueue#poll):
+          _ <- IO.blocking { Thread.sleep(100L) }
+          _ <- IO.cede // move back to the WSTP
+          _ <- IO { th1.get().getState() mustEqual Thread.State.TIMED_WAITING }
+          _ <- IO { th2.get().getState() mustEqual Thread.State.TIMED_WAITING }
+          _ <- IO { th3.get().getState() must not(beEqualTo(Thread.State.TIMED_WAITING)) }
+        } yield ok
+        test.unsafeRunSync()
+      } finally {
+        runtime.shutdown()
+      }
+    }
+
     trait DummyPoller {
       def poll: IO[Unit]
     }
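The assertions in the test rely on a thread parked in a timed SynchronousQueue#poll reporting Thread.State.TIMED_WAITING. A minimal sketch of that observation (illustrative names; not part of the test suite):

import java.util.concurrent.{SynchronousQueue, TimeUnit}

object ParkedThreadStateDemo {
  def main(args: Array[String]): Unit = {
    val q = new SynchronousQueue[AnyRef](false)
    val waiter = new Thread(() => { q.poll(5, TimeUnit.SECONDS); () })
    waiter.start()
    Thread.sleep(200) // give the waiter time to park inside poll
    println(waiter.getState()) // expected: TIMED_WAITING
    q.offer(new AnyRef) // hand something over so the demo exits promptly
    waiter.join()
  }
}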