@@ -758,6 +758,124 @@ def reverse_string_2(s):""",
758
758
"pipelineSpec" : json .loads (_TEST_EVAL_PIPELINE_SPEC_JSON ),
759
759
}
760
760
)
761
+ _TEST_DISTILLATION_PIPELINE_SPEC = {
762
+ "components" : {},
763
+ "pipelineInfo" : {
764
+ "description" : "Vertex kfp pipeline for distillation." ,
765
+ "name" : "distillation" ,
766
+ },
767
+ "root" : {
768
+ "dag" : {"tasks" : {}},
769
+ "inputDefinitions" : {
770
+ "parameters" : {
771
+ "accelerator_type" : {
772
+ "defaultValue" : "GPU" ,
773
+ "isOptional" : True ,
774
+ "parameterType" : "STRING" ,
775
+ },
776
+ "api_endpoint" : {
777
+ "defaultValue" : "aiplatform.googleapis.com/ui" ,
778
+ "isOptional" : True ,
779
+ "parameterType" : "STRING" ,
780
+ },
781
+ "dataset_uri" : {"parameterType" : "STRING" },
782
+ "enable_checkpoint_selection" : {
783
+ "defaultValue" : "default" ,
784
+ "isOptional" : True ,
785
+ "parameterType" : "STRING" ,
786
+ },
787
+ "enable_early_stopping" : {
788
+ "defaultValue" : True ,
789
+ "isOptional" : True ,
790
+ "parameterType" : "BOOLEAN" ,
791
+ },
792
+ "encryption_spec_key_name" : {
793
+ "defaultValue" : "" ,
794
+ "isOptional" : True ,
795
+ "parameterType" : "STRING" ,
796
+ },
797
+ "evaluation_data_uri" : {
798
+ "defaultValue" : "" ,
799
+ "isOptional" : True ,
800
+ "parameterType" : "STRING" ,
801
+ },
802
+ "evaluation_interval" : {
803
+ "defaultValue" : 100 ,
804
+ "isOptional" : True ,
805
+ "parameterType" : "NUMBER_INTEGER" ,
806
+ },
807
+ "evaluation_output_root_dir" : {
808
+ "defaultValue" : "" ,
809
+ "isOptional" : True ,
810
+ "parameterType" : "STRING" ,
811
+ },
812
+ "learning_rate_multiplier" : {
813
+ "defaultValue" : 1 ,
814
+ "isOptional" : True ,
815
+ "parameterType" : "NUMBER_DOUBLE" ,
816
+ },
817
+ "location" : {
818
+ "defaultValue" : "" ,
819
+ "isOptional" : True ,
820
+ "parameterType" : "STRING" ,
821
+ },
822
+ "max_context_length" : {
823
+ "defaultValue" : "" ,
824
+ "isOptional" : True ,
825
+ "parameterType" : "STRING" ,
826
+ },
827
+ "model_display_name" : {
828
+ "defaultValue" : "distilled-student-model" ,
829
+ "isOptional" : True ,
830
+ "parameterType" : "STRING" ,
831
+ },
832
+ "project" : {"parameterType" : "STRING" },
833
+ "student_model_reference" : {
834
+ "defaultValue" : "text-bison@002" ,
835
+ "isOptional" : True ,
836
+ "parameterType" : "STRING" ,
837
+ },
838
+ "teacher_model_reference" : {
839
+ "defaultValue" : "text-unicorn@001" ,
840
+ "isOptional" : True ,
841
+ "parameterType" : "STRING" ,
842
+ },
843
+ "temperature" : {
844
+ "defaultValue" : 0 ,
845
+ "isOptional" : True ,
846
+ "parameterType" : "NUMBER_DOUBLE" ,
847
+ },
848
+ "tensorboard_resource_id" : {
849
+ "defaultValue" : "" ,
850
+ "isOptional" : True ,
851
+ "parameterType" : "STRING" ,
852
+ },
853
+ "tpu_training_skip_cmek" : {
854
+ "defaultValue" : False ,
855
+ "isOptional" : True ,
856
+ "parameterType" : "BOOLEAN" ,
857
+ },
858
+ "train_steps" : {
859
+ "defaultValue" : 300 ,
860
+ "isOptional" : True ,
861
+ "parameterType" : "NUMBER_INTEGER" ,
862
+ },
863
+ "version" : {
864
+ "defaultValue" : "latest" ,
865
+ "isOptional" : True ,
866
+ "parameterType" : "STRING" ,
867
+ },
868
+ }
869
+ },
870
+ },
871
+ "schemaVersion" : "2.1.0" ,
872
+ "sdkVersion" : "kfp-2.4.0" ,
873
+ }
874
+
875
+ _TEST_DISTILLATION_PIPELINE_SPEC_JSON = json .dumps (
876
+ _TEST_DISTILLATION_PIPELINE_SPEC ,
877
+ )
878
+
761
879
762
880
# Eval classification spec
763
881
@@ -875,6 +993,10 @@ def reverse_string_2(s):""",
875
993
}
876
994
)
877
995
996
+ _URL_DATA = {
997
+ "https://us-kfp.pkg.dev/ml-pipeline/research/distillation/v1.0.0" : _TEST_DISTILLATION_PIPELINE_SPEC_JSON ,
998
+ }
999
+
878
1000
879
1001
@pytest .fixture
880
1002
def mock_pipeline_bucket_exists ():
@@ -1225,6 +1347,19 @@ def mock_request_urlopen_eval_classification(
1225
1347
yield request .param , mock_urlopen
1226
1348
1227
1349
1350
+ @pytest .fixture
1351
+ def mock_urllib_request_urlopen (request : str ) -> Tuple [str , mock .MagicMock ]:
1352
+ url = request .param
1353
+ data = _URL_DATA [url ]
1354
+ with mock .patch .object (urllib_request , "urlopen" ) as mock_urlopen :
1355
+ mock_read_response = mock .MagicMock ()
1356
+ mock_decode_response = mock .MagicMock ()
1357
+ mock_decode_response .return_value = data
1358
+ mock_read_response .return_value .decode = mock_decode_response
1359
+ mock_urlopen .return_value .read = mock_read_response
1360
+ yield url , mock_urlopen
1361
+
1362
+
1228
1363
@pytest .fixture
1229
1364
def get_endpoint_mock ():
1230
1365
with mock .patch .object (
@@ -4251,3 +4386,102 @@ def test_model_evaluation_text_classification_base_model_only_summary_metrics(
4251
4386
)
4252
4387
assert eval_metrics .confidenceMetrics is None
4253
4388
assert eval_metrics .auPrc == _TEST_TEXT_CLASSIFICATION_METRICS ["auPrc" ]
4389
+
4390
+ @pytest .mark .parametrize (
4391
+ "job_spec" ,
4392
+ [
4393
+ _TEST_DISTILLATION_PIPELINE_SPEC_JSON ,
4394
+ ],
4395
+ )
4396
+ @pytest .mark .parametrize (
4397
+ "mock_urllib_request_urlopen" ,
4398
+ ["https://us-kfp.pkg.dev/ml-pipeline/research/distillation/v1.0.0" ],
4399
+ indirect = True ,
4400
+ )
4401
+ def test_text_generation_model_distill_from (
4402
+ self ,
4403
+ mock_pipeline_service_create ,
4404
+ mock_pipeline_job_get ,
4405
+ mock_pipeline_bucket_exists ,
4406
+ job_spec ,
4407
+ mock_load_yaml_and_json ,
4408
+ mock_gcs_from_string ,
4409
+ mock_gcs_upload ,
4410
+ mock_urllib_request_urlopen ,
4411
+ mock_get_tuned_model ,
4412
+ ):
4413
+ """Tests distilling the text generation model."""
4414
+ aiplatform .init (
4415
+ project = _TEST_PROJECT ,
4416
+ location = _TEST_LOCATION ,
4417
+ encryption_spec_key_name = _TEST_ENCRYPTION_KEY_NAME ,
4418
+ )
4419
+ with mock .patch .object (
4420
+ target = model_garden_service_client .ModelGardenServiceClient ,
4421
+ attribute = "get_publisher_model" ,
4422
+ return_value = gca_publisher_model .PublisherModel (
4423
+ _TEXT_BISON_PUBLISHER_MODEL_DICT
4424
+ ),
4425
+ ):
4426
+ model = preview_language_models .TextGenerationModel .from_pretrained (
4427
+ "text-bison@001"
4428
+ )
4429
+
4430
+ dataset_uri = "gs://bucket/distillation.training_data.jsonl"
4431
+ evaluation_data_uri = "gs://bucket/eval.jsonl"
4432
+ evaluation_interval = 37
4433
+ enable_early_stopping = True
4434
+ enable_checkpoint_selection = True
4435
+ tensorboard_name = (
4436
+ f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } /tensorboards/123"
4437
+ )
4438
+
4439
+ tuning_job = model .distill_from (
4440
+ dataset = dataset_uri ,
4441
+ teacher_model = "text-unicorn@001" ,
4442
+ learning_rate_multiplier = 2.0 ,
4443
+ train_steps = 10 ,
4444
+ evaluation_spec = preview_language_models .TuningEvaluationSpec (
4445
+ evaluation_data = evaluation_data_uri ,
4446
+ evaluation_interval = evaluation_interval ,
4447
+ enable_early_stopping = enable_early_stopping ,
4448
+ enable_checkpoint_selection = enable_checkpoint_selection ,
4449
+ tensorboard = tensorboard_name ,
4450
+ ),
4451
+ accelerator_type = "TPU" ,
4452
+ )
4453
+ call_kwargs = mock_pipeline_service_create .call_args [1 ]
4454
+ pipeline_arguments = call_kwargs [
4455
+ "pipeline_job"
4456
+ ].runtime_config .parameter_values
4457
+ assert pipeline_arguments ["teacher_model_reference" ] == "text-unicorn@001"
4458
+ assert pipeline_arguments ["student_model_reference" ] == "text-bison@001"
4459
+ assert pipeline_arguments ["dataset_uri" ] == dataset_uri
4460
+ assert pipeline_arguments ["project" ] == _TEST_PROJECT
4461
+ assert pipeline_arguments ["location" ] == _TEST_LOCATION
4462
+ assert pipeline_arguments ["train_steps" ] == 10
4463
+ assert pipeline_arguments ["learning_rate_multiplier" ] == 2.0
4464
+ assert pipeline_arguments ["evaluation_data_uri" ] == evaluation_data_uri
4465
+ assert pipeline_arguments ["evaluation_interval" ] == evaluation_interval
4466
+ assert pipeline_arguments ["enable_early_stopping" ] == enable_early_stopping
4467
+ assert (
4468
+ pipeline_arguments ["enable_checkpoint_selection" ]
4469
+ == enable_checkpoint_selection
4470
+ )
4471
+ assert pipeline_arguments ["tensorboard_resource_id" ] == tensorboard_name
4472
+ assert pipeline_arguments ["accelerator_type" ] == "TPU"
4473
+ assert (
4474
+ pipeline_arguments ["encryption_spec_key_name" ]
4475
+ == _TEST_ENCRYPTION_KEY_NAME
4476
+ )
4477
+ assert (
4478
+ call_kwargs ["pipeline_job" ].encryption_spec .kms_key_name
4479
+ == _TEST_ENCRYPTION_KEY_NAME
4480
+ )
4481
+
4482
+ # Testing the tuned model
4483
+ tuned_model = tuning_job .get_tuned_model ()
4484
+ assert (
4485
+ tuned_model ._endpoint_name
4486
+ == test_constants .EndpointConstants ._TEST_ENDPOINT_NAME
4487
+ )
0 commit comments