Add Hugging Face Chat Completion support to Inference Plugin #127254

Conversation

Contributor

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic commented Apr 23, 2025

This change extends the existing Hugging Face provider integration so that completion (both streaming and non-streaming) and chat_completion (streaming only) tasks can be executed through the Inference API.
Examples of requests (RQ) and responses (RS) from local testing:

Non-streaming

Create Completion Endpoint:

RQ:
curl --location --request PUT 'localhost:9200/_inference/completion/hugging-face-completion' \
--header 'Content-Type: application/json' \
--header 'Authorization: Basic %auth_token%' \
--data '{
    "service": "hugging_face",
    "service_settings": {
        "api_key": "%hf_token%",
        "model_id": "tgi",
        "max_input_tokens" : 128,
        "url": "%hf_url%"
    }
}'

RS:
{
    "inference_id": "hugging-face-completion",
    "task_type": "completion",
    "service": "hugging_face",
    "service_settings": {
        "model_id": "tgi",
        "url": "%hf_url%",
        "max_input_tokens": 128,
        "rate_limit": {
            "requests_per_minute": 3000
        }
    }
}

Perform Non-Streaming Completion:

RQ:
curl --location 'localhost:9200/_inference/completion/hugging-face-completion' \
--header 'Content-Type: application/json' \
--header 'Authorization: Basic %auth_token%' \
--data '{
    "input": "The sky above the port was the color of television tuned to a dead channel."
}'

RS:
{
    "completion": [
        {
            "result": "This instruction describes an imagery, specifically a scene that creates a visual analogy between the sky and a television screen. The analogy implies that the sky is empty, dull, and possibly grim, as a television set with no signal would be bare and desolate.\n\n\nFor a rephrased instruction that maintains the same core meaning but changes the details, one could say:\n\n\n\"The firmament above the harbor mirrored the static hum of an unlit screen.\"\n\n\nThis sentence still captures the essence of an uneventful sky while using different words and imagery. It evokes a sense of abandonment and stillness by comparing it to the static seen on an unused television."
        }
    ]
}
Streaming

Create Chat Completion Endpoint:

RQ:
curl --location --request PUT 'localhost:9200/_inference/chat_completion/hugging-face-chat-completion' \
--header 'Content-Type: application/json' \
--header 'Authorization: Basic %auth_token%' \
--data '{
    "service": "hugging_face",
    "service_settings": {
        "api_key": "%hf_token%",
        "model_id": "tgi",
        "max_input_tokens" : 128,
        "url": "%hf_url%"
    }
}'

RS:
{
    "inference_id": "hugging-face-completion",
    "task_type": "completion",
    "service": "hugging_face",
    "service_settings": {
        "model_id": "tgi",
        "url": "%hf_url%",
        "max_input_tokens": 128,
        "rate_limit": {
            "requests_per_minute": 3000
        }
    }
}

Perform Streaming Completion:

RQ:
curl --location 'localhost:9200/_inference/completion/hugging-face-completion/_stream' \
--header 'Content-Type: application/json' \
--header 'Authorization: Basic %auth_token%' \
--data '{
    "input": "The sky above the port was the color of television tuned to a dead channel."
}'

Perform Streaming Chat Completion:

RQ:
curl --location 'localhost:9200/_inference/chat_completion/hugging-face-chat-completion/_stream' \
--header 'Content-Type: application/json' \
--header 'Authorization: Basic %auth_token%' \
--data '{
    "model": "tgi",
    "messages": [
        {
            "role": "user",
            "content": "What is deep learning?"
        }
    ],
    "max_completion_tokens": 150
}'

Tested on models:
https://huggingface.co/Qwen/QwQ-32B
https://huggingface.co/microsoft/Phi-3-mini-128k-instruct

  • Have you signed the contributor license agreement? - YES
  • Have you followed the contributor guidelines? - YES
  • If submitting code, have you built your formula locally prior to submission with gradle check? - YES
  • If submitting code, is your pull request against main? Unless there is a good reason otherwise, we prefer pull requests against main and will backport as needed. - YES
  • If submitting code, have you checked that your submission is for an OS and architecture that we support? - YES
  • If you are submitting this code for a class then read our policy for that. - YES

@elasticsearchmachine elasticsearchmachine added external-contributor Pull request authored by a developer outside the Elasticsearch team v9.1.0 labels Apr 23, 2025
@jonathan-buttner jonathan-buttner added :ml Machine learning Team:ML Meta label for the ML team >enhancement v8.19.0 labels Apr 23, 2025
@@ -361,6 +362,7 @@ public void loadExtensions(ExtensionLoader loader) {
public List<InferenceServiceExtension.Factory> getInferenceServiceFactories() {
return List.of(
context -> new HuggingFaceElserService(httpFactory.get(), serviceComponents.get()),
context -> new HuggingFaceChatCompletionService(httpFactory.get(), serviceComponents.get()),
Contributor

I haven't looked through the entire PR but just wanted to check. We should try to add the chat completion functionality to the existing HuggingFaceService logic.

For example the OpenAI service supports many task types: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java#L175-L197

Contributor Author

The change is done. The completion logic is now in a single HuggingFaceService class.

import java.util.Objects;
import java.util.function.Supplier;

public class HuggingFaceCompletionRequestManager extends HuggingFaceRequestManager {
Contributor

We're trying to move away from the request manager pattern because it adds duplicate code. Could you look into following the pattern we started here (we haven't refactored all the services yet but if it's possible to do for hugging face it'd be great if we could do it now)?

#124144

One option would be to leave the other hugging face request managers as they are (if possible, it may not be though) and then use one of the generic request managers like shown in the PR above for this new functionality.

Contributor Author

Sure thing. I will adopt the approach from the shared PR. Thanks Jonathan!

Contributor Author

I made the change that moves the chat_completion and completion tasks away from the request manager pattern.

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic marked this pull request as ready for review April 29, 2025 20:15
@elasticsearchmachine
Collaborator

Pinging @elastic/ml-core (Team:ML)

@jonathan-buttner jonathan-buttner added the auto-backport Automatically create backport pull requests when merged label Apr 30, 2025
Contributor

@jonathan-buttner jonathan-buttner left a comment

Looking good, left a few comments

I was testing streaming chat completions and ran into an issue with how Hugging Face returns an error for a request that includes tools:

I provisioned an HF inference endpoint for the model Qwen/QwQ-32B

This is the endpoint name: jon-qwq-32b-qkm

PUT _inference/chat_completion/test-chat
{
    "service": "hugging_face",
    "service_settings": {
        "api_key": "<api_key>",
        "max_input_tokens" : 128,
        "model_id": "tgi",
        "url": "<url>"
    }
}

The following request fails:

POST _inference/chat_completion/test-chat/_stream
{
    "messages": [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's the price of a scarf?"
                }
            ]
        }
    ],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_current_price",
                "description": "Get the current price of a item",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "item": {
                            "id": "123"
                        }
                    }
                }
            }
        }
    ],
    "tool_choice": {
        "type": "function",
        "function": {
            "name": "get_current_price"
        }
    }
}

Response

event: error
data: {"error":{"code":"bad_request","message":"Required [id, choices, model, object]","type":"illegal_argument_exception"}}

The logs show the underlying issue is that hugging face returns:

{"error":{"message":"Input validation error: cannot compile regex from schema: Unsupported JSON Schema structure {\"id\":\"123\"} \nMake sure it is valid to the JSON Schema specification and check if it's supported by Outlines.\nIf it should be supported, please open an issue.","http_status_code":422}}

{"error": {"message" ...}, ...}

Our OpenAI code requires there to be a type field in the error object. To fix this we can do the following:

We'll create a new response handler similar to what we're doing here: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java

The new response handler can extend OpenAiUnifiedChatCompletionResponseHandler

We'll do the same thing as in the file I linked, but we'll pass in a different lambda, like we're doing here: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java#L37

Our lambda will use an error parser that can extract the fields from the error I mentioned above. I think my suggestion would be to create a new class similar to ErrorMessageResponseEntity which can extract the message field and maybe the http_status_code.

We'll then use that code like what's being done here: https://github.com/elastic/elasticsearch/blob/main/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceUnifiedChatCompletionResponseHandler.java#L62

Finally we'll use our new response handler here:

private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new OpenAiUnifiedChatCompletionResponseHandler(

It'll look something like this:

    private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new HuggingFaceChatCompletionResponseHandler(
        "hugging face chat completion",
        OpenAiChatCompletionResponseEntity::fromResponse
    );

We're actually trying to move away from including "unified" in the names but we haven't gotten around to cleaning up the rest of the code base yet.
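
To make the shape of that parsing concrete, here is a minimal, hypothetical sketch of the error-extraction step; the class name, the fromParsedBody helper, and the Map-based input are illustrative assumptions, not the actual Elasticsearch implementation:

    import java.util.Map;
    import java.util.Optional;

    // Hypothetical parser for the Hugging Face error shape shown above:
    // {"error":{"message":"...","http_status_code":422}}
    final class HuggingFaceErrorResponse {
        private final String message;
        private final Integer httpStatusCode;

        HuggingFaceErrorResponse(String message, Integer httpStatusCode) {
            this.message = message;
            this.httpStatusCode = httpStatusCode;
        }

        // `body` is the response body already parsed into a Map; how it gets parsed is out of scope here.
        static Optional<HuggingFaceErrorResponse> fromParsedBody(Map<String, Object> body) {
            if (body.get("error") instanceof Map<?, ?> error && error.get("message") instanceof String message) {
                Integer status = null;
                if (error.get("http_status_code") instanceof Number statusCode) {
                    status = statusCode.intValue();
                }
                return Optional.of(new HuggingFaceErrorResponse(message, status));
            }
            return Optional.empty(); // not an error payload we recognize
        }

        String message() {
            return message;
        }

        Integer httpStatusCode() {
            return httpStatusCode;
        }
    }

The new handler extending OpenAiUnifiedChatCompletionResponseHandler would then surface message (and, when present, http_status_code) instead of requiring the OpenAI-style type field.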

private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable<>(
() -> {
var configurationMap = new HashMap<String, SettingsConfiguration>();

configurationMap.put(
URL,
new SettingsConfiguration.Builder(supportedTaskTypes).setDefaultValue("https://api.openai.com/v1/embeddings")
new SettingsConfiguration.Builder(SUPPORTED_TASK_TYPES).setDefaultValue("https://api.openai.com/v1/embeddings")
Contributor

Oops, looks like we have an existing bug here (unrelated to your changes). Can you remove the setDefaultValue? It shouldn't be pointing to OpenAI 😅

Contributor Author

I initially assumed it was there for some internal configuration and didn't want to introduce any risk by changing it. Removed.

unifiedChatInput -> new HuggingFaceUnifiedChatCompletionRequest(unifiedChatInput, overriddenModel),
UnifiedChatInput.class
);
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "CHAT COMPLETION", model.getInferenceEntityId());
Contributor

nit: How about we move this into a function something like:

private static String errorMessage(String requestDescription, String inferenceId) {
  return format("Failed to send Hugging Face %s request from inference entity id [%s]", requestDescription, inferenceId);
}

It might be a little easier to see how the string is being formatted if the raw string is included in the format call.

Contributor Author

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic May 2, 2025

Suggestion:
Maybe we should use a TaskType taskType parameter instead of String requestDescription? That would restrict the values to the defined set of tasks and rule out inconsistent formatting; in the current implementation the values are "text embeddings" and "ELSER", which is a bit messy. That approach would change "ELSER" to sparse_embedding and make the other values lowercase as well (see the sketch below).

P.S. Having "ELSER" and "sparse embedding" used interchangeably might also be worth unifying, to keep the vocabulary consistent.
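
For illustration, a rough sketch of the suggested signature, assuming format behaves as in the snippet above and that TaskType renders as the lowercase task name when formatted:

    private static String errorMessage(TaskType taskType, String inferenceId) {
        // taskType is assumed to format as e.g. "completion", "chat_completion", "sparse_embedding"
        return format("Failed to send Hugging Face %s request from inference entity id [%s]", taskType, inferenceId);
    }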

Contributor Author

Added the version described above. Please let me know if you'd prefer to stick with the version you proposed initially.

"text embeddings",
model.getInferenceEntityId()
);
var errorMessage = format(FAILED_TO_SEND_REQUEST_ERROR_MESSAGE, "text embeddings", model.getInferenceEntityId());
Contributor

nit: Same comment as above suggesting making this a function.

Contributor Author

Did the change described in my comment above.


@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.V_8_14_0;
Contributor

We'll need to create a new version number instead of using an old one here. Here's an example: https://github.com/elastic/elasticsearch/pull/122218/files#diff-85e782e9e33a0f8ca8e99b41c17f9d04e3a7981d435abf44a3aa5d954a47cd8f

Basically we need to create a version for the 8.x branch, in the linked PR that was ML_INFERENCE_DEEPSEEK_8_19 = def(8_841_0_09);

And we need to create a version for the 9.x branch: ML_INFERENCE_DEEPSEEK = def(9_029_00_0);

When creating the variables we'll want to increment the value so it is the newest for both 8.x and 9.x. So as of writing this the latest value is here: https://github.com/elastic/elasticsearch/blob/main/server/src/main/java/org/elasticsearch/TransportVersions.java#L231

So our value for the 9.x branch should be def(9_0656_0_00);, for the 8.x branch it should be: def(8_842_0_20);

The value we put here on line 181 should be the 9.x version. When we backport this PR to 8.x branch we'll switch it to the 8.x variable name. Here's an example of the backport for deepseek: https://github.com/elastic/elasticsearch/pull/124796/files#diff-85e782e9e33a0f8ca8e99b41c17f9d04e3a7981d435abf44a3aa5d954a47cd8f

This will change as other people in the organization add their own transport versions, which will cause merge conflicts, so as you update from the main branch we'll just need to keep bumping the value until we merge the PR.

Contributor Author

for the 8.x branch it should be: def(8_842_0_20);

842 would be a new server version part and 20 a new patch part. According to the documentation in TransportVersions.java, only one of them should be incremented: either the server part or the patch part, not both.

To determine the id of the next TransportVersion constant, do the following:

  • Use the same major version, unless bumping majors
  • Bump the server version part by 1, unless creating a patch version
  • Leave the subsidiary part as 0
  • Bump the patch part if creating a patch version

The last 23 8.x versions are patch updates, and the 9.x versions are all server part updates. I would assume that the new 8.x version should be another patch update, and the new 9.x version another server part update.

Added the versions.
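
For illustration only, the two constants might end up looking roughly like this; the names are placeholders and the numeric ids (derived from the DeepSeek examples above) would need to be re-bumped to the latest values in TransportVersions.java at merge time:

    // main (9.x): bump the server part of the latest 9.x id, leave subsidiary/patch at 0 (placeholder id)
    public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED = def(9_030_0_00);

    // 8.x backport: bump only the patch part of the latest 8.x id (placeholder id)
    public static final TransportVersion ML_INFERENCE_HUGGING_FACE_CHAT_COMPLETION_ADDED_8_19 = def(8_841_0_10);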

Contributor

Ah you're right, good catch 👍

this.uri = createUri(in.readString());
this.maxInputTokens = in.readOptionalVInt();

if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
Contributor

We only need to do the onOrAfter() call if we're adding a new field to an existing setting that was introduced in a previous version. Since this entire file is new we can remove these if blocks as this code is guaranteed to be introduced after v8.15.0
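
For example, based on the snippet above, the read path can simply become the following, with any remaining fields read the same way, unconditionally:

    // No transport-version guard is needed: these settings are new, so every node
    // that can deserialize them already understands all of their fields.
    this.uri = createUri(in.readString());
    this.maxInputTokens = in.readOptionalVInt();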

Contributor Author

Removed redundant check.


String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);

var uri = extractUri(map, URL, validationException);
Contributor

Just a note: it seems like for both dedicated inference endpoints and serverless, Hugging Face requires v1/chat/completions to be in the path. We'll want to make sure we include this in the documentation: users need to include that segment or they'll get an error.

Contributor Author

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic May 2, 2025

Sure thing. I'll include this in the documentation, along with the fact that the model must support the OpenAI interface; apparently not all of them do.

out.writeString(uri.toString());
out.writeOptionalVInt(maxInputTokens);

if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
Contributor

Let's remove the if-block

Contributor Author

Removed.

@@ -112,14 +115,14 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce
);

assertThat(webServer.requests(), hasSize(1));
assertNull(webServer.requests().get(0).getUri().getQuery());
assertNull(webServer.requests().getFirst().getUri().getQuery());
Contributor

I would actually leave these as get(0), because when we backport this, the 8.19 branch uses an older JDK where this method (getFirst) won't exist, I believe.

Contributor Author

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic May 2, 2025

Yes, that is true. It wouldn't exist. Reverted the changes.

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;

public class HuggingFaceChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<
Contributor

Let's extend AbstractBWCWireSerializationTestCase instead (it helps with future testing when we add new fields to the serialization).
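
A rough sketch of what that typically adds, assuming the base class exposes a mutateInstanceForVersion hook (the exact signature may differ):

    @Override
    protected HuggingFaceChatCompletionServiceSettings mutateInstanceForVersion(
        HuggingFaceChatCompletionServiceSettings instance,
        TransportVersion version
    ) {
        // Nothing is version-dependent yet; once a field is added in a later transport
        // version, this method would strip it when testing against older versions.
        return instance;
    }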

Contributor Author

Done.

serviceSettings.toXContent(builder, null);
String xContentResult = Strings.toString(builder);

assertThat(xContentResult, is("""
Contributor

You can leave this as is but for future tests let's use this utility class so we can create the expected json in a more readable fashion (I realize most of our tests use the method you have here).

        var expected = XContentHelper.stripWhitespace("""
            {
                "secret_parameters": {
                    "test_key": "test_value"
                }
            }
            """);

        assertThat(xContentResult, is(expected));

Contributor Author

Done.


/**
* This class is responsible for managing the Hugging Face inference service.
* It handles the creation of models, chunked inference, and unified completion inference.
Member

nit: The class also handles non-chunked inference which should be included in the javadoc.

Contributor Author

I rephrased it so it is more specific. Thanks.

@@ -167,13 +220,15 @@ public static InferenceServiceConfiguration get() {
return configuration.getOrCompute();
}

private Configuration() {}
Member

Why is this line needed?

Contributor Author

In short, to prevent this class from being instantiated.

Since this class only has static members, there is no reason to allow instantiating it. To prevent that, we can hide the default constructor that every object gets by declaring a private one.
It is optional, so if you want, I can remove it.
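
A generic illustration of the idiom (unrelated to the actual contents of the Configuration class):

    // A holder with only static members: the private constructor hides the implicit
    // public no-arg constructor, so `new Defaults()` does not compile outside the class.
    final class Defaults {
        static final int MAX_INPUT_TOKENS = 128;

        private Defaults() {}
    }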

public static final String NAME = "hugging_face_completion_service_settings";
// At the time of writing HuggingFace hasn't posted the default rate limit for inference endpoints so the value here is only a guess
// 3000 requests per minute
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(3000);
Member

How did you arrive at the 3000 default? Is there somewhere that confirms that 3000 is a viable number (even if it is different than a recommended default from Hugging Face)?

Contributor Author

To be honest, this is taken as-is from the Hugging Face service settings for the other tasks (sparse and text embeddings). I haven't seen any publication from Hugging Face stating a different rate limit, so I decided to go with what we have for the other operations.
The original was committed by @jonathan-buttner, so please let me know if this is something to reconsider.

Contributor

Yeah, I haven't seen a rate limit in Hugging Face's docs, and it'll probably depend on how the model is deployed (via serverless or a dedicated endpoint). I'm OK with 3000; we can document that users should use an appropriate value for their environment.

public static HuggingFaceChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();

String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
Member

What is model ID used for? My understanding for other task types for hugging face is that we take in the URI instead of model ID. What does it mean if both URI and model ID are set?

Contributor Author

For the text generation task on the Hugging Face side, the model ID is sent to the Hugging Face API as part of the OpenAI-like request schema. While it appears to be optional for dedicated endpoints, it is mandatory for serverless HF Inference Endpoints, so in order to receive a successful response from those we need to provide it as part of the request.

For the other tasks we integrate with, it is not sent because it is not defined in the request schema on the HF side.
The URI is always mandatory and is provided by the customer because HF doesn't have a default URL.

What is model ID used for?

It is defined by HF as a field of the OpenAI-like request sent to HF; we send it as part of the payload.

What does it mean if both URI and model ID are set?

When both the model ID and the URI are set, it means we know the URI to call and can send a request with the field HF expects to be present if it is a serverless endpoint.

@@ -24,25 +24,35 @@

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

public class HuggingFaceInferenceRequest implements Request {
/**
* This class is responsible for creating a request to the Hugging Face API for embeddings.
Member

nit: The class doesn't really create the request so much as it is the request. Maybe you can say it represents the request?

Contributor Author

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic May 2, 2025

Sadly, I tend to disagree. Although it is named HuggingFaceInferenceRequest and implements the Request interface, looking at this class's behavior we can see that it literally creates the request in the createHttpRequest method. The object itself can hardly be called a request, because it is not what is sent to Hugging Face; it just creates the actual request that goes out to HF. Though I agree that it represents the request in some way.
@jonathan-buttner What do you think?

Contributor

Yeah our naming convention here is confusing.

How about we change the class doc to:

This class is responsible for creating a Hugging Face embeddings HTTP request.

Contributor Author

Did the change.

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader;

/**
* This class is responsible for creating a request to the Hugging Face API for chat completions.
Member

Same comment as the other request class. I think this more represents a request than creates it.

Contributor Author

Replied above. Would like to hear Jonathan's take on it.

Contributor Author

Change is done.

builder.startObject();
unifiedRequestEntity.toXContent(builder, params);

builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
Member

This is related to the question from the Model class, but if the model ID is optional, would it be better to use the URI here instead? Or maybe we want to identify it with a combination of both?

Contributor Author

@Jan-Kazlouski-elastic Jan-Kazlouski-elastic May 2, 2025

The URI is mandatory.
The model ID is optional, depending on the endpoint type (dedicated/serverless).
Identification can be performed by the URI alone for a dedicated endpoint, or by the combination of both for a serverless endpoint.
Using the URI alone would cause an error response from a serverless endpoint.

Member

Got it, thanks for clarifying. It seems like in the updated code we include the model ID if it is provided, which makes sense, but we do not include the URI as part of toXContent. Is there a reason we are not including the URI if it is mandatory in all cases?

Member

I spoke with Jonathan and he explained that this is serializing the request body, so it makes sense that we aren't serializing the URI here. The most recent updates look good to me.
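
For illustration, the request-body serialization being discussed might look roughly like this, based on the diff lines above; the null check and the method framing are assumptions:

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        unifiedRequestEntity.toXContent(builder, params);
        // The model id goes into the request body because serverless HF endpoints require it;
        // the URI only addresses the endpoint and is never part of the body.
        if (model.getServiceSettings().modelId() != null) {
            builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
        }
        builder.endObject();
        return builder;
    }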

var initialInputs = initialRequestAsMap.get("inputs");
assertThat(initialInputs, is(List.of("123")));

}
}

public void testExecute_ReturnsSuccessfulResponse_ForChatCompletionAction() throws IOException {
Member

Seems like this test and the next one have very similar processes. Can we reduce code duplication by moving the shared code into a shared function that takes in the values that are generated differently?

Contributor Author

Done. They are quite different so I extracted most of the repeated code into methods.

public class HuggingFaceUnifiedChatCompletionRequestTests extends ESTestCase {

public void testCreateRequest_WithStreaming() throws IOException {
var request = createRequest("url", "secret", "abcd", "model", true);
Member

For random inputs you can use randomAlphaOfLength(x);
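
For example, assuming the existing createRequest(url, secret, input, model, stream) test helper accepts arbitrary strings for these parameters:

    // randomAlphaOfLength(n) is an ESTestCase helper returning a random alphabetic string of length n
    var request = createRequest(
        randomAlphaOfLength(10), // url
        randomAlphaOfLength(10), // secret
        randomAlphaOfLength(5),  // input
        randomAlphaOfLength(8),  // model
        true                     // streaming enabled
    );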

Contributor Author

Done.

…chat-completion-integration

# Conflicts:
#	x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java
@jonathan-buttner
Contributor

jonathan-buttner commented May 9, 2025

@Jan-Kazlouski-elastic I was doing some testing with the Qwen model and noticed that the function name field was coming back as null, causing our parsing logic to fail. I put up a PR to fix that here: #127976

Would you mind cherrypicking that commit into this PR?

Contributor

@jonathan-buttner jonathan-buttner left a comment

Changes look great! Thanks for addressing my concerns. I left one note about the function name change; you're welcome to pull that in if you like, or I can merge that PR later. Can you merge in the latest main to resolve the transport version issue? After that we're good to merge 👍

Jan-Kazlouski-elastic and others added 2 commits May 11, 2025 16:15
…chat-completion-integration

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
@Jan-Kazlouski-elastic
Contributor Author

@jonathan-buttner
I've

  • cherry-picked your commit for the null function name
  • resolved the conflict in TransportVersions
  • merged the latest changes from the main branch
  • executed the "gradle check" task

…icsearch into feature/hugging-face-chat-completion-integration

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
…icsearch into feature/hugging-face-chat-completion-integration
…icsearch into feature/hugging-face-chat-completion-integration
…icsearch into feature/hugging-face-chat-completion-integration
…icsearch into feature/hugging-face-chat-completion-integration
Contributor

@jonathan-buttner jonathan-buttner left a comment

Great work!

@jonathan-buttner jonathan-buttner enabled auto-merge (squash) May 19, 2025 13:58
…icsearch into feature/hugging-face-chat-completion-integration

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
auto-merge was automatically disabled May 19, 2025 14:30

Head branch was pushed to by a user without write access

@jonathan-buttner jonathan-buttner enabled auto-merge (squash) May 19, 2025 15:05
@jonathan-buttner jonathan-buttner merged commit d1ad917 into elastic:main May 19, 2025
19 checks passed
@elasticsearchmachine
Collaborator

💔 Backport failed

Status Branch Result
8.19 Commit could not be cherrypicked due to conflicts

You can use sqren/backport to manually backport by running backport --upstream elastic/elasticsearch --pr 127254

@jonathan-buttner
Contributor

💚 All backports created successfully

Status Branch Result
8.19

Questions ?

Please refer to the Backport tool documentation

jonathan-buttner pushed a commit to jonathan-buttner/elasticsearch that referenced this pull request May 19, 2025
…#127254)

* Add Hugging Face Chat Completion support to Inference Plugin

* Add support for streaming chat completion task for HuggingFace

* [CI] Auto commit changes from spotless

* Add support for non-streaming completion task for HuggingFace

* Remove RequestManager for HF Chat Completion Task

* Refactored Hugging Face Completion Service Settings, removed Request Manager, added Unit Tests

* Refactored Hugging Face Action Creator, added Unit Tests

* Add Hugging Face Server Test

* [CI] Auto commit changes from spotless

* Removed parameters from media type for Chat Completion Request and unit tests

* Removed OpenAI default URL in HuggingFaceService's configuration, fixed formatting in InferenceGetServicesIT

* Refactor error message handling in HuggingFaceActionCreator and HuggingFaceService

* Update minimal supported version and add Hugging Face transport version constants

* Made modelId field optional in HuggingFaceChatCompletionModel, updated unit tests

* Removed max input tokens field from HuggingFaceChatCompletionServiceSettings, fixed unit tests

* Removed if statement checking TransportVersion for HuggingFaceChatCompletionServiceSettings constructor with StreamInput param

* Removed getFirst() method calls for backport compatibility

* Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWCWireSerializationTestCase for future serialization testing

* Refactored tests to use stripWhitespace method for readability

* Refactored javadoc for HuggingFaceService

* Renamed HF chat completion TransportVersion constant names

* Added random string generation in unit test

* Refactored javadocs for HuggingFace requests

* Refactored tests to reduce duplication

* Added changelog file

* Add HuggingFaceChatCompletionResponseHandler and associated tests

* Refactor error handling in HuggingFaceServiceTests to standardize error response codes and types

* Refactor HuggingFace error handling to improve response structure and add streaming support

* Allowing null function name for hugging face models

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Co-authored-by: Jonathan Buttner <jonathan.buttner@elastic.co>
(cherry picked from commit d1ad917)

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
@jonathan-buttner
Contributor

Backport is here: #128152

elasticsearchmachine pushed a commit that referenced this pull request May 19, 2025
#128152)

* Add Hugging Face Chat Completion support to Inference Plugin

* Add support for streaming chat completion task for HuggingFace

* [CI] Auto commit changes from spotless

* Add support for non-streaming completion task for HuggingFace

* Remove RequestManager for HF Chat Completion Task

* Refactored Hugging Face Completion Service Settings, removed Request Manager, added Unit Tests

* Refactored Hugging Face Action Creator, added Unit Tests

* Add Hugging Face Server Test

* [CI] Auto commit changes from spotless

* Removed parameters from media type for Chat Completion Request and unit tests

* Removed OpenAI default URL in HuggingFaceService's configuration, fixed formatting in InferenceGetServicesIT

* Refactor error message handling in HuggingFaceActionCreator and HuggingFaceService

* Update minimal supported version and add Hugging Face transport version constants

* Made modelId field optional in HuggingFaceChatCompletionModel, updated unit tests

* Removed max input tokens field from HuggingFaceChatCompletionServiceSettings, fixed unit tests

* Removed if statement checking TransportVersion for HuggingFaceChatCompletionServiceSettings constructor with StreamInput param

* Removed getFirst() method calls for backport compatibility

* Made HuggingFaceChatCompletionServiceSettingsTests extend AbstractBWCWireSerializationTestCase for future serialization testing

* Refactored tests to use stripWhitespace method for readability

* Refactored javadoc for HuggingFaceService

* Renamed HF chat completion TransportVersion constant names

* Added random string generation in unit test

* Refactored javadocs for HuggingFace requests

* Refactored tests to reduce duplication

* Added changelog file

* Add HuggingFaceChatCompletionResponseHandler and associated tests

* Refactor error handling in HuggingFaceServiceTests to standardize error response codes and types

* Refactor HuggingFace error handling to improve response structure and add streaming support

* Allowing null function name for hugging face models

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Co-authored-by: Jonathan Buttner <jonathan.buttner@elastic.co>
(cherry picked from commit d1ad917)

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java

Co-authored-by: Jan-Kazlouski-elastic <jan.kazlouski@elastic.co>