Skip to content

Commit 8960a80

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add a cloneable protocol for Reasoning Engine.
PiperOrigin-RevId: 638428830
1 parent 3b83ba9 commit 8960a80

File tree

4 files changed

+47
-3
lines changed

4 files changed

+47
-3
lines changed

tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py

+14
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,20 @@ def test_set_up(self, vertexai_init_mock):
135135
agent.set_up()
136136
assert agent._runnable is not None
137137

138+
def test_clone(self, vertexai_init_mock):
139+
agent = reasoning_engines.LangchainAgent(
140+
model=_TEST_MODEL,
141+
prompt=self.prompt,
142+
output_parser=self.output_parser,
143+
)
144+
agent.set_up()
145+
assert agent._runnable is not None
146+
agent_clone = agent.clone()
147+
assert agent._runnable is not None
148+
assert agent_clone._runnable is None
149+
agent_clone.set_up()
150+
assert agent_clone._runnable is not None
151+
138152
def test_query(self, langchain_dump_mock):
139153
agent = reasoning_engines.LangchainAgent(
140154
model=_TEST_MODEL,

tests/unit/vertex_langchain/test_reasoning_engines.py

+3
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def query(self, unused_arbitrary_string_name: str) -> str:
4646
"""Runs the engine."""
4747
return unused_arbitrary_string_name.upper()
4848

49+
def clone(self):
50+
return self
51+
4952

5053
_TEST_RETRY = base._DEFAULT_RETRY
5154
_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials())

vertexai/preview/reasoning_engines/templates/langchain.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,14 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
import json
1716
from typing import (
1817
TYPE_CHECKING,
1918
Any,
2019
Callable,
2120
Dict,
22-
List,
2321
Mapping,
2422
Optional,
2523
Sequence,
26-
Tuple,
2724
Union,
2825
)
2926

@@ -390,6 +387,24 @@ def set_up(self):
390387
runnable_kwargs=self._runnable_kwargs,
391388
)
392389

390+
def clone(self) -> "LangchainAgent":
391+
"""Returns a clone of the LangchainAgent."""
392+
import copy
393+
394+
return LangchainAgent(
395+
model=self._model_name,
396+
prompt=copy.deepcopy(self._prompt),
397+
tools=copy.deepcopy(self._tools),
398+
output_parser=copy.deepcopy(self._output_parser),
399+
chat_history=copy.deepcopy(self._chat_history),
400+
model_kwargs=copy.deepcopy(self._model_kwargs),
401+
model_tool_kwargs=copy.deepcopy(self._model_tool_kwargs),
402+
agent_executor_kwargs=copy.deepcopy(self._agent_executor_kwargs),
403+
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
404+
model_builder=self._model_builder,
405+
runnable_builder=self._runnable_builder,
406+
)
407+
393408
def query(
394409
self,
395410
*,

vertexai/reasoning_engines/_reasoning_engines.py

+12
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ def query(self, **kwargs):
4646
"""Runs the Reasoning Engine to serve the user query."""
4747

4848

49+
@typing.runtime_checkable
50+
class Cloneable(Protocol):
51+
"""Protocol for Reasoning Engine applications that can be cloned."""
52+
53+
@abc.abstractmethod
54+
def clone(self):
55+
"""Return a clone of the object."""
56+
57+
4958
class ReasoningEngine(base.VertexAiResourceNounWithFutureManager, Queryable):
5059
"""Represents a Vertex AI Reasoning Engine resource."""
5160

@@ -214,6 +223,9 @@ def create(
214223
"Invalid query signature. This might be due to a missing "
215224
"`self` argument in the reasoning_engine.query method."
216225
) from err
226+
if isinstance(reasoning_engine, Cloneable):
227+
# Avoid undeployable ReasoningChain states.
228+
reasoning_engine = reasoning_engine.clone()
217229
if isinstance(requirements, str):
218230
try:
219231
_LOGGER.info(f"Reading requirements from {requirements=}")

0 commit comments

Comments
 (0)