diff --git a/docs/changelog/122218.yaml b/docs/changelog/122218.yaml new file mode 100644 index 0000000000000..bfd44399e3e8d --- /dev/null +++ b/docs/changelog/122218.yaml @@ -0,0 +1,5 @@ +pr: 122218 +summary: Integrate with `DeepSeek` API +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 72305efe26fd2..3ace93ece62f0 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -147,6 +147,7 @@ static TransportVersion def(int id) { public static final TransportVersion JINA_AI_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_19 = def(8_841_0_06); public static final TransportVersion RETRY_ILM_ASYNC_ACTION_REQUIRE_ERROR_8_19 = def(8_841_0_07); public static final TransportVersion INFERENCE_CONTEXT_8_X = def(8_841_0_08); + public static final TransportVersion ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -183,6 +184,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_SERIALIZE_BLOCK_TYPE_CODE = def(9_026_0_00); public static final TransportVersion ESQL_THREAD_NAME_IN_DRIVER_PROFILE = def(9_027_0_00); public static final TransportVersion INFERENCE_CONTEXT = def(9_028_0_00); + public static final TransportVersion ML_INFERENCE_DEEPSEEK = def(9_029_0_00); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 859a065b6e1a0..6f9a550481049 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -25,7 +25,7 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest { @SuppressWarnings("unchecked") public void testGetServicesWithoutTaskType() throws IOException { List services = getAllServices(); - assertThat(services.size(), equalTo(20)); + assertThat(services.size(), equalTo(21)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -41,6 +41,7 @@ public void testGetServicesWithoutTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "deepseek", "elastic", "elasticsearch", "googleaistudio", @@ -114,7 +115,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(9)); + assertThat(services.size(), equalTo(10)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -130,6 +131,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "azureaistudio", "azureopenai", "cohere", + "deepseek", "googleaistudio", "openai", "streaming_completion_test_service" @@ -141,7 +143,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { @SuppressWarnings("unchecked") public void 
testGetServicesWithChatCompletionTaskType() throws IOException { List services = getServices(TaskType.CHAT_COMPLETION); - assertThat(services.size(), equalTo(3)); + assertThat(services.size(), equalTo(4)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -149,7 +151,7 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { providers[i] = (String) serviceConfig.get("service"); } - assertArrayEquals(List.of("elastic", "openai", "streaming_completion_test_service").toArray(), providers); + assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers); } @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index d57a6b86e4e71..e83d4243c7445 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -58,6 +58,7 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings; @@ -153,6 +154,7 @@ public static List 
getNamedWriteables() { addUnifiedNamedWriteables(namedWriteables); namedWriteables.addAll(StreamingTaskManager.namedWriteables()); + namedWriteables.addAll(DeepSeekChatCompletionModel.namedWriteables()); return namedWriteables; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index ecc217c7dfe7d..7263f204808d2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -116,6 +116,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; @@ -362,6 +363,7 @@ public List getInferenceServiceFactories() { context -> new IbmWatsonxService(httpFactory.get(), serviceComponents.get()), context -> new JinaAIService(httpFactory.get(), serviceComponents.get()), context -> new VoyageAIService(httpFactory.get(), serviceComponents.get()), + context -> new DeepSeekService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java new file mode 100644 index 
0000000000000..5fbc8883d5051 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/deepseek/DeepSeekChatCompletionRequest.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.deepseek; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class DeepSeekChatCompletionRequest implements Request { + private static final String MODEL_FIELD = "model"; + private static final String MAX_TOKENS = "max_tokens"; + + private final DeepSeekChatCompletionModel model; + private final UnifiedChatInput unifiedChatInput; + + public DeepSeekChatCompletionRequest(UnifiedChatInput unifiedChatInput, DeepSeekChatCompletionModel model) { + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.model 
= Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + httpPost.setEntity(createEntity()); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(model.apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + private ByteArrayEntity createEntity() { + var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model); + try (var builder = JsonXContent.contentBuilder()) { + builder.startObject(); + new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS); + builder.field(MODEL_FIELD, modelId); + + if (unifiedChatInput.getRequest().maxCompletionTokens() != null) { + builder.field(MAX_TOKENS, unifiedChatInput.getRequest().maxCompletionTokens()); + } + + builder.endObject(); + return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8)); + } catch (IOException e) { + throw new ElasticsearchException("Failed to serialize request payload.", e); + } + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public boolean isStreaming() { + return unifiedChatInput.stream(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java new file mode 100644 index 0000000000000..ffc5bfb1eb918 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DeepSeekRequestManager.java @@ -0,0 
+1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.deepseek.DeepSeekChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel; + +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs.createUnsupportedTypeException; + +public class DeepSeekRequestManager extends BaseRequestManager { + + private static final Logger logger = LogManager.getLogger(DeepSeekRequestManager.class); + + private static final ResponseHandler CHAT_COMPLETION = createChatCompletionHandler(); + private static final ResponseHandler COMPLETION = createCompletionHandler(); + + private final DeepSeekChatCompletionModel model; + + public DeepSeekRequestManager(DeepSeekChatCompletionModel model, ThreadPool threadPool) { + super(threadPool, model.getInferenceEntityId(), model.rateLimitGroup(), model.rateLimitSettings()); + this.model = 
Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + switch (inferenceInputs) { + case UnifiedChatInput uci -> execute(uci, requestSender, hasRequestCompletedFunction, listener); + case ChatCompletionInput cci -> execute(cci, requestSender, hasRequestCompletedFunction, listener); + default -> throw createUnsupportedTypeException(inferenceInputs, UnifiedChatInput.class); + } + } + + private void execute( + UnifiedChatInput inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var request = new DeepSeekChatCompletionRequest(inferenceInputs, model); + execute(new ExecutableInferenceRequest(requestSender, logger, request, CHAT_COMPLETION, hasRequestCompletedFunction, listener)); + } + + private void execute( + ChatCompletionInput inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var unifiedInputs = new UnifiedChatInput(inferenceInputs.getInputs(), "user", inferenceInputs.stream()); + var request = new DeepSeekChatCompletionRequest(unifiedInputs, model); + execute(new ExecutableInferenceRequest(requestSender, logger, request, COMPLETION, hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createChatCompletionHandler() { + return new OpenAiUnifiedChatCompletionResponseHandler("deepseek chat completion", OpenAiChatCompletionResponseEntity::fromResponse); + } + + private static ResponseHandler createCompletionHandler() { + return new OpenAiChatCompletionResponseHandler("deepseek completion", OpenAiChatCompletionResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionRequestEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionRequestEntity.java index 41900a96d9e7a..22df64f5fe80a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -21,12 +21,15 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec public static final String USER_FIELD = "user"; private static final String MODEL_FIELD = "model"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private final UnifiedChatInput unifiedChatInput; private final OpenAiChatCompletionModel model; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { - this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); this.model = Objects.requireNonNull(model); } @@ -41,6 +44,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(USER_FIELD, model.getTaskSettings().user()); } + if (unifiedChatInput.getRequest().maxCompletionTokens() != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens()); + } + builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java index ded8a074478cf..2631eaa085fb1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java @@ -17,12 +17,15 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject { private static final String MODEL_FIELD = "model"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private final UnifiedChatInput unifiedChatInput; private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; private final String modelId; public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) { - this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(Objects.requireNonNull(unifiedChatInput)); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); this.modelId = Objects.requireNonNull(modelId); } @@ -31,6 +34,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); unifiedRequestEntity.toXContent(builder, params); builder.field(MODEL_FIELD, modelId); + + if (unifiedChatInput.getRequest().maxCompletionTokens() != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens()); + } + builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java index 5e6d09cde2b9f..6a6f8d92c74ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java @@ -32,7 +32,6 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment { public static final String MESSAGES_FIELD = "messages"; private static final String ROLE_FIELD = "role"; private static final String CONTENT_FIELD = "content"; - private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; private static final String STOP_FIELD = "stop"; private static final String TEMPERATURE_FIELD = "temperature"; private static final String TOOL_CHOICE_FIELD = "tool_choice"; @@ -104,10 +103,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); - if (unifiedRequest.maxCompletionTokens() != null) { - builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); - } - // Underlying providers expect OpenAI to only return 1 possible choice. builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java new file mode 100644 index 0000000000000..bcfcf279ab768 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekChatCompletionModel.java @@ -0,0 +1,200 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.deepseek; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; + +/** + * Design notes: + * This provider tries to match the OpenAI, so we'll design around that as well. 
+ * + * Task Type: + * - Chat Completion + * + * Service Settings: + * - api_key + * - model + * - url + * + * Task Settings: + * - nothing? + * + * Rate Limiting: + * - The website claims to want unlimited, so we're setting it as MAX_INT per minute? + */ +public class DeepSeekChatCompletionModel extends Model { + // Per-node rate limit group and settings, limiting the outbound requests this node can make to INTEGER.MAX_VALUE per minute. + private static final Object RATE_LIMIT_GROUP = new Object(); + private static final RateLimitSettings RATE_LIMIT_SETTINGS = new RateLimitSettings(Integer.MAX_VALUE); + + private static final URI DEFAULT_URI = URI.create("https://api.deepseek.com/chat/completions"); + private final DeepSeekServiceSettings serviceSettings; + private final DefaultSecretSettings secretSettings; + + public static List namedWriteables() { + return List.of(new NamedWriteableRegistry.Entry(ServiceSettings.class, DeepSeekServiceSettings.NAME, DeepSeekServiceSettings::new)); + } + + public static DeepSeekChatCompletionModel createFromNewInput( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettingsMap + ) { + var validationException = new ValidationException(); + + var model = extractRequiredString(serviceSettingsMap, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = createOptionalUri( + extractOptionalString(serviceSettingsMap, URL, ModelConfigurations.SERVICE_SETTINGS, validationException) + ); + var secureApiToken = extractRequiredSecureString( + serviceSettingsMap, + "api_key", + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + var serviceSettings = new DeepSeekServiceSettings(model, uri); + var taskSettings = new EmptyTaskSettings(); + var secretSettings = new DefaultSecretSettings(secureApiToken); + var modelConfigurations = new ModelConfigurations(inferenceEntityId, 
taskType, service, serviceSettings, taskSettings); + return new DeepSeekChatCompletionModel(serviceSettings, secretSettings, modelConfigurations, new ModelSecrets(secretSettings)); + } + + public static DeepSeekChatCompletionModel readFromStorage( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettingsMap, + Map secrets + ) { + var validationException = new ValidationException(); + + var model = extractRequiredString(serviceSettingsMap, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var uri = createOptionalUri( + extractOptionalString(serviceSettingsMap, URL, ModelConfigurations.SERVICE_SETTINGS, validationException) + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + var serviceSettings = new DeepSeekServiceSettings(model, uri); + var taskSettings = new EmptyTaskSettings(); + var secretSettings = DefaultSecretSettings.fromMap(secrets); + var modelConfigurations = new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings); + return new DeepSeekChatCompletionModel(serviceSettings, secretSettings, modelConfigurations, new ModelSecrets(secretSettings)); + } + + private DeepSeekChatCompletionModel( + DeepSeekServiceSettings serviceSettings, + DefaultSecretSettings secretSettings, + ModelConfigurations configurations, + ModelSecrets secrets + ) { + super(configurations, secrets); + this.serviceSettings = serviceSettings; + this.secretSettings = secretSettings; + } + + public SecureString apiKey() { + return secretSettings.apiKey(); + } + + public String model() { + return serviceSettings.modelId(); + } + + public URI uri() { + return serviceSettings.uri() != null ? 
serviceSettings.uri() : DEFAULT_URI; + } + + public Object rateLimitGroup() { + return RATE_LIMIT_GROUP; + } + + public RateLimitSettings rateLimitSettings() { + return RATE_LIMIT_SETTINGS; + } + + private record DeepSeekServiceSettings(String modelId, URI uri) implements ServiceSettings { + private static final String NAME = "deep_seek_service_settings"; + + DeepSeekServiceSettings { + Objects.requireNonNull(modelId); + } + + DeepSeekServiceSettings(StreamInput in) throws IOException { + this(in.readString(), in.readOptional(url -> URI.create(url.readString()))); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_DEEPSEEK; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalString(uri != null ? uri.toString() : null); + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID, modelId); + if (uri != null) { + builder.field(URL, uri.toString()); + } + return builder.endObject(); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java new file mode 100644 index 0000000000000..4433c43e1b8f0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekService.java @@ -0,0 +1,233 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.deepseek; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.DeepSeekRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static 
org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class DeepSeekService extends SenderService { + private static final String NAME = "deepseek"; + private static final String CHAT_COMPLETION_ERROR_PREFIX = "deepseek chat completions"; + private static final String COMPLETION_ERROR_PREFIX = "deepseek completions"; + private static final String SERVICE_NAME = "DeepSeek"; + // The task types exposed via the _inference/_services API + private static final EnumSet SUPPORTED_TASK_TYPES_FOR_SERVICES_API = EnumSet.of( + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION + ); + + public DeepSeekService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + protected void doInfer( + Model model, + InferenceInputs inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + doInfer(model, inputs, timeout, COMPLETION_ERROR_PREFIX, listener); + } + + private void doInfer( + Model model, + InferenceInputs inputs, + TimeValue timeout, + String errorPrefix, + ActionListener listener + ) { + if (model instanceof DeepSeekChatCompletionModel deepSeekModel) { + var requestCreator = new DeepSeekRequestManager(deepSeekModel, getServiceComponents().threadPool()); + var errorMessage = constructFailedToSendRequestMessage(errorPrefix); + var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage); + action.execute(inputs, timeout, listener); + } else { + 
listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + doInfer(model, inputs, timeout, CHAT_COMPLETION_ERROR_PREFIX, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + DocumentsOnlyInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + listener.onFailure(new UnsupportedOperationException(Strings.format("The %s service only supports unified completion", NAME))); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + ActionListener parsedModelListener + ) { + ActionListener.completeWith(parsedModelListener, () -> { + var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + try { + return DeepSeekChatCompletionModel.createFromNewInput(modelId, taskType, NAME, serviceSettingsMap); + } finally { + throwIfNotEmptyMap(serviceSettingsMap, NAME); + } + }); + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + var serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + var secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS); + return DeepSeekChatCompletionModel.readFromStorage(modelId, taskType, NAME, serviceSettingsMap, secretSettingsMap); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + return parsePersistedConfigWithSecrets(modelId, taskType, config, config); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public EnumSet supportedTaskTypes() { + return 
SUPPORTED_TASK_TYPES_FOR_SERVICES_API; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_DEEPSEEK; + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.CHAT_COMPLETION); + } + + @Override + public void checkModelConfig(Model model, ActionListener listener) { + // TODO: Remove this function once all services have been updated to use the new model validators + ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener); + } + + private static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + MODEL_ID, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDescription( + "The name of the model to use for the inference task." + ) + .setLabel("Model ID") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.putAll( + DefaultSecretSettings.toSettingsConfigurationWithDescription( + "The DeepSeek API authentication key. 
For more details about generating DeepSeek API keys, " + + "refer to https://api-docs.deepseek.com.", + SUPPORTED_TASK_TYPES_FOR_SERVICES_API + ) + ); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES_FOR_SERVICES_API).setDefaultValue( + "https://api.deepseek.com/chat/completions" + ) + .setDescription("The URL endpoint to use for the requests.") + .setLabel("URL") + .setRequired(false) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES_FOR_SERVICES_API) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java new file mode 100644 index 0000000000000..277eba9e7dbfc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/deepseek/DeepSeekServiceTests.java @@ -0,0 +1,441 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.deepseek; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static 
org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.action.support.ActionTestUtils.assertNoFailureListener; +import static org.elasticsearch.action.support.ActionTestUtils.assertNoSuccessListener; +import static org.elasticsearch.common.Strings.format; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; + +public class DeepSeekServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig() throws IOException, URISyntaxException { + parseRequestConfig(format(""" + { + "service_settings": { + "api_key": "12345", + "model_id": "some-cool-model", + "url": "%s" + } + } + """, webServer.getUri(null).toString()), assertNoFailureListener(model -> { + if (model instanceof DeepSeekChatCompletionModel deepSeekModel) { + assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray())); + assertThat(deepSeekModel.model(), equalTo("some-cool-model")); + 
assertThat(deepSeekModel.uri(), equalTo(webServer.getUri(null))); + } else { + fail("Expected DeepSeekModel, found " + (model != null ? model.getClass().getSimpleName() : "null")); + } + })); + } + + public void testParseRequestConfigWithoutApiKey() throws IOException { + parseRequestConfig(""" + { + "service_settings": { + "model_id": "some-cool-model" + } + } + """, assertNoSuccessListener(e -> { + if (e instanceof ValidationException ve) { + assertThat( + ve.getMessage(), + equalTo("Validation Failed: 1: [service_settings] does not contain the required setting [api_key];") + ); + } + })); + } + + public void testParseRequestConfigWithoutModel() throws IOException { + parseRequestConfig(""" + { + "service_settings": { + "api_key": "1234" + } + } + """, assertNoSuccessListener(e -> { + if (e instanceof ValidationException ve) { + assertThat( + ve.getMessage(), + equalTo("Validation Failed: 1: [service_settings] does not contain the required setting [model_id];") + ); + } + })); + } + + public void testParseRequestConfigWithExtraSettings() throws IOException { + parseRequestConfig( + """ + { + "service_settings": { + "api_key": "12345", + "model_id": "some-cool-model", + "so": "extra" + } + } + """, + assertNoSuccessListener( + e -> assertThat( + e.getMessage(), + equalTo("Model configuration contains settings [{so=extra}] unknown to the [deepseek] service") + ) + ) + ); + } + + public void testParsePersistedConfig() throws IOException { + var deepSeekModel = parsePersistedConfig(""" + { + "service_settings": { + "model_id": "some-cool-model" + }, + "secret_settings": { + "api_key": "12345" + } + } + """); + assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray())); + assertThat(deepSeekModel.model(), equalTo("some-cool-model")); + } + + public void testParsePersistedConfigWithUrl() throws IOException { + var deepSeekModel = parsePersistedConfig(""" + { + "service_settings": { + "model_id": "some-cool-model", + "url": "http://localhost:989" + }, 
+ "secret_settings": { + "api_key": "12345" + } + } + """); + assertThat(deepSeekModel.apiKey().getChars(), equalTo("12345".toCharArray())); + assertThat(deepSeekModel.model(), equalTo("some-cool-model")); + assertThat(deepSeekModel.uri(), equalTo(URI.create("http://localhost:989"))); + } + + public void testParsePersistedConfigWithoutApiKey() { + assertThrows( + "Validation Failed: 1: [secret_settings] does not contain the required setting [api_key];", + ValidationException.class, + () -> parsePersistedConfig(""" + { + "service_settings": { + "model_id": "some-cool-model" + }, + "secret_settings": { + } + } + """) + ); + } + + public void testParsePersistedConfigWithoutModel() { + assertThrows( + "Validation Failed: 1: [service_settings] does not contain the required setting [model];", + ValidationException.class, + () -> parsePersistedConfig(""" + { + "service_settings": { + }, + "secret_settings": { + "api_key": "12345" + } + } + """) + ); + } + + public void testParsePersistedConfigWithoutServiceSettings() { + assertThrows( + "Validation Failed: 1: [service_settings] does not contain the required setting [model];", + ElasticsearchStatusException.class, + () -> parsePersistedConfig(""" + { + "secret_settings": { + "api_key": "12345" + } + } + """) + ); + } + + public void testDoUnifiedInfer() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {"choices": [{"delta": {"content": "hello, world", "role": "assistant"}, "finish_reason": null, "index": 0, \ + "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \ + "object": "chat.completion.chunk", "system_fingerprint": "fp_1234"} + + data: [DONE] + + """)); + doUnifiedCompletionInfer().hasNoErrors().hasEvent(""" + {"id":"12345","choices":[{"delta":{"content":"hello, world","role":"assistant"},"index":0}],""" + """ + "model":"deepseek-chat","object":"chat.completion.chunk"}"""); + } + + public void testDoInfer() throws Exception { + 
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + {"choices": [{"message": {"content": "hello, world", "role": "assistant"}, "finish_reason": "stop", "index": 0, \ + "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \ + "object": "chat.completion", "system_fingerprint": "fp_1234"}""")); + try (var service = createService()) { + var model = createModel(service, TaskType.COMPLETION); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, null, List.of("hello"), false, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + var result = listener.actionGet(TIMEOUT); + assertThat(result, isA(ChatCompletionResults.class)); + var completionResults = (ChatCompletionResults) result; + assertThat( + completionResults.results().stream().map(ChatCompletionResults.Result::predictedValue).toList(), + equalTo(List.of("hello, world")) + ); + } + } + + public void testDoInferStream() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {"choices": [{"delta": {"content": "hello, world", "role": "assistant"}, "finish_reason": null, "index": 0, \ + "logprobs": null}], "created": 1718345013, "id": "12345", "model": "deepseek-chat", \ + "object": "chat.completion.chunk", "system_fingerprint": "fp_1234"} + + data: [DONE] + + """)); + try (var service = createService()) { + var model = createModel(service, TaskType.COMPLETION); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer(model, null, List.of("hello"), true, Map.of(), InputType.UNSPECIFIED, TIMEOUT, listener); + InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } + } + + public void testUnifiedCompletionError() { + String responseJson = """ + { + "error": { + "message": "The model `deepseek-not-chat` does not exist or you do not have access to it.", + "type": 
"invalid_request_error", + "param": null, + "code": "model_not_found" + } + }"""; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + var e = assertThrows(UnifiedChatCompletionException.class, this::doUnifiedCompletionInfer); + assertThat( + e.getMessage(), + equalTo( + "Received an unsuccessful status code for request from inference entity id [inference-id] status" + + " [404]. Error message: [The model `deepseek-not-chat` does not exist or you do not have access to it.]" + ) + ); + } + + private void testStreamError(String expectedResponse) throws Exception { + try (var service = createService()) { + var model = createModel(service, TaskType.CHAT_COMPLETION); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { + e = unwrapCause(e); + assertThat(e, isA(UnifiedChatCompletionException.class)); + try (var builder = XContentFactory.jsonBuilder()) { + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(expectedResponse)); + } + }); + } + } + + public void testMidStreamUnifiedCompletionError() throws Exception { + String responseJson = """ + event: error + data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } } + + """; + webServer.enqueue(new 
MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "message":"Received an error response for request from inference entity id [inference-id]. Error message: \ + [Timed out waiting for more data]",\ + "type":"timeout"\ + }}"""); + } + + public void testUnifiedCompletionMalformedError() throws Exception { + String responseJson = """ + data: { invalid json } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"bad_request",\ + "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\ + at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\ + "type":"x_content_parse_exception"\ + }}"""); + } + + public void testDoChunkedInferAlwaysFails() throws IOException { + try (var service = createService()) { + service.doChunkedInfer(mock(), mock(), Map.of(), InputType.UNSPECIFIED, TIMEOUT, assertNoSuccessListener(e -> { + assertThat(e, isA(UnsupportedOperationException.class)); + assertThat(e.getMessage(), equalTo("The deepseek service only supports unified completion")); + })); + } + } + + private DeepSeekService createService() { + return new DeepSeekService( + HttpRequestSenderTests.createSenderFactory(threadPool, clientManager), + createWithEmptySettings(threadPool) + ); + } + + private void parseRequestConfig(String json, ActionListener listener) throws IOException { + try (var service = createService()) { + service.parseRequestConfig("inference-id", TaskType.CHAT_COMPLETION, map(json), listener); + } + } + + private Map map(String json) throws IOException { + try ( + var parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, json.getBytes(StandardCharsets.UTF_8)) + ) { + return parser.map(); + } + } + + private DeepSeekChatCompletionModel parsePersistedConfig(String json) throws IOException { + try (var service = createService()) { + var model 
= service.parsePersistedConfig("inference-id", TaskType.CHAT_COMPLETION, map(json)); + assertThat(model, isA(DeepSeekChatCompletionModel.class)); + return (DeepSeekChatCompletionModel) model; + } + } + + private InferenceEventsAssertion doUnifiedCompletionInfer() throws Exception { + try (var service = createService()) { + var model = createModel(service, TaskType.CHAT_COMPLETION); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + TIMEOUT, + listener + ); + return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + } + } + + private DeepSeekChatCompletionModel createModel(DeepSeekService service, TaskType taskType) throws URISyntaxException, IOException { + var model = service.parsePersistedConfig("inference-id", taskType, map(Strings.format(""" + { + "service_settings": { + "model_id": "some-cool-model", + "url": "%s" + }, + "secret_settings": { + "api_key": "12345" + } + } + """, webServer.getUri(null).toString()))); + assertThat(model, isA(DeepSeekChatCompletionModel.class)); + return (DeepSeekChatCompletionModel) model; + } +}