Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
14a5383
Inference changes
jonathan-buttner May 8, 2025
eba5fce
Custom service fixes
jonathan-buttner May 8, 2025
9af98be
Update docs/changelog/127939.yaml
jonathan-buttner May 8, 2025
cb09e30
Cleaning up from failed merge
jonathan-buttner May 8, 2025
c8642cd
Merge branch 'custom-inference-service-jon' of github.com:jonathan-bu…
jonathan-buttner May 8, 2025
e7c62d8
Fixing changelog
jonathan-buttner May 8, 2025
6bb2a95
[CI] Auto commit changes from spotless
May 8, 2025
67329e2
Fixing transport version
jonathan-buttner May 20, 2025
dd14970
Merge branch 'custom-inference-service-jon' of github.com:jonathan-bu…
jonathan-buttner May 20, 2025
6be22b5
Fixing test
jonathan-buttner May 20, 2025
da1c71f
Fixing transport version
jonathan-buttner May 20, 2025
84c16ce
Adding feature flag
jonathan-buttner May 29, 2025
7eb72ff
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 29, 2025
280d4dd
[CI] Auto commit changes from spotless
May 29, 2025
d1137b6
Fixing test issue
jonathan-buttner May 29, 2025
e955bf4
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 29, 2025
27fdfa8
Merge branch 'custom-inference-service-jon' of github.com:jonathan-bu…
jonathan-buttner May 29, 2025
7d2c112
[CI] Auto commit changes from spotless
May 29, 2025
63fdaed
Fixing the expected values
jonathan-buttner May 29, 2025
95f05d2
Merge branch 'main' of github.com:elastic/elasticsearch into custom-i…
jonathan-buttner May 29, 2025
e085040
Merge branch 'custom-inference-service-jon' of github.com:jonathan-bu…
jonathan-buttner May 29, 2025
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127939.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127939
summary: Add Custom inference service
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RESCORE_VECTOR_ALLOW_ZERO_BACKPORT_8_19 = def(8_841_0_27);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL_8_19 = def(8_841_0_30);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -252,7 +253,7 @@ static TransportVersion def(int id) {
public static final TransportVersion FIELD_CAPS_ADD_CLUSTER_ALIAS = def(9_073_0_00);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_074_0_00);
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);

public static final TransportVersion ADD_INFERENCE_CUSTOM_MODEL = def(9_076_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(22));
assertThat(services.size(), equalTo(23));

var providers = providers(services);

Expand All @@ -39,6 +39,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"elastic",
"elasticsearch",
Expand Down Expand Up @@ -70,7 +71,7 @@ private Iterable<String> providers(List<Object> services) {

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(16));
assertThat(services.size(), equalTo(17));

var providers = providers(services);

Expand All @@ -83,6 +84,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"elasticsearch",
"googleaistudio",
"googlevertexai",
Expand All @@ -101,7 +103,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {

public void testGetServicesWithRerankTaskType() throws IOException {
List<Object> services = getServices(TaskType.RERANK);
assertThat(services.size(), equalTo(7));
assertThat(services.size(), equalTo(8));

var providers = providers(services);

Expand All @@ -111,6 +113,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
List.of(
"alibabacloud-ai-search",
"cohere",
"custom",
"elasticsearch",
"googlevertexai",
"jinaai",
Expand All @@ -123,7 +126,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(10));
assertThat(services.size(), equalTo(11));

var providers = providers(services);

Expand All @@ -137,6 +140,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"deepseek",
"googleaistudio",
"openai",
Expand All @@ -157,7 +161,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException {

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.SPARSE_EMBEDDING);
assertThat(services.size(), equalTo(6));
assertThat(services.size(), equalTo(7));

var providers = providers(services);

Expand All @@ -166,6 +170,7 @@ public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"custom",
"elastic",
"elasticsearch",
"hugging_face",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
Expand Down Expand Up @@ -154,6 +163,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addAlibabaCloudSearchNamedWriteables(namedWriteables);
addJinaAINamedWriteables(namedWriteables);
addVoyageAINamedWriteables(namedWriteables);
addCustomNamedWriteables(namedWriteables);

addUnifiedNamedWriteables(namedWriteables);

Expand All @@ -165,6 +175,38 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
return namedWriteables;
}

private static void addCustomNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, CustomServiceSettings.NAME, CustomServiceSettings::new)
);

namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, CustomTaskSettings.NAME, CustomTaskSettings::new));

namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, CustomSecretSettings.NAME, CustomSecretSettings::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(CustomResponseParser.class, TextEmbeddingResponseParser.NAME, TextEmbeddingResponseParser::new)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
CustomResponseParser.class,
SparseEmbeddingResponseParser.NAME,
SparseEmbeddingResponseParser::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(CustomResponseParser.class, RerankResponseParser.NAME, RerankResponseParser::new)
);

namedWriteables.add(new NamedWriteableRegistry.Entry(CustomResponseParser.class, NoopResponseParser.NAME, NoopResponseParser::new));

namedWriteables.add(
new NamedWriteableRegistry.Entry(CustomResponseParser.class, CompletionResponseParser.NAME, CompletionResponseParser::new)
);
}

private static void addUnifiedNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
var writeables = UnifiedCompletionRequest.getNamedWriteables();
namedWriteables.addAll(writeables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.custom.CustomService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
Expand Down Expand Up @@ -396,6 +397,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
context -> new CustomService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public abstract class BaseResponseHandler implements ResponseHandler {
public static final String METHOD_NOT_ALLOWED = "Received a method not allowed status code";

protected final String requestType;
private final ResponseParser parseFunction;
protected final ResponseParser parseFunction;
private final Function<HttpResult, ErrorResponse> errorParseFunction;
private final boolean canHandleStreamingResponses;

Expand Down
Loading
Loading