21
21
import com .google .firebase .concurrent .FirebaseExecutors ;
22
22
import com .google .firebase .vertexai .FirebaseVertexAI ;
23
23
import com .google .firebase .vertexai .GenerativeModel ;
24
+ import com .google .firebase .vertexai .LiveGenerativeModel ;
24
25
import com .google .firebase .vertexai .java .ChatFutures ;
25
26
import com .google .firebase .vertexai .java .GenerativeModelFutures ;
27
+ import com .google .firebase .vertexai .java .LiveModelFutures ;
28
+ import com .google .firebase .vertexai .java .LiveSessionFutures ;
26
29
import com .google .firebase .vertexai .type .BlockReason ;
27
30
import com .google .firebase .vertexai .type .Candidate ;
28
31
import com .google .firebase .vertexai .type .Citation ;
33
36
import com .google .firebase .vertexai .type .FileDataPart ;
34
37
import com .google .firebase .vertexai .type .FinishReason ;
35
38
import com .google .firebase .vertexai .type .FunctionCallPart ;
39
+ import com .google .firebase .vertexai .type .FunctionResponsePart ;
36
40
import com .google .firebase .vertexai .type .GenerateContentResponse ;
41
+ import com .google .firebase .vertexai .type .GenerationConfig ;
37
42
import com .google .firebase .vertexai .type .HarmCategory ;
38
43
import com .google .firebase .vertexai .type .HarmProbability ;
39
44
import com .google .firebase .vertexai .type .HarmSeverity ;
40
45
import com .google .firebase .vertexai .type .ImagePart ;
41
46
import com .google .firebase .vertexai .type .InlineDataPart ;
47
+ import com .google .firebase .vertexai .type .LiveContentResponse ;
48
+ import com .google .firebase .vertexai .type .LiveGenerationConfig ;
49
+ import com .google .firebase .vertexai .type .MediaData ;
42
50
import com .google .firebase .vertexai .type .ModalityTokenCount ;
43
51
import com .google .firebase .vertexai .type .Part ;
44
52
import com .google .firebase .vertexai .type .PromptFeedback ;
53
+ import com .google .firebase .vertexai .type .ResponseModality ;
45
54
import com .google .firebase .vertexai .type .SafetyRating ;
55
+ import com .google .firebase .vertexai .type .SpeechConfig ;
46
56
import com .google .firebase .vertexai .type .TextPart ;
47
57
import com .google .firebase .vertexai .type .UsageMetadata ;
58
+ import com .google .firebase .vertexai .type .Voices ;
48
59
import java .util .Calendar ;
49
60
import java .util .List ;
50
61
import java .util .Map ;
51
62
import java .util .concurrent .Executor ;
52
63
import kotlinx .serialization .json .JsonElement ;
53
64
import kotlinx .serialization .json .JsonNull ;
65
+ import kotlinx .serialization .json .JsonObject ;
54
66
import org .junit .Assert ;
55
67
import org .reactivestreams .Publisher ;
56
68
import org .reactivestreams .Subscriber ;
@@ -63,9 +75,31 @@ public class JavaCompileTests {
63
75
64
76
public void initializeJava () throws Exception {
65
77
FirebaseVertexAI vertex = FirebaseVertexAI .getInstance ();
66
- GenerativeModel model = vertex .generativeModel ("fake-model-name" );
78
+ GenerativeModel model = vertex .generativeModel ("fake-model-name" , getConfig ());
79
+ LiveGenerativeModel live = vertex .liveModel ("fake-model-name" , getLiveConfig ());
67
80
GenerativeModelFutures futures = GenerativeModelFutures .from (model );
81
+ LiveModelFutures liveFutures = LiveModelFutures .from (live );
68
82
testFutures (futures );
83
+ testLiveFutures (liveFutures );
84
+ }
85
+
86
+ private GenerationConfig getConfig () {
87
+ return new GenerationConfig .Builder ().build ();
88
+ // TODO b/406558430 GenerationConfig.Builder.setParts returns void
89
+ }
90
+
91
+ private LiveGenerationConfig getLiveConfig () {
92
+ return new LiveGenerationConfig .Builder ()
93
+ .setTopK (10 )
94
+ .setTopP (11.0F )
95
+ .setTemperature (32.0F )
96
+ .setCandidateCount (1 )
97
+ .setMaxOutputTokens (0xCAFEBABE )
98
+ .setFrequencyPenalty (1.0F )
99
+ .setPresencePenalty (2.0F )
100
+ .setResponseModality (ResponseModality .AUDIO )
101
+ .setSpeechConfig (new SpeechConfig (Voices .AOEDE ))
102
+ .build ();
69
103
}
70
104
71
105
private void testFutures (GenerativeModelFutures futures ) throws Exception {
@@ -236,4 +270,62 @@ public void validateUsageMetadata(UsageMetadata metadata) {
236
270
}
237
271
}
238
272
}
273
+
274
+ private void testLiveFutures (LiveModelFutures futures ) throws Exception {
275
+ LiveSessionFutures session = futures .connect ().get ();
276
+ session
277
+ .receive ()
278
+ .subscribe (
279
+ new Subscriber <LiveContentResponse >() {
280
+ @ Override
281
+ public void onSubscribe (Subscription s ) {
282
+ s .request (Long .MAX_VALUE );
283
+ }
284
+
285
+ @ Override
286
+ public void onNext (LiveContentResponse response ) {
287
+ validateLiveContentResponse (response );
288
+ }
289
+
290
+ @ Override
291
+ public void onError (Throwable t ) {
292
+ // Ignore
293
+ }
294
+
295
+ @ Override
296
+ public void onComplete () {
297
+ // Also ignore
298
+ }
299
+ });
300
+
301
+ session .send ("Fake message" );
302
+ session .send (new Content .Builder ().addText ("Fake message" ).build ());
303
+
304
+ byte [] bytes = new byte [] {(byte ) 0xCA , (byte ) 0xFE , (byte ) 0xBA , (byte ) 0xBE };
305
+ session .sendMediaStream (List .of (new MediaData (bytes , "image/jxl" )));
306
+
307
+ FunctionResponsePart functionResponse =
308
+ new FunctionResponsePart ("myFunction" , new JsonObject (Map .of ()));
309
+ session .sendFunctionResponse (List .of (functionResponse , functionResponse ));
310
+
311
+ session .startAudioConversation (part -> functionResponse );
312
+ session .startAudioConversation ();
313
+ session .stopAudioConversation ();
314
+ session .stopReceiving ();
315
+ session .close ();
316
+ }
317
+
318
+ private void validateLiveContentResponse (LiveContentResponse response ) {
319
+ //int status = response.getStatus();
320
+ //Assert.assertEquals(status, LiveContentResponse.Status.Companion.getNORMAL());
321
+ //Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getINTERRUPTED());
322
+ //Assert.assertNotEquals(status, LiveContentResponse.Status.Companion.getTURN_COMPLETE());
323
+ // TODO b/412743328 LiveContentResponse.Status inaccessible for Java users
324
+ Content data = response .getData ();
325
+ if (data != null ) {
326
+ validateContent (data );
327
+ }
328
+ String text = response .getText ();
329
+ validateFunctionCalls (response .getFunctionCalls ());
330
+ }
239
331
}
0 commit comments