
Commit 9c25ec4

feat: expose n_ubatch and dynamically adjust ntokens for bench (#104)
* feat: expose n_ubatch in the context params
* feat: limit max tokens by n_ubatch in bench
1 parent fb3896e · commit 9c25ec4
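For context on how the new parameter is consumed from JavaScript, here is a minimal usage sketch. It assumes llama.rn's public `initLlama` entry point; the option names follow `NativeContextParams` below, and the model path and values are illustrative, not taken from this commit.

```ts
import { initLlama } from 'llama.rn'

// Minimal sketch: pass n_ubatch alongside the existing context params.
// In llama.cpp terms, n_batch is the logical batch size and n_ubatch the
// physical micro-batch size, so n_ubatch <= n_batch is the usual setup.
const context = await initLlama({
  model: 'file:///models/model-q4_0.gguf', // hypothetical local path
  n_ctx: 2048,
  n_batch: 512,
  n_ubatch: 512, // Android falls back to 512 when this is omitted
})
```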

5 files changed: +12 −1 lines changed

android/src/main/java/com/rnllama/LlamaContext.java (+3)

@@ -50,6 +50,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       params.hasKey("n_ctx") ? params.getInt("n_ctx") : 512,
       // int n_batch,
       params.hasKey("n_batch") ? params.getInt("n_batch") : 512,
+      // int n_ubatch,
+      params.hasKey("n_ubatch") ? params.getInt("n_ubatch") : 512,
       // int n_threads,
       params.hasKey("n_threads") ? params.getInt("n_threads") : 0,
       // int n_gpu_layers, // TODO: Support this
@@ -412,6 +414,7 @@ protected static native long initContext(
     int embd_normalize,
     int n_ctx,
     int n_batch,
+    int n_ubatch,
     int n_threads,
     int n_gpu_layers, // TODO: Support this
     boolean flash_attn,

android/src/main/jni.cpp (+2)

@@ -226,6 +226,7 @@ Java_com_rnllama_LlamaContext_initContext(
     jint embd_normalize,
     jint n_ctx,
     jint n_batch,
+    jint n_ubatch,
     jint n_threads,
     jint n_gpu_layers, // TODO: Support this
     jboolean flash_attn,
@@ -256,6 +257,7 @@ Java_com_rnllama_LlamaContext_initContext(

     defaultParams.n_ctx = n_ctx;
     defaultParams.n_batch = n_batch;
+    defaultParams.n_ubatch = n_ubatch;

     if (pooling_type != -1) {
         defaultParams.pooling_type = static_cast<enum llama_pooling_type>(pooling_type);

cpp/rn-llama.hpp (+5 −1)

@@ -683,7 +683,11 @@ struct llama_rn_context
     double tg_std = 0;

     // TODO: move batch into llama_rn_context (related https://github.com/mybigday/llama.rn/issues/30)
-    llama_batch batch = llama_batch_init(512, 0, 1);
+    llama_batch batch = llama_batch_init(
+        std::min(pp, params.n_ubatch), // max n_tokens is limited by n_ubatch
+        0, // No embeddings
+        1  // Single sequence
+    );

     for (int i = 0; i < nr; i++)
     {
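The change above replaces a hard-coded 512-slot batch: the bench path now allocates min(pp, params.n_ubatch) token slots, so a prompt-processing request can no longer ask llama_batch_init for more tokens than the configured micro-batch size. A sketch of the effect, assuming the `bench(pp, tg, pl, nr)` method llama.rn exposes on a context (values are illustrative, not from the commit):

```ts
import { initLlama } from 'llama.rn'

// Illustrative: a context with a deliberately small micro-batch.
const ctx = await initLlama({
  model: 'file:///models/model-q4_0.gguf', // hypothetical local path
  n_ubatch: 256,
})

// bench(pp, tg, pl, nr): prompt tokens, generated tokens, parallel
// sequences, repetitions. With n_ubatch = 256, the requested pp of 512 is
// capped at 256 token slots by the std::min(pp, params.n_ubatch) above.
const result = await ctx.bench(512, 128, 1, 3)
console.log(result)
```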

ios/RNLlamaContext.mm (+1)

@@ -94,6 +94,7 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
 #endif
     }
     if (params[@"n_batch"]) defaultParams.n_batch = [params[@"n_batch"] intValue];
+    if (params[@"n_ubatch"]) defaultParams.n_ubatch = [params[@"n_ubatch"] intValue];
     if (params[@"use_mmap"]) defaultParams.use_mmap = [params[@"use_mmap"] boolValue];

     if (params[@"pooling_type"] && [params[@"pooling_type"] isKindOfClass:[NSNumber class]]) {

src/NativeRNLlama.ts (+1)

@@ -12,6 +12,7 @@ export type NativeContextParams = {

   n_ctx?: number
   n_batch?: number
+  n_ubatch?: number

   n_threads?: number
   n_gpu_layers?: number
