diff --git a/docs/changelog/127254.yaml b/docs/changelog/127254.yaml new file mode 100644 index 0000000000000..366b1a2cce00b --- /dev/null +++ b/docs/changelog/127254.yaml @@ -0,0 +1,5 @@ +pr: 127254 +summary: "[ML] Add HuggingFace Chat Completion support to the Inference Plugin" +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index b66af9934dad4..5313e3e6a2d13 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -175,6 +175,7 @@ static TransportVersion def(int id) { public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28); public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29); public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION_8_19 = def(8_841_0_30); + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_31); public static final TransportVersion V_9_0_0 = def(9_000_0_09); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11); @@ -255,6 +256,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00); public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00); public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00); + public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index 682eebd0fa69b..1aa302fcb2b32 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -123,7 +123,7 @@ public void testGetServicesWithRerankTaskType() throws IOException { public void testGetServicesWithCompletionTaskType() throws IOException { List services = getServices(TaskType.COMPLETION); - assertThat(services.size(), equalTo(10)); + assertThat(services.size(), equalTo(11)); var providers = providers(services); @@ -140,7 +140,8 @@ public void testGetServicesWithCompletionTaskType() throws IOException { "deepseek", "googleaistudio", "openai", - "streaming_completion_test_service" + "streaming_completion_test_service", + "hugging_face" ).toArray() ) ); @@ -148,11 +149,14 @@ public void testGetServicesWithCompletionTaskType() throws IOException { public void testGetServicesWithChatCompletionTaskType() throws IOException { List services = getServices(TaskType.CHAT_COMPLETION); - assertThat(services.size(), equalTo(4)); + assertThat(services.size(), equalTo(5)); var providers = providers(services); - assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray())); + assertThat( + providers, + containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray()) + ); } public void testGetServicesWithSparseEmbeddingTaskType() throws IOException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 5c719c08142a2..dc6c96d0d860d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -78,6 +78,7 @@ import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings; @@ -357,6 +358,13 @@ private static void addHuggingFaceNamedWriteables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..8dffd612db5c8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandler.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler; + +import java.util.Locale; +import java.util.Optional; + +import static org.elasticsearch.core.Strings.format; + +/** + * Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints. + * Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code". + */ +public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler { + + private static final String HUGGING_FACE_ERROR = "hugging_face_error"; + + public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse); + } + + @Override + protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) { + assert request.isStreaming() : "Only streaming requests support this format"; + var responseStatusCode = result.response().getStatusLine().getStatusCode(); + if (request.isStreaming()) { + var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode); + var restStatus = toRestStatus(responseStatusCode); + return errorResponse instanceof HuggingFaceErrorResponseEntity + ? new UnifiedChatCompletionException( + restStatus, + errorMessage, + HUGGING_FACE_ERROR, + restStatus.name().toLowerCase(Locale.ROOT) + ) + : new UnifiedChatCompletionException( + restStatus, + errorMessage, + createErrorType(errorResponse), + restStatus.name().toLowerCase(Locale.ROOT) + ); + } else { + return super.buildError(message, request, result, errorResponse); + } + } + + @Override + protected Exception buildMidStreamError(Request request, String message, Exception e) { + var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message); + if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format( + "%s for request from inference entity id [%s]. Error message: [%s]", + SERVER_ERROR_OBJECT, + request.getInferenceEntityId(), + errorResponse.getErrorMessage() + ), + HUGGING_FACE_ERROR, + extractErrorCode(streamingHuggingFaceErrorResponseEntity) + ); + } else if (e != null) { + return UnifiedChatCompletionException.fromThrowable(e); + } else { + return new UnifiedChatCompletionException( + RestStatus.INTERNAL_SERVER_ERROR, + format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), + createErrorType(errorResponse), + "stream_error" + ); + } + } + + private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) { + return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null + ? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode()) + : null; + } + + /** + * Represents a structured error response specifically for streaming operations + * using HuggingFace APIs. This is separate from non-streaming error responses, + * which are handled by {@link HuggingFaceErrorResponseEntity}. + * An example error response for failed field validation for streaming operation would look like + * + * { + * "error": "Input validation error: cannot compile regex from schema", + * "http_status_code": 422 + * } + * + */ + private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse { + private static final ConstructingObjectParser, Void> ERROR_PARSER = new ConstructingObjectParser<>( + HUGGING_FACE_ERROR, + true, + args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0]) + ); + private static final ConstructingObjectParser ERROR_BODY_PARSER = + new ConstructingObjectParser<>( + HUGGING_FACE_ERROR, + true, + args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1]) + ); + + static { + ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message")); + ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code")); + + ERROR_PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + ERROR_BODY_PARSER, + null, + new ParseField("error") + ); + } + + /** + * Parses a streaming HuggingFace error response from a JSON string. + * + * @param response the raw JSON string representing an error + * @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails + */ + private static ErrorResponse fromString(String response) { + try ( + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(XContentParserConfiguration.EMPTY, response) + ) { + return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR); + } catch (Exception e) { + // swallow the error + } + + return ErrorResponse.UNDEFINED_ERROR; + } + + @Nullable + private final Integer httpStatusCode; + + StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) { + super(errorMessage); + this.httpStatusCode = httpStatusCode; + } + + @Nullable + public Integer httpStatusCode() { + return httpStatusCode; + } + + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java index 6a750a9ded9b3..62133eff4b658 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceModel.java @@ -9,17 +9,18 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.Objects; -public abstract class HuggingFaceModel extends Model { +public abstract class HuggingFaceModel extends RateLimitGroupingModel { private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings; private final SecureString apiKey; @@ -38,6 +39,16 @@ public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() { return rateLimitServiceSettings; } + @Override + public int rateLimitGroupingHash() { + return Objects.hash(rateLimitServiceSettings.uri(), apiKey); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitServiceSettings.rateLimitSettings(); + } + public SecureString apiKey() { return apiKey; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java index b09cf8d98b7f3..7bb140e91ec5d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceRequestManager.java @@ -19,7 +19,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest; import java.util.List; import java.util.Objects; @@ -64,7 +64,7 @@ public void execute( ) { List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); var truncatedInput = truncate(docsInput, model.getTokenLimit()); - var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); + var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index f2a53520e18e6..133f7b5be6b62 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -26,15 +26,21 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -42,16 +48,29 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; +/** + * This class is responsible for managing the Hugging Face inference service. + * It manages model creation, as well as chunked, non-chunked, and unified completion inference. + */ public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; private static final String SERVICE_NAME = "Hugging Face"; - private static final EnumSet supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING); + private static final EnumSet SUPPORTED_TASK_TYPES = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.SPARSE_EMBEDDING, + TaskType.COMPLETION, + TaskType.CHAT_COMPLETION + ); + private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler( + "hugging face chat completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { super(factory, serviceComponents); @@ -78,6 +97,14 @@ protected HuggingFaceModel createModel( context ); case SPARSE_EMBEDDING -> new HuggingFaceElserModel(inferenceEntityId, taskType, NAME, serviceSettings, secretSettings, context); + case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + secretSettings, + context + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } @@ -139,7 +166,29 @@ protected void doUnifiedCompletionInfer( TimeValue timeout, ActionListener listener ) { - throwUnsupportedUnifiedCompletionOperation(NAME); + if (model instanceof HuggingFaceChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + HuggingFaceChatCompletionModel huggingFaceChatCompletionModel = (HuggingFaceChatCompletionModel) model; + var overriddenModel = HuggingFaceChatCompletionModel.of(huggingFaceChatCompletionModel, inputs.getRequest()); + var manager = new GenericRequestManager<>( + getServiceComponents().threadPool(), + overriddenModel, + UNIFIED_CHAT_COMPLETION_HANDLER, + unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel), + UnifiedChatInput.class + ); + var errorMessage = HuggingFaceActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId()); + var action = new SenderExecutableAction(getSender(), manager, errorMessage); + + action.execute(inputs, timeout, listener); + } + + @Override + public Set supportedStreamingTasks() { + return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION); } @Override @@ -149,7 +198,7 @@ public InferenceServiceConfiguration getConfiguration() { @Override public EnumSet supportedTaskTypes() { - return supportedTaskTypes; + return SUPPORTED_TASK_TYPES; } @Override @@ -167,14 +216,15 @@ public static InferenceServiceConfiguration get() { return configuration.getOrCompute(); } + private Configuration() {} + private static final LazyInitializable configuration = new LazyInitializable<>( () -> { var configurationMap = new HashMap(); configurationMap.put( URL, - new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings") - .setDescription("The URL endpoint to use for the requests.") + new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDescription("The URL endpoint to use for the requests.") .setLabel("URL") .setRequired(true) .setSensitive(false) @@ -183,12 +233,12 @@ public static InferenceServiceConfiguration get() { .build() ); - configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes)); - configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes)); + configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); + configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(SUPPORTED_TASK_TYPES)); return new InferenceServiceConfiguration.Builder().setService(NAME) .setName(SERVICE_NAME) - .setTaskTypes(supportedTaskTypes) + .setTaskTypes(SUPPORTED_TASK_TYPES) .setConfigurations(configurationMap) .build(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java index d0578edaeaec6..df1ddcb017970 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java @@ -7,16 +7,26 @@ package org.elasticsearch.xpack.inference.services.huggingface.action; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRequestManager; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceResponseHandler; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceElserResponseEntity; import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity; import java.util.Objects; @@ -26,6 +36,13 @@ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the hugging face model type. */ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor { + + public static final String COMPLETION_ERROR_PREFIX = "Hugging Face completions"; + static final String USER_ROLE = "user"; + static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler( + "hugging face completion", + OpenAiChatCompletionResponseEntity::fromResponse + ); private final Sender sender; private final ServiceComponents serviceComponents; @@ -46,11 +63,7 @@ public ExecutableAction create(HuggingFaceEmbeddingsModel model) { serviceComponents.truncator(), serviceComponents.threadPool() ); - var errorMessage = format( - "Failed to send Hugging Face %s request from inference entity id [%s]", - "text embeddings", - model.getInferenceEntityId() - ); + var errorMessage = buildErrorMessage(TaskType.TEXT_EMBEDDING, model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } @@ -63,11 +76,25 @@ public ExecutableAction create(HuggingFaceElserModel model) { serviceComponents.truncator(), serviceComponents.threadPool() ); - var errorMessage = format( - "Failed to send Hugging Face %s request from inference entity id [%s]", - "ELSER", - model.getInferenceEntityId() - ); + var errorMessage = buildErrorMessage(TaskType.SPARSE_EMBEDDING, model.getInferenceEntityId()); return new SenderExecutableAction(sender, requestCreator, errorMessage); } + + @Override + public ExecutableAction create(HuggingFaceChatCompletionModel model) { + var manager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + + var errorMessage = buildErrorMessage(TaskType.COMPLETION, model.getInferenceEntityId()); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, COMPLETION_ERROR_PREFIX); + } + + public static String buildErrorMessage(TaskType requestType, String inferenceId) { + return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestType.toString(), inferenceId); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java index 3fb7b538769e9..ee308db774b1d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionVisitor.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.huggingface.action; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; @@ -15,4 +16,6 @@ public interface HuggingFaceActionVisitor { ExecutableAction create(HuggingFaceEmbeddingsModel model); ExecutableAction create(HuggingFaceElserModel model); + + ExecutableAction create(HuggingFaceChatCompletionModel model); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java new file mode 100644 index 0000000000000..450cec1cc8199 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModel.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.completion; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; +import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Map; + +public class HuggingFaceChatCompletionModel extends HuggingFaceModel { + + /** + * Creates a new {@link HuggingFaceChatCompletionModel} by copying properties from an existing model, + * replacing the {@code modelId} in the service settings with the one from the given {@link UnifiedCompletionRequest}, + * if present. If the request does not specify a model ID, the original value is retained. + * + * @param model the original model to copy from + * @param request the request potentially containing an overridden model ID + * @return a new {@link HuggingFaceChatCompletionModel} with updated service settings + */ + public static HuggingFaceChatCompletionModel of(HuggingFaceChatCompletionModel model, UnifiedCompletionRequest request) { + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new HuggingFaceChatCompletionServiceSettings( + request.model() != null ? request.model() : originalModelServiceSettings.modelId(), + originalModelServiceSettings.uri(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new HuggingFaceChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getSecretSettings() + ); + } + + public HuggingFaceChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + HuggingFaceChatCompletionServiceSettings.fromMap(serviceSettings, context), + DefaultSecretSettings.fromMap(secrets) + ); + } + + HuggingFaceChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + HuggingFaceChatCompletionServiceSettings serviceSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings), + new ModelSecrets(secretSettings), + serviceSettings, + secretSettings + ); + } + + @Override + public HuggingFaceChatCompletionServiceSettings getServiceSettings() { + return (HuggingFaceChatCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(HuggingFaceActionVisitor creator) { + return creator.create(this); + } + + @Override + public Integer getTokenLimit() { + throw new UnsupportedOperationException("Token Limit for chat completion is sent in request and not retrieved from the model"); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..af88316ef5161 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettings.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings.extractUri; + +/** + * Settings for the Hugging Face chat completion service. + *

+ * This class contains the settings required to configure a Hugging Face chat completion service, including the model ID, URL, maximum input + * tokens, and rate limit settings. + *

+ */ +public class HuggingFaceChatCompletionServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + HuggingFaceRateLimitServiceSettings { + + public static final String NAME = "hugging_face_completion_service_settings"; + // At the time of writing HuggingFace hasn't posted the default rate limit for inference endpoints so the value his is only a guess + // 3000 requests per minute + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000); + + /** + * Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a map of settings. + * @param map the map of settings + * @param context the context for parsing the settings + * @return a new instance of {@link HuggingFaceChatCompletionServiceSettings} + */ + public static HuggingFaceChatCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + var uri = extractUri(map, URL, validationException); + + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + HuggingFaceService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + return new HuggingFaceChatCompletionServiceSettings(modelId, uri, rateLimitSettings); + } + + private final String modelId; + private final URI uri; + private final RateLimitSettings rateLimitSettings; + + public HuggingFaceChatCompletionServiceSettings(@Nullable String modelId, String url, @Nullable RateLimitSettings rateLimitSettings) { + this(modelId, createUri(url), rateLimitSettings); + } + + public HuggingFaceChatCompletionServiceSettings(@Nullable String modelId, URI uri, @Nullable RateLimitSettings rateLimitSettings) { + this.modelId = modelId; + this.uri = uri; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + /** + * Creates a new instance of {@link HuggingFaceChatCompletionServiceSettings} from a stream input. + * @param in the stream input + * @throws IOException if an I/O error occurs + */ + public HuggingFaceChatCompletionServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readOptionalString(); + this.uri = createUri(in.readString()); + this.rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public URI uri() { + return uri; + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentFragmentOfExposedFields(builder, params); + builder.endObject(); + + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + if (modelId != null) { + builder.field(MODEL_ID, modelId); + } + builder.field(URL, uri.toString()); + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(modelId); + out.writeString(uri.toString()); + rateLimitSettings.writeTo(out); + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + HuggingFaceChatCompletionServiceSettings that = (HuggingFaceChatCompletionServiceSettings) object; + return Objects.equals(modelId, that.modelId) + && Objects.equals(uri, that.uri) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, uri, rateLimitSettings); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java new file mode 100644 index 0000000000000..71ce78b8bd0b2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequest.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.completion; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceAccount; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +/** + * This class is responsible for creating Hugging Face chat completions HTTP requests. + * It handles the preparation of the HTTP request with the necessary headers and body. + */ +public class HuggingFaceUnifiedChatCompletionRequest implements Request { + + private final HuggingFaceAccount account; + private final HuggingFaceChatCompletionModel model; + private final UnifiedChatInput unifiedChatInput; + + public HuggingFaceUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) { + this.account = HuggingFaceAccount.of(model); + this.model = Objects.requireNonNull(model); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + } + + /** + * Creates an HTTP request to the Hugging Face API for chat completions. + * The request includes the necessary headers and the input data as a JSON entity. + * + * @return an HttpRequest object containing the HTTP POST request + */ + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(getURI()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); + httpPost.setHeader(createAuthBearerHeader(model.apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + public URI getURI() { + return account.uri(); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public Request truncate() { + // Truncation is not applicable for chat completion requests + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // Truncation is not applicable for chat completion requests + return null; + } + + @Override + public boolean isStreaming() { + return unifiedChatInput.stream(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..372954b3ddf07 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,51 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request.completion; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +public class HuggingFaceUnifiedChatCompletionRequestEntity implements ToXContentObject { + + private static final String MODEL_FIELD = "model"; + private static final String MAX_TOKENS_FIELD = "max_tokens"; + + private final UnifiedChatInput unifiedChatInput; + private final HuggingFaceChatCompletionModel model; + private final UnifiedChatCompletionRequestEntity unifiedRequestEntity; + + public HuggingFaceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) { + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); + this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput); + this.model = Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + unifiedRequestEntity.toXContent(builder, params); + + if (model.getServiceSettings().modelId() != null) { + builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); + } + + if (unifiedChatInput.getRequest().maxCompletionTokens() != null) { + builder.field(MAX_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens()); + } + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequest.java similarity index 73% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequest.java index af4fafff0fb2f..788ae0ab15a2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequest.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.huggingface.request; +package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; @@ -24,25 +24,35 @@ import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; -public class HuggingFaceInferenceRequest implements Request { +/** + * This class is responsible for creating Hugging Face embeddings HTTP requests. + * It handles the truncation of input data and prepares the HTTP request with the necessary headers and body. + */ +public class HuggingFaceEmbeddingsRequest implements Request { private final Truncator truncator; private final HuggingFaceAccount account; private final Truncator.TruncationResult truncationResult; private final HuggingFaceModel model; - public HuggingFaceInferenceRequest(Truncator truncator, Truncator.TruncationResult input, HuggingFaceModel model) { + public HuggingFaceEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, HuggingFaceModel model) { this.truncator = Objects.requireNonNull(truncator); this.account = HuggingFaceAccount.of(model); this.truncationResult = Objects.requireNonNull(input); this.model = Objects.requireNonNull(model); } + /** + * Creates an HTTP request to the Hugging Face API for embeddings. + * The request includes the necessary headers and the input data as a JSON entity. + * + * @return an HttpRequest object containing the HTTP POST request + */ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new HuggingFaceInferenceRequestEntity(truncationResult.input())).getBytes(StandardCharsets.UTF_8) + Strings.toString(new HuggingFaceEmbeddingsRequestEntity(truncationResult.input())).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters()); @@ -64,7 +74,7 @@ public String getInferenceEntityId() { public Request truncate() { var truncateResult = truncator.truncate(truncationResult.input()); - return new HuggingFaceInferenceRequest(truncator, truncateResult, model); + return new HuggingFaceEmbeddingsRequest(truncator, truncateResult, model); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequestEntity.java similarity index 72% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequestEntity.java index 0c38929568f61..b63f01a8df2b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/embeddings/HuggingFaceEmbeddingsRequestEntity.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.huggingface.request; +package org.elasticsearch.xpack.inference.services.huggingface.request.embeddings; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -14,11 +14,15 @@ import java.util.List; import java.util.Objects; -public record HuggingFaceInferenceRequestEntity(List inputs) implements ToXContentObject { +/** + * This class represents the request entity for Hugging Face embeddings. + * It contains a list of input strings that will be used to generate embeddings. + */ +public record HuggingFaceEmbeddingsRequestEntity(List inputs) implements ToXContentObject { private static final String INPUTS_FIELD = "inputs"; - public HuggingFaceInferenceRequestEntity { + public HuggingFaceEmbeddingsRequestEntity { Objects.requireNonNull(inputs); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java index 63b5e5622d871..d30a60c341a58 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceErrorResponseEntity.java @@ -21,6 +21,9 @@ public HuggingFaceErrorResponseEntity(String message) { } /** + * Represents a structured error response specifically for non-streaming operations + * using HuggingFace APIs. This is separate from streaming error responses, + * which are handled by private nested HuggingFaceChatCompletionResponseHandler.StreamingHuggingFaceErrorResponseEntity. * An example error response for invalid auth would look like * * { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java index efd52a4960f11..e60ba31823107 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -29,6 +29,7 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.Flow; +import java.util.function.Function; import static org.elasticsearch.core.Strings.format; @@ -37,6 +38,14 @@ public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponsePa super(requestType, parseFunction, OpenAiErrorResponse::fromResponse); } + public OpenAiUnifiedChatCompletionResponseHandler( + String requestType, + ResponseParser parseFunction, + Function errorParseFunction + ) { + super(requestType, parseFunction, errorParseFunction); + } + @Override public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); @@ -59,7 +68,7 @@ protected Exception buildError(String message, Request request, HttpResult resul : new UnifiedChatCompletionException( restStatus, errorMessage, - errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown", + createErrorType(errorResponse), restStatus.name().toLowerCase(Locale.ROOT) ); } else { @@ -67,7 +76,11 @@ protected Exception buildError(String message, Request request, HttpResult resul } } - private static Exception buildMidStreamError(Request request, String message, Exception e) { + protected static String createErrorType(ErrorResponse errorResponse) { + return errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown"; + } + + protected Exception buildMidStreamError(Request request, String message, Exception e) { var errorResponse = OpenAiErrorResponse.fromString(message); if (errorResponse instanceof OpenAiErrorResponse oer) { return new UnifiedChatCompletionException( @@ -88,7 +101,7 @@ private static Exception buildMidStreamError(Request request, String message, Ex return new UnifiedChatCompletionException( RestStatus.INTERNAL_SERVER_ERROR, format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()), - errorResponse != null ? errorResponse.getClass().getSimpleName() : "unknown", + createErrorType(errorResponse), "stream_error" ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java index 5ab743c3d4cc0..983bb5efbf3fa 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessor.java @@ -250,7 +250,7 @@ private static class FunctionParser { static { PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD)); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD)); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD)); } public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java new file mode 100644 index 0000000000000..f1560edc48ed0 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.ExceptionsHelper.unwrapCause; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class HuggingFaceChatCompletionResponseHandlerTests extends ESTestCase { + private final HuggingFaceChatCompletionResponseHandler responseHandler = new HuggingFaceChatCompletionResponseHandler( + "chat completions", + (a, b) -> mock() + ); + + public void testFailValidationWithAllFields() throws IOException { + var responseJson = """ + { + "error": "a message", + "type": "validation" + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(""" + {"error":{"code":"bad_request","message":"Received a server error status code for request from \ + inference entity id [id] status [500]. \ + Error message: [a message]",\ + "type":"hugging_face_error"}}""")); + } + + public void testFailValidationWithoutOptionalFields() throws IOException { + var responseJson = """ + { + "error": "a message" + } + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(""" + {"error":{"code":"bad_request","message":"Received a server error status code for request from \ + inference entity id [id] status [500]. \ + Error message: [a message]","type":"hugging_face_error"}}""")); + } + + public void testFailValidationWithInvalidJson() throws IOException { + var responseJson = """ + what? this isn't a json + """; + + var errorJson = invalidResponseJson(responseJson); + + assertThat(errorJson, is(""" + {"error":{"code":"bad_request","message":"Received a server error status code for request from inference entity id [id] status\ + [500]","type":"ErrorResponse"}}""")); + } + + private String invalidResponseJson(String responseJson) throws IOException { + var exception = invalidResponse(responseJson); + assertThat(exception, isA(RetryException.class)); + assertThat(unwrapCause(exception), isA(UnifiedChatCompletionException.class)); + return toJson((UnifiedChatCompletionException) unwrapCause(exception)); + } + + private Exception invalidResponse(String responseJson) { + return expectThrows( + RetryException.class, + () -> responseHandler.validateResponse( + mock(), + mock(), + mockRequest(), + new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)), + true + ) + ); + } + + private static Request mockRequest() { + var request = mock(Request.class); + when(request.getInferenceEntityId()).thenReturn("id"); + when(request.isStreaming()).thenReturn(true); + return request; + } + + private static HttpResponse mock500Response() { + int statusCode = 500; + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + + return response; + } + + private String toJson(UnifiedChatCompletionException e) throws IOException { + try (var builder = XContentFactory.jsonBuilder()) { + e.toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + return XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index b50b821d66359..12c0d85674a58 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -12,6 +12,7 @@ import org.apache.http.HttpHeaders; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; @@ -29,20 +30,29 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettingsTests; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel; @@ -53,14 +63,19 @@ import org.junit.Before; import java.io.IOException; +import java.util.EnumSet; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.ExceptionsHelper.unwrapCause; import static org.elasticsearch.common.xcontent.XContentHelper.toXContent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; @@ -74,7 +89,12 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.isA; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; public class HuggingFaceServiceTests extends ESTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); @@ -175,6 +195,438 @@ public void testParseRequestConfig_CreatesAnElserModel() throws IOException { } } + public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel() throws IOException { + var url = "url"; + var model = "model"; + var secret = "secret"; + + try (var service = createHuggingFaceService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class)); + + var completionsModel = (HuggingFaceChatCompletionModel) m; + + assertThat(completionsModel.getServiceSettings().uri().toString(), is(url)); + assertThat(completionsModel.getServiceSettings().modelId(), is(model)); + assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); + + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap( + HuggingFaceChatCompletionServiceSettingsTests.getServiceSettingsMap(url, model), + getSecretSettingsMap(secret) + ), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_CreatesHuggingFaceChatCompletionsModel_WithoutModelId() throws IOException { + var url = "url"; + var secret = "secret"; + + try (var service = createHuggingFaceService()) { + ActionListener modelVerificationListener = ActionListener.wrap(m -> { + assertThat(m, instanceOf(HuggingFaceChatCompletionModel.class)); + + var completionsModel = (HuggingFaceChatCompletionModel) m; + + assertThat(completionsModel.getServiceSettings().uri().toString(), is(url)); + assertNull(completionsModel.getServiceSettings().modelId()); + assertThat(completionsModel.getSecretSettings().apiKey().toString(), is(secret)); + + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap(getServiceSettingsMap(url), getSecretSettingsMap(secret)), + modelVerificationListener + ); + } + } + + public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name", TaskType.CHAT_COMPLETION); + + try (var service = new HuggingFaceService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + null, + null, + List.of(""), + false, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":"stop"\ + }\ + ],\ + "usage":{\ + "prompt_tokens": 16,\ + "completion_tokens": 28,\ + "total_tokens": 44,\ + "prompt_tokens_details": {\ + "cached_tokens": 0,\ + "audio_tokens": 0\ + },\ + "completion_tokens_details": {\ + "reasoning_tokens": 0,\ + "audio_tokens": 0,\ + "accepted_prediction_tokens": 0,\ + "rejected_prediction_tokens": 0\ + }\ + }\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """ + "model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """ + "usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}"""); + } + } + + public void testUnifiedCompletionNonStreamingError() throws Exception { + String responseJson = """ + { + "error": "Model not found." + } + """; + webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + var latch = new CountDownLatch(1); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + ActionListener.runAfter(ActionTestUtils.assertNoSuccessListener(e -> { + try (var builder = XContentFactory.jsonBuilder()) { + var t = unwrapCause(e); + assertThat(t, isA(UnifiedChatCompletionException.class)); + ((UnifiedChatCompletionException) t).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(""" + {\ + "error":{\ + "code":"not_found",\ + "message":"Received an unsuccessful status code for request from inference entity id [id] status \ + [404]. Error message: [Model not found.]",\ + "type":"hugging_face_error"\ + }}""")); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }), latch::countDown) + ); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + } + } + + public void testMidStreamUnifiedCompletionError() throws Exception { + String responseJson = """ + event: error + data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \ + {\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by Outlines.\\n\ + If it should be supported, please open an issue.","http_status_code":422}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"422",\ + "message":"Received an error response for request from inference entity id [id]. Error message: [Input validation error: \ + cannot compile regex from schema: Unsupported JSON Schema structure {\\"id\\":\\"123\\"} \\nMake sure it is valid to the \ + JSON Schema specification and check if it's supported by Outlines.\\nIf it should be supported, please open an issue.]",\ + "type":"hugging_face_error"\ + }}"""); + } + + public void testMidStreamUnifiedCompletionErrorNoMessage() throws Exception { + String responseJson = """ + event: error + data: {"error":{"http_status_code":422}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"422",\ + "message":"Received an error response for request from inference entity id [id]. Error message: \ + [unknown]",\ + "type":"hugging_face_error"\ + }}"""); + } + + public void testMidStreamUnifiedCompletionErrorNoHttpStatusCode() throws Exception { + String responseJson = """ + event: error + data: {"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \ + {\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported by \ + Outlines.\\nIf it should be supported, please open an issue."}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "message":"Received an error response for request from inference entity id [id]. Error message: \ + [Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure \ + {\\"id\\":\\"123\\"} \\nMake sure it is valid to the JSON Schema specification and check if it's supported\ + by Outlines.\\nIf it should be supported, please open an issue.]",\ + "type":"hugging_face_error"\ + }}"""); + } + + public void testMidStreamUnifiedCompletionErrorNoHttpStatusCodeNoMessage() throws Exception { + String responseJson = """ + event: error + data: {"error":{}} + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "message":"Received an error response for request from inference entity id [id]. Error message: \ + [unknown]",\ + "type":"hugging_face_error"\ + }}"""); + } + + public void testUnifiedCompletionMalformedError() throws Exception { + String responseJson = """ + data: { invalid json } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + testStreamError(""" + {\ + "error":{\ + "code":"bad_request",\ + "message":"[1:3] Unexpected character ('i' (code 105)): was expecting double-quote to start field name\\n\ + at [Source: (String)\\"{ invalid json }\\"; line: 1, column: 3]",\ + "type":"x_content_parse_exception"\ + }}"""); + } + + private void testStreamError(String expectedResponse) throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + var model = HuggingFaceChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null)) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoEvents().hasErrorMatching(e -> { + e = unwrapCause(e); + assertThat(e, isA(UnifiedChatCompletionException.class)); + try (var builder = XContentFactory.jsonBuilder()) { + ((UnifiedChatCompletionException) e).toXContentChunked(EMPTY_PARAMS).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, EMPTY_PARAMS); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + }); + var json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, builder.contentType()); + + assertThat(json, is(expectedResponse)); + } + }); + } + } + + public void testInfer_StreamRequest() throws Exception { + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":null\ + }\ + ]\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } + + private InferenceEventsAssertion streamCompletion() throws Exception { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) { + var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + true, + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + return InferenceEventsAssertion.assertThat(listener.actionGet(TIMEOUT)).hasFinishedStream(); + } + } + + public void testInfer_StreamRequest_ErrorResponse() throws Exception { + String responseJson = """ + { + "error": { + "message": "You didn't provide an API key..." + } + }"""; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var e = assertThrows(ElasticsearchStatusException.class, this::streamCompletion); + assertThat(e.status(), equalTo(RestStatus.UNAUTHORIZED)); + assertThat( + e.getMessage(), + equalTo( + "Received an authentication error status code for request from inference entity id [id] status [401]. " + + "Error message: [You didn't provide an API key...]" + ) + ); + } + + public void testInfer_StreamRequestRetry() throws Exception { + webServer.enqueue(new MockResponse().setResponseCode(503).setBody(""" + { + "error": { + "message": "server busy" + } + }""")); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(""" + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":null\ + }\ + ]\ + } + + """)); + + streamCompletion().hasNoErrors().hasEvent(""" + {"completion":[{"delta":"hello, world"}]}"""); + } + + public void testSupportsStreaming() throws IOException { + try (var service = new HuggingFaceService(mock(), createWithEmptySettings(mock()))) { + assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION))); + assertFalse(service.canStream(TaskType.ANY)); + } + } + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createHuggingFaceService()) { var config = getRequestConfigMap(getServiceSettingsMap("url"), getSecretSettingsMap("secret")); @@ -258,6 +710,25 @@ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel() throw } } + public void testParsePersistedConfigWithSecrets_CreatesACompletionModel() throws IOException { + try (var service = createHuggingFaceService()) { + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("url"), new HashMap<>(), getSecretSettingsMap("secret")); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(HuggingFaceChatCompletionModel.class)); + + var chatCompletionModel = (HuggingFaceChatCompletionModel) model; + assertThat(chatCompletionModel.getServiceSettings().uri().toString(), is("url")); + assertThat(chatCompletionModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = createHuggingFaceService()) { var persistedConfig = getPersistedConfigMap( @@ -821,7 +1292,7 @@ public void testGetConfiguration() throws Exception { { "service": "hugging_face", "name": "Hugging Face", - "task_types": ["text_embedding", "sparse_embedding"], + "task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"], "configurations": { "api_key": { "description": "API Key for the provider you're connecting to.", @@ -830,7 +1301,7 @@ public void testGetConfiguration() throws Exception { "sensitive": true, "updatable": true, "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"] }, "rate_limit.requests_per_minute": { "description": "Minimize the number of rate limit errors.", @@ -839,17 +1310,16 @@ public void testGetConfiguration() throws Exception { "sensitive": false, "updatable": false, "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"] }, "url": { - "default_value": "https://api.openai.com/v1/embeddings", "description": "The URL endpoint to use for the requests.", "label": "URL", "required": true, "sensitive": false, "updatable": false, "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding"] + "supported_task_types": ["text_embedding", "sparse_embedding", "completion", "chat_completion"] } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index 09619cd076ff5..895aea8067c46 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -24,10 +24,13 @@ import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests; import org.junit.After; @@ -38,6 +41,7 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; @@ -425,4 +429,107 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { } } + + public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "object": "chat.completion", + "id": "", + "created": 1745855316, + "model": "/repository", + "system_fingerprint": "3.2.3-sha-a1f3ebe", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 8, + "completion_tokens": 50, + "total_tokens": 58 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createChatCompletionFuture(sender, createWithEmptySettings(threadPool)); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("Hello there, how may I assist you today?")))); + + assertChatCompletionRequest(); + } + } + + public void testSend_FailsFromInvalidResponseFormat_ForChatCompletionAction() throws IOException { + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "invalid_field": "unexpected" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + PlainActionFuture listener = createChatCompletionFuture( + sender, + new ServiceComponents(threadPool, mockThrottlerManager(), settings, TruncatorTests.createTruncator()) + ); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("Failed to send Hugging Face completion request from inference entity id " + "[id]. Cause: Required [choices]") + ); + + assertChatCompletionRequest(); + } + } + + private PlainActionFuture createChatCompletionFuture(Sender sender, ServiceComponents threadPool) { + var model = HuggingFaceChatCompletionModelTests.createCompletionModel(getUrl(webServer), "secret", "model"); + var actionCreator = new HuggingFaceActionCreator(sender, threadPool); + var action = actionCreator.create(model); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("Hello"), false), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + return listener; + } + + private void assertChatCompletionRequest() throws IOException { + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat( + webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), + equalTo(XContentType.JSON.mediaTypeWithoutParameters()) + ); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "Hello")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java new file mode 100644 index 0000000000000..6fdd24b970a9d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceChatCompletionActionTests.java @@ -0,0 +1,243 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.action; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.COMPLETION_HANDLER; +import static org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator.USER_ROLE; +import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class HuggingFaceChatCompletionActionTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "result content" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result content")))); + assertThat(webServer.requests(), hasSize(1)); + + MockRequest request = webServer.requests().get(0); + + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaTypeWithoutParameters())); + assertThat(request.getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); + + var requestMap = entityAsMap(request.getBody()); + assertThat(requestMap.size(), is(4)); + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); + } + } + + public void testExecute_ThrowsURISyntaxException_ForInvalidUrl() throws IOException { + try (var sender = mock(Sender.class)) { + var thrownException = expectThrows(IllegalArgumentException.class, () -> createAction("^^", sender)); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("Failed to send hugging face chat completions request. Cause: failed")); + } + + public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = createSender(senderFactory)) { + sender.start(); + + String responseJson = """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0613", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("hugging face chat completions only accepts 1 input")); + assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST)); + } + } + + private ExecutableAction createAction(String url, Sender sender) { + var model = createCompletionModel(url, "secret", "model"); + var manager = new GenericRequestManager<>( + threadPool, + model, + COMPLETION_HANDLER, + inputs -> new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model), + ChatCompletionInput.class + ); + var errorMessage = constructFailedToSendRequestMessage("hugging face chat completions"); + return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "hugging face chat completions"); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java new file mode 100644 index 0000000000000..f23e5f8969c5b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionModelTests.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.List; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class HuggingFaceChatCompletionModelTests extends ESTestCase { + + public void testThrowsURISyntaxException_ForInvalidUrl() { + var thrownException = expectThrows(IllegalArgumentException.class, () -> createCompletionModel("^^", "secret", "id")); + assertThat(thrownException.getMessage(), containsString("unable to parse url [^^]")); + } + + public static HuggingFaceChatCompletionModel createCompletionModel(String url, String apiKey, String modelId) { + return new HuggingFaceChatCompletionModel( + "id", + TaskType.COMPLETION, + "service", + new HuggingFaceChatCompletionServiceSettings(modelId, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static HuggingFaceChatCompletionModel createChatCompletionModel(String url, String apiKey, String modelId) { + return new HuggingFaceChatCompletionModel( + "id", + TaskType.CHAT_COMPLETION, + "service", + new HuggingFaceChatCompletionServiceSettings(modelId, url, null), + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() { + var model = createCompletionModel("url", "api_key", "model_name"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = HuggingFaceChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() { + var model = createCompletionModel("url", "api_key", null); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = HuggingFaceChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model")); + } + + public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() { + var model = createCompletionModel("url", "api_key", null); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = HuggingFaceChatCompletionModel.of(model, request); + + assertNull(overriddenModel.getServiceSettings().modelId()); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createCompletionModel("url", "api_key", "model_name"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + var overriddenModel = HuggingFaceChatCompletionModel.of(model, request); + + assertThat(overriddenModel.getServiceSettings().modelId(), is("model_name")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..51b6429884374 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/completion/HuggingFaceChatCompletionServiceSettingsTests.java @@ -0,0 +1,271 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class HuggingFaceChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< + HuggingFaceChatCompletionServiceSettings> { + + public static final String MODEL_ID = "some model"; + public static final String CORRECT_URL = "https://www.elastic.co"; + public static final int RATE_LIMIT = 2; + + public void testFromMap_AllFields_Success() { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is( + new HuggingFaceChatCompletionServiceSettings( + MODEL_ID, + ServiceUtils.createUri(CORRECT_URL), + new RateLimitSettings(RATE_LIMIT) + ) + ) + ); + } + + public void testFromMap_MissingModelId_Success() { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + assertThat( + serviceSettings, + is(new HuggingFaceChatCompletionServiceSettings(null, ServiceUtils.createUri(CORRECT_URL), new RateLimitSettings(RATE_LIMIT))) + ); + } + + public void testFromMap_MissingRateLimit_Success() { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, MODEL_ID, ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + assertThat(serviceSettings, is(new HuggingFaceChatCompletionServiceSettings(MODEL_ID, ServiceUtils.createUri(CORRECT_URL), null))); + } + + public void testFromMap_MissingUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format("Validation Failed: 1: [service_settings] does not contain the required setting [url];", ServiceFields.URL) + ) + ); + } + + public void testFromMap_EmptyUrl_ThrowsException() { + var thrownException = expectThrows( + ValidationException.class, + () -> HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + "", + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value empty string. [%s] must be a non-empty string;", + ServiceFields.URL + ) + ) + ); + } + + public void testFromMap_InvalidUrl_ThrowsException() { + String invalidUrl = "https://www.elastic^^co"; + var thrownException = expectThrows( + ValidationException.class, + () -> HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + invalidUrl, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid url [%s] received for field [%s]", + invalidUrl, + ServiceFields.URL + ) + ) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + MODEL_ID, + ServiceFields.URL, + CORRECT_URL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, RATE_LIMIT)) + ) + ), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "model_id": "some model", + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 2 + } + } + """); + + assertThat(xContentResult, is(expected)); + } + + public void testToXContent_DoesNotWriteOptionalValues_DefaultRateLimit() throws IOException { + var serviceSettings = HuggingFaceChatCompletionServiceSettings.fromMap( + new HashMap<>(Map.of(ServiceFields.URL, CORRECT_URL)), + ConfigurationParseContext.PERSISTENT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + serviceSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + var expected = XContentHelper.stripWhitespace(""" + { + "url": "https://www.elastic.co", + "rate_limit": { + "requests_per_minute": 3000 + } + } + """); + assertThat(xContentResult, is(expected)); + } + + @Override + protected Writeable.Reader instanceReader() { + return HuggingFaceChatCompletionServiceSettings::new; + } + + @Override + protected HuggingFaceChatCompletionServiceSettings createTestInstance() { + return createRandomWithNonNullUrl(); + } + + @Override + protected HuggingFaceChatCompletionServiceSettings mutateInstance(HuggingFaceChatCompletionServiceSettings instance) + throws IOException { + return randomValueOtherThan(instance, HuggingFaceChatCompletionServiceSettingsTests::createRandomWithNonNullUrl); + } + + @Override + protected HuggingFaceChatCompletionServiceSettings mutateInstanceForVersion( + HuggingFaceChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + private static HuggingFaceChatCompletionServiceSettings createRandomWithNonNullUrl() { + return createRandom(randomAlphaOfLength(15)); + } + + private static HuggingFaceChatCompletionServiceSettings createRandom(String url) { + var modelId = randomAlphaOfLength(8); + + return new HuggingFaceChatCompletionServiceSettings(modelId, ServiceUtils.createUri(url), RateLimitSettingsTests.createRandom()); + } + + public static Map getServiceSettingsMap(String url, String model) { + var map = new HashMap(); + + map.put(ServiceFields.URL, url); + map.put(ServiceFields.MODEL_ID, model); + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestEntityTests.java similarity index 78% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestEntityTests.java index ed15fb77d04a7..22e6e1c33a36c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestEntityTests.java @@ -12,16 +12,17 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequestEntity; import java.io.IOException; import java.util.List; import static org.hamcrest.CoreMatchers.is; -public class HuggingFaceInferenceRequestEntityTests extends ESTestCase { +public class HuggingFaceEmbeddingsRequestEntityTests extends ESTestCase { public void testXContent() throws IOException { - var entity = new HuggingFaceInferenceRequestEntity(List.of("abc")); + var entity = new HuggingFaceEmbeddingsRequestEntity(List.of("abc")); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestTests.java similarity index 90% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestTests.java index 768eeb622b943..3cbf8a5e5b28f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceInferenceRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceEmbeddingsRequestTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest; import java.io.IOException; import java.net.URI; @@ -25,7 +26,7 @@ import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class HuggingFaceInferenceRequestTests extends ESTestCase { +public class HuggingFaceEmbeddingsRequestTests extends ESTestCase { @SuppressWarnings("unchecked") public void testCreateRequest() throws URISyntaxException, IOException { var huggingFaceRequest = createRequest("www.google.com", "secret", "abc"); @@ -67,9 +68,9 @@ public void testIsTruncated_ReturnsTrue() throws URISyntaxException, IOException assertTrue(truncatedRequest.getTruncationInfo()[0]); } - public static HuggingFaceInferenceRequest createRequest(String url, String apiKey, String input) throws URISyntaxException { + public static HuggingFaceEmbeddingsRequest createRequest(String url, String apiKey, String input) throws URISyntaxException { - return new HuggingFaceInferenceRequest( + return new HuggingFaceEmbeddingsRequest( TruncatorTests.createTruncator(), new Truncator.TruncationResult(List.of(input), new boolean[] { false }), HuggingFaceEmbeddingsModelTests.createModel(url, apiKey) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..81d26036036c6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModel; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequestEntity; + +import java.io.IOException; +import java.util.ArrayList; + +import static org.elasticsearch.xpack.inference.Utils.assertJsonEquals; +import static org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests.createCompletionModel; + +public class HuggingFaceUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + private static final String ROLE = "user"; + + public void testModelUserFieldsSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + ROLE, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + var unifiedRequest = UnifiedCompletionRequest.of(messageList); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + HuggingFaceChatCompletionModel model = createCompletionModel("test-url", "api-key", "test-endpoint"); + + HuggingFaceUnifiedChatCompletionRequestEntity entity = new HuggingFaceUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java new file mode 100644 index 0000000000000..103a2a035cee1 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/HuggingFaceUnifiedChatCompletionRequestTests.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.huggingface.request; + +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.huggingface.request.completion.HuggingFaceUnifiedChatCompletionRequest; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class HuggingFaceUnifiedChatCompletionRequestTests extends ESTestCase { + + public void testCreateRequest_WithStreaming() throws IOException { + var request = createRequest("url", "secret", randomAlphaOfLength(15), "model", true); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap.get("stream"), is(true)); + } + + public void testTruncate_DoesNotReduceInputTextSize() throws IOException { + String input = randomAlphaOfLength(5); + var request = createRequest("url", "secret", input, "model", true); + var truncatedRequest = request.truncate(); + assertThat(request.getURI().toString(), is("url")); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(5)); + + // We do not truncate for Hugging Face chat completions + assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input)))); + assertThat(requestMap.get("model"), is("model")); + assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); + } + + public void testTruncationInfo_ReturnsNull() { + var request = createRequest("url", "secret", randomAlphaOfLength(5), "model", true); + assertNull(request.getTruncationInfo()); + } + + public static HuggingFaceUnifiedChatCompletionRequest createRequest(String url, String apiKey, String input, @Nullable String model) { + return createRequest(url, apiKey, input, model, false); + } + + public static HuggingFaceUnifiedChatCompletionRequest createRequest( + @Nullable String url, + String apiKey, + String input, + @Nullable String model, + boolean stream + ) { + var chatCompletionModel = HuggingFaceChatCompletionModelTests.createCompletionModel(url, apiKey, model); + return new HuggingFaceUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessorTests.java index c51644a1e279f..1c72aaf4a147b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedStreamingProcessorTests.java @@ -19,6 +19,8 @@ import java.io.IOException; import java.util.List; +import static org.hamcrest.Matchers.is; + public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase { public void testJsonLiteral() { @@ -182,6 +184,73 @@ public void testJsonLiteralCornerCases() { } } + public void testJsonNullFunctionName() throws IOException { + String json = """ + { + "object": "chat.completion.chunk", + "id": "", + "created": 1746800254, + "model": "/repository", + "system_fingerprint": "3.2.3-sha-a1f3ebe", + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "tool_calls": [ + { + "index": 0, + "id": "8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a", + "type": "function", + "function": { + "name": null, + "arguments": " \\\"" + } + } + ] + }, + "logprobs": null, + "finish_reason": null + } + ], + "usage": null + } + """; + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, json)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertThat(chunk.id(), is("")); + assertThat(chunk.model(), is("/repository")); + assertThat(chunk.object(), is("chat.completion.chunk")); + assertNull(chunk.usage()); + + List choices = chunk.choices(); + assertThat(choices.size(), is(1)); + + // First choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0); + assertNull(firstChoice.delta().content()); + assertNull(firstChoice.delta().refusal()); + assertThat(firstChoice.delta().role(), is("assistant")); + assertNull(firstChoice.finishReason()); + assertThat(firstChoice.index(), is(0)); + + List toolCalls = firstChoice.delta() + .toolCalls(); + assertThat(toolCalls.size(), is(1)); + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertThat(toolCall.index(), is(0)); + assertThat(toolCall.id(), is("8f7c27be-6803-48e6-bba4-8cdcbcd2ff9a")); + assertThat(toolCall.type(), is("function")); + assertNull(toolCall.function().name()); + assertThat(toolCall.function().arguments(), is(" \"")); + } + } + public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException { // Generate random values for the JSON fields int toolCallIndex = randomIntBetween(0, 10);