Skip to content

Commit 16299d1

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Vizier - Fixed pyvizier client study creation errors
PiperOrigin-RevId: 544186919
1 parent 69aaf01 commit 16299d1

File tree

2 files changed

+451
-38
lines changed

2 files changed

+451
-38
lines changed

google/cloud/aiplatform/vizier/pyvizier/proto_converters.py

+39-27
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Converters for OSS Vizier's protos from/to PyVizier's classes."""
2-
import datetime
32
import logging
3+
from datetime import timezone
44
from typing import List, Optional, Sequence, Tuple, Union
55

66
from google.protobuf import duration_pb2
7+
from google.protobuf import struct_pb2
8+
from google.protobuf import timestamp_pb2
79
from google.cloud.aiplatform.compat.types import study as study_pb2
810
from google.cloud.aiplatform.vizier.pyvizier import ScaleType
911
from google.cloud.aiplatform.vizier.pyvizier import ParameterType
@@ -80,8 +82,8 @@ def _set_default_value(
8082
default_value: Union[float, int, str],
8183
):
8284
"""Sets the protos' default_value field."""
83-
which_pv_spec = proto.WhichOneof("parameter_value_spec")
84-
getattr(proto, which_pv_spec).default_value.value = default_value
85+
which_pv_spec = proto._pb.WhichOneof("parameter_value_spec")
86+
getattr(proto, which_pv_spec).default_value = default_value
8587

8688
@classmethod
8789
def _matching_parent_values(
@@ -280,17 +282,16 @@ def to_proto(
280282
cls, parameter_value: ParameterValue, name: str
281283
) -> study_pb2.Trial.Parameter:
282284
"""Returns Parameter Proto."""
283-
proto = study_pb2.Trial.Parameter(parameter_id=name)
284-
285285
if isinstance(parameter_value.value, int):
286-
proto.value.number_value = parameter_value.value
286+
value = struct_pb2.Value(number_value=parameter_value.value)
287287
elif isinstance(parameter_value.value, bool):
288-
proto.value.bool_value = parameter_value.value
288+
value = struct_pb2.Value(bool_value=parameter_value.value)
289289
elif isinstance(parameter_value.value, float):
290-
proto.value.number_value = parameter_value.value
290+
value = struct_pb2.Value(number_value=parameter_value.value)
291291
elif isinstance(parameter_value.value, str):
292-
proto.value.string_value = parameter_value.value
292+
value = struct_pb2.Value(string_value=parameter_value.value)
293293

294+
proto = study_pb2.Trial.Parameter(parameter_id=name, value=value)
294295
return proto
295296

296297

@@ -340,18 +341,19 @@ def from_proto(cls, proto: study_pb2.Measurement) -> Measurement:
340341
@classmethod
341342
def to_proto(cls, measurement: Measurement) -> study_pb2.Measurement:
342343
"""Converts to Measurement proto."""
343-
proto = study_pb2.Measurement()
344+
int_seconds = int(measurement.elapsed_secs)
345+
proto = study_pb2.Measurement(
346+
elapsed_duration=duration_pb2.Duration(
347+
seconds=int_seconds,
348+
nanos=int(1e9 * (measurement.elapsed_secs - int_seconds)),
349+
)
350+
)
344351
for name, metric in measurement.metrics.items():
345352
proto.metrics.append(
346353
study_pb2.Measurement.Metric(metric_id=name, value=metric.value)
347354
)
348355

349356
proto.step_count = measurement.steps
350-
int_seconds = int(measurement.elapsed_secs)
351-
proto.elapsed_duration = duration_pb2.Duration(
352-
seconds=int_seconds,
353-
nanos=int(1e9 * (measurement.elapsed_secs - int_seconds)),
354-
)
355357
return proto
356358

357359

@@ -426,8 +428,11 @@ def from_proto(cls, proto: study_pb2.Trial) -> Trial:
426428
infeasibility_reason = None
427429
if proto.state == study_pb2.Trial.State.SUCCEEDED:
428430
if proto.end_time:
429-
completion_ts = proto.end_time.nanosecond / 1e9
430-
completion_time = datetime.datetime.fromtimestamp(completion_ts)
431+
completion_time = (
432+
proto.end_time.timestamp_pb()
433+
.ToDatetime()
434+
.replace(tzinfo=timezone.utc)
435+
)
431436
elif proto.state == study_pb2.Trial.State.INFEASIBLE:
432437
infeasibility_reason = proto.infeasible_reason
433438

@@ -437,8 +442,11 @@ def from_proto(cls, proto: study_pb2.Trial) -> Trial:
437442

438443
creation_time = None
439444
if proto.start_time:
440-
creation_ts = proto.start_time.nanosecond / 1e9
441-
creation_time = datetime.datetime.fromtimestamp(creation_ts)
445+
creation_time = (
446+
proto.start_time.timestamp_pb()
447+
.ToDatetime()
448+
.replace(tzinfo=timezone.utc)
449+
)
442450
return Trial(
443451
id=int(proto.name.split("/")[-1]),
444452
description=proto.name,
@@ -481,22 +489,26 @@ def to_proto(cls, pytrial: Trial) -> study_pb2.Trial:
481489

482490
# pytrial always adds an empty metric. Ideally, we should remove it if the
483491
# metric does not exist in the study config.
492+
# setattr() is required here as `proto.final_measurement.CopyFrom`
493+
# raises AttributeErrors when setting the field on the pb2 compat types.
484494
if pytrial.final_measurement is not None:
485-
proto.final_measurement.CopyFrom(
486-
MeasurementConverter.to_proto(pytrial.final_measurement)
495+
setattr(
496+
proto,
497+
"final_measurement",
498+
MeasurementConverter.to_proto(pytrial.final_measurement),
487499
)
488500

489501
for measurement in pytrial.measurements:
490502
proto.measurements.append(MeasurementConverter.to_proto(measurement))
491503

492504
if pytrial.creation_time is not None:
493-
creation_secs = datetime.datetime.timestamp(pytrial.creation_time)
494-
proto.start_time.seconds = int(creation_secs)
495-
proto.start_time.nanos = int(1e9 * (creation_secs - int(creation_secs)))
505+
start_time = timestamp_pb2.Timestamp()
506+
start_time.FromDatetime(pytrial.creation_time)
507+
setattr(proto, "start_time", start_time)
496508
if pytrial.completion_time is not None:
497-
completion_secs = datetime.datetime.timestamp(pytrial.completion_time)
498-
proto.end_time.seconds = int(completion_secs)
499-
proto.end_time.nanos = int(1e9 * (completion_secs - int(completion_secs)))
509+
end_time = timestamp_pb2.Timestamp()
510+
end_time.FromDatetime(pytrial.completion_time)
511+
setattr(proto, "end_time", end_time)
500512
if pytrial.infeasibility_reason is not None:
501513
proto.infeasible_reason = pytrial.infeasibility_reason
502514
return proto

0 commit comments

Comments
 (0)