37
37
from langchain_core .messages .base import get_msg_title_repr
38
38
from langchain_core .prompt_values import ChatPromptValue , ImageURL , PromptValue
39
39
from langchain_core .prompts .base import BasePromptTemplate
40
+ from langchain_core .prompts .dict import DictPromptTemplate
40
41
from langchain_core .prompts .image import ImagePromptTemplate
41
42
from langchain_core .prompts .message import (
42
43
BaseMessagePromptTemplate ,
43
- _DictMessagePromptTemplate ,
44
44
)
45
45
from langchain_core .prompts .prompt import PromptTemplate
46
46
from langchain_core .prompts .string import (
@@ -396,9 +396,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
396
396
397
397
prompt : Union [
398
398
StringPromptTemplate ,
399
- list [
400
- Union [StringPromptTemplate , ImagePromptTemplate , _DictMessagePromptTemplate ]
401
- ],
399
+ list [Union [StringPromptTemplate , ImagePromptTemplate , DictPromptTemplate ]],
402
400
]
403
401
"""Prompt template."""
404
402
additional_kwargs : dict = Field (default_factory = dict )
@@ -447,7 +445,12 @@ def from_template(
447
445
raise ValueError (msg )
448
446
prompt = []
449
447
for tmpl in template :
450
- if isinstance (tmpl , str ) or isinstance (tmpl , dict ) and "text" in tmpl :
448
+ if (
449
+ isinstance (tmpl , str )
450
+ or isinstance (tmpl , dict )
451
+ and "text" in tmpl
452
+ and set (tmpl .keys ()) <= {"type" , "text" }
453
+ ):
451
454
if isinstance (tmpl , str ):
452
455
text : str = tmpl
453
456
else :
@@ -457,7 +460,15 @@ def from_template(
457
460
text , template_format = template_format
458
461
)
459
462
)
460
- elif isinstance (tmpl , dict ) and "image_url" in tmpl :
463
+ elif (
464
+ isinstance (tmpl , dict )
465
+ and "image_url" in tmpl
466
+ and set (tmpl .keys ())
467
+ <= {
468
+ "type" ,
469
+ "image_url" ,
470
+ }
471
+ ):
461
472
img_template = cast ("_ImageTemplateParam" , tmpl )["image_url" ]
462
473
input_variables = []
463
474
if isinstance (img_template , str ):
@@ -503,7 +514,7 @@ def from_template(
503
514
"format."
504
515
)
505
516
raise ValueError (msg )
506
- data_template_obj = _DictMessagePromptTemplate (
517
+ data_template_obj = DictPromptTemplate (
507
518
template = cast ("dict[str, Any]" , tmpl ),
508
519
template_format = template_format ,
509
520
)
@@ -592,7 +603,7 @@ def format(self, **kwargs: Any) -> BaseMessage:
592
603
elif isinstance (prompt , ImagePromptTemplate ):
593
604
formatted = prompt .format (** inputs )
594
605
content .append ({"type" : "image_url" , "image_url" : formatted })
595
- elif isinstance (prompt , _DictMessagePromptTemplate ):
606
+ elif isinstance (prompt , DictPromptTemplate ):
596
607
formatted = prompt .format (** inputs )
597
608
content .append (formatted )
598
609
return self ._msg_class (
@@ -624,7 +635,7 @@ async def aformat(self, **kwargs: Any) -> BaseMessage:
624
635
elif isinstance (prompt , ImagePromptTemplate ):
625
636
formatted = await prompt .aformat (** inputs )
626
637
content .append ({"type" : "image_url" , "image_url" : formatted })
627
- elif isinstance (prompt , _DictMessagePromptTemplate ):
638
+ elif isinstance (prompt , DictPromptTemplate ):
628
639
formatted = prompt .format (** inputs )
629
640
content .append (formatted )
630
641
return self ._msg_class (
0 commit comments