Skip to content

[ML] Integrate with DeepSeek API #122218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/122218.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 122218
summary: Integrate with `DeepSeek` API
area: Machine Learning
type: enhancement
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ static TransportVersion def(int id) {
public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06);
public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07);
public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08);
public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand Down Expand Up @@ -183,6 +184,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_SERIALIZE_BLOCK_TYPE_CODE = def(9_026_0_00);
public static final TransportVersion ESQL_THREAD_NAME_IN_DRIVER_PROFILE = def(9_027_0_00);
public static final TransportVersion INFERENCE_CONTEXT = def(9_028_0_00);
public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_029_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
@SuppressWarnings("unchecked")
public void testGetServicesWithoutTaskType() throws IOException {
List<Object> services = getAllServices();
assertThat(services.size(), equalTo(20));
assertThat(services.size(), equalTo(21));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -41,6 +41,7 @@ public void testGetServicesWithoutTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"elastic",
"elasticsearch",
"googleaistudio",
Expand Down Expand Up @@ -114,7 +115,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(9));
assertThat(services.size(), equalTo(10));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Expand All @@ -130,6 +131,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"azureaistudio",
"azureopenai",
"cohere",
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
Expand All @@ -141,15 +143,15 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
@SuppressWarnings("unchecked")
public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(3));
assertThat(services.size(), equalTo(4));

String[] providers = new String[services.size()];
for (int i = 0; i < services.size(); i++) {
Map<String, Object> serviceConfig = (Map<String, Object>) services.get(i);
providers[i] = (String) serviceConfig.get("service");
}

assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers);
assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers);
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
Expand Down Expand Up @@ -153,6 +154,7 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
addUnifiedNamedWriteables(namedWriteables);

namedWriteables.addAll(StreamingTaskManager.namedWriteables());
namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables());

return namedWriteables;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService;
import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService;
import org.elasticsearch.xpack.inference.services.cohere.CohereService;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
Expand Down Expand Up @@ -362,6 +363,7 @@ public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()),
context -> new JinaAIService(httpFactory.get(), serviceComponents.get()),
context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()),
context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()),
ElasticsearchInternalService::new
);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.deepseek;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

public class DeepSeekChatCompletionRequest implements Request {
private static final String MODEL_FIELD = "model";
private static final String MAX_TOKENS = "max_tokens";

private final DeepSeekChatCompletionModel model;
private final UnifiedChatInput unifiedChatInput;

public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) {
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.model = Objects.requireNonNull(model);
}

@Override
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(model.uri());

httpPost.setEntity(createEntity());

httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));

return new HttpRequest(httpPost, getInferenceEntityId());
}

private ByteArrayEntity createEntity() {
var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model);
try (var builder = JsonXContent.contentBuilder()) {
builder.startObject();
new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS);
builder.field(MODEL_FIELD, modelId);

if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
builder.field(MAX_TOKENS, unifiedChatInput.getRequest().maxCompletionTokens());
}

builder.endObject();
return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8));
} catch (IOException e) {
throw new ElasticsearchException("Failed to serialize request payload.", e);
}
}

@Override
public URI getURI() {
return model.uri();
}

@Override
public Request truncate() {
return this;
}

@Override
public boolean[] getTruncationInfo() {
return null;
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public boolean isStreaming() {
return unifiedChatInput.stream();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.deepseek.DeepSeekChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity;
import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;

import java.util.Objects;
import java.util.function.Supplier;

import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException;

public class DeepSeekRequestManager extends BaseRequestManager {

private static final Logger logger = LogManager.getLogger(DeepSeekRequestManager.class);

private static final ResponseHandler CHAT_COMPLETION = createChatCompletionHandler();
private static final ResponseHandler COMPLETION = createCompletionHandler();

private final DeepSeekChatCompletionModel model;

public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) {
super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings());
this.model = Objects.requireNonNull(model);
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
switch (inferenceInputs) {
case UnifiedChatInput uci -> execute(uci, requestSender, hasRequestCompletedFunction, listener);
case ChatCompletionInput cci -> execute(cci, requestSender, hasRequestCompletedFunction, listener);
default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class);
}
}

private void execute(
UnifiedChatInput inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var request = new DeepSeekChatCompletionRequest(inferenceInputs, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, CHAT_COMPLETION, hasRequestCompletedFunction, listener));
}

private void execute(
ChatCompletionInput inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var unifiedInputs = new UnifiedChatInput(inferenceInputs.getInputs(), "user", inferenceInputs.stream());
var request = new DeepSeekChatCompletionRequest(unifiedInputs, model);
execute(new ExecutableInferenceRequest(requestSender, logger, request, COMPLETION, hasRequestCompletedFunction, listener));
}

private static ResponseHandler createChatCompletionHandler() {
return new OpenAiUnifiedChatCompletionResponseHandler("deepseek chat completion", OpenAiChatCompletionResponseEntity::fromResponse);
}

private static ResponseHandler createCompletionHandler() {
return new OpenAiChatCompletionResponseHandler("deepseek completion", OpenAiChatCompletionResponseEntity::fromResponse);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec

public static final String USER_FIELD = "user";
private static final String MODEL_FIELD = "model";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";

private final UnifiedChatInput unifiedChatInput;
private final OpenAiChatCompletionModel model;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;

public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.model = Objects.requireNonNull(model);
}

Expand All @@ -41,6 +44,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(USER_FIELD, model.getTaskSettings().user());
}

if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
}

builder.endObject();

return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
private static final String MODEL_FIELD = "model";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";

private final UnifiedChatInput unifiedChatInput;
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
private final String modelId;

public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput));
this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
this.modelId = Objects.requireNonNull(modelId);
}

Expand All @@ -31,6 +34,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.startObject();
unifiedRequestEntity.toXContent(builder, params);
builder.field(MODEL_FIELD, modelId);

if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
}

builder.endObject();

return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
public static final String MESSAGES_FIELD = "messages";
private static final String ROLE_FIELD = "role";
private static final String CONTENT_FIELD = "content";
private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
private static final String STOP_FIELD = "stop";
private static final String TEMPERATURE_FIELD = "temperature";
private static final String TOOL_CHOICE_FIELD = "tool_choice";
Expand Down Expand Up @@ -104,10 +103,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.endArray();

if (unifiedRequest.maxCompletionTokens() != null) {
builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens());
}

// Underlying providers expect OpenAI to only return 1 possible choice.
builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);

Expand Down
Loading