Skip to content

Commit 4f59470

Browse files
ellismsvincbeck
authored andcommitted
Added Amazon SageMaker Notebook hook and operators (apache#33219)
--------- Co-authored-by: Vincent <[email protected]>
1 parent 90d18fc commit 4f59470

File tree

4 files changed

+573
-0
lines changed

4 files changed

+573
-0
lines changed

airflow/providers/amazon/aws/operators/sagemaker.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from airflow.providers.amazon.aws.utils import trim_none_values
3838
from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
3939
from airflow.providers.amazon.aws.utils.tags import format_tags
40+
from airflow.utils.helpers import prune_dict
4041
from airflow.utils.json import AirflowJsonEncoder
4142

4243
if TYPE_CHECKING:
@@ -1523,3 +1524,244 @@ def execute(self, context: Context) -> str:
15231524
arn = ans["ExperimentArn"]
15241525
self.log.info("Experiment %s created successfully with ARN %s.", self.name, arn)
15251526
return arn
1527+
1528+
1529+
class SageMakerCreateNotebookOperator(BaseOperator):
1530+
"""
1531+
Create a SageMaker notebook.
1532+
1533+
More information regarding parameters of this operator can be found here
1534+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_notebook_instance.html.
1535+
1536+
.. seealso:
1537+
For more information on how to use this operator, take a look at the guide:
1538+
:ref:`howto/operator:SageMakerCreateNotebookOperator`
1539+
1540+
:param instance_name: The name of the notebook instance.
1541+
:param instance_type: The type of instance to create.
1542+
:param role_arn: The Amazon Resource Name (ARN) of the IAM role that SageMaker can assume to access
1543+
:param volume_size_in_gb: Size in GB of the EBS root device volume of the notebook instance.
1544+
:param volume_kms_key_id: The KMS key ID for the EBS root device volume.
1545+
:param lifecycle_config_name: The name of the lifecycle configuration to associate with the notebook
1546+
:param direct_internet_access: Whether to enable direct internet access for the notebook instance.
1547+
:param root_access: Whether to give the notebook instance root access to the Amazon S3 bucket.
1548+
:param wait_for_completion: Whether or not to wait for the notebook to be InService before returning
1549+
:param create_instance_kwargs: Additional configuration options for the create call.
1550+
:param aws_conn_id: The AWS connection ID to use.
1551+
1552+
:return: The ARN of the created notebook.
1553+
"""
1554+
1555+
template_fields: Sequence[str] = (
1556+
"instance_name",
1557+
"instance_type",
1558+
"role_arn",
1559+
"volume_size_in_gb",
1560+
"volume_kms_key_id",
1561+
"lifecycle_config_name",
1562+
"direct_internet_access",
1563+
"root_access",
1564+
"wait_for_completion",
1565+
"create_instance_kwargs",
1566+
)
1567+
1568+
ui_color = "#ff7300"
1569+
1570+
def __init__(
1571+
self,
1572+
*,
1573+
instance_name: str,
1574+
instance_type: str,
1575+
role_arn: str,
1576+
volume_size_in_gb: int | None = None,
1577+
volume_kms_key_id: str | None = None,
1578+
lifecycle_config_name: str | None = None,
1579+
direct_internet_access: str | None = None,
1580+
root_access: str | None = None,
1581+
create_instance_kwargs: dict[str, Any] = {},
1582+
wait_for_completion: bool = True,
1583+
aws_conn_id: str = "aws_default",
1584+
**kwargs,
1585+
):
1586+
super().__init__(**kwargs)
1587+
self.instance_name = instance_name
1588+
self.instance_type = instance_type
1589+
self.role_arn = role_arn
1590+
self.volume_size_in_gb = volume_size_in_gb
1591+
self.volume_kms_key_id = volume_kms_key_id
1592+
self.lifecycle_config_name = lifecycle_config_name
1593+
self.direct_internet_access = direct_internet_access
1594+
self.root_access = root_access
1595+
self.wait_for_completion = wait_for_completion
1596+
self.aws_conn_id = aws_conn_id
1597+
self.create_instance_kwargs = create_instance_kwargs
1598+
1599+
if self.create_instance_kwargs.get("tags") is not None:
1600+
self.create_instance_kwargs["tags"] = format_tags(self.create_instance_kwargs["tags"])
1601+
1602+
@cached_property
1603+
def hook(self) -> SageMakerHook:
1604+
"""Create and return SageMakerHook."""
1605+
return SageMakerHook(aws_conn_id=self.aws_conn_id)
1606+
1607+
def execute(self, context: Context):
1608+
1609+
create_notebook_instance_kwargs = {
1610+
"NotebookInstanceName": self.instance_name,
1611+
"InstanceType": self.instance_type,
1612+
"RoleArn": self.role_arn,
1613+
"VolumeSizeInGB": self.volume_size_in_gb,
1614+
"KmsKeyId": self.volume_kms_key_id,
1615+
"LifecycleConfigName": self.lifecycle_config_name,
1616+
"DirectInternetAccess": self.direct_internet_access,
1617+
"RootAccess": self.root_access,
1618+
}
1619+
if len(self.create_instance_kwargs) > 0:
1620+
create_notebook_instance_kwargs.update(self.create_instance_kwargs)
1621+
1622+
self.log.info("Creating SageMaker notebook %s.", self.instance_name)
1623+
response = self.hook.conn.create_notebook_instance(**prune_dict(create_notebook_instance_kwargs))
1624+
1625+
self.log.info("SageMaker notebook created: %s", response["NotebookInstanceArn"])
1626+
1627+
if self.wait_for_completion:
1628+
self.log.info("Waiting for SageMaker notebook %s to be in service", self.instance_name)
1629+
waiter = self.hook.conn.get_waiter("notebook_instance_in_service")
1630+
waiter.wait(NotebookInstanceName=self.instance_name)
1631+
1632+
return response["NotebookInstanceArn"]
1633+
1634+
1635+
class SageMakerStopNotebookOperator(BaseOperator):
1636+
"""
1637+
Stop a notebook instance.
1638+
1639+
.. seealso:
1640+
For more information on how to use this operator, take a look at the guide:
1641+
:ref:`howto/operator:SageMakerStopNotebookOperator`
1642+
1643+
:param instance_name: The name of the notebook instance to stop.
1644+
:param wait_for_completion: Whether or not to wait for the notebook to be stopped before returning
1645+
:param aws_conn_id: The AWS connection ID to use.
1646+
"""
1647+
1648+
template_fields: Sequence[str] = ("instance_name", "wait_for_completion")
1649+
1650+
ui_color = "#ff7300"
1651+
1652+
def __init__(
1653+
self,
1654+
instance_name: str,
1655+
wait_for_completion: bool = True,
1656+
aws_conn_id: str = "aws_default",
1657+
**kwargs,
1658+
):
1659+
super().__init__(**kwargs)
1660+
self.instance_name = instance_name
1661+
self.wait_for_completion = wait_for_completion
1662+
self.aws_conn_id = aws_conn_id
1663+
1664+
@cached_property
1665+
def hook(self) -> SageMakerHook:
1666+
"""Create and return SageMakerHook."""
1667+
return SageMakerHook(aws_conn_id=self.aws_conn_id)
1668+
1669+
def execute(self, context):
1670+
self.log.info("Stopping SageMaker notebook %s.", self.instance_name)
1671+
self.hook.conn.stop_notebook_instance(NotebookInstanceName=self.instance_name)
1672+
1673+
if self.wait_for_completion:
1674+
self.log.info("Waiting for SageMaker notebook %s to stop", self.instance_name)
1675+
self.hook.conn.get_waiter("notebook_instance_stopped").wait(
1676+
NotebookInstanceName=self.instance_name
1677+
)
1678+
1679+
1680+
class SageMakerDeleteNotebookOperator(BaseOperator):
1681+
"""
1682+
Delete a notebook instance.
1683+
1684+
.. seealso:
1685+
For more information on how to use this operator, take a look at the guide:
1686+
:ref:`howto/operator:SageMakerDeleteNotebookOperator`
1687+
1688+
:param instance_name: The name of the notebook instance to delete.
1689+
:param wait_for_completion: Whether or not to wait for the notebook to delete before returning.
1690+
:param aws_conn_id: The AWS connection ID to use.
1691+
"""
1692+
1693+
template_fields: Sequence[str] = ("instance_name", "wait_for_completion")
1694+
1695+
ui_color = "#ff7300"
1696+
1697+
def __init__(
1698+
self,
1699+
instance_name: str,
1700+
wait_for_completion: bool = True,
1701+
aws_conn_id: str = "aws_default",
1702+
**kwargs,
1703+
):
1704+
super().__init__(**kwargs)
1705+
self.instance_name = instance_name
1706+
self.aws_conn_id = aws_conn_id
1707+
self.wait_for_completion = wait_for_completion
1708+
1709+
@cached_property
1710+
def hook(self) -> SageMakerHook:
1711+
"""Create and return SageMakerHook."""
1712+
return SageMakerHook(aws_conn_id=self.aws_conn_id)
1713+
1714+
def execute(self, context):
1715+
self.log.info("Deleting SageMaker notebook %s....", self.instance_name)
1716+
self.hook.conn.delete_notebook_instance(NotebookInstanceName=self.instance_name)
1717+
1718+
if self.wait_for_completion:
1719+
self.log.info("Waiting for SageMaker notebook %s to delete...", self.instance_name)
1720+
self.hook.conn.get_waiter("notebook_instance_deleted").wait(
1721+
NotebookInstanceName=self.instance_name
1722+
)
1723+
1724+
1725+
class SageMakerStartNoteBookOperator(BaseOperator):
1726+
"""
1727+
Start a notebook instance.
1728+
1729+
.. seealso:
1730+
For more information on how to use this operator, take a look at the guide:
1731+
:ref:`howto/operator:SageMakerStartNotebookOperator`
1732+
1733+
:param instance_name: The name of the notebook instance to start.
1734+
:param wait_for_completion: Whether or not to wait for notebook to be InService before returning
1735+
:param aws_conn_id: The AWS connection ID to use.
1736+
"""
1737+
1738+
template_fields: Sequence[str] = ("instance_name", "wait_for_completion")
1739+
1740+
ui_color = "#ff7300"
1741+
1742+
def __init__(
1743+
self,
1744+
instance_name: str,
1745+
wait_for_completion: bool = True,
1746+
aws_conn_id: str = "aws_default",
1747+
**kwargs,
1748+
):
1749+
super().__init__(**kwargs)
1750+
self.instance_name = instance_name
1751+
self.aws_conn_id = aws_conn_id
1752+
self.wait_for_completion = wait_for_completion
1753+
1754+
@cached_property
1755+
def hook(self) -> SageMakerHook:
1756+
"""Create and return SageMakerHook."""
1757+
return SageMakerHook(aws_conn_id=self.aws_conn_id)
1758+
1759+
def execute(self, context):
1760+
self.log.info("Starting SageMaker notebook %s....", self.instance_name)
1761+
self.hook.conn.start_notebook_instance(NotebookInstanceName=self.instance_name)
1762+
1763+
if self.wait_for_completion:
1764+
self.log.info("Waiting for SageMaker notebook %s to start...", self.instance_name)
1765+
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
1766+
NotebookInstanceName=self.instance_name
1767+
)

