From 63f21deb295e299ef8e28dba497f366a393b7bfc Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 23 Apr 2025 17:17:44 +0300 Subject: [PATCH 01/29] Add Hugging Face Chat Completion support to Inference Plugin --- .../InferenceNamedWriteablesProvider.java | 8 + .../xpack/inference/InferencePlugin.java | 2 + .../HuggingFaceCompletionRequestManager.java | 66 ++++++ .../HuggingFaceEmbeddingsRequestManager.java | 75 +++++++ .../HuggingFaceRequestManager.java | 58 +----- .../huggingface/HuggingFaceService.java | 25 ++- .../action/HuggingFaceActionCreator.java | 34 ++-- .../action/HuggingFaceActionVisitor.java | 3 + .../HuggingFaceChatCompletionModel.java | 95 +++++++++ .../HuggingFaceChatCompletionService.java | 167 +++++++++++++++ ...gingFaceChatCompletionServiceSettings.java | 190 ++++++++++++++++++ ...ggingFaceUnifiedChatCompletionRequest.java | 71 +++++++ ...aceUnifiedChatCompletionRequestEntity.java | 49 +++++ .../HuggingFaceInferenceRequest.java | 2 +- .../HuggingFaceInferenceRequestEntity.java | 2 +- ...ggingFaceChatCompletionResponseEntity.java | 56 ++++++ ...uggingFaceInferenceRequestEntityTests.java | 1 + .../HuggingFaceInferenceRequestTests.java | 1 + 18 files changed, 829 insertions(+), 76 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceCompletionRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/{ => embeddings}/HuggingFaceInferenceRequest.java (99%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/{ => embeddings}/HuggingFaceInferenceRequestEntity.java (98%) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceChatCompletionResponseEntity.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 63d9d8a3bd9d1..9db4e8f9c621e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -78,6 +78,7 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings; @@ -353,6 +354,13 @@ private static void addHuggingFaceNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 114d9eaedfa53..643de58f53e17 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -127,6 +127,7 @@ import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionService; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; @@ -361,6 +362,7 @@ public void loadExtensions(ExtensionLoader loader) { public List getInferenceServiceFactories() { return List.of( context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), + context -> new HuggingFaceChatCompletionService(httpFactory.get(), serviceComponents.get()), context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), context -> new CohereService(httpFactory.get(), serviceComponents.get()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceCompletionRequestManager.java new file mode 100644 index 0000000000000..14bdda8cd3dc8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceCompletionRequestManager.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; + +import java.util.Objects; +import java.util.function.Supplier; + +public class HuggingFaceCompletionRequestManager extends HuggingFaceRequestManager { + private static final Logger logger = LogManager.getLogger(HuggingFaceCompletionRequestManager.class); + + public static HuggingFaceCompletionRequestManager of( + HuggingFaceChatCompletionModel model, + ResponseHandler responseHandler, + ThreadPool threadPool + ) { + return new HuggingFaceCompletionRequestManager( + Objects.requireNonNull(model), + Objects.requireNonNull(responseHandler), + Objects.requireNonNull(threadPool) + ); + } + + private final HuggingFaceChatCompletionModel model; + private final ResponseHandler responseHandler; + + private HuggingFaceCompletionRequestManager( + HuggingFaceChatCompletionModel model, + ResponseHandler responseHandler, + ThreadPool threadPool + ) { + super(model, threadPool); + this.model = model; + this.responseHandler = responseHandler; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var chatCompletionInput = inferenceInputs.castTo(UnifiedChatInput.class); + HuggingFaceUnifiedChatCompletionRequest request = new HuggingFaceUnifiedChatCompletionRequest(chatCompletionInput, model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..c6ed5a000dc98 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceInferenceRequest; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class HuggingFaceEmbeddingsRequestManager extends HuggingFaceRequestManager { + private static final Logger logger = LogManager.getLogger(HuggingFaceEmbeddingsRequestManager.class); + + public static HuggingFaceEmbeddingsRequestManager of( + HuggingFaceModel model, + ResponseHandler responseHandler, + Truncator truncator, + ThreadPool threadPool + ) { + return new HuggingFaceEmbeddingsRequestManager( + Objects.requireNonNull(model), + Objects.requireNonNull(responseHandler), + Objects.requireNonNull(truncator), + Objects.requireNonNull(threadPool) + ); + } + + private final HuggingFaceModel model; + private final ResponseHandler responseHandler; + private final Truncator truncator; + + private HuggingFaceEmbeddingsRequestManager( + HuggingFaceModel model, + ResponseHandler responseHandler, + Truncator truncator, + ThreadPool threadPool + ) { + super(model, threadPool); + this.model = model; + this.responseHandler = responseHandler; + this.truncator = truncator; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); + var truncatedInput = truncate(docsInput, model.getTokenLimit()); + var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java index b09cf8d98b7f3..33c97b6fb811b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java @@ -7,66 +7,12 @@ package org.elasticsearch.xpack.inference.services.huggingface; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.common.Truncator; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; -import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest; -import java.util.List; -import java.util.Objects; -import java.util.function.Supplier; - -import static org.elasticsearch.xpack.inference.common.Truncator.truncate; - -public class HuggingFaceRequestManager extends BaseRequestManager { - private static final Logger logger = LogManager.getLogger(HuggingFaceRequestManager.class); - - public static HuggingFaceRequestManager of( - HuggingFaceModel model, - ResponseHandler responseHandler, - Truncator truncator, - ThreadPool threadPool - ) { - return new HuggingFaceRequestManager( - Objects.requireNonNull(model), - Objects.requireNonNull(responseHandler), - Objects.requireNonNull(truncator), - Objects.requireNonNull(threadPool) - ); - } - - private final HuggingFaceModel model; - private final ResponseHandler responseHandler; - private final Truncator truncator; - - private HuggingFaceRequestManager(HuggingFaceModel model, ResponseHandler responseHandler, Truncator truncator, ThreadPool threadPool) { +public abstract class HuggingFaceRequestManager extends BaseRequestManager { + protected HuggingFaceRequestManager(HuggingFaceModel model, ThreadPool threadPool) { super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); - this.model = model; - this.responseHandler = responseHandler; - this.truncator = truncator; - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); - var truncatedInput = truncate(docsInput, model.getTokenLimit()); - var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); } record RateLimitGrouping(int accountHash) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index f2a53520e18e6..5cef2c65d13ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -33,6 +33,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -51,7 +52,11 @@ public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; private static final String SERVICE_NAME = "Hugging Face"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING); + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.SPARSE_EMBEDDING, + TaskType.COMPLETION + ); public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); @@ -78,6 +83,14 @@ protected HuggingFaceModel createModel( context ); case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); + case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + secretSettings, + context + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } @@ -149,7 +162,7 @@ public InferenceServiceConfiguration getConfiguration() { @Override public EnumSet supportedTaskTypes() { - return supportedTaskTypes; + return SUPPORTED_TASK_TYPES; } @Override @@ -173,7 +186,7 @@ public static InferenceServiceConfiguration get() { configurationMap.put( URL, - new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings") + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDefaultValue("https://api.openai.com/v1/embeddings") .setDescription("The URL endpoint to use for the requests.") .setLabel("URL") .setRequired(true) @@ -183,12 +196,12 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) - .setTaskTypes(supportedTaskTypes) + .setTaskTypes(SUPPORTED_TASK_TYPES) .setConfigurations(configurationMap) .build(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index d0578edaeaec6..82390b9d358c2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -11,10 +11,13 @@ import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceCompletionRequestManager; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; @@ -26,6 +29,9 @@ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type. */ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { + + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = + "Failed to send Hugging Face %s request from inference entity id [%s]"; private final Sender sender; private final ServiceComponents serviceComponents; @@ -40,34 +46,38 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) { "hugging face text embeddings", HuggingFaceEmbeddingsResponseEntity::fromResponse ); - var requestCreator = HuggingFaceRequestManager.of( + var requestCreator = HuggingFaceEmbeddingsRequestManager.of( model, responseHandler, serviceComponents.truncator(), serviceComponents.threadPool() ); - var errorMessage = format( - "Failed to send Hugging Face %s request from inference entity id [%s]", - "text embeddings", - model.getInferenceEntityId() - ); + var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "text embeddings", model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } @Override public ExecutableAction create(HuggingFaceElserModel model) { var responseHandler = new HuggingFaceResponseHandler("hugging face elser", HuggingFaceElserResponseEntity::fromResponse); - var requestCreator = HuggingFaceRequestManager.of( + var requestCreator = HuggingFaceEmbeddingsRequestManager.of( model, responseHandler, serviceComponents.truncator(), serviceComponents.threadPool() ); - var errorMessage = format( - "Failed to send Hugging Face %s request from inference entity id [%s]", - "ELSER", - model.getInferenceEntityId() + var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "ELSER", model.getInferenceEntityId()); + return new SenderExecutableAction(sender, requestCreator, errorMessage); + } + + @Override + public ExecutableAction create(HuggingFaceChatCompletionModel model) { + var responseHandler = new HuggingFaceResponseHandler( + "hugging face chat completion", + HuggingFaceChatCompletionResponseEntity::fromResponse ); + + var requestCreator = HuggingFaceCompletionRequestManager.of(model, responseHandler, serviceComponents.threadPool()); + var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java index 3fb7b538769e9..ee308db774b1d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.huggingface.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; @@ -15,4 +16,6 @@ public interface HuggingFaceActionVisitor { ExecutableAction create(HuggingFaceEmbeddingsModel model); ExecutableAction create(HuggingFaceElserModel model); + + ExecutableAction create(HuggingFaceChatCompletionModel model); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java new file mode 100644 index 0000000000000..c843ab03f9932 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.completion; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; +import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; +import java.util.Objects; + +public class HuggingFaceChatCompletionModel extends HuggingFaceModel { + + public static HuggingFaceChatCompletionModel of(HuggingFaceChatCompletionModel model, UnifiedCompletionRequest request) { + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new HuggingFaceChatCompletionServiceSettings( + Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + originalModelServiceSettings.uri(), + originalModelServiceSettings.maxInputTokens(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new HuggingFaceChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getSecretSettings() + ); + } + + public HuggingFaceChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + HuggingFaceChatCompletionServiceSettings.fromMap(serviceSettings, context), + DefaultSecretSettings.fromMap(secrets) + ); + } + + HuggingFaceChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + HuggingFaceChatCompletionServiceSettings serviceSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings), + new ModelSecrets(secretSettings), + serviceSettings, + secretSettings + ); + } + + @Override + public HuggingFaceChatCompletionServiceSettings getServiceSettings() { + return (HuggingFaceChatCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(HuggingFaceActionVisitor creator) { + return creator.create(this); + } + + @Override + public Integer getTokenLimit() { + return getServiceSettings().maxInputTokens(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionService.java new file mode 100644 index 0000000000000..fcaab46835090 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionService.java @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.completion; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; +import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; +import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings.URL; + +public class HuggingFaceChatCompletionService extends HuggingFaceBaseService { + public static final String NAME = "hugging_face_chat_completion"; + + private static final String SERVICE_NAME = "Hugging Face CHAT COMPLETION"; + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); + + public HuggingFaceChatCompletionService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + protected HuggingFaceModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + ChunkingSettings chunkingSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + return switch (taskType) { + case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + secretSettings, + context + ); + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + }; + } + + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof HuggingFaceChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model; + var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); + var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest()); + var action = overriddenModel.accept(actionCreator); + + action.execute(inputs, timeout, listener); + } + + @Override + protected void doChunkedInfer( + Model model, + EmbeddingsInput inputs, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener> listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + + @Override + public InferenceServiceConfiguration getConfiguration() { + return Configuration.get(); + } + + @Override + public boolean hideFromConfigurationApi() { + return true; + } + + @Override + public EnumSet supportedTaskTypes() { + return SUPPORTED_TASK_TYPES; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_12_0; + } + + public static class Configuration { + public static InferenceServiceConfiguration get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable configuration = new LazyInitializable<>( + () -> { + var configurationMap = new HashMap(); + + configurationMap.put( + URL, + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") + .setLabel("URL") + .setRequired(true) + .setSensitive(false) + .setUpdatable(false) + .setType(SettingsConfigurationFieldType.STRING) + .build() + ); + + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + + return new InferenceServiceConfiguration.Builder().setService(NAME) + .setName(SERVICE_NAME) + .setTaskTypes(SUPPORTED_TASK_TYPES) + .setConfigurations(configurationMap) + .build(); + } + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..8ab84dea6b12b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -0,0 +1,190 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; + +public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + HuggingFaceRateLimitServiceSettings { + + public static final String NAME = "hugging_face_completion_service_settings"; + public static final String URL = "url"; + // At the time of writing HuggingFace hasn't posted the default rate limit for inference endpoints so the value his is only a guess + // 3000 requests per minute + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + public static HuggingFaceChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + var uri = extractUri(map, URL, validationException); + + Integer maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + HuggingFaceService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + return new HuggingFaceChatCompletionServiceSettings(modelId, uri, maxInputTokens, rateLimitSettings); + } + + private final String modelId; + private final URI uri; + private final Integer maxInputTokens; + private final RateLimitSettings rateLimitSettings; + + public HuggingFaceChatCompletionServiceSettings( + String modelId, + @Nullable String url, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + this(modelId, createUri(url), maxInputTokens, rateLimitSettings); + } + + public HuggingFaceChatCompletionServiceSettings( + String modelId, + @Nullable URI uri, + @Nullable Integer maxInputTokens, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.modelId = modelId; + this.uri = uri; + this.maxInputTokens = maxInputTokens; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public HuggingFaceChatCompletionServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.uri = createUri(in.readString()); + this.maxInputTokens = in.readOptionalVInt(); + + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { + this.rateLimitSettings = new RateLimitSettings(in); + } else { + this.rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; + } + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public URI uri() { + return uri; + } + + public int maxInputTokens() { + return maxInputTokens; + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + + builder.field(URL, uri.toString()); + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_12_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalString(uri != null ? uri.toString() : null); + out.writeOptionalVInt(maxInputTokens); + + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { + rateLimitSettings.writeTo(out); + } + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + HuggingFaceChatCompletionServiceSettings that = (HuggingFaceChatCompletionServiceSettings) object; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, maxInputTokens, rateLimitSettings); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java new file mode 100644 index 0000000000000..78171101fbe6d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class HuggingFaceUnifiedChatCompletionRequest implements Request { + + private final HuggingFaceAccount account; + private final HuggingFaceChatCompletionModel model; + private final UnifiedChatInput unifiedChatInput; + + public HuggingFaceUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) { + this.account = HuggingFaceAccount.of(model); + this.model = Objects.requireNonNull(model); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + } + + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(getURI()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(model.apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + public URI getURI() { + return account.uri(); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public Request truncate() { + return this; + } + + @Override + public boolean[] getTruncationInfo() { + return null; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..de7f1bf04de0d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.completion; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +public class HuggingFaceUnifiedChatCompletionRequestEntity implements ToXContentObject { + + private static final String MODEL_FIELD = "model"; + private static final String MAX_TOKENS_FIELD = "max_tokens"; + + private final UnifiedChatInput unifiedChatInput; + private final HuggingFaceChatCompletionModel model; + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; + + public HuggingFaceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) { + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + unifiedRequestEntity.toXContent(builder, params); + + builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); + + if (unifiedChatInput.getRequest().maxCompletionTokens() != null) { + builder.field(MAX_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens()); + } + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceInferenceRequest.java similarity index 99% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceInferenceRequest.java index af4fafff0fb2f..ea90e7de5a2f3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceInferenceRequest.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.huggingface.request; +package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceInferenceRequestEntity.java similarity index 98% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceInferenceRequestEntity.java index 0c38929568f61..e275d84d2acc5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceInferenceRequestEntity.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.huggingface.request; +package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceChatCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceChatCompletionResponseEntity.java new file mode 100644 index 0000000000000..7e7dacb148f14 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceChatCompletionResponseEntity.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.response; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class HuggingFaceChatCompletionResponseEntity { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in HuggingFace chat response"; + + public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + // Ensure the response starts with an array + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_ARRAY, token, jsonParser); + + // Move to the first object in the array + token = jsonParser.nextToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + // Position the parser at the "generated_text" field + positionParserAtTokenAfterField(jsonParser, "generated_text", FAILED_TO_FIND_FIELD_TEMPLATE); + + // Extract the "generated_text" value + XContentParser.Token contentToken = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser); + String content = jsonParser.text(); + + return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content))); + } + } + + private HuggingFaceChatCompletionResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java index ed15fb77d04a7..3f7e716d25749 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceInferenceRequestEntity; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java index 768eeb622b943..e1f38b57ece3f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceInferenceRequest; import java.io.IOException; import java.net.URI; From 65e40606e9e45a3118e19c40f9fba83f3e4f0ae4 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 25 Apr 2025 13:39:23 +0300 Subject: [PATCH 02/29] Add support for streaming chat completion task for HuggingFace --- .../xpack/inference/InferencePlugin.java | 2 - ...gingFaceChatCompletionRequestManager.java} | 17 +- .../HuggingFaceEmbeddingsRequestManager.java | 17 +- .../huggingface/HuggingFaceService.java | 28 ++- .../action/HuggingFaceActionCreator.java | 11 +- .../HuggingFaceChatCompletionService.java | 167 ------------------ ...gingFaceChatCompletionServiceSettings.java | 30 +++- ...ggingFaceUnifiedChatCompletionRequest.java | 17 ++ ...java => HuggingFaceEmbeddingsRequest.java} | 18 +- ...> HuggingFaceEmbeddingsRequestEntity.java} | 8 +- ...ggingFaceChatCompletionResponseEntity.java | 56 ------ ...gingFaceEmbeddingsRequestEntityTests.java} | 6 +- ...=> HuggingFaceEmbeddingsRequestTests.java} | 8 +- 13 files changed, 127 insertions(+), 258 deletions(-) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/{HuggingFaceCompletionRequestManager.java => HuggingFaceChatCompletionRequestManager.java} (80%) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionService.java rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/{HuggingFaceInferenceRequest.java => HuggingFaceEmbeddingsRequest.java} (73%) rename x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/{HuggingFaceInferenceRequestEntity.java => HuggingFaceEmbeddingsRequestEntity.java} (73%) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceChatCompletionResponseEntity.java rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/{HuggingFaceInferenceRequestEntityTests.java => HuggingFaceEmbeddingsRequestEntityTests.java} (83%) rename x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/{HuggingFaceInferenceRequestTests.java => HuggingFaceEmbeddingsRequestTests.java} (92%) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 643de58f53e17..114d9eaedfa53 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -127,7 +127,6 @@ import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; -import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionService; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; import org.elasticsearch.xpack.inference.services.jinaai.JinaAIService; @@ -362,7 +361,6 @@ public void loadExtensions(ExtensionLoader loader) { public List getInferenceServiceFactories() { return List.of( context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()), - context -> new HuggingFaceChatCompletionService(httpFactory.get(), serviceComponents.get()), context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), context -> new CohereService(httpFactory.get(), serviceComponents.get()), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionRequestManager.java similarity index 80% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceCompletionRequestManager.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionRequestManager.java index 14bdda8cd3dc8..0eec750042379 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionRequestManager.java @@ -23,15 +23,22 @@ import java.util.Objects; import java.util.function.Supplier; -public class HuggingFaceCompletionRequestManager extends HuggingFaceRequestManager { - private static final Logger logger = LogManager.getLogger(HuggingFaceCompletionRequestManager.class); +/** + * Manages the execution of chat completion requests for Hugging Face models. + *

+ * This class is responsible for creating and executing requests to Hugging Face's chat completion API. + * It extends {@link HuggingFaceRequestManager} to provide specific functionality for chat completion models. + *

+ */ +public class HuggingFaceChatCompletionRequestManager extends HuggingFaceRequestManager { + private static final Logger logger = LogManager.getLogger(HuggingFaceChatCompletionRequestManager.class); - public static HuggingFaceCompletionRequestManager of( + public static HuggingFaceChatCompletionRequestManager of( HuggingFaceChatCompletionModel model, ResponseHandler responseHandler, ThreadPool threadPool ) { - return new HuggingFaceCompletionRequestManager( + return new HuggingFaceChatCompletionRequestManager( Objects.requireNonNull(model), Objects.requireNonNull(responseHandler), Objects.requireNonNull(threadPool) @@ -41,7 +48,7 @@ public static HuggingFaceCompletionRequestManager of( private final HuggingFaceChatCompletionModel model; private final ResponseHandler responseHandler; - private HuggingFaceCompletionRequestManager( + private HuggingFaceChatCompletionRequestManager( HuggingFaceChatCompletionModel model, ResponseHandler responseHandler, ThreadPool threadPool diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java index c6ed5a000dc98..b4fc7217d8338 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java @@ -18,7 +18,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceInferenceRequest; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest; import java.util.List; import java.util.Objects; @@ -26,9 +26,22 @@ import static org.elasticsearch.xpack.inference.common.Truncator.truncate; +/** + * This class is responsible for managing requests to the Hugging Face API for generating embeddings. + * It handles the execution of requests, including truncation of input data and response handling. + */ public class HuggingFaceEmbeddingsRequestManager extends HuggingFaceRequestManager { private static final Logger logger = LogManager.getLogger(HuggingFaceEmbeddingsRequestManager.class); + /** + * Creates a new instance of HuggingFaceEmbeddingsRequestManager. + * + * @param model The Hugging Face model to be used for generating embeddings. + * @param responseHandler The response handler for processing the API responses. + * @param truncator The truncator for handling input data truncation. + * @param threadPool The thread pool for executing requests. + * @return A new instance of HuggingFaceEmbeddingsRequestManager. + */ public static HuggingFaceEmbeddingsRequestManager of( HuggingFaceModel model, ResponseHandler responseHandler, @@ -68,7 +81,7 @@ public void execute( ) { List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); var truncatedInput = truncate(docsInput, model.getTokenLimit()); - var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); + var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 5cef2c65d13ed..13142472a1ef3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -43,11 +43,15 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; +/** + * This class is responsible for managing the Hugging Face inference service. + * It handles the creation of models, chunked inference, and unified completion inference. + */ public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; @@ -55,7 +59,8 @@ public class HuggingFaceService extends HuggingFaceBaseService { private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, - TaskType.COMPLETION + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION ); public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { @@ -152,7 +157,21 @@ protected void doUnifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - throwUnsupportedUnifiedCompletionOperation(NAME); + if (model instanceof HuggingFaceChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model; + var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); + var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest()); + var action = overriddenModel.accept(actionCreator); + + action.execute(inputs, timeout, listener); + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); } @Override @@ -180,6 +199,9 @@ public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); } + private Configuration() { + } + private static final LazyInitializable configuration = new LazyInitializable<>( () -> { var configurationMap = new HashMap(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index 82390b9d358c2..a9b4d5bedb7ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -11,15 +11,16 @@ import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceCompletionRequestManager; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionRequestManager; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler; import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import java.util.Objects; @@ -71,12 +72,12 @@ public ExecutableAction create(HuggingFaceElserModel model) { @Override public ExecutableAction create(HuggingFaceChatCompletionModel model) { - var responseHandler = new HuggingFaceResponseHandler( + var responseHandler = new OpenAiUnifiedChatCompletionResponseHandler( "hugging face chat completion", - HuggingFaceChatCompletionResponseEntity::fromResponse + OpenAiChatCompletionResponseEntity::fromResponse ); - var requestCreator = HuggingFaceCompletionRequestManager.of(model, responseHandler, serviceComponents.threadPool()); + var requestCreator = HuggingFaceChatCompletionRequestManager.of(model, responseHandler, serviceComponents.threadPool()); var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionService.java deleted file mode 100644 index fcaab46835090..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionService.java +++ /dev/null @@ -1,167 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.huggingface.completion; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.util.LazyInitializable; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInference; -import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.InferenceServiceConfiguration; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.SettingsConfiguration; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; -import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; -import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; -import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; -import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; - -import java.util.EnumSet; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; -import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings.URL; - -public class HuggingFaceChatCompletionService extends HuggingFaceBaseService { - public static final String NAME = "hugging_face_chat_completion"; - - private static final String SERVICE_NAME = "Hugging Face CHAT COMPLETION"; - private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); - - public HuggingFaceChatCompletionService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { - super(factory, serviceComponents); - } - - @Override - public String name() { - return NAME; - } - - @Override - protected HuggingFaceModel createModel( - String inferenceEntityId, - TaskType taskType, - Map serviceSettings, - ChunkingSettings chunkingSettings, - @Nullable Map secretSettings, - String failureMessage, - ConfigurationParseContext context - ) { - return switch (taskType) { - case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel( - inferenceEntityId, - taskType, - NAME, - serviceSettings, - secretSettings, - context - ); - default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); - }; - } - - @Override - protected void doUnifiedCompletionInfer( - Model model, - UnifiedChatInput inputs, - TimeValue timeout, - ActionListener listener - ) { - if (model instanceof HuggingFaceChatCompletionModel == false) { - listener.onFailure(createInvalidModelException(model)); - return; - } - HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model; - var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); - var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest()); - var action = overriddenModel.accept(actionCreator); - - action.execute(inputs, timeout, listener); - } - - @Override - protected void doChunkedInfer( - Model model, - EmbeddingsInput inputs, - Map taskSettings, - InputType inputType, - TimeValue timeout, - ActionListener> listener - ) { - throwUnsupportedUnifiedCompletionOperation(NAME); - } - - @Override - public InferenceServiceConfiguration getConfiguration() { - return Configuration.get(); - } - - @Override - public boolean hideFromConfigurationApi() { - return true; - } - - @Override - public EnumSet supportedTaskTypes() { - return SUPPORTED_TASK_TYPES; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_12_0; - } - - public static class Configuration { - public static InferenceServiceConfiguration get() { - return configuration.getOrCompute(); - } - - private static final LazyInitializable configuration = new LazyInitializable<>( - () -> { - var configurationMap = new HashMap(); - - configurationMap.put( - URL, - new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") - .setLabel("URL") - .setRequired(true) - .setSensitive(false) - .setUpdatable(false) - .setType(SettingsConfigurationFieldType.STRING) - .build() - ); - - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); - - return new InferenceServiceConfiguration.Builder().setService(NAME) - .setName(SERVICE_NAME) - .setTaskTypes(SUPPORTED_TASK_TYPES) - .setConfigurations(configurationMap) - .build(); - } - ); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java index 8ab84dea6b12b..d0c04ae2c8169 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -34,6 +34,13 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; +/** + * Settings for the Hugging Face chat completion service. + *

+ * This class contains the settings required to configure a Hugging Face chat completion service, including the model ID, URL, maximum input + * tokens, and rate limit settings. + *

+ */ public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings, @@ -44,7 +51,14 @@ public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentOb // At the time of writing HuggingFace hasn't posted the default rate limit for inference endpoints so the value his is only a guess // 3000 requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); - + private static final int DEFAULT_TOKEN_LIMIT = 512; + + /** + * Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a map of settings. + * @param map the map of settings + * @param context the context for parsing the settings + * @return a new instance of {@link HuggingFaceChatCompletionServiceSettings} + */ public static HuggingFaceChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); @@ -80,7 +94,7 @@ public static HuggingFaceChatCompletionServiceSettings fromMap(Map inputs) implements ToXContentObject { +/** + * This class represents the request entity for Hugging Face embeddings. + * It contains a list of input strings that will be used to generate embeddings. + */ +public record HuggingFaceEmbeddingsRequestEntity(List inputs) implements ToXContentObject { private static final String INPUTS_FIELD = "inputs"; - public HuggingFaceInferenceRequestEntity { + public HuggingFaceEmbeddingsRequestEntity { Objects.requireNonNull(inputs); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceChatCompletionResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceChatCompletionResponseEntity.java deleted file mode 100644 index 7e7dacb148f14..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceChatCompletionResponseEntity.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.huggingface.response; - -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentParser; -import org.elasticsearch.xcontent.XContentParserConfiguration; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; -import org.elasticsearch.xpack.inference.external.http.HttpResult; -import org.elasticsearch.xpack.inference.external.request.Request; - -import java.io.IOException; -import java.util.List; - -import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; -import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; - -public class HuggingFaceChatCompletionResponseEntity { - private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in HuggingFace chat response"; - - public static ChatCompletionResults fromResponse(Request request, HttpResult response) throws IOException { - var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - - try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { - moveToFirstToken(jsonParser); - - // Ensure the response starts with an array - XContentParser.Token token = jsonParser.currentToken(); - ensureExpectedToken(XContentParser.Token.START_ARRAY, token, jsonParser); - - // Move to the first object in the array - token = jsonParser.nextToken(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); - - // Position the parser at the "generated_text" field - positionParserAtTokenAfterField(jsonParser, "generated_text", FAILED_TO_FIND_FIELD_TEMPLATE); - - // Extract the "generated_text" value - XContentParser.Token contentToken = jsonParser.currentToken(); - ensureExpectedToken(XContentParser.Token.VALUE_STRING, contentToken, jsonParser); - String content = jsonParser.text(); - - return new ChatCompletionResults(List.of(new ChatCompletionResults.Result(content))); - } - } - - private HuggingFaceChatCompletionResponseEntity() {} -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestEntityTests.java similarity index 83% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestEntityTests.java index 3f7e716d25749..22e6e1c33a36c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestEntityTests.java @@ -12,17 +12,17 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceInferenceRequestEntity; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequestEntity; import java.io.IOException; import java.util.List; import static org.hamcrest.CoreMatchers.is; -public class HuggingFaceInferenceRequestEntityTests extends ESTestCase { +public class HuggingFaceEmbeddingsRequestEntityTests extends ESTestCase { public void testXContent() throws IOException { - var entity = new HuggingFaceInferenceRequestEntity(List.of("abc")); + var entity = new HuggingFaceEmbeddingsRequestEntity(List.of("abc")); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestTests.java similarity index 92% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestTests.java index e1f38b57ece3f..3cbf8a5e5b28f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestTests.java @@ -14,7 +14,7 @@ import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests; -import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceInferenceRequest; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest; import java.io.IOException; import java.net.URI; @@ -26,7 +26,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class HuggingFaceInferenceRequestTests extends ESTestCase { +public class HuggingFaceEmbeddingsRequestTests extends ESTestCase { @SuppressWarnings("unchecked") public void testCreateRequest() throws URISyntaxException, IOException { var huggingFaceRequest = createRequest("www.google.com", "secret", "abc"); @@ -68,9 +68,9 @@ public void testIsTruncated_ReturnsTrue() throws URISyntaxException, IOException assertTrue(truncatedRequest.getTruncationInfo()[0]); } - public static HuggingFaceInferenceRequest createRequest(String url, String apiKey, String input) throws URISyntaxException { + public static HuggingFaceEmbeddingsRequest createRequest(String url, String apiKey, String input) throws URISyntaxException { - return new HuggingFaceInferenceRequest( + return new HuggingFaceEmbeddingsRequest( TruncatorTests.createTruncator(), new Truncator.TruncationResult(List.of(input), new boolean[] { false }), HuggingFaceEmbeddingsModelTests.createModel(url, apiKey) From 404f640ea62cf723a3876cc43af52e89d93593aa Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 25 Apr 2025 11:08:32 +0000 Subject: [PATCH 03/29] [CI] Auto commit changes from spotless --- .../inference/services/huggingface/HuggingFaceService.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 13142472a1ef3..08f9cd3645ef9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -199,8 +199,7 @@ public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); } - private Configuration() { - } + private Configuration() {} private static final LazyInitializable configuration = new LazyInitializable<>( () -> { From ceebb9adf723400c3ebbb6517434ef5d2822e555 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 25 Apr 2025 16:17:10 +0300 Subject: [PATCH 04/29] Add support for non-streaming completion task for HuggingFace --- .../huggingface/HuggingFaceModel.java | 15 +++++++++-- .../huggingface/HuggingFaceService.java | 25 +++++++++++++++-- .../action/HuggingFaceActionCreator.java | 27 ++++++++++++++----- 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java index 6a750a9ded9b3..62133eff4b658 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java @@ -9,17 +9,18 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.Objects; -public abstract class HuggingFaceModel extends Model { +public abstract class HuggingFaceModel extends RateLimitGroupingModel { private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings; private final SecureString apiKey; @@ -38,6 +39,16 @@ public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() { return rateLimitServiceSettings; } + @Override + public int rateLimitGroupingHash() { + return Objects.hash(rateLimitServiceSettings.uri(), apiKey); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitServiceSettings.rateLimitSettings(); + } + public SecureString apiKey() { return apiKey; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 08f9cd3645ef9..66510d2250c04 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -26,7 +26,10 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; @@ -36,6 +39,9 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -45,6 +51,7 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -55,6 +62,8 @@ public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; + private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = + "Failed to send Hugging Face %s request from inference entity id [%s]"; private static final String SERVICE_NAME = "Hugging Face"; private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( TaskType.TEXT_EMBEDDING, @@ -62,6 +71,10 @@ public class HuggingFaceService extends HuggingFaceBaseService { TaskType.COMPLETION, TaskType.CHAT_COMPLETION ); + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + "hugging face chat completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); @@ -161,10 +174,18 @@ protected void doUnifiedCompletionInfer( listener.onFailure(createInvalidModelException(model)); return; } + HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model; - var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents()); var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest()); - var action = overriddenModel.accept(actionCreator); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + UNIFIED_CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); action.execute(inputs, timeout, listener); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index a9b4d5bedb7ea..6d0c7b04e986d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -9,14 +9,19 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionRequestManager; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler; import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; @@ -31,8 +36,14 @@ */ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { + public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions"; + private static final String USER_ROLE = "user"; private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Hugging Face %s request from inference entity id [%s]"; + static final ResponseHandler COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + "hugging face completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); private final Sender sender; private final ServiceComponents serviceComponents; @@ -72,13 +83,15 @@ public ExecutableAction create(HuggingFaceElserModel model) { @Override public ExecutableAction create(HuggingFaceChatCompletionModel model) { - var responseHandler = new OpenAiUnifiedChatCompletionResponseHandler( - "hugging face chat completion", - OpenAiChatCompletionResponseEntity::fromResponse + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class ); - var requestCreator = HuggingFaceChatCompletionRequestManager.of(model, responseHandler, serviceComponents.threadPool()); - var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId()); - return new SenderExecutableAction(sender, requestCreator, errorMessage); + var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "COMPLETION", model.getInferenceEntityId()); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); } } From acaa35b5d7a20f0c3ae14e13418b5d1add50e87d Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 25 Apr 2025 16:28:28 +0300 Subject: [PATCH 05/29] Remove RequestManager for HF Chat Completion Task --- ...ggingFaceChatCompletionRequestManager.java | 73 ------------------- 1 file changed, 73 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionRequestManager.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionRequestManager.java deleted file mode 100644 index 0eec750042379..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionRequestManager.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.huggingface; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; -import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; - -import java.util.Objects; -import java.util.function.Supplier; - -/** - * Manages the execution of chat completion requests for Hugging Face models. - *

- * This class is responsible for creating and executing requests to Hugging Face's chat completion API. - * It extends {@link HuggingFaceRequestManager} to provide specific functionality for chat completion models. - *

- */ -public class HuggingFaceChatCompletionRequestManager extends HuggingFaceRequestManager { - private static final Logger logger = LogManager.getLogger(HuggingFaceChatCompletionRequestManager.class); - - public static HuggingFaceChatCompletionRequestManager of( - HuggingFaceChatCompletionModel model, - ResponseHandler responseHandler, - ThreadPool threadPool - ) { - return new HuggingFaceChatCompletionRequestManager( - Objects.requireNonNull(model), - Objects.requireNonNull(responseHandler), - Objects.requireNonNull(threadPool) - ); - } - - private final HuggingFaceChatCompletionModel model; - private final ResponseHandler responseHandler; - - private HuggingFaceChatCompletionRequestManager( - HuggingFaceChatCompletionModel model, - ResponseHandler responseHandler, - ThreadPool threadPool - ) { - super(model, threadPool); - this.model = model; - this.responseHandler = responseHandler; - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - var chatCompletionInput = inferenceInputs.castTo(UnifiedChatInput.class); - HuggingFaceUnifiedChatCompletionRequest request = new HuggingFaceUnifiedChatCompletionRequest(chatCompletionInput, model); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); - } -} From ff3ef500f894bd92ce51a1b703d232e37861ffe4 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 28 Apr 2025 11:49:24 +0300 Subject: [PATCH 06/29] Refactored Hugging Face Completion Service Settings, removed Request Manager, added Unit Tests --- .../HuggingFaceEmbeddingsRequestManager.java | 88 ------ .../HuggingFaceRequestManager.java | 58 +++- .../action/HuggingFaceActionCreator.java | 6 +- ...gingFaceChatCompletionServiceSettings.java | 35 +- .../huggingface/HuggingFaceServiceTests.java | 8 +- .../action/HuggingFaceActionCreatorTests.java | 157 +++++++-- .../HuggingFaceChatCompletionModelTests.java | 83 +++++ ...aceChatCompletionServiceSettingsTests.java | 298 ++++++++++++++++++ 8 files changed, 595 insertions(+), 138 deletions(-) delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java deleted file mode 100644 index b4fc7217d8338..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceEmbeddingsRequestManager.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.huggingface; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.common.Truncator; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; -import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest; - -import java.util.List; -import java.util.Objects; -import java.util.function.Supplier; - -import static org.elasticsearch.xpack.inference.common.Truncator.truncate; - -/** - * This class is responsible for managing requests to the Hugging Face API for generating embeddings. - * It handles the execution of requests, including truncation of input data and response handling. - */ -public class HuggingFaceEmbeddingsRequestManager extends HuggingFaceRequestManager { - private static final Logger logger = LogManager.getLogger(HuggingFaceEmbeddingsRequestManager.class); - - /** - * Creates a new instance of HuggingFaceEmbeddingsRequestManager. - * - * @param model The Hugging Face model to be used for generating embeddings. - * @param responseHandler The response handler for processing the API responses. - * @param truncator The truncator for handling input data truncation. - * @param threadPool The thread pool for executing requests. - * @return A new instance of HuggingFaceEmbeddingsRequestManager. - */ - public static HuggingFaceEmbeddingsRequestManager of( - HuggingFaceModel model, - ResponseHandler responseHandler, - Truncator truncator, - ThreadPool threadPool - ) { - return new HuggingFaceEmbeddingsRequestManager( - Objects.requireNonNull(model), - Objects.requireNonNull(responseHandler), - Objects.requireNonNull(truncator), - Objects.requireNonNull(threadPool) - ); - } - - private final HuggingFaceModel model; - private final ResponseHandler responseHandler; - private final Truncator truncator; - - private HuggingFaceEmbeddingsRequestManager( - HuggingFaceModel model, - ResponseHandler responseHandler, - Truncator truncator, - ThreadPool threadPool - ) { - super(model, threadPool); - this.model = model; - this.responseHandler = responseHandler; - this.truncator = truncator; - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); - var truncatedInput = truncate(docsInput, model.getTokenLimit()); - var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java index 33c97b6fb811b..7bb140e91ec5d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java @@ -7,12 +7,66 @@ package org.elasticsearch.xpack.inference.services.huggingface; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.BaseRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest; -public abstract class HuggingFaceRequestManager extends BaseRequestManager { - protected HuggingFaceRequestManager(HuggingFaceModel model, ThreadPool threadPool) { +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class HuggingFaceRequestManager extends BaseRequestManager { + private static final Logger logger = LogManager.getLogger(HuggingFaceRequestManager.class); + + public static HuggingFaceRequestManager of( + HuggingFaceModel model, + ResponseHandler responseHandler, + Truncator truncator, + ThreadPool threadPool + ) { + return new HuggingFaceRequestManager( + Objects.requireNonNull(model), + Objects.requireNonNull(responseHandler), + Objects.requireNonNull(truncator), + Objects.requireNonNull(threadPool) + ); + } + + private final HuggingFaceModel model; + private final ResponseHandler responseHandler; + private final Truncator truncator; + + private HuggingFaceRequestManager(HuggingFaceModel model, ResponseHandler responseHandler, Truncator truncator, ThreadPool threadPool) { super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings()); + this.model = model; + this.responseHandler = responseHandler; + this.truncator = truncator; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); + var truncatedInput = truncate(docsInput, model.getTokenLimit()); + var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); } record RateLimitGrouping(int accountHash) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index 6d0c7b04e986d..4c3dc67c51c04 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -16,7 +16,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler; import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; @@ -58,7 +58,7 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) { "hugging face text embeddings", HuggingFaceEmbeddingsResponseEntity::fromResponse ); - var requestCreator = HuggingFaceEmbeddingsRequestManager.of( + var requestCreator = HuggingFaceRequestManager.of( model, responseHandler, serviceComponents.truncator(), @@ -71,7 +71,7 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) { @Override public ExecutableAction create(HuggingFaceElserModel model) { var responseHandler = new HuggingFaceResponseHandler("hugging face elser", HuggingFaceElserResponseEntity::fromResponse); - var requestCreator = HuggingFaceEmbeddingsRequestManager.of( + var requestCreator = HuggingFaceRequestManager.of( model, responseHandler, serviceComponents.truncator(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java index d0c04ae2c8169..cc4c5bdb8fe63 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -29,9 +29,10 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; /** @@ -47,11 +48,9 @@ public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentOb HuggingFaceRateLimitServiceSettings { public static final String NAME = "hugging_face_completion_service_settings"; - public static final String URL = "url"; // At the time of writing HuggingFace hasn't posted the default rate limit for inference endpoints so the value his is only a guess // 3000 requests per minute private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); - private static final int DEFAULT_TOKEN_LIMIT = 512; /** * Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a map of settings. @@ -62,7 +61,7 @@ public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentOb public static HuggingFaceChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); - String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); var uri = extractUri(map, URL, validationException); @@ -93,7 +92,7 @@ public static HuggingFaceChatCompletionServiceSettings fromMap(Map) requestMap.get("inputs"); @@ -178,14 +181,14 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx ); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(1)); assertThat(requestMap.get("inputs"), instanceOf(List.class)); var inputList = (List) requestMap.get("inputs"); @@ -228,14 +231,14 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(1)); assertThat(requestMap.get("inputs"), instanceOf(List.class)); var inputList = (List) requestMap.get("inputs"); @@ -292,14 +295,14 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws ); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); assertThat(requestMap.size(), is(1)); assertThat(requestMap.get("inputs"), instanceOf(List.class)); var inputList = (List) requestMap.get("inputs"); @@ -350,14 +353,14 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc assertThat(webServer.requests(), hasSize(2)); { - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var initialRequestAsMap = entityAsMap(webServer.requests().get(0).getBody()); + var initialRequestAsMap = entityAsMap(webServer.requests().getFirst().getBody()); var initialInputs = initialRequestAsMap.get("inputs"); assertThat(initialInputs, is(List.of("abcd"))); } @@ -412,17 +415,123 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); + assertNull(webServer.requests().getFirst().getUri().getQuery()); assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var initialRequestAsMap = entityAsMap(webServer.requests().get(0).getBody()); + var initialRequestAsMap = entityAsMap(webServer.requests().getFirst().getBody()); var initialInputs = initialRequestAsMap.get("inputs"); assertThat(initialInputs, is(List.of("123"))); } } + + public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "object": "chat.completion", + "id": "", + "created": 1745855316, + "model": "/repository", + "system_fingerprint": "3.2.3-sha-a1f3ebe", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 8, + "completion_tokens": 50, + "total_tokens": 58 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); + var actionCreator = new HuggingFaceActionCreator(sender, createWithEmptySettings(threadPool)); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + } + + public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "invalid_field": "unexpected" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); + var actionCreator = new HuggingFaceActionCreator( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to send Hugging Face COMPLETION request from inference entity id " + "[model]. Cause: Required [choices]") + ); + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java new file mode 100644 index 0000000000000..c35ba36081440 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.List; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class HuggingFaceChatCompletionModelTests extends ESTestCase { + + public void testThrowsURISyntaxException_ForInvalidUrl() { + var thrownException = expectThrows(IllegalArgumentException.class, () -> createCompletionModel("^^", "secret", "id")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); + } + + public static HuggingFaceChatCompletionModel createCompletionModel(String url, String apiKey, String modelId) { + return new HuggingFaceChatCompletionModel( + modelId, + TaskType.COMPLETION, + "service", + new HuggingFaceChatCompletionServiceSettings(modelId, url, null, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static HuggingFaceChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelId) { + return new HuggingFaceChatCompletionModel( + modelId, + TaskType.CHAT_COMPLETION, + "service", + new HuggingFaceChatCompletionServiceSettings(modelId, url, null, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { + var model = createCompletionModel("url", "api_key", "model_name"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = HuggingFaceChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel("url", "api_key", "model_name"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = HuggingFaceChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..0e31e78b21a44 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java @@ -0,0 +1,298 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class HuggingFaceChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase< + HuggingFaceChatCompletionServiceSettings> { + + public static final String MODEL_ID = "some model"; + public static final String CORRECT_URL = "https://www.elastic.co"; + public static final int INPUT_TOKENS = 8192; + public static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + ServiceFields.MAX_INPUT_TOKENS, + INPUT_TOKENS, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new HuggingFaceChatCompletionServiceSettings( + MODEL_ID, + ServiceUtils.createUri(CORRECT_URL), + INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_MissingMaxInputTokens_Success() { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new HuggingFaceChatCompletionServiceSettings( + MODEL_ID, + ServiceUtils.createUri(CORRECT_URL), + null, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_MissingModelId_Success() { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + CORRECT_URL, + ServiceFields.MAX_INPUT_TOKENS, + INPUT_TOKENS, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new HuggingFaceChatCompletionServiceSettings( + null, + ServiceUtils.createUri(CORRECT_URL), + INPUT_TOKENS, + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_MissingRateLimit_Success() { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL, ServiceFields.MAX_INPUT_TOKENS, INPUT_TOKENS) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is(new HuggingFaceChatCompletionServiceSettings(MODEL_ID, ServiceUtils.createUri(CORRECT_URL), INPUT_TOKENS, null)) + ); + } + + public void testFromMap_MissingUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.MAX_INPUT_TOKENS, + INPUT_TOKENS, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format("Validation Failed: 1: [service_settings] does not contain the required setting [url];", ServiceFields.URL) + ) + ); + } + + public void testFromMap_EmptyUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + "", + ServiceFields.MAX_INPUT_TOKENS, + INPUT_TOKENS, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;", + ServiceFields.URL + ) + ) + ); + } + + public void testFromMap_InvalidUrl_ThrowsException() { + String invalidUrl = "https://www.elastic^^co"; + var thrownException = expectThrows( + ValidationException.class, + () -> HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + invalidUrl, + ServiceFields.MAX_INPUT_TOKENS, + INPUT_TOKENS, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", + invalidUrl, + ServiceFields.URL + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + ServiceFields.MAX_INPUT_TOKENS, + INPUT_TOKENS, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model_id":"some model","url":"https://www.elastic.co","max_input_tokens":8192,"rate_limit":{"requests_per_minute":2}}""")); + } + + public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"url":"https://www.elastic.co","rate_limit":{"requests_per_minute":3000}}""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return HuggingFaceChatCompletionServiceSettings::new; + } + + @Override + protected HuggingFaceChatCompletionServiceSettings createTestInstance() { + return createRandomWithNonNullUrl(); + } + + @Override + protected HuggingFaceChatCompletionServiceSettings mutateInstance(HuggingFaceChatCompletionServiceSettings instance) + throws IOException { + return randomValueOtherThan(instance, HuggingFaceChatCompletionServiceSettingsTests::createRandomWithNonNullUrl); + } + + private static HuggingFaceChatCompletionServiceSettings createRandomWithNonNullUrl() { + return createRandom(randomAlphaOfLength(15)); + } + + private static HuggingFaceChatCompletionServiceSettings createRandom(String url) { + var modelId = randomAlphaOfLength(8); + var maxInputTokens = randomFrom(randomIntBetween(128, 4096), null); + + return new HuggingFaceChatCompletionServiceSettings( + modelId, + ServiceUtils.createUri(url), + maxInputTokens, + RateLimitSettingsTests.createRandom() + ); + } +} From 965093b052fd8a08648686b071ccdb606d9eed23 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 29 Apr 2025 21:19:03 +0300 Subject: [PATCH 07/29] Refactored Hugging Face Action Creator, added Unit Tests --- .../action/HuggingFaceActionCreator.java | 6 +- .../huggingface/HuggingFaceServiceTests.java | 422 ++++++++++++++++++ .../HuggingFaceChatCompletionActionTests.java | 243 ++++++++++ ...aceChatCompletionServiceSettingsTests.java | 9 + ...ifiedChatCompletionRequestEntityTests.java | 69 +++ ...FaceUnifiedChatCompletionRequestTests.java | 80 ++++ 6 files changed, 826 insertions(+), 3 deletions(-) create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index 4c3dc67c51c04..bc54e4cb38ab9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -24,7 +24,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; -import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import java.util.Objects; @@ -37,10 +37,10 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions"; - private static final String USER_ROLE = "user"; + static final String USER_ROLE = "user"; private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = "Failed to send Hugging Face %s request from inference entity id [%s]"; - static final ResponseHandler COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( "hugging face completion", OpenAiChatCompletionResponseEntity::fromResponse ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 3646610e17588..c96554e4029a4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -12,6 +12,7 @@ import org.apache.http.HttpHeaders; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; @@ -29,20 +30,29 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettingsTests; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; @@ -53,14 +63,19 @@ import org.junit.Before; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; @@ -74,7 +89,12 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; public class HuggingFaceServiceTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); @@ -175,6 +195,389 @@ public void testParseRequestConfig_CreatesAnElserModel() throws IOException { } } + public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel() throws IOException { + var url = "url"; + var model = "model"; + var secret = "secret"; + + try (var service = createHuggingFaceService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class)); + + var completionsModel = (HuggingFaceChatCompletionModel) m; + + assertThat(completionsModel.getServiceSettings().uri().toString(), is(url)); + assertThat(completionsModel.getServiceSettings().modelId(), is(model)); + assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); + + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap( + HuggingFaceChatCompletionServiceSettingsTests.getServiceSettingsMap(url, model), + getSecretSettingsMap(secret) + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel_WithoutModelId() throws IOException { + var url = "url"; + var secret = "secret"; + + try (var service = createHuggingFaceService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class)); + + var completionsModel = (HuggingFaceChatCompletionModel) m; + + assertThat(completionsModel.getServiceSettings().uri().toString(), is(url)); + assertNull(completionsModel.getServiceSettings().modelId()); + assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); + + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap(getServiceSettingsMap(url), getSecretSettingsMap(secret)), + modelVerificationListener + ); + } + } + + public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); + + try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":"stop"\ + }\ + ],\ + "usage":{\ + "prompt_tokens": 16,\ + "completion_tokens": 28,\ + "total_tokens": 44,\ + "prompt_tokens_details": {\ + "cached_tokens": 0,\ + "audio_tokens": 0\ + },\ + "completion_tokens_details": {\ + "reasoning_tokens": 0,\ + "audio_tokens": 0,\ + "accepted_prediction_tokens": 0,\ + "rejected_prediction_tokens": 0\ + }\ + }\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """ + "model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """ + "usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}"""); + } + } + + public void testUnifiedCompletionError() throws Exception { + String responseJson = """ + { + "error": { + "message": "The model `gpt-4awero` does not exist or you do not have access to it.", + "type": "invalid_request_error", + "param": null, + "code": "model_not_found" + } + }"""; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + var latch = new CountDownLatch(1); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> { + try (var builder = XContentFactory.jsonBuilder()) { + var t = unwrapCause(e); + assertThat(t, isA(UnifiedChatCompletionException.class)); + ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(""" + {\ + "error":{\ + "code":"model_not_found",\ + "message":"Received an unsuccessful status code for request from inference entity id [model] status \ + [404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\ + "type":"invalid_request_error"\ + }}""")); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }), latch::countDown) + ); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + } + } + + public void testMidStreamUnifiedCompletionError() throws Exception { + String responseJson = """ + event: error + data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "message":"Received an error response for request from inference entity id [model]. Error message: \ + [Timed out waiting for more data]",\ + "type":"timeout"\ + }}"""); + } + + private void testStreamError(String expectedResponse) throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { + e = unwrapCause(e); + assertThat(e, isA(UnifiedChatCompletionException.class)); + try (var builder = XContentFactory.jsonBuilder()) { + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(expectedResponse)); + } + }); + } + } + + public void testUnifiedCompletionMalformedError() throws Exception { + String responseJson = """ + data: { invalid json } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"bad_request",\ + "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\ + at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\ + "type":"x_content_parse_exception"\ + }}"""); + } + + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":null\ + }\ + ]\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } + + private InferenceEventsAssertion streamCompletion() throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + } + } + + public void testInfer_StreamRequest_ErrorResponse() throws Exception { + String responseJson = """ + { + "error": { + "message": "You didn't provide an API key...", + "type": "invalid_request_error", + "param": null, + "code": null + } + }"""; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion); + assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED)); + assertThat( + e.getMessage(), + equalTo( + "Received an authentication error status code for request from inference entity id [model] status [401]. " + + "Error message: [You didn't provide an API key...]" + ) + ); + } + + public void testInfer_StreamRequestRetry() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(503).setBody(""" + { + "error": { + "message": "server busy", + "type": "server_busy" + } + }""")); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":null\ + }\ + ]\ + } + + """)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } + + public void testSupportsStreaming() throws IOException { + try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) { + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); + } + } + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createHuggingFaceService()) { var config = getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret")); @@ -258,6 +661,25 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throw } } + public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws IOException { + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>(), getSecretSettingsMap("secret")); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(HuggingFaceChatCompletionModel.class)); + + var chatCompletionModel = (HuggingFaceChatCompletionModel) model; + assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is("url")); + assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java new file mode 100644 index 0000000000000..9f65bd82fc76e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java @@ -0,0 +1,243 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.COMPLETION_HANDLER; +import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.USER_ROLE; +import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class HuggingFaceChatCompletionActionTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "result content" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content")))); + assertThat(webServer.requests(), hasSize(1)); + + MockRequest request = webServer.requests().getFirst(); + + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(request.getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + } + + public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException { + try (var sender = mock(Sender.class)) { + var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("^^", sender)); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("Failed to send hugging face chat completions request. Cause: failed")); + } + + public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0613", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("hugging face chat completions only accepts 1 input")); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + private ExecutableAction createAction(String url, Sender sender) { + var model = createCompletionModel(url, "secret", "model"); + var manager = new GenericRequestManager<>( + threadPool, + model, + COMPLETION_HANDLER, + inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + var errorMessage = constructFailedToSendRequestMessage("hugging face chat completions"); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "hugging face chat completions"); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java index 0e31e78b21a44..61c6756d54426 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java @@ -295,4 +295,13 @@ private static HuggingFaceChatCompletionServiceSettings createRandom(String url) RateLimitSettingsTests.createRandom() ); } + + public static Map getServiceSettingsMap(String url, String model) { + var map = new HashMap(); + + map.put(ServiceFields.URL, url); + map.put(ServiceFields.MODEL_ID, model); + + return map; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..81d26036036c6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity; + +import java.io.IOException; +import java.util.ArrayList; + +import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; +import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel; + +public class HuggingFaceUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + private static final String ROLE = "user"; + + public void testModelUserFieldsSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + HuggingFaceChatCompletionModel model = createCompletionModel("test-url", "api-key", "test-endpoint"); + + HuggingFaceUnifiedChatCompletionRequestEntity entity = new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java new file mode 100644 index 0000000000000..8989c76a46cf6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class HuggingFaceUnifiedChatCompletionRequestTests extends ESTestCase { + + public void testCreateRequest_WithStreaming() throws IOException { + var request = createRequest("url", "secret", "abcd", "model", true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("stream"), is(true)); + } + + public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { + var request = createRequest("url", "secret", "abcd", "model", true); + var truncatedRequest = request.truncate(); + assertThat(request.getURI().toString(), is("url")); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(5)); + + // We do not truncate for Hugging Face chat completions + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); + } + + public void testTruncationInfo_ReturnsNull() { + var request = createRequest("url", "secret", "abcd", "model", true); + assertNull(request.getTruncationInfo()); + } + + public static HuggingFaceUnifiedChatCompletionRequest createRequest(String url, String apiKey, String input, @Nullable String model) { + return createRequest(url, apiKey, input, model, false); + } + + public static HuggingFaceUnifiedChatCompletionRequest createRequest( + @Nullable String url, + String apiKey, + String input, + @Nullable String model, + boolean stream + ) { + var chatCompletionModel = HuggingFaceChatCompletionModelTests.createCompletionModel(url, apiKey, model); + return new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } + +} From 6757b073a93f6043f13106f98cc3d51fc1ae250f Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Tue, 29 Apr 2025 22:21:19 +0300 Subject: [PATCH 08/29] Add Hugging Face Server Test --- .../xpack/inference/InferenceGetServicesIT.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index da65da3689514..c3f4ba3b52016 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -115,7 +115,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(10)); + assertThat(services.size(), equalTo(11)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -133,6 +133,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "cohere", "deepseek", "googleaistudio", + "hugging_face", "openai", "streaming_completion_test_service" ).toArray(), @@ -143,7 +144,7 @@ public void testGetServicesWithCompletionTaskType() throws IOException { @SuppressWarnings("unchecked") public void testGetServicesWithChatCompletionTaskType() throws IOException { List services = getServices(TaskType.CHAT_COMPLETION); - assertThat(services.size(), equalTo(4)); + assertThat(services.size(), equalTo(5)); String[] providers = new String[services.size()]; for (int i = 0; i < services.size(); i++) { @@ -151,7 +152,8 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { providers[i] = (String) serviceConfig.get("service"); } - assertArrayEquals(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray(), providers); + assertArrayEquals(List.of("deepseek", "elastic", "hugging_face", "openai", + "streaming_completion_test_service").toArray(), providers); } @SuppressWarnings("unchecked") From df845eb9e31fa3fa5e4b96134880ba93b27c5392 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Tue, 29 Apr 2025 19:31:09 +0000 Subject: [PATCH 09/29] [CI] Auto commit changes from spotless --- .../xpack/inference/InferenceGetServicesIT.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index c3f4ba3b52016..ab73ef0c5f693 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -152,8 +152,10 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { providers[i] = (String) serviceConfig.get("service"); } - assertArrayEquals(List.of("deepseek", "elastic", "hugging_face", "openai", - "streaming_completion_test_service").toArray(), providers); + assertArrayEquals( + List.of("deepseek", "elastic", "hugging_face", "openai", "streaming_completion_test_service").toArray(), + providers + ); } @SuppressWarnings("unchecked") From 5bbe3b7768a09b70f4a441648b8bee0771e2755d Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 12:13:34 +0300 Subject: [PATCH 10/29] Removed parameters from media type for Chat Completion Request and unit tests --- .../HuggingFaceUnifiedChatCompletionRequest.java | 2 +- .../action/HuggingFaceActionCreatorTests.java | 10 ++++++++-- .../action/HuggingFaceChatCompletionActionTests.java | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java index ea7a9402a8f93..718ee082aa813 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java @@ -54,7 +54,7 @@ public HttpRequest createHttpRequest() { ); httpPost.setEntity(byteEntity); - httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); httpPost.setHeader(createAuthBearerHeader(model.apiKey())); return new HttpRequest(httpPost, getInferenceEntityId()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index f6996512f98e3..e342c3526d862 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -475,7 +475,10 @@ public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() thro assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); @@ -523,7 +526,10 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th assertThat(webServer.requests(), hasSize(1)); assertNull(webServer.requests().getFirst().getUri().getQuery()); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat( + webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java index 9f65bd82fc76e..6e60b919328bc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java @@ -129,7 +129,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { MockRequest request = webServer.requests().getFirst(); assertNull(request.getUri().getQuery()); - assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); var requestMap = entityAsMap(request.getBody()); From 36848162d2015389278af42ab76e1cf750093822 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 13:27:03 +0300 Subject: [PATCH 11/29] Removed OpenAI default URL in HuggingFaceService's configuration, fixed formatting in InferenceGetServicesIT --- .../xpack/inference/InferenceGetServicesIT.java | 5 ++++- .../inference/services/huggingface/HuggingFaceService.java | 3 +-- .../services/huggingface/HuggingFaceServiceTests.java | 1 - 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index eb75671c82ea9..1aa302fcb2b32 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -153,7 +153,10 @@ public void testGetServicesWithChatCompletionTaskType() throws IOException { var providers = providers(services); - assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())); + assertThat( + providers, + containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray()) + ); } public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 66510d2250c04..a7f66b6b293f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -228,8 +228,7 @@ private Configuration() {} configurationMap.put( URL, - new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDefaultValue("https://api.openai.com/v1/embeddings") - .setDescription("The URL endpoint to use for the requests.") + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") .setLabel("URL") .setRequired(true) .setSensitive(false) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index c96554e4029a4..5801b9984a9a8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -1264,7 +1264,6 @@ public void testGetConfiguration() throws Exception { "supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"] }, "url": { - "default_value": "https://api.openai.com/v1/embeddings", "description": "The URL endpoint to use for the requests.", "label": "URL", "required": true, From 7670d2c798622227a2c0a6e0376ea78b8d9b31e6 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 14:08:43 +0300 Subject: [PATCH 12/29] Refactor error message handling in HuggingFaceActionCreator and HuggingFaceService --- .../services/huggingface/HuggingFaceService.java | 5 +---- .../action/HuggingFaceActionCreator.java | 13 ++++++++----- .../action/HuggingFaceActionCreatorTests.java | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index a7f66b6b293f2..9b908e5c68817 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -51,7 +51,6 @@ import java.util.Map; import java.util.Set; -import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -62,8 +61,6 @@ public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; - private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = - "Failed to send Hugging Face %s request from inference entity id [%s]"; private static final String SERVICE_NAME = "Hugging Face"; private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( TaskType.TEXT_EMBEDDING, @@ -184,7 +181,7 @@ protected void doUnifiedCompletionInfer( unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel), UnifiedChatInput.class ); - var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId()); + var errorMessage = HuggingFaceActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId()); var action = new SenderExecutableAction(getSender(), manager, errorMessage); action.execute(inputs, timeout, listener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index bc54e4cb38ab9..df1ddcb017970 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.huggingface.action; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; @@ -38,8 +39,6 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions"; static final String USER_ROLE = "user"; - private static final String FAILED_TO_SEND_REQUEST_ERROR_MESSAGE = - "Failed to send Hugging Face %s request from inference entity id [%s]"; static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( "hugging face completion", OpenAiChatCompletionResponseEntity::fromResponse @@ -64,7 +63,7 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) { serviceComponents.truncator(), serviceComponents.threadPool() ); - var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "text embeddings", model.getInferenceEntityId()); + var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } @@ -77,7 +76,7 @@ public ExecutableAction create(HuggingFaceElserModel model) { serviceComponents.truncator(), serviceComponents.threadPool() ); - var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "ELSER", model.getInferenceEntityId()); + var errorMessage = buildErrorMessage(TaskType.SPARSE_EMBEDDING, model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } @@ -91,7 +90,11 @@ public ExecutableAction create(HuggingFaceChatCompletionModel model) { ChatCompletionInput.class ); - var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "COMPLETION", model.getInferenceEntityId()); + var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId()); return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); } + + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestType.toString(), inferenceId); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index e342c3526d862..064d1d9ab1a6b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -521,7 +521,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to send Hugging Face COMPLETION request from inference entity id " + "[model]. Cause: Required [choices]") + is("Failed to send Hugging Face completion request from inference entity id " + "[model]. Cause: Required [choices]") ); assertThat(webServer.requests(), hasSize(1)); From 6630be7dff2cc8e9fef7759617cc5e9068274111 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 16:50:06 +0300 Subject: [PATCH 13/29] Update minimal supported version and add Hugging Face transport version constants --- .../src/main/java/org/elasticsearch/TransportVersions.java | 2 ++ .../HuggingFaceChatCompletionServiceSettings.java | 7 ++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 8e05c6755c91c..9765d2ca40b01 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -164,6 +164,7 @@ static TransportVersion def(int id) { public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20); public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21); public static final TransportVersion ESQL_REPORT_ORIGINAL_TYPES_BACKPORT_8_19 = def(8_841_0_22); + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_8_19 = def(8_841_0_23); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00); @@ -236,6 +237,7 @@ static TransportVersion def(int id) { public static final TransportVersion PINNED_RETRIEVER = def(9_068_0_00); public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_069_0_00); public static final TransportVersion WRITE_LOAD_INCLUDES_BUFFER_WRITES = def(9_070_00_0); + public static final TransportVersion ML_INFERENCE_HUGGING_FACE = def(9_071_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java index cc4c5bdb8fe63..185914bc9792f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -178,7 +178,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_14_0; + return TransportVersions.ML_INFERENCE_HUGGING_FACE; } @Override @@ -186,10 +186,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(modelId); out.writeString(uri.toString()); out.writeOptionalVInt(maxInputTokens); - - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { - rateLimitSettings.writeTo(out); - } + rateLimitSettings.writeTo(out); } @Override From 1efb2ee8e2c047ea51d591a6274acbc19b5c4d40 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 18:23:47 +0300 Subject: [PATCH 14/29] Made modelId field optional in HuggingFaceChatCompletionModel, updated unit tests --- .../HuggingFaceChatCompletionModel.java | 12 +++++- .../huggingface/HuggingFaceServiceTests.java | 6 +-- .../action/HuggingFaceActionCreatorTests.java | 2 +- .../HuggingFaceChatCompletionModelTests.java | 42 +++++++++++++++++-- 4 files changed, 53 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java index c843ab03f9932..cfa2aedab192b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java @@ -19,14 +19,22 @@ import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.util.Map; -import java.util.Objects; public class HuggingFaceChatCompletionModel extends HuggingFaceModel { + /** + * Creates a new {@link HuggingFaceChatCompletionModel} by copying properties from an existing model, + * replacing the {@code modelId} in the service settings with the one from the given {@link UnifiedCompletionRequest}, + * if present. If the request does not specify a model ID, the original value is retained. + * + * @param model the original model to copy from + * @param request the request potentially containing an overridden model ID + * @return a new {@link HuggingFaceChatCompletionModel} with updated service settings + */ public static HuggingFaceChatCompletionModel of(HuggingFaceChatCompletionModel model, UnifiedCompletionRequest request) { var originalModelServiceSettings = model.getServiceSettings(); var overriddenServiceSettings = new HuggingFaceChatCompletionServiceSettings( - Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + request.model() != null ? request.model() : originalModelServiceSettings.modelId(), originalModelServiceSettings.uri(), originalModelServiceSettings.maxInputTokens(), originalModelServiceSettings.rateLimitSettings() diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 5801b9984a9a8..7e5f6cc78a241 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -386,7 +386,7 @@ public void testUnifiedCompletionError() throws Exception { {\ "error":{\ "code":"model_not_found",\ - "message":"Received an unsuccessful status code for request from inference entity id [model] status \ + "message":"Received an unsuccessful status code for request from inference entity id [id] status \ [404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\ "type":"invalid_request_error"\ }}""")); @@ -409,7 +409,7 @@ public void testMidStreamUnifiedCompletionError() throws Exception { testStreamError(""" {\ "error":{\ - "message":"Received an error response for request from inference entity id [model]. Error message: \ + "message":"Received an error response for request from inference entity id [id]. Error message: \ [Timed out waiting for more data]",\ "type":"timeout"\ }}"""); @@ -532,7 +532,7 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception { assertThat( e.getMessage(), equalTo( - "Received an authentication error status code for request from inference entity id [model] status [401]. " + "Received an authentication error status code for request from inference entity id [id] status [401]. " + "Error message: [You didn't provide an API key...]" ) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index 064d1d9ab1a6b..8315a7c1318a6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -521,7 +521,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat( thrownException.getMessage(), - is("Failed to send Hugging Face completion request from inference entity id " + "[model]. Cause: Required [choices]") + is("Failed to send Hugging Face completion request from inference entity id " + "[id]. Cause: Required [choices]") ); assertThat(webServer.requests(), hasSize(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java index c35ba36081440..38aa268fe176c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java @@ -27,7 +27,7 @@ public void testThrowsURISyntaxException_ForInvalidUrl() { public static HuggingFaceChatCompletionModel createCompletionModel(String url, String apiKey, String modelId) { return new HuggingFaceChatCompletionModel( - modelId, + "id", TaskType.COMPLETION, "service", new HuggingFaceChatCompletionServiceSettings(modelId, url, null, null), @@ -37,7 +37,7 @@ public static HuggingFaceChatCompletionModel createCompletionModel(String url, S public static HuggingFaceChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelId) { return new HuggingFaceChatCompletionModel( - modelId, + "id", TaskType.CHAT_COMPLETION, "service", new HuggingFaceChatCompletionServiceSettings(modelId, url, null, null), @@ -45,7 +45,7 @@ public static HuggingFaceChatCompletionModel createChatCompletionModel(String ur ); } - public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { + public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { var model = createCompletionModel("url", "api_key", "model_name"); var request = new UnifiedCompletionRequest( List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), @@ -63,6 +63,42 @@ public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); } + public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() { + var model = createCompletionModel("url", "api_key", null); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = HuggingFaceChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { + var model = createCompletionModel("url", "api_key", null); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = HuggingFaceChatCompletionModel.of(model, request); + + assertNull(overriddenModel.getServiceSettings().modelId()); + } + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { var model = createCompletionModel("url", "api_key", "model_name"); var request = new UnifiedCompletionRequest( From 61537d072f9879a8f260846f6471761d695afa63 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 19:43:10 +0300 Subject: [PATCH 15/29] Removed max input tokens field from HuggingFaceChatCompletionServiceSettings, fixed unit tests --- .../HuggingFaceChatCompletionModel.java | 3 +- ...gingFaceChatCompletionServiceSettings.java | 41 ++--------- .../HuggingFaceChatCompletionModelTests.java | 4 +- ...aceChatCompletionServiceSettingsTests.java | 70 ++----------------- 4 files changed, 13 insertions(+), 105 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java index cfa2aedab192b..450cec1cc8199 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java @@ -36,7 +36,6 @@ public static HuggingFaceChatCompletionModel of(HuggingFaceChatCompletionModel m var overriddenServiceSettings = new HuggingFaceChatCompletionServiceSettings( request.model() != null ? request.model() : originalModelServiceSettings.modelId(), originalModelServiceSettings.uri(), - originalModelServiceSettings.maxInputTokens(), originalModelServiceSettings.rateLimitSettings() ); @@ -98,6 +97,6 @@ public ExecutableAction accept(HuggingFaceActionVisitor creator) { @Override public Integer getTokenLimit() { - return getServiceSettings().maxInputTokens(); + throw new UnsupportedOperationException("Token Limit for chat completion is sent in request and not retrieved from the model"); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java index 185914bc9792f..ff05a7bb4073b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -27,11 +27,9 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; @@ -65,13 +63,6 @@ public static HuggingFaceChatCompletionServiceSettings fromMap(Map(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -58,35 +55,6 @@ public void testFromMap_AllFields_Success() { new HuggingFaceChatCompletionServiceSettings( MODEL_ID, ServiceUtils.createUri(CORRECT_URL), - INPUT_TOKENS, - new RateLimitSettings(RATE_LIMIT) - ) - ) - ); - } - - public void testFromMap_MissingMaxInputTokens_Success() { - var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( - new HashMap<>( - Map.of( - ServiceFields.MODEL_ID, - MODEL_ID, - ServiceFields.URL, - CORRECT_URL, - RateLimitSettings.FIELD_NAME, - new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) - ) - ), - ConfigurationParseContext.PERSISTENT - ); - - assertThat( - serviceSettings, - is( - new HuggingFaceChatCompletionServiceSettings( - MODEL_ID, - ServiceUtils.createUri(CORRECT_URL), - null, new RateLimitSettings(RATE_LIMIT) ) ) @@ -99,8 +67,6 @@ public void testFromMap_MissingModelId_Success() { Map.of( ServiceFields.URL, CORRECT_URL, - ServiceFields.MAX_INPUT_TOKENS, - INPUT_TOKENS, RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -110,29 +76,17 @@ public void testFromMap_MissingModelId_Success() { assertThat( serviceSettings, - is( - new HuggingFaceChatCompletionServiceSettings( - null, - ServiceUtils.createUri(CORRECT_URL), - INPUT_TOKENS, - new RateLimitSettings(RATE_LIMIT) - ) - ) + is(new HuggingFaceChatCompletionServiceSettings(null, ServiceUtils.createUri(CORRECT_URL), new RateLimitSettings(RATE_LIMIT))) ); } public void testFromMap_MissingRateLimit_Success() { var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( - new HashMap<>( - Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL, ServiceFields.MAX_INPUT_TOKENS, INPUT_TOKENS) - ), + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), ConfigurationParseContext.PERSISTENT ); - assertThat( - serviceSettings, - is(new HuggingFaceChatCompletionServiceSettings(MODEL_ID, ServiceUtils.createUri(CORRECT_URL), INPUT_TOKENS, null)) - ); + assertThat(serviceSettings, is(new HuggingFaceChatCompletionServiceSettings(MODEL_ID, ServiceUtils.createUri(CORRECT_URL), null))); } public void testFromMap_MissingUrl_ThrowsException() { @@ -143,8 +97,6 @@ public void testFromMap_MissingUrl_ThrowsException() { Map.of( ServiceFields.MODEL_ID, MODEL_ID, - ServiceFields.MAX_INPUT_TOKENS, - INPUT_TOKENS, RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -171,8 +123,6 @@ public void testFromMap_EmptyUrl_ThrowsException() { MODEL_ID, ServiceFields.URL, "", - ServiceFields.MAX_INPUT_TOKENS, - INPUT_TOKENS, RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -203,8 +153,6 @@ public void testFromMap_InvalidUrl_ThrowsException() { MODEL_ID, ServiceFields.URL, invalidUrl, - ServiceFields.MAX_INPUT_TOKENS, - INPUT_TOKENS, RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -233,8 +181,6 @@ public void testToXContent_WritesAllValues() throws IOException { MODEL_ID, ServiceFields.URL, CORRECT_URL, - ServiceFields.MAX_INPUT_TOKENS, - INPUT_TOKENS, RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) ) @@ -247,7 +193,7 @@ public void testToXContent_WritesAllValues() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" - {"model_id":"some model","url":"https://www.elastic.co","max_input_tokens":8192,"rate_limit":{"requests_per_minute":2}}""")); + {"model_id":"some model","url":"https://www.elastic.co","rate_limit":{"requests_per_minute":2}}""")); } public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { @@ -286,14 +232,8 @@ private static HuggingFaceChatCompletionServiceSettings createRandomWithNonNullU private static HuggingFaceChatCompletionServiceSettings createRandom(String url) { var modelId = randomAlphaOfLength(8); - var maxInputTokens = randomFrom(randomIntBetween(128, 4096), null); - return new HuggingFaceChatCompletionServiceSettings( - modelId, - ServiceUtils.createUri(url), - maxInputTokens, - RateLimitSettingsTests.createRandom() - ); + return new HuggingFaceChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); } public static Map getServiceSettingsMap(String url, String model) { From 64c0685b26363ac0e22c2a66c539235b4931eeb7 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 19:56:05 +0300 Subject: [PATCH 16/29] Removed if statement checking TransportVersion for HuggingFaceChatCompletionServiceSettings constructor with StreamInput param --- .../HuggingFaceChatCompletionServiceSettings.java | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java index ff05a7bb4073b..b6490a5cff19a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -99,12 +99,7 @@ public HuggingFaceChatCompletionServiceSettings(@Nullable String modelId, URI ur public HuggingFaceChatCompletionServiceSettings(StreamInput in) throws IOException { this.modelId = in.readOptionalString(); this.uri = createUri(in.readString()); - - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { - this.rateLimitSettings = new RateLimitSettings(in); - } else { - this.rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; - } + this.rateLimitSettings = new RateLimitSettings(in); } @Override From 46889014b7cf5a18a7d50f6ab295456363699192 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 19:59:46 +0300 Subject: [PATCH 17/29] Removed getFirst() method calls for backport compatibility --- .../action/HuggingFaceActionCreatorTests.java | 64 +++++++++---------- .../HuggingFaceChatCompletionActionTests.java | 2 +- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index 8315a7c1318a6..08b13b54a67ae 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -115,14 +115,14 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce ); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(1)); assertThat(requestMap.get("inputs"), instanceOf(List.class)); var inputList = (List) requestMap.get("inputs"); @@ -181,14 +181,14 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx ); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(1)); assertThat(requestMap.get("inputs"), instanceOf(List.class)); var inputList = (List) requestMap.get("inputs"); @@ -231,14 +231,14 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I assertThat(result.asMap(), is(TextEmbeddingFloatResultsTests.buildExpectationFloat(List.of(new float[] { -0.0123F, 0.123F })))); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(1)); assertThat(requestMap.get("inputs"), instanceOf(List.class)); var inputList = (List) requestMap.get("inputs"); @@ -295,14 +295,14 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws ); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(1)); assertThat(requestMap.get("inputs"), instanceOf(List.class)); var inputList = (List) requestMap.get("inputs"); @@ -353,14 +353,14 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc assertThat(webServer.requests(), hasSize(2)); { - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var initialRequestAsMap = entityAsMap(webServer.requests().getFirst().getBody()); + var initialRequestAsMap = entityAsMap(webServer.requests().get(0).getBody()); var initialInputs = initialRequestAsMap.get("inputs"); assertThat(initialInputs, is(List.of("abcd"))); } @@ -415,14 +415,14 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var initialRequestAsMap = entityAsMap(webServer.requests().getFirst().getBody()); + var initialRequestAsMap = entityAsMap(webServer.requests().get(0).getBody()); var initialInputs = initialRequestAsMap.get("inputs"); assertThat(initialInputs, is(List.of("123"))); @@ -474,14 +474,14 @@ public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() thro assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); assertThat(requestMap.get("model"), is("model")); @@ -525,14 +525,14 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th ); assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().getFirst().getUri().getQuery()); + assertNull(webServer.requests().get(0).getUri().getQuery()); assertThat( - webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters()) ); - assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - var requestMap = entityAsMap(webServer.requests().getFirst().getBody()); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); assertThat(requestMap.size(), is(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); assertThat(requestMap.get("model"), is("model")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java index 6e60b919328bc..6fdd24b970a9d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java @@ -126,7 +126,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content")))); assertThat(webServer.requests(), hasSize(1)); - MockRequest request = webServer.requests().getFirst(); + MockRequest request = webServer.requests().get(0); assertNull(request.getUri().getQuery()); assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); From bfc807221311ecae2a855ae991eca36117a92c57 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 20:09:29 +0300 Subject: [PATCH 18/29] Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWCWireSerializationTestCase for future serialization testing --- ...ggingFaceChatCompletionServiceSettingsTests.java | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java index 7cfc2d7d2adf5..1dd5f3533268e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java @@ -7,13 +7,14 @@ package org.elasticsearch.xpack.inference.services.huggingface.completion; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -27,7 +28,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -public class HuggingFaceChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase< +public class HuggingFaceChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< HuggingFaceChatCompletionServiceSettings> { public static final String MODEL_ID = "some model"; @@ -226,6 +227,14 @@ protected HuggingFaceChatCompletionServiceSettings mutateInstance(HuggingFaceCha return randomValueOtherThan(instance, HuggingFaceChatCompletionServiceSettingsTests::createRandomWithNonNullUrl); } + @Override + protected HuggingFaceChatCompletionServiceSettings mutateInstanceForVersion( + HuggingFaceChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + private static HuggingFaceChatCompletionServiceSettings createRandomWithNonNullUrl() { return createRandom(randomAlphaOfLength(15)); } From 13ef13b5f17994012122deafbc28b826758a73ee Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 20:25:29 +0300 Subject: [PATCH 19/29] Refactored tests to use stripWhitespace method for readability --- ...aceChatCompletionServiceSettingsTests.java | 27 ++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java index 1dd5f3533268e..51b6429884374 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; @@ -192,9 +193,17 @@ public void testToXContent_WritesAllValues() throws IOException { XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"model_id":"some model","url":"https://www.elastic.co","rate_limit":{"requests_per_minute":2}}""")); + var expected = XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 2 + } + } + """); + + assertThat(xContentResult, is(expected)); } public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { @@ -206,9 +215,15 @@ public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"url":"https://www.elastic.co","rate_limit":{"requests_per_minute":3000}}""")); + var expected = XContentHelper.stripWhitespace(""" + { + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 3000 + } + } + """); + assertThat(xContentResult, is(expected)); } @Override From 129caaf3cf73bf050719cc380ab8b1332830c35b Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 20:42:31 +0300 Subject: [PATCH 20/29] Refactored javadoc for HuggingFaceService --- .../inference/services/huggingface/HuggingFaceService.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 9b908e5c68817..462358007cd0b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -56,7 +56,7 @@ /** * This class is responsible for managing the Hugging Face inference service. - * It handles the creation of models, chunked inference, and unified completion inference. + * It manages model creation, as well as chunked, non-chunked, and unified completion inference. */ public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; From 214de5fafd4db54a0b506ccd9e51bfeccaa7048e Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 21:45:09 +0300 Subject: [PATCH 21/29] Renamed HF chat completion TransportVersion constant names --- server/src/main/java/org/elasticsearch/TransportVersions.java | 4 ++-- .../completion/HuggingFaceChatCompletionServiceSettings.java | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 9765d2ca40b01..2b09ee98e926b 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -164,7 +164,7 @@ static TransportVersion def(int id) { public static final TransportVersion SEARCH_INCREMENTAL_TOP_DOCS_NULL_BACKPORT_8_19 = def(8_841_0_20); public static final TransportVersion ML_INFERENCE_SAGEMAKER_8_19 = def(8_841_0_21); public static final TransportVersion ESQL_REPORT_ORIGINAL_TYPES_BACKPORT_8_19 = def(8_841_0_22); - public static final TransportVersion ML_INFERENCE_HUGGING_FACE_8_19 = def(8_841_0_23); + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_23); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00); @@ -237,7 +237,7 @@ static TransportVersion def(int id) { public static final TransportVersion PINNED_RETRIEVER = def(9_068_0_00); public static final TransportVersion ML_INFERENCE_SAGEMAKER = def(9_069_0_00); public static final TransportVersion WRITE_LOAD_INCLUDES_BUFFER_WRITES = def(9_070_00_0); - public static final TransportVersion ML_INFERENCE_HUGGING_FACE = def(9_071_0_00); + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_071_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java index b6490a5cff19a..af88316ef5161 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -144,7 +144,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_HUGGING_FACE; + return TransportVersions.ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED; } @Override From d3411d607960241bfaacffa227a2bc17451651c3 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 21:51:38 +0300 Subject: [PATCH 22/29] Added random string generation in unit test --- ...HuggingFaceUnifiedChatCompletionRequestTests.java | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java index 8989c76a46cf6..103a2a035cee1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java @@ -15,7 +15,6 @@ import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; import java.io.IOException; -import java.net.URISyntaxException; import java.util.List; import java.util.Map; @@ -27,7 +26,7 @@ public class HuggingFaceUnifiedChatCompletionRequestTests extends ESTestCase { public void testCreateRequest_WithStreaming() throws IOException { - var request = createRequest("url", "secret", "abcd", "model", true); + var request = createRequest("url", "secret", randomAlphaOfLength(15), "model", true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -37,8 +36,9 @@ public void testCreateRequest_WithStreaming() throws IOException { assertThat(requestMap.get("stream"), is(true)); } - public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { - var request = createRequest("url", "secret", "abcd", "model", true); + public void testTruncate_DoesNotReduceInputTextSize() throws IOException { + String input = randomAlphaOfLength(5); + var request = createRequest("url", "secret", input, "model", true); var truncatedRequest = request.truncate(); assertThat(request.getURI().toString(), is("url")); @@ -50,7 +50,7 @@ public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, assertThat(requestMap, aMapWithSize(5)); // We do not truncate for Hugging Face chat completions - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd")))); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); assertTrue((Boolean) requestMap.get("stream")); @@ -58,7 +58,7 @@ public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, } public void testTruncationInfo_ReturnsNull() { - var request = createRequest("url", "secret", "abcd", "model", true); + var request = createRequest("url", "secret", randomAlphaOfLength(5), "model", true); assertNull(request.getTruncationInfo()); } From e170b96f51a5d1b41bd9a0dac176299d1dbb929d Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 22:40:17 +0300 Subject: [PATCH 23/29] Refactored javadocs for HuggingFace requests --- .../completion/HuggingFaceUnifiedChatCompletionRequest.java | 2 +- .../request/embeddings/HuggingFaceEmbeddingsRequest.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java index 718ee082aa813..71ce78b8bd0b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java @@ -25,7 +25,7 @@ import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; /** - * This class is responsible for creating a request to the Hugging Face API for chat completions. + * This class is responsible for creating Hugging Face chat completions HTTP requests. * It handles the preparation of the HTTP request with the necessary headers and body. */ public class HuggingFaceUnifiedChatCompletionRequest implements Request { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequest.java index 44dfb7eebe894..788ae0ab15a2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequest.java @@ -25,7 +25,7 @@ import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; /** - * This class is responsible for creating a request to the Hugging Face API for embeddings. + * This class is responsible for creating Hugging Face embeddings HTTP requests. * It handles the truncation of input data and prepares the HTTP request with the necessary headers and body. */ public class HuggingFaceEmbeddingsRequest implements Request { From 473dee6c48c37cf0b0b16a198a07a03a318813fb Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 22:41:14 +0300 Subject: [PATCH 24/29] Refactored tests to reduce duplication --- .../action/HuggingFaceActionCreatorTests.java | 72 +++++++++---------- 1 file changed, 32 insertions(+), 40 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index 08b13b54a67ae..895aea8067c46 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests; @@ -462,31 +463,13 @@ public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() thro """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); - var actionCreator = new HuggingFaceActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + PlainActionFuture listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool)); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("n"), is(1)); - assertThat(requestMap.get("stream"), is(false)); + assertChatCompletionRequest(); } } @@ -508,15 +491,10 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); - var actionCreator = new HuggingFaceActionCreator( + PlainActionFuture listener = createChatCompletionFuture( sender, new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) ); - var action = actionCreator.create(model); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -524,20 +502,34 @@ public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() th is("Failed to send Hugging Face completion request from inference entity id " + "[id]. Cause: Required [choices]") ); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat( - webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), - equalTo(XContentType.JSON.mediaTypeWithoutParameters()) - ); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("n"), is(1)); - assertThat(requestMap.get("stream"), is(false)); + assertChatCompletionRequest(); } } + + private PlainActionFuture createChatCompletionFuture(Sender sender, ServiceComponents threadPool) { + var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); + var actionCreator = new HuggingFaceActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + return listener; + } + + private void assertChatCompletionRequest() throws IOException { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } } From cb0310040dcf2f2f8de6b207779f4ea3e9a9772c Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Fri, 2 May 2025 22:45:35 +0300 Subject: [PATCH 25/29] Added changelog file --- docs/changelog/127254.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/127254.yaml diff --git a/docs/changelog/127254.yaml b/docs/changelog/127254.yaml new file mode 100644 index 0000000000000..366b1a2cce00b --- /dev/null +++ b/docs/changelog/127254.yaml @@ -0,0 +1,5 @@ +pr: 127254 +summary: "[ML] Add HuggingFace Chat Completion support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: [] From aae528a012eaf6e45ce1c9363542d85108e31908 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 5 May 2025 22:35:39 +0300 Subject: [PATCH 26/29] Add HuggingFaceChatCompletionResponseHandler and associated tests --- ...gingFaceChatCompletionResponseHandler.java | 167 ++++++++++++++++++ .../huggingface/HuggingFaceService.java | 3 +- ...iUnifiedChatCompletionResponseHandler.java | 19 +- ...aceChatCompletionResponseHandlerTests.java | 132 ++++++++++++++ 4 files changed, 316 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..921991f9a2e97 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -0,0 +1,167 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; + +import java.util.Locale; +import java.util.Optional; +import java.util.concurrent.Flow; + +import static org.elasticsearch.core.Strings.format; + +/** + * Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints. + * Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code". + */ +public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + return super.parseResult(request, flow); + } + + public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, HuggingFaceErrorResponse::fromResponse); + } + + @Override + protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + assert request.isStreaming() : "Only streaming requests support this format"; + var responseStatusCode = result.response().getStatusLine().getStatusCode(); + if (request.isStreaming()) { + var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); + var restStatus = toRestStatus(responseStatusCode); + return errorResponse instanceof HuggingFaceErrorResponse huggingFaceErrorResponse + ? new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + extractErrorCode(huggingFaceErrorResponse) + ) + : new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } else { + return super.buildError(message, request, result, errorResponse); + } + } + + @Override + protected Exception buildMidStreamError(Request request, String message, Exception e) { + var errorResponse = HuggingFaceErrorResponse.fromString(message); + if (errorResponse instanceof HuggingFaceErrorResponse huggingFaceErrorResponse) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + request.getInferenceEntityId(), + errorResponse.getErrorMessage() + ), + createErrorType(errorResponse), + extractErrorCode(huggingFaceErrorResponse) + ); + } else if (e != null) { + return UnifiedChatCompletionException.fromThrowable(e); + } else { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), + createErrorType(errorResponse), + "stream_error" + ); + } + } + + private static String extractErrorCode(HuggingFaceErrorResponse huggingFaceErrorResponse) { + return huggingFaceErrorResponse.httpStatusCode() != null ? String.valueOf(huggingFaceErrorResponse.httpStatusCode()) : null; + } + + private static class HuggingFaceErrorResponse extends ErrorResponse { + private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( + "hugging_face_error", + true, + args -> Optional.ofNullable((HuggingFaceErrorResponse) args[0]) + ); + private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( + "hugging_face_error", + true, + args -> new HuggingFaceErrorResponse((String) args[0], (Integer) args[1]) + ); + + static { + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message")); + ERROR_BODY_PARSER.declareIntOrNull(ConstructingObjectParser.optionalConstructorArg(), -1, new ParseField("http_status_code")); + + ERROR_PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + ERROR_BODY_PARSER, + null, + new ParseField("error") + ); + } + + private static ErrorResponse fromResponse(HttpResult response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response.body()) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } + + private static ErrorResponse fromString(String response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } + + @Nullable + private final Integer httpStatusCode; + + HuggingFaceErrorResponse(String errorMessage, @Nullable Integer httpStatusCode) { + super(errorMessage); + this.httpStatusCode = httpStatusCode; + } + + @Nullable + public Integer httpStatusCode() { + return httpStatusCode; + } + + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 462358007cd0b..133f7b5be6b62 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -40,7 +40,6 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; -import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -68,7 +67,7 @@ public class HuggingFaceService extends HuggingFaceBaseService { TaskType.COMPLETION, TaskType.CHAT_COMPLETION ); - private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler( + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler( "hugging face chat completion", OpenAiChatCompletionResponseEntity::fromResponse ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index efd52a4960f11..e60ba31823107 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -29,6 +29,7 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.Flow; +import java.util.function.Function; import static org.elasticsearch.core.Strings.format; @@ -37,6 +38,14 @@ public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponsePa super(requestType, parseFunction, OpenAiErrorResponse::fromResponse); } + public OpenAiUnifiedChatCompletionResponseHandler( + String requestType, + ResponseParser parseFunction, + Function errorParseFunction + ) { + super(requestType, parseFunction, errorParseFunction); + } + @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); @@ -59,7 +68,7 @@ protected Exception buildError(String message, Request request, HttpResult resul : new UnifiedChatCompletionException( restStatus, errorMessage, - errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown", + createErrorType(errorResponse), restStatus.name().toLowerCase(Locale.ROOT) ); } else { @@ -67,7 +76,11 @@ protected Exception buildError(String message, Request request, HttpResult resul } } - private static Exception buildMidStreamError(Request request, String message, Exception e) { + protected static String createErrorType(ErrorResponse errorResponse) { + return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; + } + + protected Exception buildMidStreamError(Request request, String message, Exception e) { var errorResponse = OpenAiErrorResponse.fromString(message); if (errorResponse instanceof OpenAiErrorResponse oer) { return new UnifiedChatCompletionException( @@ -88,7 +101,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex return new UnifiedChatCompletionException( RestStatus.INTERNAL_SERVER_ERROR, format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown", + createErrorType(errorResponse), "stream_error" ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..e209666e4f07f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HuggingFaceChatCompletionResponseHandlerTests extends ESTestCase { + private final HuggingFaceChatCompletionResponseHandler responseHandler = new HuggingFaceChatCompletionResponseHandler( + "chat completions", + (a, b) -> mock() + ); + + public void testFailValidationWithAllFields() throws IOException { + var responseJson = """ + { + "error": { + "message": "a message", + "http_status_code": 422 + } + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(""" + {"error":{"code":"422","message":"Received a server error status code for request from inference entity id [id] status [500]. \ + Error message: [a message]","type":"HuggingFaceErrorResponse"}}""")); + } + + public void testFailValidationWithoutOptionalFields() throws IOException { + var responseJson = """ + { + "error": { + "message": "a message" + } + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(""" + {"error":{"message":"Received a server error status code for request from inference entity id [id] status [500]. \ + Error message: [a message]","type":"HuggingFaceErrorResponse"}}""")); + } + + public void testFailValidationWithInvalidJson() throws IOException { + var responseJson = """ + what? this isn't a json + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(""" + {"error":{"code":"bad_request","message":"Received a server error status code for request from inference entity id [id] status\ + [500]","type":"ErrorResponse"}}""")); + } + + private String invalidResponseJson(String responseJson) throws IOException { + var exception = invalidResponse(responseJson); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private Exception invalidResponse(String responseJson) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)), + true + ) + ); + } + + private static Request mockRequest() { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn("id"); + when(request.isStreaming()).thenReturn(true); + return request; + } + + private static HttpResponse mock500Response() { + int statusCode = 500; + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } +} From 82f8049936a4a2687c98bd7a7eef1c2be1384b0d Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Mon, 5 May 2025 23:13:14 +0300 Subject: [PATCH 27/29] Refactor error handling in HuggingFaceServiceTests to standardize error response codes and types --- .../huggingface/HuggingFaceServiceTests.java | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 7e5f6cc78a241..63a0410533479 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -352,9 +352,7 @@ public void testUnifiedCompletionError() throws Exception { { "error": { "message": "The model `gpt-4awero` does not exist or you do not have access to it.", - "type": "invalid_request_error", - "param": null, - "code": "model_not_found" + "http_status_code": "404" } }"""; webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); @@ -385,10 +383,10 @@ public void testUnifiedCompletionError() throws Exception { assertThat(json, is(""" {\ "error":{\ - "code":"model_not_found",\ + "code":"404",\ "message":"Received an unsuccessful status code for request from inference entity id [id] status \ [404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\ - "type":"invalid_request_error"\ + "type":"HuggingFaceErrorResponse"\ }}""")); } catch (IOException ex) { throw new RuntimeException(ex); @@ -402,7 +400,7 @@ public void testUnifiedCompletionError() throws Exception { public void testMidStreamUnifiedCompletionError() throws Exception { String responseJson = """ event: error - data: { "error": { "message": "Timed out waiting for more data", "type": "timeout" } } + data: { "error": { "message": "Timed out waiting for more data" } } """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); @@ -411,7 +409,7 @@ public void testMidStreamUnifiedCompletionError() throws Exception { "error":{\ "message":"Received an error response for request from inference entity id [id]. Error message: \ [Timed out waiting for more data]",\ - "type":"timeout"\ + "type":"HuggingFaceErrorResponse"\ }}"""); } From cdb3c1cbb6a12c064c82c02e618863e776dc2064 Mon Sep 17 00:00:00 2001 From: Jan Kazlouski Date: Wed, 7 May 2025 20:25:29 +0300 Subject: [PATCH 28/29] Refactor HuggingFace error handling to improve response structure and add streaming support --- ...gingFaceChatCompletionResponseHandler.java | 84 +++++++------ ...aceUnifiedChatCompletionRequestEntity.java | 4 +- .../HuggingFaceErrorResponseEntity.java | 3 + ...aceChatCompletionResponseHandlerTests.java | 21 ++-- .../huggingface/HuggingFaceServiceTests.java | 119 +++++++++++++----- 5 files changed, 145 insertions(+), 86 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java index 921991f9a2e97..8dffd612db5c8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.inference.services.huggingface; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -21,11 +20,11 @@ import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity; import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; import java.util.Locale; import java.util.Optional; -import java.util.concurrent.Flow; import static org.elasticsearch.core.Strings.format; @@ -35,13 +34,10 @@ */ public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { - @Override - public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { - return super.parseResult(request, flow); - } + private static final String HUGGING_FACE_ERROR = "hugging_face_error"; public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { - super(requestType, parseFunction, HuggingFaceErrorResponse::fromResponse); + super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse); } @Override @@ -51,12 +47,12 @@ protected Exception buildError(String message, Request request, HttpResult resul if (request.isStreaming()) { var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); var restStatus = toRestStatus(responseStatusCode); - return errorResponse instanceof HuggingFaceErrorResponse huggingFaceErrorResponse + return errorResponse instanceof HuggingFaceErrorResponseEntity ? new UnifiedChatCompletionException( restStatus, errorMessage, - createErrorType(errorResponse), - extractErrorCode(huggingFaceErrorResponse) + HUGGING_FACE_ERROR, + restStatus.name().toLowerCase(Locale.ROOT) ) : new UnifiedChatCompletionException( restStatus, @@ -71,8 +67,8 @@ protected Exception buildError(String message, Request request, HttpResult resul @Override protected Exception buildMidStreamError(Request request, String message, Exception e) { - var errorResponse = HuggingFaceErrorResponse.fromString(message); - if (errorResponse instanceof HuggingFaceErrorResponse huggingFaceErrorResponse) { + var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message); + if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { return new UnifiedChatCompletionException( RestStatus.INTERNAL_SERVER_ERROR, format( @@ -81,8 +77,8 @@ protected Exception buildMidStreamError(Request request, String message, Excepti request.getInferenceEntityId(), errorResponse.getErrorMessage() ), - createErrorType(errorResponse), - extractErrorCode(huggingFaceErrorResponse) + HUGGING_FACE_ERROR, + extractErrorCode(streamingHuggingFaceErrorResponseEntity) ); } else if (e != null) { return UnifiedChatCompletionException.fromThrowable(e); @@ -96,25 +92,40 @@ protected Exception buildMidStreamError(Request request, String message, Excepti } } - private static String extractErrorCode(HuggingFaceErrorResponse huggingFaceErrorResponse) { - return huggingFaceErrorResponse.httpStatusCode() != null ? String.valueOf(huggingFaceErrorResponse.httpStatusCode()) : null; + private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { + return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null + ? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode()) + : null; } - private static class HuggingFaceErrorResponse extends ErrorResponse { + /** + * Represents a structured error response specifically for streaming operations + * using HuggingFace APIs. This is separate from non-streaming error responses, + * which are handled by {@link HuggingFaceErrorResponseEntity}. + * An example error response for failed field validation for streaming operation would look like + * + * { + * "error": "Input validation error: cannot compile regex from schema", + * "http_status_code": 422 + * } + * + */ + private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse { private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( - "hugging_face_error", + HUGGING_FACE_ERROR, true, - args -> Optional.ofNullable((HuggingFaceErrorResponse) args[0]) - ); - private static final ConstructingObjectParser ERROR_BODY_PARSER = new ConstructingObjectParser<>( - "hugging_face_error", - true, - args -> new HuggingFaceErrorResponse((String) args[0], (Integer) args[1]) + args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0]) ); + private static final ConstructingObjectParser ERROR_BODY_PARSER = + new ConstructingObjectParser<>( + HUGGING_FACE_ERROR, + true, + args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1]) + ); static { - ERROR_BODY_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("message")); - ERROR_BODY_PARSER.declareIntOrNull(ConstructingObjectParser.optionalConstructorArg(), -1, new ParseField("http_status_code")); + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message")); + ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code")); ERROR_PARSER.declareObjectOrNull( ConstructingObjectParser.optionalConstructorArg(), @@ -124,19 +135,12 @@ private static class HuggingFaceErrorResponse extends ErrorResponse { ); } - private static ErrorResponse fromResponse(HttpResult response) { - try ( - XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(XContentParserConfiguration.EMPTY, response.body()) - ) { - return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); - } catch (Exception e) { - // swallow the error - } - - return ErrorResponse.UNDEFINED_ERROR; - } - + /** + * Parses a streaming HuggingFace error response from a JSON string. + * + * @param response the raw JSON string representing an error + * @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails + */ private static ErrorResponse fromString(String response) { try ( XContentParser parser = XContentFactory.xContent(XContentType.JSON) @@ -153,7 +157,7 @@ private static ErrorResponse fromString(String response) { @Nullable private final Integer httpStatusCode; - HuggingFaceErrorResponse(String errorMessage, @Nullable Integer httpStatusCode) { + StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) { super(errorMessage); this.httpStatusCode = httpStatusCode; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java index de7f1bf04de0d..372954b3ddf07 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java @@ -36,7 +36,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); unifiedRequestEntity.toXContent(builder, params); - builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); + if (model.getServiceSettings().modelId() != null) { + builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); + } if (unifiedChatInput.getRequest().maxCompletionTokens() != null) { builder.field(MAX_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java index 63b5e5622d871..d30a60c341a58 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java @@ -21,6 +21,9 @@ public HuggingFaceErrorResponseEntity(String message) { } /** + * Represents a structured error response specifically for non-streaming operations + * using HuggingFace APIs. This is separate from streaming error responses, + * which are handled by private nested HuggingFaceChatCompletionResponseHandler.StreamingHuggingFaceErrorResponseEntity. * An example error response for invalid auth would look like * * { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java index e209666e4f07f..f1560edc48ed0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java @@ -37,34 +37,33 @@ public class HuggingFaceChatCompletionResponseHandlerTests extends ESTestCase { public void testFailValidationWithAllFields() throws IOException { var responseJson = """ { - "error": { - "message": "a message", - "http_status_code": 422 - } + "error": "a message", + "type": "validation" } """; var errorJson = invalidResponseJson(responseJson); assertThat(errorJson, is(""" - {"error":{"code":"422","message":"Received a server error status code for request from inference entity id [id] status [500]. \ - Error message: [a message]","type":"HuggingFaceErrorResponse"}}""")); + {"error":{"code":"bad_request","message":"Received a server error status code for request from \ + inference entity id [id] status [500]. \ + Error message: [a message]",\ + "type":"hugging_face_error"}}""")); } public void testFailValidationWithoutOptionalFields() throws IOException { var responseJson = """ { - "error": { - "message": "a message" - } + "error": "a message" } """; var errorJson = invalidResponseJson(responseJson); assertThat(errorJson, is(""" - {"error":{"message":"Received a server error status code for request from inference entity id [id] status [500]. \ - Error message: [a message]","type":"HuggingFaceErrorResponse"}}""")); + {"error":{"code":"bad_request","message":"Received a server error status code for request from \ + inference entity id [id] status [500]. \ + Error message: [a message]","type":"hugging_face_error"}}""")); } public void testFailValidationWithInvalidJson() throws IOException { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 63a0410533479..12c0d85674a58 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -347,14 +347,12 @@ public void testUnifiedCompletionInfer() throws Exception { } } - public void testUnifiedCompletionError() throws Exception { + public void testUnifiedCompletionNonStreamingError() throws Exception { String responseJson = """ { - "error": { - "message": "The model `gpt-4awero` does not exist or you do not have access to it.", - "http_status_code": "404" - } - }"""; + "error": "Model not found." + } + """; webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -383,10 +381,10 @@ public void testUnifiedCompletionError() throws Exception { assertThat(json, is(""" {\ "error":{\ - "code":"404",\ + "code":"not_found",\ "message":"Received an unsuccessful status code for request from inference entity id [id] status \ - [404]. Error message: [The model `gpt-4awero` does not exist or you do not have access to it.]",\ - "type":"HuggingFaceErrorResponse"\ + [404]. Error message: [Model not found.]",\ + "type":"hugging_face_error"\ }}""")); } catch (IOException ex) { throw new RuntimeException(ex); @@ -400,7 +398,46 @@ public void testUnifiedCompletionError() throws Exception { public void testMidStreamUnifiedCompletionError() throws Exception { String responseJson = """ event: error - data: { "error": { "message": "Timed out waiting for more data" } } + data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \ + {\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by Outlines.\\n\ + If it should be supported, please open an issue.","http_status_code":422}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"422",\ + "message":"Received an error response for request from inference entity id [id]. Error message: [Input validation error: \ + cannot compile regex from schema: Unsupported JSON Schema structure {\\"id\\":\\"123\\"} \\nMake sure it is valid to the \ + JSON Schema specification and check if it's supported by Outlines.\\nIf it should be supported, please open an issue.]",\ + "type":"hugging_face_error"\ + }}"""); + } + + public void testMidStreamUnifiedCompletionErrorNoMessage() throws Exception { + String responseJson = """ + event: error + data: {"error":{"http_status_code":422}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"422",\ + "message":"Received an error response for request from inference entity id [id]. Error message: \ + [unknown]",\ + "type":"hugging_face_error"\ + }}"""); + } + + public void testMidStreamUnifiedCompletionErrorNoHttpStatusCode() throws Exception { + String responseJson = """ + event: error + data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \ + {\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by \ + Outlines.\\nIf it should be supported, please open an issue."}} """; webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); @@ -408,8 +445,42 @@ public void testMidStreamUnifiedCompletionError() throws Exception { {\ "error":{\ "message":"Received an error response for request from inference entity id [id]. Error message: \ - [Timed out waiting for more data]",\ - "type":"HuggingFaceErrorResponse"\ + [Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \ + {\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported\ + by Outlines.\\nIf it should be supported, please open an issue.]",\ + "type":"hugging_face_error"\ + }}"""); + } + + public void testMidStreamUnifiedCompletionErrorNoHttpStatusCodeNoMessage() throws Exception { + String responseJson = """ + event: error + data: {"error":{}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "message":"Received an error response for request from inference entity id [id]. Error message: \ + [unknown]",\ + "type":"hugging_face_error"\ + }}"""); + } + + public void testUnifiedCompletionMalformedError() throws Exception { + String responseJson = """ + data: { invalid json } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"bad_request",\ + "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\ + at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\ + "type":"x_content_parse_exception"\ }}"""); } @@ -448,22 +519,6 @@ private void testStreamError(String expectedResponse) throws Exception { } } - public void testUnifiedCompletionMalformedError() throws Exception { - String responseJson = """ - data: { invalid json } - - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - testStreamError(""" - {\ - "error":{\ - "code":"bad_request",\ - "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\ - at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\ - "type":"x_content_parse_exception"\ - }}"""); - } - public void testInfer_StreamRequest() throws Exception { String responseJson = """ data: {\ @@ -517,10 +572,7 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception { String responseJson = """ { "error": { - "message": "You didn't provide an API key...", - "type": "invalid_request_error", - "param": null, - "code": null + "message": "You didn't provide an API key..." } }"""; webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); @@ -540,8 +592,7 @@ public void testInfer_StreamRequestRetry() throws Exception { webServer.enqueue(new MockResponse().setResponseCode(503).setBody(""" { "error": { - "message": "server busy", - "type": "server_busy" + "message": "server busy" } }""")); webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" From 9044beead5a30c917a4390230ec7da892a1d8796 Mon Sep 17 00:00:00 2001 From: Jonathan Buttner Date: Fri, 9 May 2025 11:11:15 -0400 Subject: [PATCH 29/29] Allowing null function name for hugging face models --- .../OpenAiUnifiedStreamingProcessor.java | 2 +- .../OpenAiUnifiedStreamingProcessorTests.java | 69 +++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java index 5ab743c3d4cc0..983bb5efbf3fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java @@ -250,7 +250,7 @@ private static class FunctionParser { static { PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD)); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD)); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD)); } public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessorTests.java index c51644a1e279f..1c72aaf4a147b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessorTests.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.util.List; +import static org.hamcrest.Matchers.is; + public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase { public void testJsonLiteral() { @@ -182,6 +184,73 @@ public void testJsonLiteralCornerCases() { } } + public void testJsonNullFunctionName() throws IOException { + String json = """ + { + "object": "chat.completion.chunk", + "id": "", + "created": 1746800254, + "model": "/repository", + "system_fingerprint": "3.2.3-sha-a1f3ebe", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "tool_calls": [ + { + "index": 0, + "id": "8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a", + "type": "function", + "function": { + "name": null, + "arguments": " \\\"" + } + } + ] + }, + "logprobs": null, + "finish_reason": null + } + ], + "usage": null + } + """; + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, json)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertThat(chunk.id(), is("")); + assertThat(chunk.model(), is("/repository")); + assertThat(chunk.object(), is("chat.completion.chunk")); + assertNull(chunk.usage()); + + List choices = chunk.choices(); + assertThat(choices.size(), is(1)); + + // First choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0); + assertNull(firstChoice.delta().content()); + assertNull(firstChoice.delta().refusal()); + assertThat(firstChoice.delta().role(), is("assistant")); + assertNull(firstChoice.finishReason()); + assertThat(firstChoice.index(), is(0)); + + List toolCalls = firstChoice.delta() + .toolCalls(); + assertThat(toolCalls.size(), is(1)); + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertThat(toolCall.index(), is(0)); + assertThat(toolCall.id(), is("8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a")); + assertThat(toolCall.type(), is("function")); + assertNull(toolCall.function().name()); + assertThat(toolCall.function().arguments(), is(" \"")); + } + } + public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException { // Generate random values for the JSON fields int toolCallIndex = randomIntBetween(0, 10);