Skip to content

Commit c29e8fe

Browse files
authored
Merge ba9c38f into 534cc53
2 parents 534cc53 + ba9c38f commit c29e8fe

File tree

1 file changed

+93
-1
lines changed

1 file changed

+93
-1
lines changed

firebase-vertexai/src/testUtil/java/com/google/firebase/vertexai/JavaCompileTests.java

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@
2121
import com.google.firebase.concurrent.FirebaseExecutors;
2222
import com.google.firebase.vertexai.FirebaseVertexAI;
2323
import com.google.firebase.vertexai.GenerativeModel;
24+
import com.google.firebase.vertexai.LiveGenerativeModel;
2425
import com.google.firebase.vertexai.java.ChatFutures;
2526
import com.google.firebase.vertexai.java.GenerativeModelFutures;
27+
import com.google.firebase.vertexai.java.LiveModelFutures;
28+
import com.google.firebase.vertexai.java.LiveSessionFutures;
2629
import com.google.firebase.vertexai.type.BlockReason;
2730
import com.google.firebase.vertexai.type.Candidate;
2831
import com.google.firebase.vertexai.type.Citation;
@@ -33,24 +36,33 @@
3336
import com.google.firebase.vertexai.type.FileDataPart;
3437
import com.google.firebase.vertexai.type.FinishReason;
3538
import com.google.firebase.vertexai.type.FunctionCallPart;
39+
import com.google.firebase.vertexai.type.FunctionResponsePart;
3640
import com.google.firebase.vertexai.type.GenerateContentResponse;
41+
import com.google.firebase.vertexai.type.GenerationConfig;
3742
import com.google.firebase.vertexai.type.HarmCategory;
3843
import com.google.firebase.vertexai.type.HarmProbability;
3944
import com.google.firebase.vertexai.type.HarmSeverity;
4045
import com.google.firebase.vertexai.type.ImagePart;
4146
import com.google.firebase.vertexai.type.InlineDataPart;
47+
import com.google.firebase.vertexai.type.LiveContentResponse;
48+
import com.google.firebase.vertexai.type.LiveGenerationConfig;
49+
import com.google.firebase.vertexai.type.MediaData;
4250
import com.google.firebase.vertexai.type.ModalityTokenCount;
4351
import com.google.firebase.vertexai.type.Part;
4452
import com.google.firebase.vertexai.type.PromptFeedback;
53+
import com.google.firebase.vertexai.type.ResponseModality;
4554
import com.google.firebase.vertexai.type.SafetyRating;
55+
import com.google.firebase.vertexai.type.SpeechConfig;
4656
import com.google.firebase.vertexai.type.TextPart;
4757
import com.google.firebase.vertexai.type.UsageMetadata;
58+
import com.google.firebase.vertexai.type.Voices;
4859
import java.util.Calendar;
4960
import java.util.List;
5061
import java.util.Map;
5162
import java.util.concurrent.Executor;
5263
import kotlinx.serialization.json.JsonElement;
5364
import kotlinx.serialization.json.JsonNull;
65+
import kotlinx.serialization.json.JsonObject;
5466
import org.junit.Assert;
5567
import org.reactivestreams.Publisher;
5668
import org.reactivestreams.Subscriber;
@@ -63,9 +75,31 @@ public class JavaCompileTests {
6375

6476
public void initializeJava() throws Exception {
6577
FirebaseVertexAI vertex = FirebaseVertexAI.getInstance();
66-
GenerativeModel model = vertex.generativeModel("fake-model-name");
78+
GenerativeModel model = vertex.generativeModel("fake-model-name", getConfig());
79+
LiveGenerativeModel live = vertex.liveModel("fake-model-name", getLiveConfig());
6780
GenerativeModelFutures futures = GenerativeModelFutures.from(model);
81+
LiveModelFutures liveFutures = LiveModelFutures.from(live);
6882
testFutures(futures);
83+
testLiveFutures(liveFutures);
84+
}
85+
86+
private GenerationConfig getConfig() {
87+
return new GenerationConfig.Builder().build();
88+
// TODO b/406558430 GenerationConfig.Builder.setParts returns void
89+
}
90+
91+
private LiveGenerationConfig getLiveConfig() {
92+
return new LiveGenerationConfig.Builder()
93+
.setTopK(10)
94+
.setTopP(11.0F)
95+
.setTemperature(32.0F)
96+
.setCandidateCount(1)
97+
.setMaxOutputTokens(0xCAFEBABE)
98+
.setFrequencyPenalty(1.0F)
99+
.setPresencePenalty(2.0F)
100+
.setResponseModality(ResponseModality.AUDIO)
101+
.setSpeechConfig(new SpeechConfig(Voices.AOEDE))
102+
.build();
69103
}
70104

71105
private void testFutures(GenerativeModelFutures futures) throws Exception {
@@ -236,4 +270,62 @@ public void validateUsageMetadata(UsageMetadata metadata) {
236270
}
237271
}
238272
}
273+
274+
private void testLiveFutures(LiveModelFutures futures) throws Exception {
275+
LiveSessionFutures session = futures.connect().get();
276+
session
277+
.receive()
278+
.subscribe(
279+
new Subscriber<LiveContentResponse>() {
280+
@Override
281+
public void onSubscribe(Subscription s) {
282+
s.request(Long.MAX_VALUE);
283+
}
284+
285+
@Override
286+
public void onNext(LiveContentResponse response) {
287+
validateLiveContentResponse(response);
288+
}
289+
290+
@Override
291+
public void onError(Throwable t) {
292+
// Ignore
293+
}
294+
295+
@Override
296+
public void onComplete() {
297+
// Also ignore
298+
}
299+
});
300+
301+
session.send("Fake message");
302+
session.send(new Content.Builder().addText("Fake message").build());
303+
304+
byte[] bytes = new byte[] {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE};
305+
session.sendMediaStream(List.of(new MediaData(bytes, "image/jxl")));
306+
307+
FunctionResponsePart functionResponse =
308+
new FunctionResponsePart("myFunction", new JsonObject(Map.of()));
309+
session.sendFunctionResponse(List.of(functionResponse, functionResponse));
310+
311+
session.startAudioConversation(part -> functionResponse);
312+
session.startAudioConversation();
313+
session.stopAudioConversation();
314+
session.stopReceiving();
315+
session.close();
316+
}
317+
318+
private void validateLiveContentResponse(LiveContentResponse response) {
319+
//int status = response.getStatus();
320+
//Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL());
321+
//Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED());
322+
//Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE());
323+
// TODO b/412743328 LiveContentResponse.Status inaccessible for Java users
324+
Content data = response.getData();
325+
if (data != null) {
326+
validateContent(data);
327+
}
328+
String text = response.getText();
329+
validateFunctionCalls(response.getFunctionCalls());
330+
}
239331
}

0 commit comments

Comments
 (0)