Commit 32b3ae0

Commit message: init

1 parent 0e5c2b2

File tree: 9 files changed (+56 -106 lines)


src/llm/apis/openai_completions.cpp

+22 -6

@@ -275,7 +275,7 @@ absl::Status OpenAIChatCompletionsHandler::parseChatCompletionsPart(std::optiona
     return absl::OkStatus();
 }
 
-absl::Status OpenAIChatCompletionsHandler::parseCommonPart(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline, std::optional<uint32_t> maxModelLength) {
+absl::Status OpenAIChatCompletionsHandler::parseCommonPart(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline, bool isPromptLookupPipeline, std::optional<uint32_t> maxModelLength) {
     OVMS_PROFILE_FUNCTION();
     // stream: bool; optional
     if (!doc.IsObject())
@@ -495,22 +495,30 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(std::optional<uint32_
         request.numReturnSequences = it->value.GetUint();
     }
 
-    // Speculative decoding specific parameters
+    // Assisted decoding specific parameters
 
     auto numAssistantTokensIt = doc.FindMember("num_assistant_tokens");
     auto assistantConfidenceThresholdIt = doc.FindMember("assistant_confidence_threshold");
+    auto maxNgramSizeIt = doc.FindMember("max_ngram_size");
 
     bool numAssistantTokensItHasValue = (numAssistantTokensIt != doc.MemberEnd() && !numAssistantTokensIt->value.IsNull());
     bool assistantConfidenceThresholdItHasValue = (assistantConfidenceThresholdIt != doc.MemberEnd() && !assistantConfidenceThresholdIt->value.IsNull());
+    bool maxNgramSizeItHasValue = (maxNgramSizeIt != doc.MemberEnd() && !maxNgramSizeIt->value.IsNull());
 
     if (isSpeculativePipeline) {
         if (!numAssistantTokensItHasValue && !assistantConfidenceThresholdItHasValue)
             return absl::InvalidArgumentError("Speculative decoding requires either num_assistant_tokens or assistant_confidence_threshold to be set.");
 
         if (numAssistantTokensItHasValue && assistantConfidenceThresholdItHasValue)
             return absl::InvalidArgumentError("num_assistant_tokens and assistant_confidence_threshold are mutually exclusive and cannot both be set.");
-    } else if (numAssistantTokensItHasValue || assistantConfidenceThresholdItHasValue) {
-        return absl::InvalidArgumentError("num_assistant_tokens and assistant_confidence_threshold are only supported when speculative decoding is enabled.");
+    } else if (assistantConfidenceThresholdItHasValue) {
+        return absl::InvalidArgumentError("assistant_confidence_threshold is only supported when speculative decoding is enabled.");
+    }
+
+    if (isPromptLookupPipeline) {
+        if (!numAssistantTokensItHasValue || !maxNgramSizeItHasValue) {
+            return absl::InvalidArgumentError("Prompt lookup requires num_assistant_tokens and max_ngram_size to be set.");
+        }
     }
     // num_assistant_tokens: uint;
     if (numAssistantTokensItHasValue) {
@@ -529,6 +537,14 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(std::optional<uint32_
             return absl::InvalidArgumentError("assistant_confidence_threshold must be greater than 0");
         }
     }
+
+    // max_ngram_size: uint; optional - defaults to 0
+    if (maxNgramSizeIt != doc.MemberEnd() && !maxNgramSizeIt->value.IsNull()) {
+        if (!maxNgramSizeIt->value.IsUint() || maxNgramSizeIt->value.GetUint() == 0) {
+            return absl::InvalidArgumentError("max_ngram_size must be an unsigned integer greater than 0");
+        }
+        request.maxNgramSize = maxNgramSizeIt->value.GetUint();
+    }
     request.maxModelLength = maxModelLength;
 
     // use_beam_search: bool; optional - defaults to false
@@ -573,8 +589,8 @@ ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig
     return request.createGenerationConfig();
 }
 
-absl::Status OpenAIChatCompletionsHandler::parseRequest(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline, std::optional<uint32_t> maxModelLength) {
-    absl::Status status = parseCommonPart(maxTokensLimit, bestOfLimit, isSpeculativePipeline, maxModelLength);
+absl::Status OpenAIChatCompletionsHandler::parseRequest(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline, bool isPromptLookupPipeline, std::optional<uint32_t> maxModelLength) {
+    absl::Status status = parseCommonPart(maxTokensLimit, bestOfLimit, isSpeculativePipeline, isPromptLookupPipeline, maxModelLength);
 
     if (status != absl::OkStatus())
         return status;
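
A minimal standalone sketch of the validation rules added above, applied to a hypothetical prompt-lookup request body. It assumes rapidjson (matching the FindMember/IsUint calls in the handler); the request values and the isPromptLookupPipeline flag are illustrative, not taken from the commit.

// Sketch: reproduce the new assisted-decoding request validation outside the handler.
#include <iostream>
#include <string>
#include <rapidjson/document.h>

int main() {
    // Hypothetical request body; only the fields relevant to prompt lookup are shown.
    const std::string body = R"({"max_tokens": 100, "num_assistant_tokens": 5, "max_ngram_size": 3})";

    rapidjson::Document doc;
    doc.Parse(body.c_str());
    if (doc.HasParseError())
        return 1;

    auto numAssistantTokensIt = doc.FindMember("num_assistant_tokens");
    auto maxNgramSizeIt = doc.FindMember("max_ngram_size");
    bool numAssistantTokensItHasValue = (numAssistantTokensIt != doc.MemberEnd() && !numAssistantTokensIt->value.IsNull());
    bool maxNgramSizeItHasValue = (maxNgramSizeIt != doc.MemberEnd() && !maxNgramSizeIt->value.IsNull());

    const bool isPromptLookupPipeline = true;  // assumed: servable loaded with prompt lookup enabled
    if (isPromptLookupPipeline && (!numAssistantTokensItHasValue || !maxNgramSizeItHasValue)) {
        std::cerr << "Prompt lookup requires num_assistant_tokens and max_ngram_size to be set." << std::endl;
        return 1;
    }
    if (maxNgramSizeItHasValue && (!maxNgramSizeIt->value.IsUint() || maxNgramSizeIt->value.GetUint() == 0)) {
        std::cerr << "max_ngram_size must be an unsigned integer greater than 0" << std::endl;
        return 1;
    }
    if (maxNgramSizeItHasValue)
        std::cout << "request accepted, max_ngram_size=" << maxNgramSizeIt->value.GetUint() << std::endl;
    return 0;
}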

src/llm/apis/openai_completions.hpp

+7 -4

@@ -95,9 +95,10 @@ struct OpenAIChatCompletionsRequest {
     std::optional<int> bestOf{std::nullopt};
     std::optional<float> lengthPenalty{std::nullopt};
 
-    // Speculative decoding specific (only with speculative decoding pipeline, see <docs> for reference)
+    // Assisted decoding specific (only with speculative decoding or prompt lookup pipeline)
     std::optional<int> numAssistantTokens{std::nullopt};
     std::optional<float> assistantConfidenceThreshold{std::nullopt};
+    std::optional<int> maxNgramSize{std::nullopt};
 
     std::optional<uint32_t> maxModelLength;
 
@@ -157,11 +158,13 @@ struct OpenAIChatCompletionsRequest {
 
         if (logprobschat || logprobs)
             config.logprobs = 1;
-        // Speculative decoding specific
+        // Assisted decoding specific
         if (numAssistantTokens.has_value())
            config.num_assistant_tokens = numAssistantTokens.value();
        if (assistantConfidenceThreshold.has_value())
            config.assistant_confidence_threshold = assistantConfidenceThreshold.value();
+        if (maxNgramSize.has_value())
+            config.max_ngram_size = maxNgramSize.value();
 
         return config;
     }
@@ -180,7 +183,7 @@ class OpenAIChatCompletionsHandler {
 
     absl::Status parseCompletionsPart();
     absl::Status parseChatCompletionsPart(std::optional<uint32_t> maxTokensLimit);
-    absl::Status parseCommonPart(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline, std::optional<uint32_t> maxModelLength);
+    absl::Status parseCommonPart(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline, bool isPromptLookupPipeline, std::optional<uint32_t> maxModelLength);
 
 public:
     OpenAIChatCompletionsHandler(Document& doc, Endpoint endpoint, std::chrono::time_point<std::chrono::system_clock> creationTime,
@@ -208,7 +211,7 @@ class OpenAIChatCompletionsHandler {
 
     ov::genai::GenerationConfig createGenerationConfig() const;
 
-    absl::Status parseRequest(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline, std::optional<uint32_t> maxModelLength);
+    absl::Status parseRequest(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline, bool isPromptLookupPipeline, std::optional<uint32_t> maxModelLength);
     absl::Status parseMessages();
 
     std::string serializeUnaryResponse(const std::vector<ov::genai::GenerationOutput>& generationOutputs);
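
For reference, a short sketch of how the new request field lands in the GenAI generation config, mirroring the mapping in createGenerationConfig() above. num_assistant_tokens and max_ngram_size are the fields this commit sets; the header path and the max_new_tokens value are assumptions for illustration.

// Sketch: assemble a generation config for a prompt lookup run (values are illustrative).
#include <openvino/genai/generation_config.hpp>

ov::genai::GenerationConfig makePromptLookupGenerationConfig() {
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;      // assumed output limit for the example
    config.num_assistant_tokens = 5;  // candidate tokens proposed per lookup step
    config.max_ngram_size = 3;        // longest prompt n-gram matched against generated text
    return config;
}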

src/llm/language_model/continuous_batching/servable_initializer.cpp

+14 -76

@@ -40,16 +40,6 @@
 
 namespace ovms {
 
-ov::genai::SchedulerConfig ContinuousBatchingServableInitializer::prepareDraftPipelineSchedulerConfigExperimental(const mediapipe::LLMCalculatorOptions_PipelineConfig& draftPipelineConfig) {
-    ov::genai::SchedulerConfig config;
-    config.max_num_batched_tokens = draftPipelineConfig.max_num_batched_tokens();
-    config.cache_size = draftPipelineConfig.cache_size();
-    config.dynamic_split_fuse = draftPipelineConfig.dynamic_split_fuse();
-    config.max_num_seqs = draftPipelineConfig.max_num_seqs();
-    config.enable_prefix_caching = draftPipelineConfig.enable_prefix_caching();
-    return config;
-}
-
 ov::genai::SchedulerConfig ContinuousBatchingServableInitializer::prepareDraftPipelineSchedulerConfig(const mediapipe::LLMCalculatorOptions& nodeOptions) {
     ov::genai::SchedulerConfig config;
     config.max_num_batched_tokens = nodeOptions.has_draft_max_num_batched_tokens() ? nodeOptions.draft_max_num_batched_tokens() : nodeOptions.max_num_batched_tokens();
@@ -60,72 +50,6 @@ ov::genai::SchedulerConfig ContinuousBatchingServableInitializer::prepareDraftPi
     return config;
 }
 
-Status ContinuousBatchingServableInitializer::initializeExperimental(std::shared_ptr<GenAiServable>& servable, const mediapipe::LLMCalculatorOptions& nodeOptions, std::string graphPath) {
-    auto continousBatchingPipelineConfig = nodeOptions.continuous_batching_pipeline_config();
-    auto mainPipelineConfig = continousBatchingPipelineConfig.main_pipeline_config();
-    std::string parsedModelsPath;
-    auto status = parseModelsPath(parsedModelsPath, mainPipelineConfig.models_path(), graphPath);
-    if (!status.ok()) {
-        return status;
-    }
-    auto properties = std::static_pointer_cast<ContinuousBatchingServableProperties>(servable->getProperties());
-    properties->modelsPath = parsedModelsPath;
-
-    properties->schedulerConfig.max_num_batched_tokens = mainPipelineConfig.max_num_batched_tokens();
-    properties->schedulerConfig.cache_size = mainPipelineConfig.cache_size();
-    properties->schedulerConfig.dynamic_split_fuse = mainPipelineConfig.dynamic_split_fuse();
-    properties->schedulerConfig.max_num_seqs = mainPipelineConfig.max_num_seqs();
-    properties->schedulerConfig.enable_prefix_caching = mainPipelineConfig.enable_prefix_caching();
-
-    properties->device = mainPipelineConfig.device();
-
-    // Speculative decoding enabled
-    properties->isSpeculativePipeline = false;
-    if (continousBatchingPipelineConfig.has_draft_pipeline_config()) {
-        auto draftPipelineConfig = continousBatchingPipelineConfig.draft_pipeline_config();
-        auto fsDraftModelsPath = std::filesystem::path(draftPipelineConfig.models_path());
-        std::string draftPipelinePath;
-        if (fsDraftModelsPath.is_relative()) {
-            draftPipelinePath = (std::filesystem::path(graphPath) / fsDraftModelsPath).string();
-        } else {
-            draftPipelinePath = fsDraftModelsPath.string();
-        }
-        auto draftSchedulerConfig = prepareDraftPipelineSchedulerConfigExperimental(draftPipelineConfig);
-        auto draftPipeline = ov::genai::draft_model(draftPipelinePath, draftPipelineConfig.device(), ov::genai::scheduler_config(draftSchedulerConfig));
-        properties->pluginConfig.insert(draftPipeline);
-        properties->isSpeculativePipeline = true;
-    }
-
-    status = JsonParser::parsePluginConfig(mainPipelineConfig.plugin_config(), properties->pluginConfig);
-    if (!status.ok()) {
-        SPDLOG_ERROR("Error during llm node plugin_config option parsing to JSON: {}", mainPipelineConfig.plugin_config());
-        return status;
-    }
-
-    properties->tokenizerPluginConfig = {{"PERFORMANCE_HINT", "THROUGHPUT"}};
-    try {
-        properties->pipeline = std::make_shared<ov::genai::ContinuousBatchingPipeline>(parsedModelsPath,
-            properties->schedulerConfig, properties->device,
-            properties->pluginConfig, properties->tokenizerPluginConfig);
-        properties->tokenizer = properties->pipeline->get_tokenizer();
-    } catch (const std::exception& e) {
-        SPDLOG_ERROR("Error during llm node initialization for models_path: {} exception: {}", parsedModelsPath, e.what());
-        return StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED;
-    } catch (...) {
-        SPDLOG_ERROR("Error during llm node initialization for models_path: {}", parsedModelsPath);
-        return StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED;
-    }
-
-    loadTextProcessor(properties, parsedModelsPath);
-    if (nodeOptions.has_max_tokens_limit()) {
-        properties->maxTokensLimit = nodeOptions.max_tokens_limit();
-    }
-    properties->bestOfLimit = mainPipelineConfig.best_of_limit();
-
-    properties->llmExecutorWrapper = std::make_shared<LLMExecutorWrapper>(properties->pipeline);
-    return StatusCode::OK;
-}
-
 Status ContinuousBatchingServableInitializer::initialize(std::shared_ptr<GenAiServable>& servable, const mediapipe::LLMCalculatorOptions& nodeOptions, std::string graphPath) {
     std::string parsedModelsPath;
     auto status = parseModelsPath(parsedModelsPath, nodeOptions.models_path(), graphPath);
@@ -174,6 +98,20 @@ Status ContinuousBatchingServableInitializer::initialize(std::shared_ptr<GenAiSe
         return status;
     }
 
+    std::cout << "Checking if prompt lookup is enabled" << std::endl;
+    // Check if prompt lookup is enabled
+    auto promptLookupPropertyIt = properties->pluginConfig.find("prompt_lookup");
+    if (promptLookupPropertyIt != properties->pluginConfig.end()) {
+        auto promptLookupProperty = promptLookupPropertyIt->second.as<bool>();
+        if (promptLookupProperty == true) {
+            properties->isPromptLookupPipeline = true;
+        } else {
+            properties->isPromptLookupPipeline = false;
+        }
+    }
+
+    std::cout << "properties->isPromptLookupPipeline: " << properties->isPromptLookupPipeline << std::endl;
+
     properties->tokenizerPluginConfig = {{"PERFORMANCE_HINT", "THROUGHPUT"}};
     try {
         properties->pipeline = std::make_shared<ov::genai::ContinuousBatchingPipeline>(parsedModelsPath,
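
Aside: a minimal sketch of the prompt_lookup detection above, assuming properties->pluginConfig is an ov::AnyMap (which the find()/as<bool>() calls suggest); a missing key is treated as prompt lookup being disabled.

// Sketch: read an optional boolean "prompt_lookup" key from a plugin config map.
#include <openvino/core/any.hpp>

bool isPromptLookupEnabled(ov::AnyMap& pluginConfig) {
    auto it = pluginConfig.find("prompt_lookup");
    // Absent key -> prompt lookup disabled; present key is expected to hold a bool.
    return it != pluginConfig.end() && it->second.as<bool>();
}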

src/llm/language_model/continuous_batching/servable_initializer.hpp

-2

@@ -33,11 +33,9 @@ namespace ovms {
 class Status;
 
 class ContinuousBatchingServableInitializer : public GenAiServableInitializer {
-    static ov::genai::SchedulerConfig prepareDraftPipelineSchedulerConfigExperimental(const mediapipe::LLMCalculatorOptions_PipelineConfig& draftModelConfig);
     static ov::genai::SchedulerConfig prepareDraftPipelineSchedulerConfig(const mediapipe::LLMCalculatorOptions& nodeOptions);
 
 public:
-    Status initializeExperimental(std::shared_ptr<GenAiServable>& servable, const mediapipe::LLMCalculatorOptions& nodeOptions, std::string graphPath);
     Status initialize(std::shared_ptr<GenAiServable>& servable, const mediapipe::LLMCalculatorOptions& nodeOptions, std::string graphPath) override;
 };
 }  // namespace ovms

src/llm/language_model/legacy/servable.cpp

+2 -1

@@ -72,7 +72,8 @@ absl::Status LegacyServable::parseRequest(std::shared_ptr<GenAiServableExecution
         std::chrono::system_clock::now(),
         getProperties()->tokenizer);
 
-    auto status = executionContext->apiHandler->parseRequest(getProperties()->maxTokensLimit, getProperties()->bestOfLimit, getProperties()->isSpeculativePipeline, getProperties()->maxModelLength);
+    auto status = executionContext->apiHandler->parseRequest(getProperties()->maxTokensLimit, getProperties()->bestOfLimit, getProperties()->isSpeculativePipeline,
+        getProperties()->isPromptLookupPipeline, getProperties()->maxModelLength);
     if (!status.ok()) {
         SPDLOG_LOGGER_ERROR(llm_calculator_logger, "Failed to parse request: {}", status.message());
         return status;

src/llm/servable.cpp

+2 -1

@@ -55,7 +55,8 @@ absl::Status GenAiServable::parseRequest(std::shared_ptr<GenAiServableExecutionC
         std::chrono::system_clock::now(),
         getProperties()->tokenizer);
 
-    auto status = executionContext->apiHandler->parseRequest(getProperties()->maxTokensLimit, getProperties()->bestOfLimit, getProperties()->isSpeculativePipeline, getProperties()->maxModelLength);
+    auto status = executionContext->apiHandler->parseRequest(getProperties()->maxTokensLimit, getProperties()->bestOfLimit, getProperties()->isSpeculativePipeline,
+        getProperties()->isPromptLookupPipeline, getProperties()->maxModelLength);
     if (!status.ok()) {
         SPDLOG_LOGGER_ERROR(llm_calculator_logger, "Failed to parse request: {}", status.message());
         return status;

src/llm/servable.hpp

+3 -1

@@ -82,7 +82,9 @@ struct GenAiServableProperties {
     // Sampling limits
     std::optional<uint32_t> maxTokensLimit;
     uint32_t bestOfLimit;
-    bool isSpeculativePipeline;  // sampling is generally common, but maybe we could avoid having this field at all
+    // TODO (mzegla): perhaps we can remove below bools and rely on GenAI logic entirely
+    bool isSpeculativePipeline;   // sampling is generally common, but maybe we could avoid having this field at all
+    bool isPromptLookupPipeline;  // prompt lookup is generally common, but maybe we could avoid having this field at all
     // Text processing utilities
     ov::genai::Tokenizer tokenizer;
     TextProcessor textProcessor;

src/llm/servable_initializer.cpp

+4 -14

@@ -214,7 +214,7 @@ Status initializeGenAiServable(std::shared_ptr<GenAiServable>& servable, const :
     mediapipe::LLMCalculatorOptions nodeOptions;
     graphNodeConfig.node_options(0).UnpackTo(&nodeOptions);
     Status status;
-    if (nodeOptions.has_models_path()) { // Stable initialization
+    if (nodeOptions.has_models_path()) {
         // need to initialize pipelineType with some value to avoid compiler warning, determinePipelineType will set it properly
         PipelineType pipelineType{PipelineType::LM_CB};
         status = determinePipelineType(pipelineType, nodeOptions, graphPath);
@@ -262,22 +262,12 @@ Status initializeGenAiServable(std::shared_ptr<GenAiServable>& servable, const :
             return StatusCode::INTERNAL_ERROR;
         }
     } else {
-        if (nodeOptions.has_continuous_batching_pipeline_config()) { // Experimental initialization
-            ContinuousBatchingServableInitializer cbServableInitializer;
-            servable = std::make_shared<ContinuousBatchingServable>();
-            status = cbServableInitializer.initializeExperimental(servable, nodeOptions, graphPath);
-        } else {
-            SPDLOG_LOGGER_ERROR(modelmanager_logger, "LLM node options do not contain any recognized pipeline configuration.");
-            return StatusCode::INTERNAL_ERROR;
-        }
-
-        if (status != StatusCode::OK) {
-            SPDLOG_LOGGER_ERROR(modelmanager_logger, "Error during LLM node resources initialization: {}", status.string());
-            return status;
-        }
+        SPDLOG_LOGGER_ERROR(modelmanager_logger, "LLM node requires models_path to be set.");
+        return StatusCode::INTERNAL_ERROR;
     }
     return StatusCode::OK;
 }
+
 std::optional<uint32_t> parseMaxModelLength(std::string& modelsPath) {
     std::string configPath = FileSystem::appendSlash(modelsPath) + "config.json";
     std::optional<uint32_t> maxModelLength;

src/llm/visual_language_model/legacy/servable.cpp

+2 -1

@@ -74,7 +74,8 @@ absl::Status VisualLanguageModelLegacyServable::parseRequest(std::shared_ptr<Gen
         std::chrono::system_clock::now(),
         getProperties()->tokenizer);
 
-    auto status = executionContext->apiHandler->parseRequest(getProperties()->maxTokensLimit, getProperties()->bestOfLimit, getProperties()->isSpeculativePipeline, getProperties()->maxModelLength);
+    auto status = executionContext->apiHandler->parseRequest(getProperties()->maxTokensLimit, getProperties()->bestOfLimit, getProperties()->isSpeculativePipeline,
+        getProperties()->isPromptLookupPipeline, getProperties()->maxModelLength);
     if (!status.ok()) {
         SPDLOG_LOGGER_ERROR(llm_calculator_logger, "Failed to parse request: {}", status.message());
         return status;
