Skip to content

[ML] Adding configurable inference service #127939

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
14a5383
Inference changes
jonathan-buttner May 8, 2025
eba5fce
Custom service fixes
jonathan-buttner May 8, 2025
9af98be
Update docs/changelog/127939.yaml
jonathan-buttner May 8, 2025
cb09e30
Cleaning up from failed merge
jonathan-buttner May 8, 2025
c8642cd
Merge branch 'custom-inference-service-jon' of github.com:jonathan-bu…
jonathan-buttner May 8, 2025
e7c62d8
Fixing changelog
jonathan-buttner May 8, 2025
6bb2a95
[CI] Auto commit changes from spotless
elasticsearchmachine May 8, 2025
67329e2
Fixing transport version
jonathan-buttner May 20, 2025
dd14970
Merge branch 'custom-inference-service-jon' of github.com:jonathan-bu…
jonathan-buttner May 20, 2025
6be22b5
Fixing test
jonathan-buttner May 20, 2025
da1c71f
Fixing transport version
jonathan-buttner May 20, 2025
84c16ce
Adding feature flag
jonathan-buttner May 29, 2025
7eb72ff
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 29, 2025
280d4dd
[CI] Auto commit changes from spotless
elasticsearchmachine May 29, 2025
d1137b6
Fixing test issue
jonathan-buttner May 29, 2025
e955bf4
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 29, 2025
27fdfa8
Merge branch 'custom-inference-service-jon' of github.com:jonathan-bu…
jonathan-buttner May 29, 2025
7d2c112
[CI] Auto commit changes from spotless
elasticsearchmachine May 29, 2025
63fdaed
Fixing the expected values
jonathan-buttner May 29, 2025
95f05d2
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 29, 2025
e085040
Merge branch 'custom-inference-service-jon' of github.com:jonathan-bu…
jonathan-buttner May 29, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127939.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127939
summary: Add Custom inference service
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_RERANK_ADDED_8_19 = def(8_841_0_36);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION_8_19 = def(8_841_0_37);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED_8_19 = def(8_841_0_38);
public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED_8_19 = def(8_841_0_39);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -269,7 +270,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_DRY_RUN = def(9_081_0_00);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_CHAT_COMPLETION = def(9_082_0_00);
public static final TransportVersion ML_INFERENCE_VERTEXAI_CHATCOMPLETION_ADDED = def(9_083_0_00);

public static final TransportVersion INFERENCE_CUSTOM_SERVICE_ADDED = def(9_084_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ public enum FeatureFlag {
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
DOC_VALUES_SKIPPER("es.doc_values_skipper_feature_flag_enabled=true", Version.fromString("8.18.1"), null),
USE_LUCENE101_POSTINGS_FORMAT("es.use_lucene101_postings_format_feature_flag_enabled=true", Version.fromString("9.1.0"), null);
USE_LUCENE101_POSTINGS_FORMAT("es.use_lucene101_postings_format_feature_flag_enabled=true", Version.fromString("9.1.0"), null),
INFERENCE_CUSTOM_SERVICE_ENABLED("es.inference_custom_service_feature_flag_enabled=true", Version.fromString("8.19.0"), null);

public final String systemProperty;
public final Version from;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.FeatureFlag;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.junit.ClassRule;
Expand Down Expand Up @@ -46,6 +47,7 @@ public class BaseMockEISAuthServerTest extends ESRestTestCase {
// This plugin is located in the inference/qa/test-service-plugin package, look for TestInferenceServicePlugin
.plugin("inference-service-test")
.user("x_pack_rest_user", "x-pack-test-password")
.feature(FeatureFlag.INFERENCE_CUSTOM_SERVICE_ENABLED)
.build();

// The reason we're doing this is to make sure the mock server is initialized first so we can get the address before communicating
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.FeatureFlag;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -50,6 +51,7 @@ public class InferenceBaseRestTest extends ESRestTestCase {
.setting("xpack.security.enabled", "true")
.plugin("inference-service-test")
.user("x_pack_rest_user", "x-pack-test-password")
.feature(FeatureFlag.INFERENCE_CUSTOM_SERVICE_ENABLED)
.build();

@ClassRule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(22));
assertThat(services.size(), equalTo(23));

var providers = providers(services);

Expand All @@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"elastic",
"elasticsearch",
Expand Down Expand Up @@ -70,7 +71,7 @@ private Iterable<String> providers(List<Object> services) {

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(16));
assertThat(services.size(), equalTo(17));

var providers = providers(services);

Expand All @@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"elasticsearch",
"googleaistudio",
"googlevertexai",
Expand All @@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(8));
assertThat(services.size(), equalTo(9));

var providers = providers(services);

Expand All @@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
List.of(
"alibabacloud-ai-search",
"cohere",
"custom",
"elasticsearch",
"googlevertexai",
"jinaai",
Expand All @@ -124,7 +127,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(12));
assertThat(services.size(), equalTo(13));

var providers = providers(services);

Expand All @@ -138,6 +141,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"googleaistudio",
"openai",
Expand Down Expand Up @@ -173,7 +177,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));

