[ML] Refactor OpenAI request managers #124144
Conversation
@@ -27,6 +36,18 @@
 */
public class OpenAiActionCreator implements OpenAiActionVisitor {
    public static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions";
    public static final String USER_ROLE = "user";

    static final ResponseHandler COMPLETION_HANDLER = new OpenAiChatCompletionResponseHandler(
These changes basically move the logic from the request manager files into here.
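A rough sketch of what that wiring could look like in the creator, using the GenericRequestManager introduced later in this PR; the request constructor, input type, and SenderExecutableAction arguments are assumptions, not the exact code:

    @Override
    public ExecutableAction create(OpenAiChatCompletionModel model, Map<String, Object> taskSettings) {
        // The shared handler and error prefix defined above replace the logic
        // that previously lived in a dedicated request manager file. This
        // assumes the model satisfies the RateLimitGroupingModel contract.
        var manager = new GenericRequestManager<>(
            serviceComponents.threadPool(),
            model,
            COMPLETION_HANDLER,
            inputs -> new OpenAiChatCompletionRequest(inputs, model),
            ChatCompletionInput.class
        );
        return new SenderExecutableAction(sender, manager, COMPLETION_ERROR_PREFIX);
    }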
    // If two inference endpoints have the same information defining the group but different
    // rate limits, they should be in different groups; otherwise whoever initially created
    // the group would set the rate and the other inference endpoint's rate would be ignored.
    return new EndpointGrouping(rateLimitGroup, rateLimitSettings);
We don't need to recreate the object on each call.
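A minimal sketch of the fix, assuming both fields are final so the grouping can be built once in the constructor instead of on every call (the surrounding class is abbreviated):

    private final EndpointGrouping endpointGrouping;

    // In the constructor, after the final fields are assigned:
    this.endpointGrouping = new EndpointGrouping(rateLimitGroup, rateLimitSettings);

    // The accessor then returns the cached instance instead of allocating a new one.
    EndpointGrouping endpointGrouping() {
        return endpointGrouping;
    }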
        this.rateLimitSettings = rateLimitSettings;
    }

    BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel) {
This is a stopgap. Once all the request managers are refactored, the old constructor can be removed.
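A sketch of how the two constructors could coexist during the migration; the field types and the RateLimitGroupingModel accessor names are assumptions:

    // Old constructor: unmigrated managers still pass the grouping pieces separately.
    BaseRequestManager(ThreadPool threadPool, String inferenceEntityId, Object rateLimitGroup, RateLimitSettings rateLimitSettings) {
        this.threadPool = Objects.requireNonNull(threadPool);
        this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
        this.rateLimitGroup = Objects.requireNonNull(rateLimitGroup);
        this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings);
    }

    // Stopgap constructor: refactored managers hand over a model that already
    // bundles the grouping hash and settings; once every manager is migrated,
    // the constructor above can be deleted.
    BaseRequestManager(ThreadPool threadPool, RateLimitGroupingModel rateLimitGroupingModel) {
        this(
            threadPool,
            rateLimitGroupingModel.inferenceEntityId(),
            rateLimitGroupingModel.rateLimitGroupingHash(),
            rateLimitGroupingModel.rateLimitSettings()
        );
    }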
 * This is a temporary class to use while we refactor all the request managers. After all the request managers extend
 * this class we'll move this functionality directly into the {@link BaseRequestManager}.
 */
public class GenericRequestManager<T extends InferenceInputs> extends BaseRequestManager {
Once all the request managers are refactored, I envision that we'll be able to move this logic up into the base class.
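An illustrative shape for the class, with the constructor parameters inferred from how a generic manager would have to be wired (not the exact signature):

    public class GenericRequestManager<T extends InferenceInputs> extends BaseRequestManager {
        private final ResponseHandler responseHandler;
        private final Function<T, Request> requestCreator;
        private final Class<T> inputType;

        public GenericRequestManager(
            ThreadPool threadPool,
            RateLimitGroupingModel rateLimitGroupingModel,
            ResponseHandler responseHandler,
            Function<T, Request> requestCreator,
            Class<T> inputType
        ) {
            super(threadPool, rateLimitGroupingModel); // the stopgap constructor above
            this.responseHandler = Objects.requireNonNull(responseHandler);
            this.requestCreator = Objects.requireNonNull(requestCreator);
            this.inputType = Objects.requireNonNull(inputType);
        }

        protected Request createRequest(InferenceInputs inferenceInputs) {
            // Narrow the untyped inputs and delegate to the factory; this is the
            // piece each provider previously duplicated in its own manager file.
            return requestCreator.apply(inputType.cast(inferenceInputs));
        }
    }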
import static org.elasticsearch.xpack.inference.common.Truncator.truncate;

public class TruncatingRequestManager extends BaseRequestManager {
Currently this would only be used for text embedding requests.
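A hedged sketch of the truncation step such a manager would perform before building an embeddings request; only the statically imported truncate comes from the diff above, the accessor names are assumptions:

    // Cut the input documents down to the model's token limit before the request
    // is created, so oversized text embedding payloads don't fail outright.
    var truncationResult = truncate(inputs, model.maxInputTokens());
    var request = requestCreator.apply(truncationResult);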
        this.truncationResult = Objects.requireNonNull(input);
        this.model = Objects.requireNonNull(model);
    }

    public HttpRequest createHttpRequest() {
-       HttpPost httpPost = new HttpPost(account.uri());
+       HttpPost httpPost = new HttpPost(model.uri());
I pushed all the account-related stuff into the model since we need it there to calculate the hash anyway.
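A sketch of the resulting request construction, assuming the model now exposes the connection details OpenAiAccount used to hold (the header helpers and HttpRequest constructor are abbreviated assumptions):

    public HttpRequest createHttpRequest() {
        // uri() and apiKey() live on the model because the rate limit grouping
        // hash is computed from them there anyway, so the separate account
        // object can be deleted.
        HttpPost httpPost = new HttpPost(model.uri());
        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
        httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
        // ... entity and remaining headers unchanged ...
        return new HttpRequest(httpPost, getInferenceEntityId());
    }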
    );
    }

    public static URI buildDefaultUri() throws URISyntaxException {
It's public because it's used in a few tests.
@@ -62,4 +69,16 @@ public OpenAiRateLimitServiceSettings rateLimitServiceSettings() {
    }

    public abstract ExecutableAction accept(OpenAiActionVisitor creator, Map<String, Object> taskSettings);

    public int rateLimitGroupingHash() {
        return Objects.hash(rateLimitServiceSettings.modelId(), apiKey, uri);
We could probably calculate this only once. To avoid weird bugs, I suppose it's better to do it on every call in case one of the fields gets reset, although they shouldn't change since they're final.
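If we did decide to compute it once, a minimal sketch; because every input is final, the cached value can never go stale:

    private final int rateLimitGroupingHash;

    // In the constructor, after rateLimitServiceSettings, apiKey, and uri are assigned:
    this.rateLimitGroupingHash = Objects.hash(rateLimitServiceSettings.modelId(), apiKey, uri);

    public int rateLimitGroupingHash() {
        return rateLimitGroupingHash;
    }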
Pinging @elastic/ml-core (Team:ML)
||
public abstract int rateLimitGroupingHash(); | ||
|
||
public abstract RateLimitSettings rateLimitSettings(); |
Should we maybe (eventually) move this into Model, since I think everyone has RateLimitSettings anyway?
Good idea. There might be a few places, like the Bedrock implementation, that don't. I'll see if we can handle that elegantly.
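One hedged way to handle providers without rate limiting would be a null-object default on Model that implementations like Bedrock simply inherit; the DISABLED constant is hypothetical:

    public abstract class Model {
        // Overridden by every service that rate limits; services without rate
        // limits keep this default. DISABLED is a hypothetical null-object constant.
        public RateLimitSettings rateLimitSettings() {
            return RateLimitSettings.DISABLED;
        }
    }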
* Code compiling
* Removing OpenAiAccount
💚 Backport successful
This PR demonstrates how we can remove many of the RequestManager files within the inference API. I only did this for OpenAI as a demonstration. If we're ok with the approach, I can do it for the rest of the services.