Skip to content

Commit d4fc734

Browse files
authored
core[patch]: update dict prompt template (#30967)
Align with JS changes made in langchain-ai/langchainjs#8043
1 parent 4bc7076 commit d4fc734

File tree

11 files changed

+230
-318
lines changed

11 files changed

+230
-318
lines changed

libs/community/tests/unit_tests/load/test_serializable.py

-155
This file was deleted.

libs/core/langchain_core/load/mapping.py

+6
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,12 @@
540540
"chat_models",
541541
"ChatSambaStudio",
542542
),
543+
("langchain_core", "prompts", "message", "_DictMessagePromptTemplate"): (
544+
"langchain_core",
545+
"prompts",
546+
"dict",
547+
"DictPromptTemplate",
548+
),
543549
}
544550

545551
# Needed for backwards compatibility for old versions of LangChain where things

libs/core/langchain_core/prompts/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
MessagesPlaceholder,
4545
SystemMessagePromptTemplate,
4646
)
47+
from langchain_core.prompts.dict import DictPromptTemplate
4748
from langchain_core.prompts.few_shot import (
4849
FewShotChatMessagePromptTemplate,
4950
FewShotPromptTemplate,
@@ -68,6 +69,7 @@
6869
"BasePromptTemplate",
6970
"ChatMessagePromptTemplate",
7071
"ChatPromptTemplate",
72+
"DictPromptTemplate",
7173
"FewShotPromptTemplate",
7274
"FewShotPromptWithTemplates",
7375
"FewShotChatMessagePromptTemplate",
@@ -94,6 +96,7 @@
9496
"BaseChatPromptTemplate": "chat",
9597
"ChatMessagePromptTemplate": "chat",
9698
"ChatPromptTemplate": "chat",
99+
"DictPromptTemplate": "dict",
97100
"HumanMessagePromptTemplate": "chat",
98101
"MessagesPlaceholder": "chat",
99102
"SystemMessagePromptTemplate": "chat",

libs/core/langchain_core/prompts/chat.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@
3737
from langchain_core.messages.base import get_msg_title_repr
3838
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
3939
from langchain_core.prompts.base import BasePromptTemplate
40+
from langchain_core.prompts.dict import DictPromptTemplate
4041
from langchain_core.prompts.image import ImagePromptTemplate
4142
from langchain_core.prompts.message import (
4243
BaseMessagePromptTemplate,
43-
_DictMessagePromptTemplate,
4444
)
4545
from langchain_core.prompts.prompt import PromptTemplate
4646
from langchain_core.prompts.string import (
@@ -396,9 +396,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
396396

397397
prompt: Union[
398398
StringPromptTemplate,
399-
list[
400-
Union[StringPromptTemplate, ImagePromptTemplate, _DictMessagePromptTemplate]
401-
],
399+
list[Union[StringPromptTemplate, ImagePromptTemplate, DictPromptTemplate]],
402400
]
403401
"""Prompt template."""
404402
additional_kwargs: dict = Field(default_factory=dict)
@@ -447,7 +445,12 @@ def from_template(
447445
raise ValueError(msg)
448446
prompt = []
449447
for tmpl in template:
450-
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
448+
if (
449+
isinstance(tmpl, str)
450+
or isinstance(tmpl, dict)
451+
and "text" in tmpl
452+
and set(tmpl.keys()) <= {"type", "text"}
453+
):
451454
if isinstance(tmpl, str):
452455
text: str = tmpl
453456
else:
@@ -457,7 +460,15 @@ def from_template(
457460
text, template_format=template_format
458461
)
459462
)
460-
elif isinstance(tmpl, dict) and "image_url" in tmpl:
463+
elif (
464+
isinstance(tmpl, dict)
465+
and "image_url" in tmpl
466+
and set(tmpl.keys())
467+
<= {
468+
"type",
469+
"image_url",
470+
}
471+
):
461472
img_template = cast("_ImageTemplateParam", tmpl)["image_url"]
462473
input_variables = []
463474
if isinstance(img_template, str):
@@ -503,7 +514,7 @@ def from_template(
503514
"format."
504515
)
505516
raise ValueError(msg)
506-
data_template_obj = _DictMessagePromptTemplate(
517+
data_template_obj = DictPromptTemplate(
507518
template=cast("dict[str, Any]", tmpl),
508519
template_format=template_format,
509520
)
@@ -592,7 +603,7 @@ def format(self, **kwargs: Any) -> BaseMessage:
592603
elif isinstance(prompt, ImagePromptTemplate):
593604
formatted = prompt.format(**inputs)
594605
content.append({"type": "image_url", "image_url": formatted})
595-
elif isinstance(prompt, _DictMessagePromptTemplate):
606+
elif isinstance(prompt, DictPromptTemplate):
596607
formatted = prompt.format(**inputs)
597608
content.append(formatted)
598609
return self._msg_class(
@@ -624,7 +635,7 @@ async def aformat(self, **kwargs: Any) -> BaseMessage:
624635
elif isinstance(prompt, ImagePromptTemplate):
625636
formatted = await prompt.aformat(**inputs)
626637
content.append({"type": "image_url", "image_url": formatted})
627-
elif isinstance(prompt, _DictMessagePromptTemplate):
638+
elif isinstance(prompt, DictPromptTemplate):
628639
formatted = prompt.format(**inputs)
629640
content.append(formatted)
630641
return self._msg_class(

0 commit comments

Comments
 (0)