Skip to content

Commit cad035c

Browse files
yeesiancopybara-github
authored andcommitted
feat: Add enable_tracing to LangchainAgent.
PiperOrigin-RevId: 641955580
1 parent a78a35e commit cad035c

File tree

4 files changed

+194
-1
lines changed

4 files changed

+194
-1
lines changed

setup.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@
140140

141141
reasoning_engine_extra_require = [
142142
"cloudpickle >= 2.2.1, < 4.0",
143+
"opentelemetry-sdk < 2",
144+
"opentelemetry-exporter-gcp-trace < 2",
143145
"pydantic >= 2.6.3, < 3",
144146
]
145147

@@ -149,9 +151,10 @@
149151
]
150152

151153
langchain_extra_require = [
152-
"langchain >= 0.1.16, < 0.2",
154+
"langchain >= 0.1.16, < 0.3",
153155
"langchain-core < 0.2",
154156
"langchain-google-vertexai < 2",
157+
"openinference-instrumentation-langchain >= 0.1.19, < 0.2",
155158
]
156159

157160
langchain_testing_extra_require = list(

tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py

+78
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from vertexai.preview import reasoning_engines
2424
from vertexai.preview.generative_models import grounding
2525
from vertexai.generative_models import Tool
26+
from vertexai.reasoning_engines import _utils
2627
import pytest
2728

2829

@@ -89,6 +90,48 @@ def mock_chatvertexai():
8990
yield model_mock
9091

9192

93+
@pytest.fixture
94+
def cloud_trace_exporter_mock():
95+
with mock.patch.object(
96+
_utils,
97+
"_import_cloud_trace_exporter_or_warn",
98+
) as cloud_trace_exporter_mock:
99+
yield cloud_trace_exporter_mock
100+
101+
102+
@pytest.fixture
103+
def tracer_provider_mock():
104+
with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock:
105+
yield tracer_provider_mock
106+
107+
108+
@pytest.fixture
109+
def simple_span_processor_mock():
110+
with mock.patch(
111+
"opentelemetry.sdk.trace.export.SimpleSpanProcessor"
112+
) as simple_span_processor_mock:
113+
yield simple_span_processor_mock
114+
115+
116+
@pytest.fixture
117+
def langchain_instrumentor_mock():
118+
with mock.patch.object(
119+
_utils,
120+
"_import_openinference_langchain_or_warn",
121+
) as langchain_instrumentor_mock:
122+
yield langchain_instrumentor_mock
123+
124+
125+
@pytest.fixture
126+
def langchain_instrumentor_none_mock():
127+
with mock.patch.object(
128+
_utils,
129+
"_import_openinference_langchain_or_warn",
130+
) as langchain_instrumentor_mock:
131+
langchain_instrumentor_mock.return_value = None
132+
yield langchain_instrumentor_mock
133+
134+
92135
@pytest.mark.usefixtures("google_auth_mock")
93136
class TestLangchainAgent:
94137
def setup_method(self):
@@ -175,6 +218,41 @@ def test_query(self, langchain_dump_mock):
175218
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
176219
)
177220

221+
@pytest.mark.usefixtures("caplog")
222+
def test_enable_tracing(
223+
self,
224+
caplog,
225+
cloud_trace_exporter_mock,
226+
tracer_provider_mock,
227+
simple_span_processor_mock,
228+
langchain_instrumentor_mock,
229+
):
230+
agent = reasoning_engines.LangchainAgent(
231+
model=_TEST_MODEL,
232+
prompt=self.prompt,
233+
output_parser=self.output_parser,
234+
enable_tracing=True,
235+
)
236+
assert agent._instrumentor is None
237+
agent.set_up()
238+
assert agent._instrumentor is not None
239+
assert (
240+
"enable_tracing=True but proceeding with tracing disabled"
241+
not in caplog.text
242+
)
243+
244+
@pytest.mark.usefixtures("caplog")
245+
def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock):
246+
agent = reasoning_engines.LangchainAgent(
247+
model=_TEST_MODEL,
248+
prompt=self.prompt,
249+
output_parser=self.output_parser,
250+
enable_tracing=True,
251+
)
252+
assert agent._instrumentor is None
253+
agent.set_up()
254+
assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text
255+
178256

179257
class TestConvertToolsOrRaise:
180258
def test_convert_tools_or_raise(self, vertexai_init_mock):

vertexai/preview/reasoning_engines/templates/langchain.py

