Skip to content

Commit 18f75b1

Browse files
committed
Android JNI llama cache temperature in class
1 parent ef7d4ca commit 18f75b1

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

Diff for: extension/android/jni/jni_layer_llama.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class ExecuTorchLlmCallbackJni
114114
class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
115115
private:
116116
friend HybridBase;
117+
float temperature_;
117118
int model_type_category_;
118119
std::unique_ptr<llm::IRunner> runner_;
119120
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -159,6 +160,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
159160
#endif
160161

161162
model_type_category_ = model_type_category;
163+
temperature_ = temperature;
162164
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
163165
multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
164166
model_path->toStdString().c_str(),
@@ -181,8 +183,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
181183
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
182184
runner_ = std::make_unique<MTKLlamaRunner>(
183185
model_path->toStdString().c_str(),
184-
tokenizer_path->toStdString().c_str(),
185-
temperature);
186+
tokenizer_path->toStdString().c_str());
186187
// Interpret the model type as LLM
187188
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
188189
#endif
@@ -222,6 +223,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
222223
executorch::extension::llm::GenerationConfig config{
223224
.echo = static_cast<bool>(echo),
224225
.seq_len = seq_len,
226+
.temperature = temperature_,
225227
};
226228
runner_->generate(
227229
prompt->toStdString(),

0 commit comments

Comments
 (0)