Skip to content

Commit 3983407

Browse files
Allow registered component to be passed into the sdk. (Azure#32756)
* allow for a registered component to be passed into the sdk for deployment creation * allow for a completed job to be passed into job definition * fix pylint error * account for cli httperror * black reformatting * resolve pylint errors
1 parent efcd0ca commit 3983407

File tree

3 files changed

+66
-12
lines changed

3 files changed

+66
-12
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/_deployment/batch/pipeline_component_batch_deployment_schema.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,20 @@
66
import logging
77
from typing import Any
88

9-
from marshmallow import fields, post_load
9+
from marshmallow import INCLUDE, fields, post_load
1010

1111
from azure.ai.ml._schema import (
1212
ArmVersionedStr,
1313
ArmStr,
1414
UnionField,
1515
RegistryStr,
16+
NestedField,
1617
)
18+
from azure.ai.ml._schema.core.fields import PipelineNodeNameStr, TypeSensitiveUnionField
1719
from azure.ai.ml._schema._deployment.deployment import DeploymentSchema
1820
from azure.ai.ml._schema.pipeline.pipeline_component import PipelineComponentFileRefField
1921
from azure.ai.ml.constants._common import AzureMLResourceType
22+
from azure.ai.ml.constants._component import NodeType
2023

2124
module_logger = logging.getLogger(__name__)
2225

@@ -37,7 +40,7 @@ class PipelineComponentBatchDeploymentSchema(DeploymentSchema):
3740
job_definition = UnionField(
3841
[
3942
ArmStr(azureml_type=AzureMLResourceType.JOB),
40-
PipelineComponentFileRefField(),
43+
NestedField("PipelineSchema", unknown=INCLUDE),
4144
]
4245
)
4346

@@ -48,3 +51,19 @@ def make(self, data: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argum
4851
)
4952

5053
return PipelineComponentBatchDeployment(**data)
54+
55+
56+
class NodeNameStr(PipelineNodeNameStr):
57+
def _get_field_name(self) -> str:
58+
return "Pipeline node"
59+
60+
61+
def PipelineJobsField():
62+
pipeline_enable_job_type = {NodeType.PIPELINE: [NestedField("PipelineSchema", unknown=INCLUDE)]}
63+
64+
pipeline_job_field = fields.Dict(
65+
keys=NodeNameStr(),
66+
values=TypeSensitiveUnionField(pipeline_enable_job_type),
67+
)
68+
69+
return pipeline_job_field

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/pipeline_component_batch_deployment.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Dict, Optional, Union
88

99
from azure.ai.ml.entities._component.component import Component
10+
from azure.ai.ml.entities._builders import BaseNode
1011
from azure.ai.ml._schema._deployment.batch.pipeline_component_batch_deployment_schema import (
1112
PipelineComponentBatchDeploymentSchema,
1213
) # pylint: disable=line-too-long
@@ -39,6 +40,8 @@ class PipelineComponentBatchDeployment(Deployment):
3940
:type description: Optional[str]
4041
:param tags: A set of tags. The tags which will be applied to the job.
4142
:type tags: Optional[Dict[str, Any]]
43+
:param job_definition: Arm ID or PipelineJob entity of an existing pipeline job.
44+
:param job_definition: Optional[Dict[str, ~azure.ai.ml.entities._builders.BaseNode]]
4245
"""
4346

4447
def __init__(
@@ -48,12 +51,13 @@ def __init__(
4851
endpoint_name: Optional[str] = None,
4952
component: Optional[Union[Component, str]] = None,
5053
settings: Optional[Dict[str, str]] = None,
54+
job_definition: Optional[Dict[str, BaseNode]] = None,
5155
**kwargs, # pylint: disable=unused-argument
5256
):
53-
self.job_definition = kwargs.pop("job_definition", None)
5457
super().__init__(endpoint_name=endpoint_name, name=name, **kwargs)
5558
self.component = component
5659
self.settings = settings
60+
self.job_definition = job_definition
5761

5862
def _to_rest_object(self, location: str) -> "RestBatchDeployment": # pylint: disable=arguments-differ
5963
if isinstance(self.component, PipelineComponent):

sdk/ml/azure-ai-ml/azure/ai/ml/operations/_batch_deployment_operations.py

+40-9
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,14 @@
2323
from azure.ai.ml._utils._package_utils import package_deployment
2424
from azure.ai.ml._utils.utils import _get_mfe_base_url_from_discovery_service, modified_operation_client
2525
from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, LROConfigurations
26-
from azure.ai.ml.entities import BatchDeployment, BatchJob, ModelBatchDeployment, PipelineComponent
26+
from azure.ai.ml.entities import BatchDeployment, BatchJob, ModelBatchDeployment, PipelineComponent, PipelineJob
2727
from azure.ai.ml.entities._deployment.deployment import Deployment
2828
from azure.ai.ml.entities._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment
2929
from azure.core.credentials import TokenCredential
3030
from azure.core.paging import ItemPaged
3131
from azure.core.polling import LROPoller
3232
from azure.core.tracing.decorator import distributed_trace
33+
from azure.core.exceptions import ResourceNotFoundError, HttpResponseError
3334

3435
from ._operation_orchestrator import OperationOrchestrator
3536

@@ -332,14 +333,25 @@ def _validate_component(self, deployment: Deployment, orchestrators: OperationOr
332333
:type orchestrators: _operation_orchestrator.OperationOrchestrator
333334
"""
334335
if isinstance(deployment.component, PipelineComponent):
335-
deployment.component = self._all_operations.all_operations[AzureMLResourceType.COMPONENT].create_or_update(
336-
name=deployment.component.name,
337-
resource_group_name=self._resource_group_name,
338-
workspace_name=self._workspace_name,
339-
component=deployment.component,
340-
version=deployment.component.version,
341-
**self._init_kwargs,
342-
)
336+
try:
337+
registered_component = self._all_operations.all_operations[AzureMLResourceType.COMPONENT].get(
338+
name=deployment.component.name, version=deployment.component.version
339+
)
340+
deployment.component = registered_component.id
341+
except Exception as err: # pylint: disable=broad-except
342+
if isinstance(err, (ResourceNotFoundError, HttpResponseError)):
343+
deployment.component = self._all_operations.all_operations[
344+
AzureMLResourceType.COMPONENT
345+
].create_or_update(
346+
name=deployment.component.name,
347+
resource_group_name=self._resource_group_name,
348+
workspace_name=self._workspace_name,
349+
component=deployment.component,
350+
version=deployment.component.version,
351+
**self._init_kwargs,
352+
)
353+
else:
354+
raise err
343355
elif isinstance(deployment.component, str):
344356
component_id = orchestrators.get_asset_arm_id(
345357
deployment.component, azureml_type=AzureMLResourceType.COMPONENT
@@ -356,3 +368,22 @@ def _validate_component(self, deployment: Deployment, orchestrators: OperationOr
356368
**self._init_kwargs,
357369
)
358370
deployment.component = job_component.id
371+
372+
elif isinstance(deployment.job_definition, PipelineJob):
373+
try:
374+
registered_job = self._all_operations.all_operations[AzureMLResourceType.JOB].get(
375+
name=deployment.job_definition.name
376+
)
377+
if registered_job:
378+
job_component = PipelineComponent(source_job_id=registered_job.name)
379+
job_component = self._component_operations.create_or_update(
380+
name=job_component.name,
381+
resource_group_name=self._resource_group_name,
382+
workspace_name=self._workspace_name,
383+
body=job_component._to_rest_object(),
384+
version=job_component.version,
385+
**self._init_kwargs,
386+
)
387+
deployment.component = job_component.id
388+
except ResourceNotFoundError as err:
389+
raise err

0 commit comments

Comments
 (0)