Skip to content

[8.17] [ML] Refactor inference request executor to leverage scheduled execution (#126858) #126950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/126858.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 126858
summary: Leverage the thread pool scheduler for the inference API to avoid a long-running thread
area: Machine Learning
type: bug
issues:
- 126853
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,6 @@
*/
class RequestExecutorService implements RequestExecutor {

/**
* Abstraction over sleeping, provided as a dependency-injection seam (mainly for testing)
* so that callers can substitute their own implementation.
*/
interface Sleeper {
/**
* Sleeps for the requested amount of time.
*
* @param sleepTime how long to sleep
* @throws InterruptedException if the sleep is interrupted
*/
void sleep(TimeValue sleepTime) throws InterruptedException;
}

// default for tests
static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
// default for tests
static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
new AdjustableCapacityBlockingQueue.QueueCreator<>() {
Expand Down Expand Up @@ -106,7 +97,6 @@ interface RateLimiterCreator {
private final Clock clock;
private final AtomicBoolean shutdown = new AtomicBoolean(false);
private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
private final Sleeper sleeper;
private final RateLimiterCreator rateLimiterCreator;
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
private final AtomicBoolean started = new AtomicBoolean(false);
Expand All @@ -117,16 +107,7 @@ interface RateLimiterCreator {
RequestExecutorServiceSettings settings,
RequestSender requestSender
) {
this(
threadPool,
DEFAULT_QUEUE_CREATOR,
startupLatch,
settings,
requestSender,
Clock.systemUTC(),
DEFAULT_SLEEPER,
DEFAULT_RATE_LIMIT_CREATOR
);
this(threadPool, DEFAULT_QUEUE_CREATOR, startupLatch, settings, requestSender, Clock.systemUTC(), DEFAULT_RATE_LIMIT_CREATOR);
}

RequestExecutorService(
Expand All @@ -136,7 +117,6 @@ interface RateLimiterCreator {
RequestExecutorServiceSettings settings,
RequestSender requestSender,
Clock clock,
Sleeper sleeper,
RateLimiterCreator rateLimiterCreator
) {
this.threadPool = Objects.requireNonNull(threadPool);
Expand All @@ -145,7 +125,6 @@ interface RateLimiterCreator {
this.requestSender = Objects.requireNonNull(requestSender);
this.settings = Objects.requireNonNull(settings);
this.clock = Objects.requireNonNull(clock);
this.sleeper = Objects.requireNonNull(sleeper);
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
}

Expand Down Expand Up @@ -188,15 +167,10 @@ public void start() {
startCleanupTask();
signalStartInitiated();

while (isShutdown() == false) {
handleTasks();
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
shutdown();
notifyRequestsOfShutdown();
terminationLatch.countDown();
handleTasks();
} catch (Exception e) {
logger.warn("Failed to start request executor", e);
cleanup();
}
}

Expand Down Expand Up @@ -231,13 +205,44 @@ void removeStaleGroupings() {
}
}

private void handleTasks() throws InterruptedException {
var timeToWait = settings.getTaskPollFrequency();
for (var endpoint : rateLimitGroupings.values()) {
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
private void scheduleNextHandleTasks(TimeValue timeToWait) {
if (shutdown.get()) {
logger.debug("Shutdown requested while scheduling next handle task call, cleaning up");
cleanup();
return;
}

sleeper.sleep(timeToWait);
threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME));
}

/**
* Shuts the service down, notifies requests of the shutdown, and counts down the
* termination latch. Any exception raised along the way is logged at warn level
* and not rethrown.
*/
private void cleanup() {
try {
shutdown();
notifyRequestsOfShutdown();
// NOTE(review): the three calls run sequentially inside one try block, so if either
// call above throws, the latch is never counted down — confirm this is acceptable
// for anything waiting on terminationLatch.
terminationLatch.countDown();
} catch (Exception e) {
logger.warn("Encountered an error while cleaning up", e);
}
}

/**
* Runs one pass over every rate limit grouping and then schedules the next pass.
* If shutdown has been requested, or if anything in the pass throws, the service
* is cleaned up instead and no further pass is scheduled from here.
*/
private void handleTasks() {
try {
// Stop before doing any work if a shutdown was requested since the last pass.
if (shutdown.get()) {
logger.debug("Shutdown requested while handling tasks, cleaning up");
cleanup();
return;
}

// Start from the configured poll frequency and shrink it to the smallest value
// returned by any grouping.
// NOTE(review): presumably executeEnqueuedTask() returns how long that grouping needs
// to wait before it can run again — confirm against its definition.
var timeToWait = settings.getTaskPollFrequency();
for (var endpoint : rateLimitGroupings.values()) {
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
}

// Hand off to the scheduler rather than looping or sleeping on this thread.
scheduleNextHandleTasks(timeToWait);
} catch (Exception e) {
// Any failure ends the processing chain: log it and clean up.
logger.warn("Encountered an error while handling tasks", e);
cleanup();
}
}

private void notifyRequestsOfShutdown() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -77,7 +78,7 @@ public void shutdown() throws IOException, InterruptedException {
}

public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
var senderFactory = createSenderFactory(clientManager, threadRef);
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());

try (var sender = createSender(senderFactory)) {
sender.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand Down Expand Up @@ -195,7 +194,7 @@ public void testExecute_Throws_WhenQueueIsFull() {
assertFalse(thrownException.isExecutorShutdown());
}

public void testTaskThrowsError_CallsOnFailure() {
public void testTaskThrowsError_CallsOnFailure() throws InterruptedException {
var requestSender = mock(RetryingHttpSender.class);

var service = createRequestExecutorService(null, requestSender);
Expand All @@ -218,6 +217,8 @@ public void testTaskThrowsError_CallsOnFailure() {
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id")));
assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class));
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);

assertTrue(service.isTerminated());
}

Expand Down Expand Up @@ -342,7 +343,6 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
createRequestExecutorServiceSettingsEmpty(),
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);

Expand All @@ -356,36 +356,7 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
});
service.start();

assertTrue(service.isTerminated());
}