+45
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def __init__(
236236
runnable_kwargs: Optional[Mapping[str, Any]] = None,
237237
model_builder: Optional[Callable] = None,
238238
runnable_builder: Optional[Callable] = None,
239+
enable_tracing: bool = False,
239240
):
240241
"""Initializes the LangchainAgent.
241242
@@ -349,6 +350,9 @@ def __init__(
349350
for customizing the orchestration logic of the Agent based on
350351
the model returned by `model_builder` and the rest of the input
351352
arguments.
353+
enable_tracing (bool):
354+
Optional. Whether to enable tracing in Cloud Trace. Defaults to
355+
False.
352356
353357
Raises:
354358
TypeError: If there is an invalid tool (e.g. function with an input
@@ -376,6 +380,8 @@ def __init__(
376380
self._model_builder = model_builder
377381
self._runnable = None
378382
self._runnable_builder = runnable_builder
383+
self._instrumentor = None
384+
self._enable_tracing = enable_tracing
379385

380386
def set_up(self):
381387
"""Sets up the agent for execution of queries at runtime.
@@ -387,6 +393,44 @@ def set_up(self):
387393
the ReasoningEngine service for deployment, as it initializes clients
388394
that can not be serialized.
389395
"""
396+
if self._enable_tracing:
397+
from vertexai.reasoning_engines import _utils
398+
399+
cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn()
400+
openinference_langchain = _utils._import_openinference_langchain_or_warn()
401+
opentelemetry = _utils._import_opentelemetry_or_warn()
402+
opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn()
403+
if all(
404+
(
405+
cloud_trace_exporter,
406+
openinference_langchain,
407+
opentelemetry,
408+
opentelemetry_sdk_trace,
409+
)
410+
):
411+
tracer_provider = opentelemetry.trace.get_tracer_provider()
412+
if tracer_provider and _utils._is_noop_tracer_provider(tracer_provider):
413+
# Set a trace provider if it has not been set.
414+
span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
415+
project_id=self._project,
416+
)
417+
span_processor = opentelemetry_sdk_trace.export.SimpleSpanProcessor(
418+
span_exporter=span_exporter,
419+
)
420+
tracer_provider = opentelemetry_sdk_trace.TracerProvider(
421+
active_span_processor=span_processor,
422+
)
423+
opentelemetry.trace.set_tracer_provider(tracer_provider)
424+
self._instrumentor = openinference_langchain.LangChainInstrumentor()
425+
self._instrumentor.instrument()
426+
else:
427+
from google.cloud.aiplatform import base
428+
429+
_LOGGER = base.Logger(__name__)
430+
_LOGGER.warning(
431+
"enable_tracing=True but proceeding with tracing disabled "
432+
"because not all packages for tracing have been installed"
433+
)
390434
model_builder = self._model_builder or _default_model_builder
391435
self._model = model_builder(
392436
model_name=self._model_name,
@@ -422,6 +466,7 @@ def clone(self) -> "LangchainAgent":
422466
runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
423467
model_builder=self._model_builder,
424468
runnable_builder=self._runnable_builder,
469+
enable_tracing=self._enable_tracing,
425470
)
426471

427472
def query(

vertexai/reasoning_engines/_utils.py

+67
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import proto
2323

24+
from google.cloud.aiplatform import base
2425
from google.protobuf import struct_pb2
2526
from google.protobuf import json_format
2627

@@ -36,6 +37,8 @@
3637

3738
JsonDict = Dict[str, Any]
3839

40+
_LOGGER = base.Logger(__name__)
41+
3942

4043
def to_proto(
4144
obj: Union[JsonDict, proto.Message],
@@ -195,6 +198,14 @@ def generate_schema(
195198
return schema
196199

197200

201+
def _is_noop_tracer_provider(tracer_provider) -> bool:
202+
"""Returns True if the tracer_provider is Proxy or NoOp."""
203+
opentelemetry = _import_opentelemetry_or_warn()
204+
ProxyTracerProvider = opentelemetry.trace.ProxyTracerProvider
205+
NoOpTracerProvider = opentelemetry.trace.NoOpTracerProvider
206+
return isinstance(tracer_provider, (NoOpTracerProvider, ProxyTracerProvider))
207+
208+
198209
def _import_cloud_storage_or_raise() -> types.ModuleType:
199210
"""Tries to import the Cloud Storage module."""
200211
try:
@@ -233,3 +244,59 @@ def _import_pydantic_or_raise() -> types.ModuleType:
233244
"'pip install google-cloud-aiplatform[reasoningengine]'."
234245
) from e
235246
return pydantic
247+
248+
249+
def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]:
250+
"""Tries to import the opentelemetry module."""
251+
try:
252+
import opentelemetry # noqa:F401
253+
254+
return opentelemetry
255+
except ImportError:
256+
_LOGGER.warning(
257+
"opentelemetry-sdk is not installed. Please call "
258+
"'pip install google-cloud-aiplatform[reasoningengine]'."
259+
)
260+
return None
261+
262+
263+
def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]:
264+
"""Tries to import the opentelemetry.sdk.trace module."""
265+
try:
266+
import opentelemetry.sdk.trace # noqa:F401
267+
268+
return opentelemetry.sdk.trace
269+
except ImportError:
270+
_LOGGER.warning(
271+
"opentelemetry-sdk is not installed. Please call "
272+
"'pip install google-cloud-aiplatform[reasoningengine]'."
273+
)
274+
return None
275+
276+
277+
def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]:
278+
"""Tries to import the opentelemetry.exporter.cloud_trace module."""
279+
try:
280+
import opentelemetry.exporter.cloud_trace # noqa:F401
281+
282+
return opentelemetry.exporter.cloud_trace
283+
except ImportError:
284+
_LOGGER.warning(
285+
"opentelemetry-exporter-gcp-trace is not installed. Please "
286+
"call 'pip install google-cloud-aiplatform[langchain]'."
287+
)
288+
return None
289+
290+
291+
def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]:
292+
"""Tries to import the openinference.instrumentation.langchain module."""
293+
try:
294+
import openinference.instrumentation.langchain # noqa:F401
295+
296+
return openinference.instrumentation.langchain
297+
except ImportError:
298+
_LOGGER.warning(
299+
"openinference-instrumentation-langchain is not installed. Please "
300+
"call 'pip install google-cloud-aiplatform[langchain]'."
301+
)
302+
return None

0 commit comments

Comments
 (0)