Skip to content

Add Hugging Face Chat Completion support to Inference Plugin #127254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
63f21de
Add Hugging Face Chat Completion support to Inference Plugin
Jan-Kazlouski-elastic Apr 23, 2025
6b7dd2e
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic Apr 25, 2025
65e4060
Add support for streaming chat completion task for HuggingFace
Jan-Kazlouski-elastic Apr 25, 2025
404f640
[CI] Auto commit changes from spotless
elasticsearchmachine Apr 25, 2025
ceebb9a
Add support for non-streaming completion task for HuggingFace
Jan-Kazlouski-elastic Apr 25, 2025
acaa35b
Remove RequestManager for HF Chat Completion Task
Jan-Kazlouski-elastic Apr 25, 2025
91fa92e
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic Apr 28, 2025
ff3ef50
Refactored Hugging Face Completion Service Settings, removed Request …
Jan-Kazlouski-elastic Apr 28, 2025
965093b
Refactored Hugging Face Action Creator, added Unit Tests
Jan-Kazlouski-elastic Apr 29, 2025
6757b07
Add Hugging Face Server Test
Jan-Kazlouski-elastic Apr 29, 2025
58ea9fd
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic Apr 29, 2025
df845eb
[CI] Auto commit changes from spotless
elasticsearchmachine Apr 29, 2025
cc24e68
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 2, 2025
5bbe3b7
Removed parameters from media type for Chat Completion Request and un…
Jan-Kazlouski-elastic May 2, 2025
3684816
Removed OpenAI default URL in HuggingFaceService's configuration, fix…
Jan-Kazlouski-elastic May 2, 2025
7670d2c
Refactor error message handling in HuggingFaceActionCreator and Huggi…
Jan-Kazlouski-elastic May 2, 2025
6630be7
Update minimal supported version and add Hugging Face transport versi…
Jan-Kazlouski-elastic May 2, 2025
1efb2ee
Made modelId field optional in HuggingFaceChatCompletionModel, update…
Jan-Kazlouski-elastic May 2, 2025
61537d0
Removed max input tokens field from HuggingFaceChatCompletionServiceS…
Jan-Kazlouski-elastic May 2, 2025
64c0685
Removed if statement checking TransportVersion for HuggingFaceChatCom…
Jan-Kazlouski-elastic May 2, 2025
4688901
Removed getFirst() method calls for backport compatibility
Jan-Kazlouski-elastic May 2, 2025
bfc8072
Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWC…
Jan-Kazlouski-elastic May 2, 2025
13ef13b
Refactored tests to use stripWhitespace method for readability
Jan-Kazlouski-elastic May 2, 2025
129caaf
Refactored javadoc for HuggingFaceService
Jan-Kazlouski-elastic May 2, 2025
214de5f
Renamed HF chat completion TransportVersion constant names
Jan-Kazlouski-elastic May 2, 2025
d3411d6
Added random string generation in unit test
Jan-Kazlouski-elastic May 2, 2025
e170b96
Refactored javadocs for HuggingFace requests
Jan-Kazlouski-elastic May 2, 2025
473dee6
Refactored tests to reduce duplication
Jan-Kazlouski-elastic May 2, 2025
cb03100
Added changelog file
Jan-Kazlouski-elastic May 2, 2025
c856853
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 5, 2025
bd2e601
Merge remote-tracking branch 'refs/remotes/origin/main' into feature/…
Jan-Kazlouski-elastic May 5, 2025
aae528a
Add HuggingFaceChatCompletionResponseHandler and associated tests
Jan-Kazlouski-elastic May 5, 2025
82f8049
Refactor error handling in HuggingFaceServiceTests to standardize err…
Jan-Kazlouski-elastic May 5, 2025
b0679d5
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 6, 2025
2fa3dff
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 7, 2025
cdb3c1c
Refactor HuggingFace error handling to improve response structure and…
Jan-Kazlouski-elastic May 7, 2025
9370b57
Merge remote-tracking branch 'origin/main' into feature/hugging-face-…
Jan-Kazlouski-elastic May 11, 2025
9044bee
Allowing null function name for hugging face models
jonathan-buttner May 9, 2025
e72a312
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 12, 2025
e2cb334
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 13, 2025
a4b5d2c
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 13, 2025
c5988ed
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 14, 2025
1547559
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 19, 2025
71c6057
Merge branch 'main' into feature/hugging-face-chat-completion-integra…
Jan-Kazlouski-elastic May 19, 2025
228fffa
Merge branch 'main' of https://github.com/Jan-Kazlouski-elastic/elast…
Jan-Kazlouski-elastic May 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/127254.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 127254
summary: "[ML] Add HuggingFace Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_28);
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING_8_19 = def(8_841_0_29);
public static final TransportVersion ESQL_DRIVER_TASK_DESCRIPTION_8_19 = def(8_841_0_30);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_31);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -255,6 +256,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_FIELD_ATTRIBUTE_DROP_TYPE = def(9_075_0_00);
public static final TransportVersion ESQL_TIME_SERIES_SOURCE_STATUS = def(9_076_0_00);
public static final TransportVersion ESQL_HASH_OPERATOR_STATUS_OUTPUT_TIME = def(9_077_0_00);
public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_078_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public void testGetServicesWithRerankTaskType() throws IOException {

public void testGetServicesWithCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.COMPLETION);
assertThat(services.size(), equalTo(10));
assertThat(services.size(), equalTo(11));

var providers = providers(services);

Expand All @@ -140,19 +140,23 @@ public void testGetServicesWithCompletionTaskType() throws IOException {
"deepseek",
"googleaistudio",
"openai",
"streaming_completion_test_service"
"streaming_completion_test_service",
"hugging_face"
).toArray()
)
);
}

public void testGetServicesWithChatCompletionTaskType() throws IOException {
List<Object> services = getServices(TaskType.CHAT_COMPLETION);
assertThat(services.size(), equalTo(4));
assertThat(services.size(), equalTo(5));

var providers = providers(services);

assertThat(providers, containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service").toArray()));
assertThat(
providers,
containsInAnyOrder(List.of("deepseek", "elastic", "openai", "streaming_completion_test_service", "hugging_face").toArray())
);
}

public void testGetServicesWithSparseEmbeddingTaskType() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.completion.HuggingFaceChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
Expand Down Expand Up @@ -357,6 +358,13 @@ private static void addHuggingFaceNamedWriteables(List<NamedWriteableRegistry.En
namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, HuggingFaceServiceSettings.NAME, HuggingFaceServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
HuggingFaceChatCompletionServiceSettings.NAME,
HuggingFaceChatCompletionServiceSettings::new
)
);
}

private static void addGoogleAiStudioNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.huggingface;

import org.elasticsearch.core.Nullable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.huggingface.response.HuggingFaceErrorResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;

import java.util.Locale;
import java.util.Optional;

import static org.elasticsearch.core.Strings.format;

/**
* Handles streaming chat completion responses and error parsing for Hugging Face inference endpoints.
* Adapts the OpenAI handler to support Hugging Face's simpler error schema with fields like "message" and "http_status_code".
*/
public class HuggingFaceChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {

private static final String HUGGING_FACE_ERROR = "hugging_face_error";

public HuggingFaceChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, HuggingFaceErrorResponseEntity::fromResponse);
}

@Override
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
assert request.isStreaming() : "Only streaming requests support this format";
var responseStatusCode = result.response().getStatusLine().getStatusCode();
if (request.isStreaming()) {
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
Copy link
Contributor

@jonathan-buttner jonathan-buttner May 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the error format for the code path that calls buildError will be different than a mid-stream error.

For example if I change the URL to be invalid for the HF endpoint (add another character or something). The error we get back doesn't include the message from HF:

[2025-05-06T13:32:06,659][WARN ][o.e.x.i.e.h.s.GenericRequestManager] [runTask-0] Failed to process the response for request from inference entity id [test-chat2] of type [hugging face chat completion] with status [404] [Not Found] org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException: Received an unsuccessful status code for request from inference entity id [test-chat2] status [404]
        at org.elasticsearch.inference@9.1.0-SNAPSHOT/org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceChatCompletionResponseHandler.buildError(HuggingFaceChatCompletionResponseHandler.java:65)
        at org.elasticsearch.inference@9.1.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.buildError(BaseResponseHandler.java:113)
        at org.elasticsearch.inference@9.1.0-SNAPSHOT/org.elasticsearch.xpack.inference.services.openai.OpenAiResponseHandler.checkForFailureStatusCode(OpenAiResponseHandler.java:90)
        at org.elasticsearch.inference@9.1.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler.validateResponse(BaseResponseHandler.java:86)
        at org.elasticsearch.inference@9.1.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender$InternalRetrier.lambda$tryAction$1(RetryingHttpSender.java:124)
        at org.elasticsearch.server@9.1.0-SNAPSHOT/org.elasticsearch.action.ActionListenerImplementations$ResponseWrappingActionListener.onResponse(ActionListenerImplementations.java:261)
        at org.elasticsearch.inference@9.1.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.http.StreamingHttpResult$2.onComplete(StreamingHttpResult.java:72)
        at org.elasticsearch.inference@9.1.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.http.StreamingHttpResultPublisher$DataPublisher.sendToSubscriber(StreamingHttpResultPublisher.java:179)
        at org.elasticsearch.inference@9.1.0-SNAPSHOT/org.elasticsearch.xpack.inference.external.http.RequestBasedTaskRunner.run(RequestBasedTaskRunner.java:59)
        at org.elasticsearch.server@9.1.0-SNAPSHOT/org.elasticsearch.common.util.concurrent.ThreadContext$ContextPreservingRunnable.run(ThreadContext.java:977)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1095)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:619)
        at java.base/java.lang.Thread.run(Thread.java:1447)

If I use curl to perform a request to a URL that doesn't exist, this is returned:

{"error":"Not Found: <url>"}

I think the best solution would be to handle both error formats. What we can do is look at the token when we parse the error field and see if it is an object or a string. Here's an example:

https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java#L286-L292

Once we do that we can probably move the error class from here HuggingFaceErrorResponse and replace HuggingFaceErrorResponseEntity.

Copy link
Contributor Author

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic May 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I use curl to perform a request to a URL that doesn't exist, this is returned:
{"error":"Not Found: "}

I'm afraid that is not the case when I do the testing. I just get 404 with no response body. However for different type of error it would have "error" field and that is it. During testing I populated responseMap with "error message" and received such response from Elastic:

{
    "error": {
        "code": "not_found",
        "message": "Received an unsuccessful status code for request from inference entity id [hugging-face-chat-completion] status [404]. Error message: [error message]",
        "type": "hugging_face_error"
    }
}

I think it is safe to assume that errors that are returned as part of the steram from HF - will be handled inside of "buildMidStreamError" method using "StreamingHuggingFaceErrorResponseEntity.fromString".
Errors that are returned from HF NOT as stream, but need to be returned as part of the stream from our side will be handled by "buildError" method using "HuggingFaceErrorResponseEntity.fromResponse".
Unit tests are added.

var restStatus = toRestStatus(responseStatusCode);
return errorResponse instanceof HuggingFaceErrorResponseEntity
? new UnifiedChatCompletionException(
restStatus,
errorMessage,
HUGGING_FACE_ERROR,
restStatus.name().toLowerCase(Locale.ROOT)
)
: new UnifiedChatCompletionException(
restStatus,
errorMessage,
createErrorType(errorResponse),
restStatus.name().toLowerCase(Locale.ROOT)
);
} else {
return super.buildError(message, request, result, errorResponse);
}
}

@Override
protected Exception buildMidStreamError(Request request, String message, Exception e) {
var errorResponse = StreamingHuggingFaceErrorResponseEntity.fromString(message);
if (errorResponse instanceof StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format(
"%s for request from inference entity id [%s]. Error message: [%s]",
SERVER_ERROR_OBJECT,
request.getInferenceEntityId(),
errorResponse.getErrorMessage()
),
HUGGING_FACE_ERROR,
extractErrorCode(streamingHuggingFaceErrorResponseEntity)
);
} else if (e != null) {
return UnifiedChatCompletionException.fromThrowable(e);
} else {
return new UnifiedChatCompletionException(
RestStatus.INTERNAL_SERVER_ERROR,
format("%s for request from inference entity id [%s]", SERVER_ERROR_OBJECT, request.getInferenceEntityId()),
createErrorType(errorResponse),
"stream_error"
);
}
}