var providers = providers(services);

Expand All @@ -182,6 +186,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"custom",
"elastic",
"elasticsearch",
"hugging_face",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.common.util.FeatureFlag;

/**
 * Holder for the feature flag constant of the custom inference service.
 * The class only exposes that constant; it cannot be instantiated or extended.
 */
public final class CustomServiceFeatureFlag {
    /**
     * {@link org.elasticsearch.xpack.inference.services.custom.CustomService} feature flag. When the feature is complete,
     * this flag will be removed.
     * Enable feature via JVM option: {@code -Des.inference_custom_service_feature_flag_enabled=true}.
     */
    public static final FeatureFlag CUSTOM_SERVICE_FEATURE_FLAG = new FeatureFlag("inference_custom_service");

    // Private constructor: this class exists solely to hold the constant above.
    private CustomServiceFeatureFlag() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
Expand Down Expand Up @@ -108,6 +117,8 @@
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.xpack.inference.CustomServiceFeatureFlag.CUSTOM_SERVICE_FEATURE_FLAG;

public class InferenceNamedWriteablesProvider {

private InferenceNamedWriteablesProvider() {}
Expand Down Expand Up @@ -158,6 +169,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
addCustomNamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand All @@ -169,6 +181,42 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
return namedWriteables;
}

/**
 * Appends the named writeable entries of the custom inference service to {@code namedWriteables}:
 * its service, task and secret settings, followed by the text embedding, sparse embedding, rerank,
 * noop and completion response parsers. Nothing is appended while the custom service feature flag
 * is disabled.
 *
 * @param namedWriteables the list the entries are appended to
 */
private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
    if (CUSTOM_SERVICE_FEATURE_FLAG.isEnabled()) {
        namedWriteables.addAll(
            List.of(
                // Settings of the custom service
                new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new),
                new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new),
                new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new),
                // Response parsers, all registered under the CustomResponseParser category
                new NamedWriteableRegistry.Entry(
                    CustomResponseParser.class,
                    TextEmbeddingResponseParser.NAME,
                    TextEmbeddingResponseParser::new
                ),
                new NamedWriteableRegistry.Entry(
                    CustomResponseParser.class,
                    SparseEmbeddingResponseParser.NAME,
                    SparseEmbeddingResponseParser::new
                ),
                new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new),
                new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new),
                new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new)
            )
        );
    }
}

private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
var writeables = UnifiedCompletionRequest.getNamedWriteables();
namedWriteables.addAll(writeables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.custom.CustomService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
Expand Down Expand Up @@ -148,8 +149,10 @@
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Stream;

import static java.util.Collections.singletonList;
import static org.elasticsearch.xpack.inference.CustomServiceFeatureFlag.CUSTOM_SERVICE_FEATURE_FLAG;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;

Expand Down Expand Up @@ -379,7 +382,11 @@ public void loadExtensions(ExtensionLoader loader) {
}

public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
return List.of(
List<InferenceServiceExtension.Factory> conditionalServices = CUSTOM_SERVICE_FEATURE_FLAG.isEnabled()
? List.of(context -> new CustomService(httpFactory.get(), serviceComponents.get()))
: List.of();

List<InferenceServiceExtension.Factory> availableServices = List.of(
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()),
context -> new OpenAiService(httpFactory.get(), serviceComponents.get()),
Expand All @@ -398,6 +405,8 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);

return Stream.concat(availableServices.stream(), conditionalServices.stream()).toList();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";

protected final String requestType;
private final ResponseParser parseFunction;
protected final ResponseParser parseFunction;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making this available so the custom response handler can immediately return on a parse failure instead of retrying.

private final Function<HttpResult, ErrorResponse> errorParseFunction;
private final boolean canHandleStreamingResponses;

Expand Down
Loading
Loading