Skip to content

Commit 668d182

Browse files
authored
[openai] add support (#8)
* Initial commit for the OpenAI integration. * Added base files. * Added the completion API. * Added newlines. * Updated the client to support async/streaming text generation. * Added chat support. * Added embedding and moderation support.
1 parent 65ab621 commit 668d182

19 files changed

+683
-1
lines changed

llm4j-huggingface/src/main/java/org/llm4j/huggingface/HFApiFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class HFApiFactory {
1414

1515
public HFApi build(Configuration config) {
1616
String apiKey = config.getString(HFConfig.API_KEY);
17-
Duration timeout = Duration.ofMillis(config.getLong(HFConfig.TIMEOUT, 15 * 1000L));
17+
Duration timeout = Duration.ofMillis(config.getLong(HFConfig.TIMEOUT, HFConfig.DEFAULT_TIMEOUT_MILLIS));
1818
HFApi api = buildApi(apiKey, timeout);
1919
return api;
2020
}

llm4j-huggingface/src/main/java/org/llm4j/huggingface/HFConfig.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ public class HFConfig {
1111
public static String WIAT_FOR_MODEL = "hf.waitForModel";
1212
public static String USE_CAHE = "hf.useCache";
1313
public static String TIMEOUT = "timeout";
14+
public static final long DEFAULT_TIMEOUT_MILLIS = 15 * 1000L;
1415

1516
}

llm4j-openai/pom.xml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.llm4j</groupId>
        <artifactId>llm4j-parent</artifactId>
        <version>0.0-SNAPSHOT</version>
        <relativePath>../parent-pom.xml</relativePath>
    </parent>

    <artifactId>llm4j-openai</artifactId>

    <packaging>jar</packaging>
    <name>LLM4J OpenAI</name>
    <description>The LLM4J API implementation for OpenAI</description>

    <url>https://github.com/dzlab</url>

    <properties>
        <module-name>org.llm4j.openai</module-name>
        <llm4j.provider.implementation>org.llm4j.openai.OpenAIServiceProvider</llm4j.provider.implementation>
        <llm4j.provider.type>openai</llm4j.provider.type>
        <!-- NOTE(review): removed stray <palm.version> property — it belonged to the
             PaLM module and was unused copy-paste residue in this OpenAI POM. -->
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.llm4j</groupId>
            <artifactId>llm4j-api</artifactId>
        </dependency>

        <!-- OpenAI REST client used by OpenAIClient -->
        <dependency>
            <groupId>dev.ai4j</groupId>
            <artifactId>openai4j</artifactId>
            <version>0.6.1</version>
        </dependency>
    </dependencies>

</project>
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package org.llm4j.openai;

import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.SyncOrAsync;
import dev.ai4j.openai4j.SyncOrAsyncOrStreaming;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.completion.CompletionRequest;
import dev.ai4j.openai4j.completion.CompletionResponse;
import dev.ai4j.openai4j.embedding.EmbeddingRequest;
import dev.ai4j.openai4j.embedding.EmbeddingResponse;
import dev.ai4j.openai4j.moderation.ModerationRequest;
import dev.ai4j.openai4j.moderation.ModerationResponse;
import org.llm4j.openai.request.TaskCallback;
import org.llm4j.openai.request.StreamingCallback;

import java.net.Proxy;
import java.time.Duration;

/**
 * Thin facade over {@link OpenAiClient} exposing synchronous, asynchronous and
 * streaming variants of the OpenAI completion, chat, embedding and moderation
 * endpoints. Instances are created through the nested {@link Builder}.
 */
public class OpenAIClient {

    private final OpenAiClient client;

    OpenAIClient(Builder builder) {
        this.client = builder.client;
    }

    /** Runs a text-completion request and blocks until the response arrives. */
    public CompletionResponse generate(CompletionRequest request) {
        return client.completion(request).execute();
    }

    /** Runs a text-completion request asynchronously, reporting via {@code callback}. */
    public void generateAsync(CompletionRequest request, TaskCallback<CompletionResponse> callback) {
        execute(client.completion(request), callback);
    }

    /** Streams partial text-completion results to {@code callback}. */
    public void generateStream(CompletionRequest request, StreamingCallback<CompletionResponse> callback) {
        execute(client.completion(request), callback);
    }

    /** Runs a chat-completion request and blocks until the response arrives. */
    public ChatCompletionResponse generate(ChatCompletionRequest request) {
        return client.chatCompletion(request).execute();
    }

    /** Runs a chat-completion request asynchronously, reporting via {@code callback}. */
    public void generateAsync(ChatCompletionRequest request, TaskCallback<ChatCompletionResponse> callback) {
        execute(client.chatCompletion(request), callback);
    }

    /** Streams partial chat-completion results to {@code callback}. */
    public void generateStream(ChatCompletionRequest request, StreamingCallback<ChatCompletionResponse> callback) {
        execute(client.chatCompletion(request), callback);
    }

    /** Runs an embedding request and blocks until the response arrives. */
    public EmbeddingResponse embed(EmbeddingRequest request) {
        return client.embedding(request).execute();
    }

    /** Runs an embedding request asynchronously, reporting via {@code callback}. */
    public void embedAsync(EmbeddingRequest request, TaskCallback<EmbeddingResponse> callback) {
        execute(client.embedding(request), callback);
    }

    /** Runs a moderation request and blocks until the response arrives. */
    public ModerationResponse moderate(ModerationRequest request) {
        return client.moderation(request).execute();
    }

    /** Runs a moderation request asynchronously, reporting via {@code callback}. */
    public void moderateAsync(ModerationRequest request, TaskCallback<ModerationResponse> callback) {
        execute(client.moderation(request), callback);
    }

    /** Wires a one-shot task to the callback and fires it. */
    private <T> void execute(SyncOrAsync<T> task, TaskCallback<T> callback) {
        task.onResponse(callback::onSuccess)
                .onError(callback::onFailure)
                .execute();
    }

    /** Wires a streaming task to the callback and fires it. */
    private <T> void execute(SyncOrAsyncOrStreaming<T> task, StreamingCallback<T> callback) {
        task.onPartialResponse(callback::onPart)
                .onComplete(callback::onComplete)
                .onError(callback::onFailure)
                .execute();
    }

    /** Assembles the underlying {@link OpenAiClient} from an {@link OpenAIConfig}. */
    static class Builder {

        OpenAiClient client;

        Builder withConfig(OpenAIConfig config) {
            Duration timeout = config.getTimeout();
            OpenAiClient.Builder delegate = OpenAiClient.builder()
                    .url(config.getUrl())
                    .apiKey(config.getApiKey())
                    // a single timeout value covers all four client phases
                    .callTimeout(timeout)
                    .connectTimeout(timeout)
                    .readTimeout(timeout)
                    .writeTimeout(timeout);
            if (config.hasProxy()) {
                Proxy.Type proxyType = config.getProxy();
                delegate.proxy(proxyType, config.getProxyIp(), config.getProxyPort());
            }
            client = delegate
                    .logRequests()
                    .logResponses()
                    .logStreamingResponses()
                    .build();
            return this;
        }

        public OpenAIClient build() {
            return new OpenAIClient(this);
        }
    }
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package org.llm4j.openai;

import org.apache.commons.configuration2.Configuration;

import java.net.Proxy;
import java.time.Duration;

import static dev.ai4j.openai4j.Model.GPT_3_5_TURBO;

/**
 * Typed, read-only view over a {@link Configuration} holding the OpenAI
 * connection and sampling settings. Each getter falls back to the matching
 * {@code *_DEFAULT} constant when the key is absent.
 */
public class OpenAIConfig {

    public static final String PROXY_TYPE = "openai.proxy.type";
    public static final String PROXY_IP = "openai.proxy.ip";
    public static final String PROXY_PORT = "openai.proxy.port";

    public static final String MODEL_ID = "openai.modelId";

    public static final String MODEL_ID_DEFAULT = GPT_3_5_TURBO.stringValue();

    public static final String API_URL = "openai.url";
    public static final String API_URL_DEFAULT = "https://api.openai.com/";
    public static final String API_KEY = "openai.apiKey";

    public static final String TEMPERATURE = "temperature";
    public static final double TEMPERATURE_DEFAULT = 0.9;

    public static final String TOP_P = "topP";
    public static final double TOP_P_DEFAULT = 0.9;

    public static final String MAX_TOKENS = "maxTokens";
    public static final int MAX_TOKENS_DEFAULT = 128;

    public static final String PRESENCE_PENALTY = "presencePenalty";
    public static final double PRESENCE_PENALTY_DEFAULT = 0.9;

    public static final String FREQUENCY_PENALTY = "frequencyPenalty";
    public static final double FREQUENCY_PENALTY_DEFAULT = 0.9;
    public static final String TIMEOUT = "timeout";

    public static final long DEFAULT_TIMEOUT_MILLIS = 15 * 1000L;

    private final Configuration config;

    public OpenAIConfig(Configuration config) {
        this.config = config;
    }

    /** API base URL; defaults to the public OpenAI endpoint. */
    public String getUrl() {
        return config.getString(API_URL, API_URL_DEFAULT);
    }

    /** API key; {@code null} when not configured. */
    public String getApiKey() {
        return config.getString(API_KEY);
    }

    /** Model identifier; defaults to GPT-3.5 Turbo. */
    public String getModelId() {
        return config.getString(MODEL_ID, MODEL_ID_DEFAULT);
    }

    public Double getTemperature() {
        return config.getDouble(TEMPERATURE, TEMPERATURE_DEFAULT);
    }

    public Double getTopP() {
        return config.getDouble(TOP_P, TOP_P_DEFAULT);
    }

    public Integer getMaxTokens() {
        return config.getInteger(MAX_TOKENS, MAX_TOKENS_DEFAULT);
    }

    public Double getPresencePenalty() {
        return config.getDouble(PRESENCE_PENALTY, PRESENCE_PENALTY_DEFAULT);
    }

    public Double getFrequencyPenalty() {
        return config.getDouble(FREQUENCY_PENALTY, FREQUENCY_PENALTY_DEFAULT);
    }

    /** Request timeout; defaults to {@link #DEFAULT_TIMEOUT_MILLIS}. */
    public Duration getTimeout() {
        long millis = config.getLong(TIMEOUT, DEFAULT_TIMEOUT_MILLIS);
        return Duration.ofMillis(millis);
    }

    /** True when a non-empty proxy type has been configured. */
    public boolean hasProxy() {
        String proxyType = config.getString(PROXY_TYPE, "");
        return !proxyType.isEmpty();
    }

    /**
     * Configured proxy type.
     * NOTE(review): {@code Proxy.Type.valueOf} throws if the configured value
     * is not a valid enum name — callers should gate on {@link #hasProxy()}.
     */
    public Proxy.Type getProxy() {
        return Proxy.Type.valueOf(config.getString(PROXY_TYPE));
    }

    public String getProxyIp() {
        return config.getString(PROXY_IP);
    }

    public Integer getProxyPort() {
        return config.getInt(PROXY_PORT);
    }
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package org.llm4j.openai;

import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.completion.CompletionRequest;
import dev.ai4j.openai4j.completion.CompletionResponse;
import dev.ai4j.openai4j.embedding.EmbeddingRequest;
import dev.ai4j.openai4j.embedding.EmbeddingResponse;
import dev.ai4j.openai4j.moderation.ModerationRequest;
import dev.ai4j.openai4j.moderation.ModerationResponse;
import dev.ai4j.openai4j.moderation.ModerationResult;
import org.apache.commons.configuration2.Configuration;
import org.llm4j.api.ChatHistory;
import org.llm4j.api.LanguageModel;
import org.llm4j.api.LanguageModelFactory;
import org.llm4j.openai.request.ChatRequestFactory;
import org.llm4j.openai.request.EmbeddingRequestFactory;
import org.llm4j.openai.request.ModerationRequestFactory;
import org.llm4j.openai.request.TextRequestFactory;

import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * {@link LanguageModel} backed by the OpenAI API. Text completion, chat,
 * embedding and moderation calls are delegated to an {@link OpenAIClient},
 * with request objects assembled from the shared {@link OpenAIConfig}.
 */
public class OpenAILanguageModel implements LanguageModel {

    private final OpenAIClient client;
    private final OpenAIConfig config;

    OpenAILanguageModel(Builder builder) {
        this.client = builder.client;
        this.config = builder.config;
    }

    /** Completes {@code text} with the configured model and returns the generated text. */
    @Override
    public String process(String text) {
        CompletionRequest request =
                new TextRequestFactory().withText(text).withConfig(config).build();
        CompletionResponse response = client.generate(request);
        return response.text();
    }

    /** Generates the next assistant message for the given chat history. */
    @Override
    public String process(ChatHistory history) {
        ChatCompletionRequest request =
                new ChatRequestFactory().withChat(history).withConfig(config).build();
        ChatCompletionResponse response = client.generate(request);
        return response.content();
    }

    /** Returns the embedding vector for {@code text}. */
    @Override
    public List<Float> embed(String text) {
        EmbeddingRequest request =
                new EmbeddingRequestFactory().withText(text).withConfig(config).build();
        EmbeddingResponse response = client.embed(request);
        return response.embedding();
    }

    /** Moderates a single input; convenience overload of {@link #moderate(List)}. */
    public List<ModerationResult> moderate(String input) {
        return moderate(Collections.singletonList(input));
    }

    /** Runs the moderation endpoint over {@code inputs} and returns one result per input. */
    public List<ModerationResult> moderate(List<String> inputs) {
        ModerationRequest request =
                new ModerationRequestFactory().withInputs(inputs).withConfig(config).build();
        ModerationResponse response = client.moderate(request);
        return response.results();
    }

    /** Factory wiring an {@link OpenAIClient} and {@link OpenAIConfig} into a model. */
    public static final class Builder implements LanguageModelFactory {
        private OpenAIClient client;
        private OpenAIConfig config;

        public LanguageModel getLanguageModel(Configuration config) {
            this.config = new OpenAIConfig(config);
            this.client = new OpenAIClient.Builder().withConfig(this.config).build();
            return new OpenAILanguageModel(this);
        }
    }
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package org.llm4j.openai;
2+
3+
import org.llm4j.api.LanguageModelFactory;
4+
import org.llm4j.spi.LLM4JServiceProvider;
5+
6+
public class OpenAIServiceProvider implements LLM4JServiceProvider {
7+
@Override
8+
public LanguageModelFactory getLLMFactory() {
9+
return new OpenAILanguageModel.Builder();
10+
}
11+
}

0 commit comments

Comments
 (0)