
add llm flush sentinel for llm_node #2023


Open

wants to merge 9 commits into main
5 changes: 5 additions & 0 deletions .github/next-release/changeset-007be3b9.md
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

add llm flush sentinel for llm_node (#2023)
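In short, the sentinel lets a custom `llm_node` flush whatever text it has produced so far to TTS immediately, without waiting for the rest of the stream to finish. A minimal sketch of the idea (a hypothetical `MyAgent`, condensed from the full example below; the signature follows the updated `Agent.llm_node` type in this PR):

```python
from collections.abc import AsyncIterable

from livekit.agents import Agent, ModelSettings, llm


class MyAgent(Agent):  # hypothetical agent, for illustration only
    async def llm_node(
        self,
        chat_ctx: llm.ChatContext,
        tools: list[llm.FunctionTool],
        model_settings: ModelSettings,
    ) -> AsyncIterable[str | llm.FlushSentinel]:
        yield "Give me a second to think about that."
        # the sentinel tells the voice pipeline to synthesize the buffered text now,
        # instead of waiting for the whole llm_node stream to complete
        yield llm.FlushSentinel()
        yield "Here is the detailed answer..."
```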
99 changes: 99 additions & 0 deletions examples/voice_agents/two_llm_agent.py
@@ -0,0 +1,99 @@
import logging
from collections.abc import AsyncIterable

from dotenv import load_dotenv

from livekit.agents import (
Agent,
AgentSession,
JobContext,
MetricsCollectedEvent,
ModelSettings,
WorkerOptions,
cli,
llm,
metrics,
)
from livekit.plugins import cartesia, deepgram, groq, silero

logger = logging.getLogger("two-llm-example")
logger.setLevel(logging.INFO)

load_dotenv()

## This example shows how to use a fast LLM alongside the main LLM to generate a response.
## The fast LLM produces a short, instant acknowledgement of the user's message,
## while the main LLM generates the full, detailed answer.


class TwoLLMAgent(Agent):
def __init__(self) -> None:
super().__init__(
instructions="You are a helpful assistant.",
llm=groq.LLM(model="llama-3.3-70b-versatile"),
)
self.fast_llm: llm.LLM = groq.LLM(model="llama-3.1-8b-instant")
self.fast_llm_prompt = llm.ChatMessage(
role="system",
content=[
"Generate a short instant response to the user's message with 5 to 10 words.",
"Do not answer the questions directly. For example, let me think about that, "
"wait a moment, that's a good question, etc.",
],
)

async def llm_node(
self,
chat_ctx: llm.ChatContext,
tools: list[llm.FunctionTool],
model_settings: ModelSettings,
) -> AsyncIterable[llm.ChatChunk | llm.FlushSentinel]:
# copy and truncate the chat ctx, then prepend the fast-response prompt
fast_chat_ctx = chat_ctx.copy(
exclude_function_call=True, exclude_instructions=True
).truncate(max_items=3)
fast_chat_ctx.items.insert(0, self.fast_llm_prompt)

quick_response = ""
async with self.fast_llm.chat(chat_ctx=fast_chat_ctx) as stream:
async for chunk in stream:
yield chunk
if chunk.delta and chunk.delta.content:
quick_response += chunk.delta.content

# flush the quick response to tts
yield llm.FlushSentinel()
logger.info(f"quick response: {quick_response}")

# (Optional) add the quick response to the chat ctx for the main llm
assert isinstance(self.llm, llm.LLM)
chat_ctx.add_message(role="assistant", content=quick_response)

# generate the response with the main llm
async for chunk in Agent.default.llm_node(
agent=self,
chat_ctx=chat_ctx,
tools=tools,
model_settings=model_settings,
):
yield chunk


async def entrypoint(ctx: JobContext):
await ctx.connect()

session = AgentSession(
vad=silero.VAD.load(),
stt=deepgram.STT(),
tts=cartesia.TTS(),
)

@session.on("metrics_collected")
def _on_metrics_collected(ev: MetricsCollectedEvent):
metrics.log_metrics(ev.metrics)

await session.start(agent=TwoLLMAgent(), room=ctx.room)


if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/llm/__init__.py
@@ -16,6 +16,7 @@
ChatChunk,
ChoiceDelta,
CompletionUsage,
FlushSentinel,
FunctionToolCall,
LLMError,
LLMStream,
@@ -88,4 +89,5 @@
"GenerationCreatedEvent",
"MessageGeneration",
"LLMError",
"FlushSentinel",
]
3 changes: 3 additions & 0 deletions livekit-agents/livekit/agents/llm/llm.py
@@ -54,6 +54,9 @@ class ChatChunk(BaseModel):
usage: CompletionUsage | None = None


class FlushSentinel: ...


class LLMError(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
type: Literal["llm_error"] = "llm_error"
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/voice/agent.py
@@ -295,8 +295,8 @@ def llm_node(
tools: list[FunctionTool],
model_settings: ModelSettings,
) -> (
AsyncIterable[llm.ChatChunk | str]
| Coroutine[Any, Any, AsyncIterable[llm.ChatChunk | str]]
AsyncIterable[llm.ChatChunk | str | llm.FlushSentinel]
| Coroutine[Any, Any, AsyncIterable[llm.ChatChunk | str | llm.FlushSentinel]]
| Coroutine[Any, Any, str]
| Coroutine[Any, Any, llm.ChatChunk]
| Coroutine[Any, Any, None]
10 changes: 9 additions & 1 deletion livekit-agents/livekit/agents/voice/agent_activity.py
@@ -1021,7 +1021,15 @@ async def _pipeline_reply_task(
await utils.aio.cancel_and_wait(*tasks)
return

tr_node = self._agent.transcription_node(llm_output, model_settings)
async def _read_text(
llm_output: AsyncIterable[str | llm.FlushSentinel],
) -> AsyncIterable[str]:
async for chunk in llm_output:
if isinstance(chunk, llm.FlushSentinel):
continue
yield chunk

tr_node = self._agent.transcription_node(_read_text(llm_output), model_settings)
if asyncio.iscoroutine(tr_node):
tr_node = await tr_node

53 changes: 39 additions & 14 deletions livekit-agents/livekit/agents/voice/generation.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterable
from collections.abc import AsyncGenerator, AsyncIterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

@@ -41,7 +41,7 @@ async def aclose(self): ...

@dataclass
class _LLMGenerationData:
text_ch: aio.Chan[str]
text_ch: aio.Chan[str | llm.FlushSentinel]
function_ch: aio.Chan[llm.FunctionCall]
generated_text: str = ""
generated_functions: list[llm.FunctionCall] = field(default_factory=list)
@@ -88,6 +88,9 @@ async def _inference_task():
data.generated_text += chunk
text_ch.send_nowait(chunk)

elif isinstance(chunk, llm.FlushSentinel):
text_ch.send_nowait(chunk)

elif isinstance(chunk, ChatChunk):
if not chunk.delta:
continue
@@ -133,23 +136,45 @@ class _TTSGenerationData:


def perform_tts_inference(
*, node: io.TTSNode, input: AsyncIterable[str], model_settings: ModelSettings
*,
node: io.TTSNode,
input: AsyncIterable[str | llm.FlushSentinel],
model_settings: ModelSettings,
) -> tuple[asyncio.Task, _TTSGenerationData]:
audio_ch = aio.Chan[rtc.AudioFrame]()

@utils.log_exceptions(logger=logger)
async def _inference_task():
tts_node = node(input, model_settings)
if asyncio.iscoroutine(tts_node):
tts_node = await tts_node

if isinstance(tts_node, AsyncIterable):
async for audio_frame in tts_node:
audio_ch.send_nowait(audio_frame)

return True

return False
# convert any input to a generator
async def _input() -> AsyncGenerator[str | llm.FlushSentinel, None]:
async for chunk in input:
yield chunk

input_gen = _input()
audio_generated = False
done = False

# split the input into segments
async def _input_segment() -> AsyncIterable[str]:
async for chunk in input_gen:
if isinstance(chunk, llm.FlushSentinel):
return
yield chunk

nonlocal done
done = True

while not done:
# create a new tts node for each segment
tts_node = node(_input_segment(), model_settings)
if asyncio.iscoroutine(tts_node):
tts_node = await tts_node

if isinstance(tts_node, AsyncIterable):
async for audio_frame in tts_node:
audio_ch.send_nowait(audio_frame)
audio_generated = True
return audio_generated

tts_task = asyncio.create_task(_inference_task())
tts_task.add_done_callback(lambda _: audio_ch.close())
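The reworked `_inference_task` above splits the incoming text at each `FlushSentinel` and runs a fresh TTS node per segment, sharing one underlying generator so every segment resumes where the previous one stopped. A self-contained sketch of that splitting pattern (hypothetical names, plain `print` standing in for the TTS node), which may be easier to follow than the unmarked diff:

```python
import asyncio
from collections.abc import AsyncGenerator, AsyncIterable


class FlushSentinel:  # stand-in for llm.FlushSentinel
    ...


async def consume_in_segments(source: AsyncIterable[str | FlushSentinel]) -> None:
    # share a single generator so each segment resumes where the last one stopped
    async def _gen() -> AsyncGenerator[str | FlushSentinel, None]:
        async for item in source:
            yield item

    shared = _gen()
    done = False

    # yield items until the next sentinel (or until the stream ends)
    async def _segment() -> AsyncGenerator[str, None]:
        nonlocal done
        async for item in shared:
            if isinstance(item, FlushSentinel):
                return
            yield item
        done = True

    seg_idx = 0
    while not done:
        # in the PR, a fresh TTS node is created here for each segment
        async for text in _segment():
            print(f"segment {seg_idx}: {text!r}")
        seg_idx += 1


async def main() -> None:
    async def stream() -> AsyncGenerator[str | FlushSentinel, None]:
        yield "Let me think about that. "   # quick response
        yield FlushSentinel()               # flushed to TTS at this point
        yield "Here is the detailed answer."

    await consume_in_segments(stream())


asyncio.run(main())
```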
4 changes: 2 additions & 2 deletions livekit-agents/livekit/agents/voice/io.py
@@ -24,8 +24,8 @@
LLMNode = Callable[
[llm.ChatContext, list[llm.FunctionTool], ModelSettings],
Union[
Optional[Union[AsyncIterable[llm.ChatChunk], AsyncIterable[str], str]],
Awaitable[Optional[Union[AsyncIterable[llm.ChatChunk], AsyncIterable[str], str]]],
Optional[Union[AsyncIterable[llm.ChatChunk | str | llm.FlushSentinel], str]],
Awaitable[Optional[Union[AsyncIterable[llm.ChatChunk | str | llm.FlushSentinel], str]]],
],
]
TTSNode = Callable[