diff --git a/docs/changelog/126858.yaml b/docs/changelog/126858.yaml new file mode 100644 index 0000000000000..d1ea2ebba73ef --- /dev/null +++ b/docs/changelog/126858.yaml @@ -0,0 +1,6 @@ +pr: 126858 +summary: Leverage threadpool schedule for inference api to avoid long running thread +area: Machine Learning +type: bug +issues: + - 126853 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index ad1324d0a315f..69dca7fb31dfd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -55,15 +55,6 @@ */ class RequestExecutorService implements RequestExecutor { - /** - * Provides dependency injection mainly for testing - */ - interface Sleeper { - void sleep(TimeValue sleepTime) throws InterruptedException; - } - - // default for tests - static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration()); // default for tests static final AdjustableCapacityBlockingQueue.QueueCreator DEFAULT_QUEUE_CREATOR = new AdjustableCapacityBlockingQueue.QueueCreator<>() { @@ -106,7 +97,6 @@ interface RateLimiterCreator { private final Clock clock; private final AtomicBoolean shutdown = new AtomicBoolean(false); private final AdjustableCapacityBlockingQueue.QueueCreator queueCreator; - private final Sleeper sleeper; private final RateLimiterCreator rateLimiterCreator; private final AtomicReference cancellableCleanupTask = new AtomicReference<>(); private final AtomicBoolean started = new AtomicBoolean(false); @@ -117,16 +107,7 @@ interface RateLimiterCreator { RequestExecutorServiceSettings settings, RequestSender requestSender ) { - this( - threadPool, - DEFAULT_QUEUE_CREATOR, - startupLatch, - settings, - requestSender, - Clock.systemUTC(), - DEFAULT_SLEEPER, - DEFAULT_RATE_LIMIT_CREATOR - ); + this(threadPool, DEFAULT_QUEUE_CREATOR, startupLatch, settings, requestSender, Clock.systemUTC(), DEFAULT_RATE_LIMIT_CREATOR); } RequestExecutorService( @@ -136,7 +117,6 @@ interface RateLimiterCreator { RequestExecutorServiceSettings settings, RequestSender requestSender, Clock clock, - Sleeper sleeper, RateLimiterCreator rateLimiterCreator ) { this.threadPool = Objects.requireNonNull(threadPool); @@ -145,7 +125,6 @@ interface RateLimiterCreator { this.requestSender = Objects.requireNonNull(requestSender); this.settings = Objects.requireNonNull(settings); this.clock = Objects.requireNonNull(clock); - this.sleeper = Objects.requireNonNull(sleeper); this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator); } @@ -188,15 +167,10 @@ public void start() { startCleanupTask(); signalStartInitiated(); - while (isShutdown() == false) { - handleTasks(); - } - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } finally { - shutdown(); - notifyRequestsOfShutdown(); - terminationLatch.countDown(); + handleTasks(); + } catch (Exception e) { + logger.warn("Failed to start request executor", e); + cleanup(); } } @@ -231,13 +205,44 @@ void removeStaleGroupings() { } } - private void handleTasks() throws InterruptedException { - var timeToWait = settings.getTaskPollFrequency(); - for (var endpoint : rateLimitGroupings.values()) { - timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); + private void scheduleNextHandleTasks(TimeValue timeToWait) { + if (shutdown.get()) { + logger.debug("Shutdown requested while scheduling next handle task call, cleaning up"); + cleanup(); + return; } - sleeper.sleep(timeToWait); + threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } + + private void cleanup() { + try { + shutdown(); + notifyRequestsOfShutdown(); + terminationLatch.countDown(); + } catch (Exception e) { + logger.warn("Encountered an error while cleaning up", e); + } + } + + private void handleTasks() { + try { + if (shutdown.get()) { + logger.debug("Shutdown requested while handling tasks, cleaning up"); + cleanup(); + return; + } + + var timeToWait = settings.getTaskPollFrequency(); + for (var endpoint : rateLimitGroupings.values()) { + timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait); + } + + scheduleNextHandleTasks(timeToWait); + } catch (Exception e) { + logger.warn("Encountered an error while handling tasks", e); + cleanup(); + } } private void notifyRequestsOfShutdown() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 79f6aa8164b75..61708dab490c5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -40,6 +40,7 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -77,7 +78,7 @@ public void shutdown() throws IOException, InterruptedException { } public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception { - var senderFactory = createSenderFactory(clientManager, threadRef); + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); try (var sender = createSender(senderFactory)) { sender.start(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java index e09e4968571e5..13399b320e176 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java @@ -50,7 +50,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -195,7 +194,7 @@ public void testExecute_Throws_WhenQueueIsFull() { assertFalse(thrownException.isExecutorShutdown()); } - public void testTaskThrowsError_CallsOnFailure() { + public void testTaskThrowsError_CallsOnFailure() throws InterruptedException { var requestSender = mock(RetryingHttpSender.class); var service = createRequestExecutorService(null, requestSender); @@ -218,6 +217,8 @@ public void testTaskThrowsError_CallsOnFailure() { var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id"))); assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class)); + service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + assertTrue(service.isTerminated()); } @@ -342,7 +343,6 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I createRequestExecutorServiceSettingsEmpty(), requestSender, Clock.systemUTC(), - RequestExecutorService.DEFAULT_SLEEPER, RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR ); @@ -356,36 +356,7 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I }); service.start(); - assertTrue(service.isTerminated()); - } - - public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception { - @SuppressWarnings("unchecked") - BlockingQueue queue = mock(LinkedBlockingQueue.class); - var sleeper = mock(RequestExecutorService.Sleeper.class); - doThrow(new InterruptedException("failed")).when(sleeper).sleep(any()); - - var service = new RequestExecutorService( - threadPool, - mockQueueCreator(queue), - null, - createRequestExecutorServiceSettingsEmpty(), - mock(RetryingHttpSender.class), - Clock.systemUTC(), - sleeper, - RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR - ); - - Future executorTermination = threadPool.generic().submit(() -> { - try { - service.start(); - } catch (Exception e) { - fail(Strings.format("Failed to shutdown executor: %s", e)); - } - }); - - executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); - + service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS); assertTrue(service.isTerminated()); } @@ -552,7 +523,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens() { settings, requestSender, Clock.systemUTC(), - RequestExecutorService.DEFAULT_SLEEPER, rateLimiterCreator ); var requestManager = RequestManagerTests.createMock(requestSender); @@ -585,7 +555,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And settings, requestSender, Clock.systemUTC(), - RequestExecutorService.DEFAULT_SLEEPER, rateLimiterCreator ); var requestManager = RequestManagerTests.createMock(requestSender); @@ -597,11 +566,15 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And doAnswer(invocation -> { service.shutdown(); + ActionListener passedListener = invocation.getArgument(4); + passedListener.onResponse(null); + return Void.TYPE; }).when(requestSender).send(any(), any(), any(), any(), any()); service.start(); + listener.actionGet(TIMEOUT); verify(requestSender, times(1)).send(any(), any(), any(), any(), any()); } @@ -619,7 +592,6 @@ public void testRemovesRateLimitGroup_AfterStaleDuration() { settings, requestSender, clock, - RequestExecutorService.DEFAULT_SLEEPER, RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR ); var requestManager = RequestManagerTests.createMock(requestSender, "id1"); @@ -653,7 +625,6 @@ public void testStartsCleanupThread() { settings, requestSender, Clock.systemUTC(), - RequestExecutorService.DEFAULT_SLEEPER, RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR );