Skip to content

Commit f6f01ea

Browse files
authored
Merge pull request #1588 from pipecat-ai/aleix/llm-aggregator-params
LLM aggregator params
2 parents 8299c96 + f385cc0 commit f6f01ea

File tree

13 files changed

+227
-128
lines changed

13 files changed

+227
-128
lines changed

CHANGELOG.md

+11-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1212
- `DeepgramTTSService` accepts `base_url` argument again, allowing you to
1313
connect to an on-prem service.
1414

15+
- Added `LLMUserAggregatorParams` and `LLMAssistantAggregatorParams` which allow
16+
you to control aggregator settings. You can now pass these arguments when
17+
creating aggregator pairs with `create_context_aggregator()`.
18+
1519
- It is now possible to disable `SoundfileMixer` when created. You can then use
1620
`MixerEnableFrame` to dynamically enable it when necessary.
1721

@@ -38,14 +42,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3842
- `DeepgramSTTService` parameter `url` is now deprecated, use `base_url`
3943
instead.
4044

45+
### Removed
46+
47+
- Parameters `user_kwargs` and `assistant_kwargs` when creating a context
48+
aggregator pair using `create_context_aggregator()` have been removed. Use
49+
`user_params` and `assistant_params` instead.
50+
4151
### Fixed
4252

4353
- Fixed a `TavusVideoService` issue that was causing audio choppiness.
4454

4555
- Fixed an issue in `SmallWebRTCTransport` where an error was thrown if the
4656
client did not create a video transceiver.
4757

48-
- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing
58+
- Fixed an issue where LLM input parameters were not being applied correctly in `GoogleVertexLLMService`, causing
4959
unexpected behavior during inference.
5060

5161
## [0.0.63] - 2025-04-11

examples/foundational/22d-natural-conversation-gemini-audio.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,10 @@
3333
from pipecat.pipeline.pipeline import Pipeline
3434
from pipecat.pipeline.runner import PipelineRunner
3535
from pipecat.pipeline.task import PipelineParams, PipelineTask
36-
from pipecat.processors.aggregators.llm_response import LLMAssistantResponseAggregator
36+
from pipecat.processors.aggregators.llm_response import (
37+
LLMAssistantAggregatorParams,
38+
LLMAssistantResponseAggregator,
39+
)
3740
from pipecat.processors.aggregators.openai_llm_context import (
3841
OpenAILLMContext,
3942
OpenAILLMContextFrame,
@@ -478,7 +481,7 @@ class LLMAggregatorBuffer(LLMAssistantResponseAggregator):
478481
"""Buffers the output of the transcription LLM. Used by the bot output gate."""
479482

480483
def __init__(self, **kwargs):
481-
super().__init__(expect_stripped_words=False)
484+
super().__init__(params=LLMAssistantAggregatorParams(expect_stripped_words=False))
482485
self._transcription = ""
483486

484487
async def process_frame(self, frame: Frame, direction: FrameDirection):

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ fal = [ "fal-client~=0.5.9" ]
5454
fireworks = []
5555
fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
5656
gladia = [ "websockets~=13.1" ]
57-
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4" ]
57+
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4", "websockets~=13.1" ]
5858
grok = []
5959
groq = [ "groq~=0.20.0" ]
6060
gstreamer = [ "pygobject~=3.50.0" ]

src/pipecat/processors/aggregators/llm_response.py

+65-10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import asyncio
88
from abc import abstractmethod
9+
from dataclasses import dataclass
910
from typing import Dict, List, Literal, Set
1011

1112
from loguru import logger
@@ -46,6 +47,16 @@
4647
from pipecat.utils.time import time_now_iso8601
4748

4849

50+
@dataclass
51+
class LLMUserAggregatorParams:
52+
aggregation_timeout: float = 1.0
53+
54+
55+
@dataclass
56+
class LLMAssistantAggregatorParams:
57+
expect_stripped_words: bool = True
58+
59+
4960
class LLMFullResponseAggregator(FrameProcessor):
5061
"""This is an LLM aggregator that aggregates a full LLM completion. It
5162
aggregates LLM text frames (tokens) received between
@@ -230,11 +241,23 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
230241
def __init__(
231242
self,
232243
context: OpenAILLMContext,
233-
aggregation_timeout: float = 1.0,
244+
*,
245+
params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
234246
**kwargs,
235247
):
236248
super().__init__(context=context, role="user", **kwargs)
237-
self._aggregation_timeout = aggregation_timeout
249+
self._params = params
250+
if "aggregation_timeout" in kwargs:
251+
import warnings
252+
253+
with warnings.catch_warnings():
254+
warnings.simplefilter("always")
255+
warnings.warn(
256+
"Parameter 'aggregation_timeout' is deprecated, use 'params' instead.",
257+
DeprecationWarning,
258+
)
259+
260+
self._params.aggregation_timeout = kwargs["aggregation_timeout"]
238261

239262
self._seen_interim_results = False
240263
self._user_speaking = False
@@ -357,7 +380,9 @@ async def _cancel_aggregation_task(self):
357380
async def _aggregation_task_handler(self):
358381
while True:
359382
try:
360-
await asyncio.wait_for(self._aggregation_event.wait(), self._aggregation_timeout)
383+
await asyncio.wait_for(
384+
self._aggregation_event.wait(), self._params.aggregation_timeout
385+
)
361386
await self._maybe_push_bot_interruption()
362387
except asyncio.TimeoutError:
363388
if not self._user_speaking:
@@ -394,9 +419,27 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
394419
395420
"""
396421

397-
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True, **kwargs):
422+
def __init__(
423+
self,
424+
context: OpenAILLMContext,
425+
*,
426+
params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
427+
**kwargs,
428+
):
398429
super().__init__(context=context, role="assistant", **kwargs)
399-
self._expect_stripped_words = expect_stripped_words
430+
self._params = params
431+
432+
if "expect_stripped_words" in kwargs:
433+
import warnings
434+
435+
with warnings.catch_warnings():
436+
warnings.simplefilter("always")
437+
warnings.warn(
438+
"Parameter 'expect_stripped_words' is deprecated, use 'params' instead.",
439+
DeprecationWarning,
440+
)
441+
442+
self._params.expect_stripped_words = kwargs["expect_stripped_words"]
400443

401444
self._started = 0
402445
self._function_calls_in_progress: Dict[str, FunctionCallInProgressFrame] = {}
@@ -558,7 +601,7 @@ async def _handle_text(self, frame: TextFrame):
558601
if not self._started:
559602
return
560603

561-
if self._expect_stripped_words:
604+
if self._params.expect_stripped_words:
562605
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
563606
else:
564607
self._aggregation += frame.text
@@ -572,8 +615,14 @@ def _context_updated_task_finished(self, task: asyncio.Task):
572615

573616

574617
class LLMUserResponseAggregator(LLMUserContextAggregator):
575-
def __init__(self, messages: List[dict] = [], **kwargs):
576-
super().__init__(context=OpenAILLMContext(messages), **kwargs)
618+
def __init__(
619+
self,
620+
messages: List[dict] = [],
621+
*,
622+
params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
623+
**kwargs,
624+
):
625+
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
577626

578627
async def push_aggregation(self):
579628
if len(self._aggregation) > 0:
@@ -588,8 +637,14 @@ async def push_aggregation(self):
588637

589638

590639
class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
591-
def __init__(self, messages: List[dict] = [], **kwargs):
592-
super().__init__(context=OpenAILLMContext(messages), **kwargs)
640+
def __init__(
641+
self,
642+
messages: List[dict] = [],
643+
*,
644+
params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
645+
**kwargs,
646+
):
647+
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
593648

594649
async def push_aggregation(self):
595650
if len(self._aggregation) > 0:

src/pipecat/services/anthropic/llm.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import json
1212
import re
1313
from dataclasses import dataclass
14-
from typing import Any, Dict, List, Mapping, Optional, Union
14+
from typing import Any, Dict, List, Optional, Union
1515

1616
import httpx
1717
from loguru import logger
@@ -35,7 +35,9 @@
3535
)
3636
from pipecat.metrics.metrics import LLMTokenUsage
3737
from pipecat.processors.aggregators.llm_response import (
38+
LLMAssistantAggregatorParams,
3839
LLMAssistantContextAggregator,
40+
LLMUserAggregatorParams,
3941
LLMUserContextAggregator,
4042
)
4143
from pipecat.processors.aggregators.openai_llm_context import (
@@ -49,10 +51,7 @@
4951
from anthropic import NOT_GIVEN, AsyncAnthropic, NotGiven
5052
except ModuleNotFoundError as e:
5153
logger.error(f"Exception: {e}")
52-
logger.error(
53-
"In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`. "
54-
+ "Also, set `ANTHROPIC_API_KEY` environment variable."
55-
)
54+
logger.error("In order to use Anthropic, you need to `pip install pipecat-ai[anthropic]`.")
5655
raise Exception(f"Missing module: {e}")
5756

5857

@@ -120,21 +119,19 @@ def create_context_aggregator(
120119
self,
121120
context: OpenAILLMContext,
122121
*,
123-
user_kwargs: Mapping[str, Any] = {},
124-
assistant_kwargs: Mapping[str, Any] = {},
122+
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
123+
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
125124
) -> AnthropicContextAggregatorPair:
126125
"""Create an instance of AnthropicContextAggregatorPair from an
127126
OpenAILLMContext. Constructor keyword arguments for both the user and
128127
assistant aggregators can be provided.
129128
130129
Args:
131130
context (OpenAILLMContext): The LLM context.
132-
user_kwargs (Mapping[str, Any], optional): Additional keyword
133-
arguments for the user context aggregator constructor. Defaults
134-
to an empty mapping.
135-
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
136-
arguments for the assistant context aggregator
137-
constructor. Defaults to an empty mapping.
131+
user_params (LLMUserAggregatorParams, optional): User aggregator
132+
parameters.
133+
assistant_params (LLMAssistantAggregatorParams, optional): Assistant
134+
aggregator parameters.
138135
139136
Returns:
140137
AnthropicContextAggregatorPair: A pair of context aggregators, one
@@ -146,8 +143,8 @@ def create_context_aggregator(
146143

147144
if isinstance(context, OpenAILLMContext):
148145
context = AnthropicLLMContext.from_openai_context(context)
149-
user = AnthropicUserContextAggregator(context, **user_kwargs)
150-
assistant = AnthropicAssistantContextAggregator(context, **assistant_kwargs)
146+
user = AnthropicUserContextAggregator(context, params=user_params)
147+
assistant = AnthropicAssistantContextAggregator(context, params=assistant_params)
151148
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
152149

153150
async def _process_context(self, context: OpenAILLMContext):

src/pipecat/services/gemini_multimodal_live/gemini.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@
1010
import time
1111
from dataclasses import dataclass
1212
from enum import Enum
13-
from typing import Any, Dict, List, Mapping, Optional, Union
13+
from typing import Any, Dict, List, Optional, Union
1414

15-
import websockets
1615
from loguru import logger
1716
from pydantic import BaseModel, Field
1817

@@ -45,6 +44,10 @@
4544
UserStoppedSpeakingFrame,
4645
)
4746
from pipecat.metrics.metrics import LLMTokenUsage
47+
from pipecat.processors.aggregators.llm_response import (
48+
LLMAssistantAggregatorParams,
49+
LLMUserAggregatorParams,
50+
)
4851
from pipecat.processors.aggregators.openai_llm_context import (
4952
OpenAILLMContext,
5053
OpenAILLMContextFrame,
@@ -61,6 +64,13 @@
6164
from . import events
6265
from .audio_transcriber import AudioTranscriber
6366

67+
try:
68+
import websockets
69+
except ModuleNotFoundError as e:
70+
logger.error(f"Exception: {e}")
71+
logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
72+
raise Exception(f"Missing module: {e}")
73+
6474

6575
def language_to_gemini_language(language: Language) -> Optional[str]:
6676
"""Maps a Language enum value to a Gemini Live supported language code.
@@ -871,21 +881,19 @@ def create_context_aggregator(
871881
self,
872882
context: OpenAILLMContext,
873883
*,
874-
user_kwargs: Mapping[str, Any] = {},
875-
assistant_kwargs: Mapping[str, Any] = {},
884+
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
885+
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
876886
) -> GeminiMultimodalLiveContextAggregatorPair:
877887
"""Create an instance of GeminiMultimodalLiveContextAggregatorPair from
878888
an OpenAILLMContext. Constructor parameters for both the user and
879889
assistant aggregators can be provided.
880890
881891
Args:
882892
context (OpenAILLMContext): The LLM context.
883-
user_kwargs (Mapping[str, Any], optional): Additional keyword
884-
arguments for the user context aggregator constructor. Defaults
885-
to an empty mapping.
886-
assistant_kwargs (Mapping[str, Any], optional): Additional keyword
887-
arguments for the assistant context aggregator
888-
constructor. Defaults to an empty mapping.
893+
user_params (LLMUserAggregatorParams, optional): User aggregator
894+
parameters.
895+
assistant_params (LLMAssistantAggregatorParams, optional): User
896+
aggregator parameters.
889897
890898
Returns:
891899
GeminiMultimodalLiveContextAggregatorPair: A pair of context
@@ -896,11 +904,8 @@ def create_context_aggregator(
896904
context.set_llm_adapter(self.get_llm_adapter())
897905

898906
GeminiMultimodalLiveContext.upgrade(context)
899-
user = GeminiMultimodalLiveUserContextAggregator(context, **user_kwargs)
907+
user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params)
900908

901-
default_assistant_kwargs = {"expect_stripped_words": True}
902-
default_assistant_kwargs.update(assistant_kwargs)
903-
assistant = GeminiMultimodalLiveAssistantContextAggregator(
904-
context, **default_assistant_kwargs
905-
)
909+
assistant_params.expect_stripped_words = True
910+
assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params)
906911
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)

0 commit comments

Comments
 (0)