Skip to content

Commit 1088300

Browse files
authored
fix: not init sampling before get embedding (#69)
* fix: not init sampling before get embedding
* fix(android): embeddings array
1 parent 7bfda3b commit 1088300

File tree

4 files changed

+27
-10
lines changed

4 files changed

+27
-10
lines changed

android/src/main/java/com/rnllama/LlamaContext.java

+11-5
Original file line number | Diff line number | Diff line change
@@ -139,7 +139,7 @@ public WritableMap completion(ReadableMap params) {
139139
}
140140
}
141141

142-
return doCompletion(
142+
WritableMap result = doCompletion(
143143
this.context,
144144
// String prompt,
145145
params.getString("prompt"),
@@ -193,6 +193,10 @@ public WritableMap completion(ReadableMap params) {
193193
params.hasKey("emit_partial_completion") ? params.getBoolean("emit_partial_completion") : false
194194
)
195195
);
196+
if (result.hasKey("error")) {
197+
throw new IllegalStateException(result.getString("error"));
198+
}
199+
return result;
196200
}
197201

198202
public void stopCompletion() {
@@ -217,12 +221,14 @@ public String detokenize(ReadableArray tokens) {
217221
return detokenize(this.context, toks);
218222
}
219223

220-
public WritableMap embedding(String text) {
224+
public WritableMap getEmbedding(String text) {
221225
if (isEmbeddingEnabled(this.context) == false) {
222226
throw new IllegalStateException("Embedding is not enabled");
223227
}
224-
WritableMap result = Arguments.createMap();
225-
result.putArray("embedding", embedding(this.context, text));
228+
WritableMap result = embedding(this.context, text);
229+
if (result.hasKey("error")) {
230+
throw new IllegalStateException(result.getString("error"));
231+
}
226232
return result;
227233
}
228234

@@ -354,7 +360,7 @@ protected static native WritableMap doCompletion(
354360
protected static native WritableArray tokenize(long contextPtr, String text);
355361
protected static native String detokenize(long contextPtr, int[] tokens);
356362
protected static native boolean isEmbeddingEnabled(long contextPtr);
357-
protected static native WritableArray embedding(long contextPtr, String text);
363+
protected static native WritableMap embedding(long contextPtr, String text);
358364
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
359365
protected static native void freeContext(long contextPtr);
360366
}

android/src/main/java/com/rnllama/RNLlama.java

+1-1
Original file line number | Diff line number | Diff line change
@@ -297,7 +297,7 @@ protected WritableMap doInBackground(Void... voids) {
297297
if (context == null) {
298298
throw new Exception("Context not found");
299299
}
300-
return context.embedding(text);
300+
return context.getEmbedding(text);
301301
} catch (Exception e) {
302302
exception = e;
303303
}

android/src/main/jni.cpp

+10-3
Original file line number | Diff line number | Diff line change
@@ -581,17 +581,24 @@ Java_com_rnllama_LlamaContext_embedding(
581581
llama->params.prompt = text_chars;
582582

583583
llama->params.n_predict = 0;
584+
585+
auto result = createWriteableMap(env);
586+
if (!llama->initSampling()) {
587+
putString(env, result, "error", "Failed to initialize sampling");
588+
return reinterpret_cast<jobject>(result);
589+
}
590+
584591
llama->beginCompletion();
585592
llama->loadPrompt();
586593
llama->doCompletion();
587594

588595
std::vector<float> embedding = llama->getEmbedding();
589596

590-
jobject result = createWritableArray(env);
591-
597+
auto embeddings = createWritableArray(env);
592598
for (const auto &val : embedding) {
593-
pushDouble(env, result, (double) val);
599+
pushDouble(env, embeddings, (double) val);
594600
}
601+
putArray(env, result, "embedding", embeddings);
595602

596603
env->ReleaseStringUTFChars(text, text_chars);
597604
return result;

ios/RNLlamaContext.mm

+5-1
Original file line number | Diff line number | Diff line change
@@ -365,8 +365,12 @@ - (NSArray *)embedding:(NSString *)text {
365365
llama->params.prompt = [text UTF8String];
366366

367367
llama->params.n_predict = 0;
368-
llama->loadPrompt();
368+
369+
if (!llama->initSampling()) {
370+
@throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to initialize sampling" userInfo:nil];
371+
}
369372
llama->beginCompletion();
373+
llama->loadPrompt();
370374
llama->doCompletion();
371375

372376
std::vector<float> result = llama->getEmbedding();

0 commit comments

Comments
 (0)