Skip to content

Commit c939421

Browse files
authored
[ML] Expose vnet parameters on sdk (Azure#26488)
1 parent 25ca3db commit c939421

File tree

5 files changed

+18
-10
lines changed

5 files changed

+18
-10
lines changed

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/_deployment/online/online_deployment.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ def make(self, data: Any, **kwargs: Any) -> Any:
6363

6464
class ManagedOnlineDeploymentSchema(OnlineDeploymentSchema):
6565
instance_type = fields.Str(required=True)
66-
egress_public_network_access = ExperimentalField(
67-
StringTransformedEnum(allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED])
68-
)
66+
egress_public_network_access = StringTransformedEnum(allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED])
6967
data_collector = ExperimentalField(NestedField(DataCollectorSchema))
7068
private_network_connection = ExperimentalField(fields.Bool())
7169

sdk/ml/azure-ai-ml/azure/ai/ml/_schema/_endpoint/online/online_endpoint.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,7 @@ def make(self, data: Any, **kwargs: Any) -> Any:
5757

5858
class ManagedOnlineEndpointSchema(OnlineEndpointSchema):
5959
provisioning_state = fields.Str()
60-
public_network_access = ExperimentalField(
61-
StringTransformedEnum(allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED])
62-
)
60+
public_network_access = StringTransformedEnum(allowed_values=[PublicNetworkAccess.ENABLED, PublicNetworkAccess.DISABLED])
6361

6462
@post_load
6563
def make(self, data: Any, **kwargs: Any) -> Any:

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_deployment/online_deployment.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,9 @@ class ManagedOnlineDeployment(OnlineDeployment):
515515
:type code_path: Union[str, PathLike], optional
516516
:param scoring_script: Scoring script name. Equivalent to code_configuration.code.scoring_script.
517517
:type scoring_script: Union[str, PathLike], optional
518+
:param egress_public_network_access: Wether to restrict communication between a deployment
519+
and the Azure resources used to by the deployment. Allowed values are: "enabled", "disabled"
520+
:param egress_public_network_access: str
518521
"""
519522

520523
def __init__(
@@ -538,12 +541,12 @@ def __init__(
538541
instance_count: int = None,
539542
code_path: Union[str, PathLike] = None, # promoted property from code_configuration.code
540543
scoring_script: Union[str, PathLike] = None, # promoted property from code_configuration.scoring_script
544+
egress_public_network_access = None,
541545
**kwargs,
542546
):
543547

544548
kwargs["type"] = EndpointComputeType.MANAGED.value
545549
self.private_network_connection = kwargs.pop("private_network_connection", None)
546-
self.egress_public_network_access = kwargs.pop("egress_public_network_access", None)
547550
self.data_collector = kwargs.pop("data_collector", None)
548551

549552
super(ManagedOnlineDeployment, self).__init__(
@@ -569,6 +572,7 @@ def __init__(
569572
)
570573

571574
self.readiness_probe = readiness_probe
575+
self.egress_public_network_access = egress_public_network_access
572576

573577
def _to_dict(self) -> Dict:
574578
return ManagedOnlineDeploymentSchema(context={BASE_PATH_CONTEXT_KEY: "./"}).dump(self)
@@ -702,4 +706,4 @@ def _validate_scale_settings(self) -> None:
702706
target=ErrorTarget.ONLINE_DEPLOYMENT,
703707
no_personal_data_message=msg,
704708
error_category=ErrorCategory.USER_ERROR,
705-
)
709+
)

sdk/ml/azure-ai-ml/azure/ai/ml/entities/_endpoint/online_endpoint.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,10 @@ class ManagedOnlineEndpoint(OnlineEndpoint):
363363
:param identity: defaults to SystemAssigned
364364
:type identity: IdentityConfiguration, optional
365365
:param kind: Kind of the resource, we have two kinds: K8s and Managed online endpoints, defaults to None.
366-
:type kind: str, optional
366+
:type kind: str, optional,
367+
:param public_network_access: Whether to allow public endpoint connectivity
368+
Allowed values are: "enabled", "disabled"
369+
:type public_network_access: str
367370
"""
368371

369372
def __init__(
@@ -379,9 +382,10 @@ def __init__(
379382
mirror_traffic: Dict[str, int] = None,
380383
identity: IdentityConfiguration = None,
381384
kind: str = None,
385+
public_network_access = None,
382386
**kwargs,
383387
):
384-
self.public_network_access = kwargs.pop("public_network_access", None)
388+
self.public_network_access = public_network_access
385389

386390
super(ManagedOnlineEndpoint, self).__init__(
387391
name=name,

sdk/ml/azure-ai-ml/tests/model/e2etests/test_model.py

+4
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,10 @@ def test_list_model_registry(self, registry_client: MLClient, randstr: Callable[
201201
model_list = [m.name for m in model_list if m is not None]
202202
assert model.name in model_list
203203

204+
@pytest.mark.skipif(
205+
condition=not is_live(),
206+
reason="Registry uploads do not record well. Investigate later"
207+
)
204208
def test_promote_model(self, randstr: Callable[[], str], client: MLClient, registry_client: MLClient) -> None:
205209
# Create model in workspace
206210
model_path = Path("./tests/test_configs/model/model_full.yml")

0 commit comments

Comments
 (0)