private static String extractErrorCode(StreamingHuggingFaceErrorResponseEntity streamingHuggingFaceErrorResponseEntity) {
return streamingHuggingFaceErrorResponseEntity.httpStatusCode() != null
? String.valueOf(streamingHuggingFaceErrorResponseEntity.httpStatusCode())
: null;
}

/**
* Represents a structured error response specifically for streaming operations
* using HuggingFace APIs. This is separate from non-streaming error responses,
* which are handled by {@link HuggingFaceErrorResponseEntity}.
* An example error response for failed field validation for streaming operation would look like
* <code>
* {
* "error": "Input validation error: cannot compile regex from schema",
* "http_status_code": 422
* }
* </code>
*/
private static class StreamingHuggingFaceErrorResponseEntity extends ErrorResponse {
private static final ConstructingObjectParser<Optional<ErrorResponse>, Void> ERROR_PARSER = new ConstructingObjectParser<>(
HUGGING_FACE_ERROR,
true,
args -> Optional.ofNullable((StreamingHuggingFaceErrorResponseEntity) args[0])
);
private static final ConstructingObjectParser<StreamingHuggingFaceErrorResponseEntity, Void> ERROR_BODY_PARSER =
new ConstructingObjectParser<>(
HUGGING_FACE_ERROR,
true,
args -> new StreamingHuggingFaceErrorResponseEntity(args[0] != null ? (String) args[0] : "unknown", (Integer) args[1])
);

static {
ERROR_BODY_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("message"));
ERROR_BODY_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField("http_status_code"));

ERROR_PARSER.declareObjectOrNull(
ConstructingObjectParser.optionalConstructorArg(),
ERROR_BODY_PARSER,
null,
new ParseField("error")
);
}

/**
* Parses a streaming HuggingFace error response from a JSON string.
*
* @param response the raw JSON string representing an error
* @return a parsed {@link ErrorResponse} or {@link ErrorResponse#UNDEFINED_ERROR} if parsing fails
*/
private static ErrorResponse fromString(String response) {
try (
XContentParser parser = XContentFactory.xContent(XContentType.JSON)
.createParser(XContentParserConfiguration.EMPTY, response)
) {
return ERROR_PARSER.apply(parser, null).orElse(ErrorResponse.UNDEFINED_ERROR);
} catch (Exception e) {
// swallow the error
}

return ErrorResponse.UNDEFINED_ERROR;
}

@Nullable
private final Integer httpStatusCode;

StreamingHuggingFaceErrorResponseEntity(String errorMessage, @Nullable Integer httpStatusCode) {
super(errorMessage);
this.httpStatusCode = httpStatusCode;
}

@Nullable
public Integer httpStatusCode() {
return httpStatusCode;
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@

import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

import java.util.Objects;

public abstract class HuggingFaceModel extends Model {
public abstract class HuggingFaceModel extends RateLimitGroupingModel {
private final HuggingFaceRateLimitServiceSettings rateLimitServiceSettings;
private final SecureString apiKey;

Expand All @@ -38,6 +39,16 @@ public HuggingFaceRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}

@Override
public int rateLimitGroupingHash() {
return Objects.hash(rateLimitServiceSettings.uri(), apiKey);
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitServiceSettings.rateLimitSettings();
}

public SecureString apiKey() {
return apiKey;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.huggingface.request.HuggingFaceInferenceRequest;
import org.elasticsearch.xpack.inference.services.huggingface.request.embeddings.HuggingFaceEmbeddingsRequest;

import java.util.List;
import java.util.Objects;
Expand Down Expand Up @@ -64,7 +64,7 @@ public void execute(
) {
List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
var truncatedInput = truncate(docsInput, model.getTokenLimit());
var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
var request = new HuggingFaceEmbeddingsRequest(truncator, truncatedInput, model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener));
}
Expand Down
Loading
Loading