@@ -1615,6 +1615,7 @@ def _prepare_request(
1615
1615
top_k : Optional [int ] = None ,
1616
1616
top_p : Optional [float ] = None ,
1617
1617
stop_sequences : Optional [List [str ]] = None ,
1618
+ candidate_count : Optional [int ] = None ,
1618
1619
) -> _PredictionRequest :
1619
1620
"""Prepares a request for the language model.
1620
1621
@@ -1629,6 +1630,7 @@ def _prepare_request(
1629
1630
top_p: The cumulative probability threshold for nucleus sampling; only the highest-probability vocabulary tokens whose probabilities sum to this value are kept. Range: [0, 1]. Default: 0.95.
1630
1631
Uses the value specified when calling `ChatModel.start_chat` by default.
1631
1632
stop_sequences: Customized stop sequences to stop the decoding process.
1633
+ candidate_count: Number of candidates to return.
1632
1634
1633
1635
Returns:
1634
1636
A `_PredictionRequest` object.
@@ -1660,6 +1662,9 @@ def _prepare_request(
1660
1662
if stop_sequences :
1661
1663
prediction_parameters ["stopSequences" ] = stop_sequences
1662
1664
1665
+ if candidate_count is not None :
1666
+ prediction_parameters ["candidateCount" ] = candidate_count
1667
+
1663
1668
message_structs = []
1664
1669
for past_message in self ._message_history :
1665
1670
message_structs .append (
@@ -1697,8 +1702,7 @@ def _parse_chat_prediction_response(
1697
1702
cls ,
1698
1703
prediction_response : aiplatform .models .Prediction ,
1699
1704
prediction_idx : int = 0 ,
1700
- candidate_idx : int = 0 ,
1701
- ) -> TextGenerationResponse :
1705
+ ) -> MultiCandidateTextGenerationResponse :
1702
1706
"""Parses prediction response for chat models.
1703
1707
1704
1708
Args:
@@ -1707,25 +1711,33 @@ def _parse_chat_prediction_response(
1707
1711
candidate_idx: Index of the candidate to parse.
1708
1712
1709
1713
Returns:
1710
- A `TextGenerationResponse ` object.
1714
+ A `MultiCandidateTextGenerationResponse` object.
1711
1715
"""
1712
1716
prediction = prediction_response .predictions [prediction_idx ]
1713
- # ! Note: For chat models, the safetyAttributes is a list.
1714
- safety_attributes = prediction ["safetyAttributes" ][candidate_idx ]
1715
- return TextGenerationResponse (
1716
- text = prediction ["candidates" ][candidate_idx ]["content" ]
1717
- if prediction .get ("candidates" )
1718
- else None ,
1717
+ candidate_count = len (prediction ["candidates" ])
1718
+ candidates = []
1719
+ for candidate_idx in range (candidate_count ):
1720
+ safety_attributes = prediction ["safetyAttributes" ][candidate_idx ]
1721
+ candidate_response = TextGenerationResponse (
1722
+ text = prediction ["candidates" ][candidate_idx ]["content" ],
1723
+ _prediction_response = prediction_response ,
1724
+ is_blocked = safety_attributes .get ("blocked" , False ),
1725
+ safety_attributes = dict (
1726
+ zip (
1727
+ # Unlike with normal prediction, in streaming prediction
1728
+ # categories and scores can be None
1729
+ safety_attributes .get ("categories" ) or [],
1730
+ safety_attributes .get ("scores" ) or [],
1731
+ )
1732
+ ),
1733
+ )
1734
+ candidates .append (candidate_response )
1735
+ return MultiCandidateTextGenerationResponse (
1736
+ text = candidates [0 ].text ,
1719
1737
_prediction_response = prediction_response ,
1720
- is_blocked = safety_attributes .get ("blocked" , False ),
1721
- safety_attributes = dict (
1722
- zip (
1723
- # Unlike with normal prediction, in streaming prediction
1724
- # categories and scores can be None
1725
- safety_attributes .get ("categories" ) or [],
1726
- safety_attributes .get ("scores" ) or [],
1727
- )
1728
- ),
1738
+ is_blocked = candidates [0 ].is_blocked ,
1739
+ safety_attributes = candidates [0 ].safety_attributes ,
1740
+ candidates = candidates ,
1729
1741
)
1730
1742
1731
1743
def send_message (
@@ -1737,7 +1749,8 @@ def send_message(
1737
1749
top_k : Optional [int ] = None ,
1738
1750
top_p : Optional [float ] = None ,
1739
1751
stop_sequences : Optional [List [str ]] = None ,
1740
- ) -> "TextGenerationResponse" :
1752
+ candidate_count : Optional [int ] = None ,
1753
+ ) -> "MultiCandidateTextGenerationResponse" :
1741
1754
"""Sends message to the language model and gets a response.
1742
1755
1743
1756
Args:
@@ -1751,9 +1764,11 @@ def send_message(
1751
1764
top_p: The cumulative probability threshold for nucleus sampling; only the highest-probability vocabulary tokens whose probabilities sum to this value are kept. Range: [0, 1]. Default: 0.95.
1752
1765
Uses the value specified when calling `ChatModel.start_chat` by default.
1753
1766
stop_sequences: Customized stop sequences to stop the decoding process.
1767
+ candidate_count: Number of candidates to return.
1754
1768
1755
1769
Returns:
1756
- A `TextGenerationResponse` object that contains the text produced by the model.
1770
+ A `MultiCandidateTextGenerationResponse` object that contains the
1771
+ text produced by the model.
1757
1772
"""
1758
1773
prediction_request = self ._prepare_request (
1759
1774
message = message ,
@@ -1762,6 +1777,7 @@ def send_message(
1762
1777
top_k = top_k ,
1763
1778
top_p = top_p ,
1764
1779
stop_sequences = stop_sequences ,
1780
+ candidate_count = candidate_count ,
1765
1781
)
1766
1782
1767
1783
prediction_response = self ._model ._endpoint .predict (
@@ -1791,7 +1807,8 @@ async def send_message_async(
1791
1807
top_k : Optional [int ] = None ,
1792
1808
top_p : Optional [float ] = None ,
1793
1809
stop_sequences : Optional [List [str ]] = None ,
1794
- ) -> "TextGenerationResponse" :
1810
+ candidate_count : Optional [int ] = None ,
1811
+ ) -> "MultiCandidateTextGenerationResponse" :
1795
1812
"""Asynchronously sends message to the language model and gets a response.
1796
1813
1797
1814
Args:
@@ -1805,9 +1822,11 @@ async def send_message_async(
1805
1822
top_p: The cumulative probability threshold for nucleus sampling; only the highest-probability vocabulary tokens whose probabilities sum to this value are kept. Range: [0, 1]. Default: 0.95.
1806
1823
Uses the value specified when calling `ChatModel.start_chat` by default.
1807
1824
stop_sequences: Customized stop sequences to stop the decoding process.
1825
+ candidate_count: Number of candidates to return.
1808
1826
1809
1827
Returns:
1810
- A `TextGenerationResponse` object that contains the text produced by the model.
1828
+ A `MultiCandidateTextGenerationResponse` object that contains
1829
+ the text produced by the model.
1811
1830
"""
1812
1831
prediction_request = self ._prepare_request (
1813
1832
message = message ,
@@ -1816,6 +1835,7 @@ async def send_message_async(
1816
1835
top_k = top_k ,
1817
1836
top_p = top_p ,
1818
1837
stop_sequences = stop_sequences ,
1838
+ candidate_count = candidate_count ,
1819
1839
)
1820
1840
1821
1841
prediction_response = await self ._model ._endpoint .predict_async (
0 commit comments