public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
@SuppressWarnings("unchecked")
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
var sleeper = mock(RequestExecutorService.Sleeper.class);
doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());

var service = new RequestExecutorService(
threadPool,
mockQueueCreator(queue),
null,
createRequestExecutorServiceSettingsEmpty(),
mock(RetryingHttpSender.class),
Clock.systemUTC(),
sleeper,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);

Future<?> executorTermination = threadPool.generic().submit(() -> {
try {
service.start();
} catch (Exception e) {
fail(Strings.format("Failed to shutdown executor: %s", e));
}
});

executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);

service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
assertTrue(service.isTerminated());
}

Expand Down Expand Up @@ -552,7 +523,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens() {
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
rateLimiterCreator
);
var requestManager = RequestManagerTests.createMock(requestSender);
Expand Down Expand Up @@ -585,7 +555,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
rateLimiterCreator
);
var requestManager = RequestManagerTests.createMock(requestSender);
Expand All @@ -597,11 +566,15 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And

doAnswer(invocation -> {
service.shutdown();
ActionListener<InferenceServiceResults> passedListener = invocation.getArgument(4);
passedListener.onResponse(null);

return Void.TYPE;
}).when(requestSender).send(any(), any(), any(), any(), any());

service.start();

listener.actionGet(TIMEOUT);
verify(requestSender, times(1)).send(any(), any(), any(), any(), any());
}

Expand All @@ -619,7 +592,6 @@ public void testRemovesRateLimitGroup_AfterStaleDuration() {
settings,
requestSender,
clock,
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
Expand Down Expand Up @@ -653,7 +625,6 @@ public void testStartsCleanupThread() {
settings,
requestSender,
Clock.systemUTC(),
RequestExecutorService.DEFAULT_SLEEPER,
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
);

Expand Down