docs/apache-airflow-providers-amazon/operators/sagemaker.rst

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,63 @@ This creates an experiment so that it's ready to be associated with processing,
222222
:start-after: [START howto_operator_sagemaker_experiment]
223223
:end-before: [END howto_operator_sagemaker_experiment]
224224

225+
.. _howto/operator:SageMakerCreateNotebookOperator:
226+
227+
Create a SageMaker Notebook Instance
228+
====================================
229+
230+
To create a SageMaker Notebook Instance , you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerCreateNotebookOperator`.
231+
This creates a SageMaker Notebook Instance ready to run Jupyter notebooks.
232+
233+
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_notebook.py
234+
:language: python
235+
:dedent: 4
236+
:start-after: [START howto_operator_sagemaker_notebook_create]
237+
:end-before: [END howto_operator_sagemaker_notebook_create]
238+
239+
.. _howto/operator:SageMakerStopNotebookOperator:
240+
241+
Stop a SageMaker Notebook Instance
242+
==================================
243+
244+
To terminate SageMaker Notebook Instance , you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerStopNotebookOperator`.
245+
This terminates the ML compute instance and disconnects the ML storage volume.
246+
247+
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_notebook.py
248+
:language: python
249+
:dedent: 4
250+
:start-after: [START howto_operator_sagemaker_notebook_stop]
251+
:end-before: [END howto_operator_sagemaker_notebook_stop]
252+
253+
.. _howto/operator:SageMakerStartNotebookOperator:
254+
255+
Start a SageMaker Notebook Instance
256+
===================================
257+
258+
To launch a SageMaker Notebook Instance and re-attach an ML storage volume, you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerStartNotebookOperator`.
259+
This launches a new ML compute instance with the latest version of the libraries and attached your ML storage volume.
260+
261+
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_notebook.py
262+
:language: python
263+
:dedent: 4
264+
:start-after: [START howto_operator_sagemaker_notebook_start]
265+
:end-before: [END howto_operator_sagemaker_notebook_start]
266+
267+
268+
.. _howto/operator:SageMakerDeleteNotebookOperator:
269+
270+
Delete a SageMaker Notebook Instance
271+
====================================
272+
273+
To delete a SageMaker Notebook Instance, you can use :class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerDeleteNotebookOperator`.
274+
This terminates the instance and deletes the ML storage volume and network interface associated with the instance. The instance must be stopped before it can be deleted.
275+
276+
.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_sagemaker_notebook.py
277+
:language: python
278+
:dedent: 4
279+
:start-after: [START howto_operator_sagemaker_notebook_delete]
280+
:end-before: [END howto_operator_sagemaker_notebook_delete]
281+
225282
Sensors
226283
-------
227284

0 commit comments

Comments
 (0)