1
1
"""Converters for OSS Vizier's protos from/to PyVizier's classes."""
2
- import datetime
3
2
import logging
3
+ from datetime import timezone
4
4
from typing import List , Optional , Sequence , Tuple , Union
5
5
6
6
from google .protobuf import duration_pb2
7
+ from google .protobuf import struct_pb2
8
+ from google .protobuf import timestamp_pb2
7
9
from google .cloud .aiplatform .compat .types import study as study_pb2
8
10
from google .cloud .aiplatform .vizier .pyvizier import ScaleType
9
11
from google .cloud .aiplatform .vizier .pyvizier import ParameterType
@@ -80,8 +82,8 @@ def _set_default_value(
80
82
default_value : Union [float , int , str ],
81
83
):
82
84
"""Sets the protos' default_value field."""
83
- which_pv_spec = proto .WhichOneof ("parameter_value_spec" )
84
- getattr (proto , which_pv_spec ).default_value . value = default_value
85
+ which_pv_spec = proto ._pb . WhichOneof ("parameter_value_spec" )
86
+ getattr (proto , which_pv_spec ).default_value = default_value
85
87
86
88
@classmethod
87
89
def _matching_parent_values (
@@ -280,17 +282,16 @@ def to_proto(
280
282
cls , parameter_value : ParameterValue , name : str
281
283
) -> study_pb2 .Trial .Parameter :
282
284
"""Returns Parameter Proto."""
283
- proto = study_pb2 .Trial .Parameter (parameter_id = name )
284
-
285
285
if isinstance (parameter_value .value , int ):
286
- proto . value . number_value = parameter_value .value
286
+ value = struct_pb2 . Value ( number_value = parameter_value .value )
287
287
elif isinstance (parameter_value .value , bool ):
288
- proto . value . bool_value = parameter_value .value
288
+ value = struct_pb2 . Value ( bool_value = parameter_value .value )
289
289
elif isinstance (parameter_value .value , float ):
290
- proto . value . number_value = parameter_value .value
290
+ value = struct_pb2 . Value ( number_value = parameter_value .value )
291
291
elif isinstance (parameter_value .value , str ):
292
- proto . value . string_value = parameter_value .value
292
+ value = struct_pb2 . Value ( string_value = parameter_value .value )
293
293
294
+ proto = study_pb2 .Trial .Parameter (parameter_id = name , value = value )
294
295
return proto
295
296
296
297
@@ -340,18 +341,19 @@ def from_proto(cls, proto: study_pb2.Measurement) -> Measurement:
340
341
@classmethod
341
342
def to_proto (cls , measurement : Measurement ) -> study_pb2 .Measurement :
342
343
"""Converts to Measurement proto."""
343
- proto = study_pb2 .Measurement ()
344
+ int_seconds = int (measurement .elapsed_secs )
345
+ proto = study_pb2 .Measurement (
346
+ elapsed_duration = duration_pb2 .Duration (
347
+ seconds = int_seconds ,
348
+ nanos = int (1e9 * (measurement .elapsed_secs - int_seconds )),
349
+ )
350
+ )
344
351
for name , metric in measurement .metrics .items ():
345
352
proto .metrics .append (
346
353
study_pb2 .Measurement .Metric (metric_id = name , value = metric .value )
347
354
)
348
355
349
356
proto .step_count = measurement .steps
350
- int_seconds = int (measurement .elapsed_secs )
351
- proto .elapsed_duration = duration_pb2 .Duration (
352
- seconds = int_seconds ,
353
- nanos = int (1e9 * (measurement .elapsed_secs - int_seconds )),
354
- )
355
357
return proto
356
358
357
359
@@ -426,8 +428,11 @@ def from_proto(cls, proto: study_pb2.Trial) -> Trial:
426
428
infeasibility_reason = None
427
429
if proto .state == study_pb2 .Trial .State .SUCCEEDED :
428
430
if proto .end_time :
429
- completion_ts = proto .end_time .nanosecond / 1e9
430
- completion_time = datetime .datetime .fromtimestamp (completion_ts )
431
+ completion_time = (
432
+ proto .end_time .timestamp_pb ()
433
+ .ToDatetime ()
434
+ .replace (tzinfo = timezone .utc )
435
+ )
431
436
elif proto .state == study_pb2 .Trial .State .INFEASIBLE :
432
437
infeasibility_reason = proto .infeasible_reason
433
438
@@ -437,8 +442,11 @@ def from_proto(cls, proto: study_pb2.Trial) -> Trial:
437
442
438
443
creation_time = None
439
444
if proto .start_time :
440
- creation_ts = proto .start_time .nanosecond / 1e9
441
- creation_time = datetime .datetime .fromtimestamp (creation_ts )
445
+ creation_time = (
446
+ proto .start_time .timestamp_pb ()
447
+ .ToDatetime ()
448
+ .replace (tzinfo = timezone .utc )
449
+ )
442
450
return Trial (
443
451
id = int (proto .name .split ("/" )[- 1 ]),
444
452
description = proto .name ,
@@ -481,22 +489,26 @@ def to_proto(cls, pytrial: Trial) -> study_pb2.Trial:
481
489
482
490
# pytrial always adds an empty metric. Ideally, we should remove it if the
483
491
# metric does not exist in the study config.
492
+ # setattr() is required here as `proto.final_measurement.CopyFrom`
493
+ # raises AttributeErrors when setting the field on the pb2 compat types.
484
494
if pytrial .final_measurement is not None :
485
- proto .final_measurement .CopyFrom (
486
- MeasurementConverter .to_proto (pytrial .final_measurement )
495
+ setattr (
496
+ proto ,
497
+ "final_measurement" ,
498
+ MeasurementConverter .to_proto (pytrial .final_measurement ),
487
499
)
488
500
489
501
for measurement in pytrial .measurements :
490
502
proto .measurements .append (MeasurementConverter .to_proto (measurement ))
491
503
492
504
if pytrial .creation_time is not None :
493
- creation_secs = datetime . datetime . timestamp ( pytrial . creation_time )
494
- proto . start_time .seconds = int ( creation_secs )
495
- proto . start_time . nanos = int ( 1e9 * ( creation_secs - int ( creation_secs )) )
505
+ start_time = timestamp_pb2 . Timestamp ( )
506
+ start_time .FromDatetime ( pytrial . creation_time )
507
+ setattr ( proto , " start_time" , start_time )
496
508
if pytrial .completion_time is not None :
497
- completion_secs = datetime . datetime . timestamp ( pytrial . completion_time )
498
- proto . end_time .seconds = int ( completion_secs )
499
- proto . end_time . nanos = int ( 1e9 * ( completion_secs - int ( completion_secs )) )
509
+ end_time = timestamp_pb2 . Timestamp ( )
510
+ end_time .FromDatetime ( pytrial . completion_time )
511
+ setattr ( proto , " end_time" , end_time )
500
512
if pytrial .infeasibility_reason is not None :
501
513
proto .infeasible_reason = pytrial .infeasibility_reason
502
514
return proto
0 commit comments