@@ -114,6 +114,7 @@ class ExecuTorchLlmCallbackJni
114
114
class ExecuTorchLlmJni : public facebook ::jni::HybridClass<ExecuTorchLlmJni> {
115
115
private:
116
116
friend HybridBase;
117
+ float temperature_;
117
118
int model_type_category_;
118
119
std::unique_ptr<llm::IRunner> runner_;
119
120
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -159,6 +160,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
159
160
#endif
160
161
161
162
model_type_category_ = model_type_category;
163
+ temperature_ = temperature;
162
164
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
163
165
multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
164
166
model_path->toStdString ().c_str (),
@@ -181,8 +183,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
181
183
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
182
184
runner_ = std::make_unique<MTKLlamaRunner>(
183
185
model_path->toStdString ().c_str (),
184
- tokenizer_path->toStdString ().c_str (),
185
- temperature);
186
+ tokenizer_path->toStdString ().c_str ());
186
187
// Interpret the model type as LLM
187
188
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
188
189
#endif
@@ -222,6 +223,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
222
223
executorch::extension::llm::GenerationConfig config{
223
224
.echo = static_cast <bool >(echo),
224
225
.seq_len = seq_len,
226
+ .temperature = temperature_,
225
227
};
226
228
runner_->generate (
227
229
prompt->toStdString (),
0 commit comments