diff --git a/src/llm/language_model/legacy/servable.cpp b/src/llm/language_model/legacy/servable.cpp index 3ab86e4011..9dcea0132a 100644 --- a/src/llm/language_model/legacy/servable.cpp +++ b/src/llm/language_model/legacy/servable.cpp @@ -106,12 +106,12 @@ absl::Status LegacyServable::parseRequest(std::shared_ptrapiHandler->getRequest().skipSpecialTokens) { streamerConfig.insert(ov::genai::skip_special_tokens(false)); } - auto ovmsCallback = [& ctx = *legacyExecutionContext](rapidjson::Document delta) -> ov::genai::StreamingStatus { + auto ovmsCallback = [& ctx = *legacyExecutionContext](rapidjson::Document delta, bool isLast) -> ov::genai::StreamingStatus { if (ctx.clientDisconnected.load()) { ctx.deltaChannel.signalComplete(); return ov::genai::StreamingStatus::CANCEL; } - ctx.deltaChannel.push(std::move(delta)); + ctx.deltaChannel.push(std::move(delta), isLast); return ov::genai::StreamingStatus::RUNNING; }; legacyExecutionContext->textStreamer = std::make_shared( diff --git a/src/llm/ovms_text_streamer.cpp b/src/llm/ovms_text_streamer.cpp index a7686133de..4d882b5e4a 100644 --- a/src/llm/ovms_text_streamer.cpp +++ b/src/llm/ovms_text_streamer.cpp @@ -196,14 +196,15 @@ ov::genai::StreamingStatus OVMSTextStreamer::flush_chunk( delta = std::move(doc); } + const bool isLast = (finish_reason != ov::genai::GenerationFinishReason::NONE); if (delta.has_value()) { - return m_callback(std::move(*delta)); + return m_callback(std::move(*delta), isLast); } - if (finish_reason != ov::genai::GenerationFinishReason::NONE) { - // Parser produced no delta for the final flush (e.g. generation ended - // on a special token the parser absorbed). Still fire the callback with - // an empty Document so preparePartialResponse can emit the finish_reason. - return m_callback(rapidjson::Document{}); + if (isLast) { + // Parser produced no delta for the final flush (e.g. generation ended on a + // special token the parser absorbed). Still fire the callback with an empty + // Document so the caller can emit the finish_reason chunk. + return m_callback(rapidjson::Document{}, true); } return ov::genai::StreamingStatus::RUNNING; } diff --git a/src/llm/ovms_text_streamer.hpp b/src/llm/ovms_text_streamer.hpp index 808054ae60..d320188076 100644 --- a/src/llm/ovms_text_streamer.hpp +++ b/src/llm/ovms_text_streamer.hpp @@ -49,11 +49,14 @@ namespace ovms { // fires the callback unconditionally, preserving existing behavior. class OVMSTextStreamer : public ov::genai::TextStreamer { public: - // Callback receives a Document and returns the streaming status. + // Callback receives a Document and the isLast flag, and returns the streaming status. // Document shape is always {"delta":{...}} matching the OpenAI delta format. // For the finish-only case (nullopt from parseChunk + STOP finishReason), // an empty Document{} is passed so the caller can emit the finish_reason chunk. - using Callback = std::function; + // isLast is true when finish_reason != NONE — callers that push into a DeltaChannel + // should forward this flag to DeltaChannel::push() so the final document and the + // completion signal are observed atomically (no separate signalComplete() needed). + using Callback = std::function; // outputParser may be nullptr (e.g. for the unary VLM path). // TODO(phase3): rework ownership — OVMSTextStreamer should not need to keep diff --git a/src/llm/servable.cpp b/src/llm/servable.cpp index 0d934cce0a..17d74e688a 100644 --- a/src/llm/servable.cpp +++ b/src/llm/servable.cpp @@ -141,8 +141,8 @@ absl::Status GenAiServable::parseRequest(std::shared_ptrapiHandler->isStream()) { - auto ovmsCallback = [& ctx = *executionContext](rapidjson::Document delta) -> ov::genai::StreamingStatus { - ctx.deltaChannel.push(std::move(delta)); + auto ovmsCallback = [& ctx = *executionContext](rapidjson::Document delta, bool isLast) -> ov::genai::StreamingStatus { + ctx.deltaChannel.push(std::move(delta), isLast); return ov::genai::StreamingStatus::RUNNING; }; ov::AnyMap streamerConfig; diff --git a/src/llm/servable.hpp b/src/llm/servable.hpp index 6d1669735a..c84801c678 100644 --- a/src/llm/servable.hpp +++ b/src/llm/servable.hpp @@ -81,16 +81,21 @@ enum class ChatTemplateMode { // thread, so the mutex is acquired but uncontested. struct DeltaChannel { // Push a delta from any thread (streamer callback). - void push(rapidjson::Document delta) { + // When isLast is true, also marks the channel complete atomically so consumers + // always see the final document and the completion flag in the same observation. + void push(rapidjson::Document delta, bool isLast = false) { { std::lock_guard lock(m_mutex); m_deltas.push_back(std::move(delta)); + if (isLast) + m_complete = true; } m_cv.notify_one(); } // Signal that no more deltas will be pushed (generation complete or cancelled). - // May be called from any thread. + // May be called from any thread. Also acts as a safety-net for paths where + // push(delta, isLast=true) may not fire (e.g. client disconnection mid-stream). void signalComplete() { { std::lock_guard lock(m_mutex); diff --git a/src/llm/visual_language_model/legacy/servable.cpp b/src/llm/visual_language_model/legacy/servable.cpp index 399682d6c5..f185332e44 100644 --- a/src/llm/visual_language_model/legacy/servable.cpp +++ b/src/llm/visual_language_model/legacy/servable.cpp @@ -122,12 +122,12 @@ absl::Status VisualLanguageModelLegacyServable::parseRequest(std::shared_ptrapiHandler->getRequest().skipSpecialTokens) { streamerConfig.insert(ov::genai::skip_special_tokens(false)); } - auto ovmsCallback = [& ctx = *legacyExecutionContext](rapidjson::Document delta) -> ov::genai::StreamingStatus { + auto ovmsCallback = [& ctx = *legacyExecutionContext](rapidjson::Document delta, bool isLast) -> ov::genai::StreamingStatus { if (ctx.clientDisconnected.load()) { ctx.deltaChannel.signalComplete(); return ov::genai::StreamingStatus::CANCEL; } - ctx.deltaChannel.push(std::move(delta)); + ctx.deltaChannel.push(std::move(delta), isLast); return ov::genai::StreamingStatus::RUNNING; }; legacyExecutionContext->textStreamer = std::make_shared( @@ -155,7 +155,7 @@ absl::Status VisualLanguageModelLegacyServable::parseRequest(std::shared_ptrapiHandler->getRequest().skipSpecialTokens) { streamerConfig.insert(ov::genai::skip_special_tokens(false)); } - auto unaryCallback = [& ctx = *legacyExecutionContext](rapidjson::Document delta) -> ov::genai::StreamingStatus { + auto unaryCallback = [& ctx = *legacyExecutionContext](rapidjson::Document delta, bool /*isLast*/) -> ov::genai::StreamingStatus { if (delta.HasMember("delta") && delta["delta"].IsObject() && delta["delta"].HasMember("content") && delta["delta"]["content"].IsString()) { ctx.accumulatedUnaryText += delta["delta"]["content"].GetString();