|
37 | 37 | from airflow.providers.amazon.aws.utils import trim_none_values
|
38 | 38 | from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
|
39 | 39 | from airflow.providers.amazon.aws.utils.tags import format_tags
|
| 40 | +from airflow.utils.helpers import prune_dict |
40 | 41 | from airflow.utils.json import AirflowJsonEncoder
|
41 | 42 |
|
42 | 43 | if TYPE_CHECKING:
|
@@ -1523,3 +1524,244 @@ def execute(self, context: Context) -> str:
|
1523 | 1524 | arn = ans["ExperimentArn"]
|
1524 | 1525 | self.log.info("Experiment %s created successfully with ARN %s.", self.name, arn)
|
1525 | 1526 | return arn
|
| 1527 | + |
| 1528 | + |
| 1529 | +class SageMakerCreateNotebookOperator(BaseOperator): |
| 1530 | + """ |
| 1531 | + Create a SageMaker notebook. |
| 1532 | +
|
| 1533 | + More information regarding parameters of this operator can be found here |
| 1534 | + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/create_notebook_instance.html. |
| 1535 | +
|
| 1536 | + .. seealso: |
| 1537 | + For more information on how to use this operator, take a look at the guide: |
| 1538 | + :ref:`howto/operator:SageMakerCreateNotebookOperator` |
| 1539 | +
|
| 1540 | + :param instance_name: The name of the notebook instance. |
| 1541 | + :param instance_type: The type of instance to create. |
| 1542 | + :param role_arn: The Amazon Resource Name (ARN) of the IAM role that SageMaker can assume to access |
| 1543 | + :param volume_size_in_gb: Size in GB of the EBS root device volume of the notebook instance. |
| 1544 | + :param volume_kms_key_id: The KMS key ID for the EBS root device volume. |
| 1545 | + :param lifecycle_config_name: The name of the lifecycle configuration to associate with the notebook |
| 1546 | + :param direct_internet_access: Whether to enable direct internet access for the notebook instance. |
| 1547 | + :param root_access: Whether to give the notebook instance root access to the Amazon S3 bucket. |
| 1548 | + :param wait_for_completion: Whether or not to wait for the notebook to be InService before returning |
| 1549 | + :param create_instance_kwargs: Additional configuration options for the create call. |
| 1550 | + :param aws_conn_id: The AWS connection ID to use. |
| 1551 | +
|
| 1552 | + :return: The ARN of the created notebook. |
| 1553 | + """ |
| 1554 | + |
| 1555 | + template_fields: Sequence[str] = ( |
| 1556 | + "instance_name", |
| 1557 | + "instance_type", |
| 1558 | + "role_arn", |
| 1559 | + "volume_size_in_gb", |
| 1560 | + "volume_kms_key_id", |
| 1561 | + "lifecycle_config_name", |
| 1562 | + "direct_internet_access", |
| 1563 | + "root_access", |
| 1564 | + "wait_for_completion", |
| 1565 | + "create_instance_kwargs", |
| 1566 | + ) |
| 1567 | + |
| 1568 | + ui_color = "#ff7300" |
| 1569 | + |
| 1570 | + def __init__( |
| 1571 | + self, |
| 1572 | + *, |
| 1573 | + instance_name: str, |
| 1574 | + instance_type: str, |
| 1575 | + role_arn: str, |
| 1576 | + volume_size_in_gb: int | None = None, |
| 1577 | + volume_kms_key_id: str | None = None, |
| 1578 | + lifecycle_config_name: str | None = None, |
| 1579 | + direct_internet_access: str | None = None, |
| 1580 | + root_access: str | None = None, |
| 1581 | + create_instance_kwargs: dict[str, Any] = {}, |
| 1582 | + wait_for_completion: bool = True, |
| 1583 | + aws_conn_id: str = "aws_default", |
| 1584 | + **kwargs, |
| 1585 | + ): |
| 1586 | + super().__init__(**kwargs) |
| 1587 | + self.instance_name = instance_name |
| 1588 | + self.instance_type = instance_type |
| 1589 | + self.role_arn = role_arn |
| 1590 | + self.volume_size_in_gb = volume_size_in_gb |
| 1591 | + self.volume_kms_key_id = volume_kms_key_id |
| 1592 | + self.lifecycle_config_name = lifecycle_config_name |
| 1593 | + self.direct_internet_access = direct_internet_access |
| 1594 | + self.root_access = root_access |
| 1595 | + self.wait_for_completion = wait_for_completion |
| 1596 | + self.aws_conn_id = aws_conn_id |
| 1597 | + self.create_instance_kwargs = create_instance_kwargs |
| 1598 | + |
| 1599 | + if self.create_instance_kwargs.get("tags") is not None: |
| 1600 | + self.create_instance_kwargs["tags"] = format_tags(self.create_instance_kwargs["tags"]) |
| 1601 | + |
| 1602 | + @cached_property |
| 1603 | + def hook(self) -> SageMakerHook: |
| 1604 | + """Create and return SageMakerHook.""" |
| 1605 | + return SageMakerHook(aws_conn_id=self.aws_conn_id) |
| 1606 | + |
| 1607 | + def execute(self, context: Context): |
| 1608 | + |
| 1609 | + create_notebook_instance_kwargs = { |
| 1610 | + "NotebookInstanceName": self.instance_name, |
| 1611 | + "InstanceType": self.instance_type, |
| 1612 | + "RoleArn": self.role_arn, |
| 1613 | + "VolumeSizeInGB": self.volume_size_in_gb, |
| 1614 | + "KmsKeyId": self.volume_kms_key_id, |
| 1615 | + "LifecycleConfigName": self.lifecycle_config_name, |
| 1616 | + "DirectInternetAccess": self.direct_internet_access, |
| 1617 | + "RootAccess": self.root_access, |
| 1618 | + } |
| 1619 | + if len(self.create_instance_kwargs) > 0: |
| 1620 | + create_notebook_instance_kwargs.update(self.create_instance_kwargs) |
| 1621 | + |
| 1622 | + self.log.info("Creating SageMaker notebook %s.", self.instance_name) |
| 1623 | + response = self.hook.conn.create_notebook_instance(**prune_dict(create_notebook_instance_kwargs)) |
| 1624 | + |
| 1625 | + self.log.info("SageMaker notebook created: %s", response["NotebookInstanceArn"]) |
| 1626 | + |
| 1627 | + if self.wait_for_completion: |
| 1628 | + self.log.info("Waiting for SageMaker notebook %s to be in service", self.instance_name) |
| 1629 | + waiter = self.hook.conn.get_waiter("notebook_instance_in_service") |
| 1630 | + waiter.wait(NotebookInstanceName=self.instance_name) |
| 1631 | + |
| 1632 | + return response["NotebookInstanceArn"] |
| 1633 | + |
| 1634 | + |
| 1635 | +class SageMakerStopNotebookOperator(BaseOperator): |
| 1636 | + """ |
| 1637 | + Stop a notebook instance. |
| 1638 | +
|
| 1639 | + .. seealso: |
| 1640 | + For more information on how to use this operator, take a look at the guide: |
| 1641 | + :ref:`howto/operator:SageMakerStopNotebookOperator` |
| 1642 | +
|
| 1643 | + :param instance_name: The name of the notebook instance to stop. |
| 1644 | + :param wait_for_completion: Whether or not to wait for the notebook to be stopped before returning |
| 1645 | + :param aws_conn_id: The AWS connection ID to use. |
| 1646 | + """ |
| 1647 | + |
| 1648 | + template_fields: Sequence[str] = ("instance_name", "wait_for_completion") |
| 1649 | + |
| 1650 | + ui_color = "#ff7300" |
| 1651 | + |
| 1652 | + def __init__( |
| 1653 | + self, |
| 1654 | + instance_name: str, |
| 1655 | + wait_for_completion: bool = True, |
| 1656 | + aws_conn_id: str = "aws_default", |
| 1657 | + **kwargs, |
| 1658 | + ): |
| 1659 | + super().__init__(**kwargs) |
| 1660 | + self.instance_name = instance_name |
| 1661 | + self.wait_for_completion = wait_for_completion |
| 1662 | + self.aws_conn_id = aws_conn_id |
| 1663 | + |
| 1664 | + @cached_property |
| 1665 | + def hook(self) -> SageMakerHook: |
| 1666 | + """Create and return SageMakerHook.""" |
| 1667 | + return SageMakerHook(aws_conn_id=self.aws_conn_id) |
| 1668 | + |
| 1669 | + def execute(self, context): |
| 1670 | + self.log.info("Stopping SageMaker notebook %s.", self.instance_name) |
| 1671 | + self.hook.conn.stop_notebook_instance(NotebookInstanceName=self.instance_name) |
| 1672 | + |
| 1673 | + if self.wait_for_completion: |
| 1674 | + self.log.info("Waiting for SageMaker notebook %s to stop", self.instance_name) |
| 1675 | + self.hook.conn.get_waiter("notebook_instance_stopped").wait( |
| 1676 | + NotebookInstanceName=self.instance_name |
| 1677 | + ) |
| 1678 | + |
| 1679 | + |
| 1680 | +class SageMakerDeleteNotebookOperator(BaseOperator): |
| 1681 | + """ |
| 1682 | + Delete a notebook instance. |
| 1683 | +
|
| 1684 | + .. seealso: |
| 1685 | + For more information on how to use this operator, take a look at the guide: |
| 1686 | + :ref:`howto/operator:SageMakerDeleteNotebookOperator` |
| 1687 | +
|
| 1688 | + :param instance_name: The name of the notebook instance to delete. |
| 1689 | + :param wait_for_completion: Whether or not to wait for the notebook to delete before returning. |
| 1690 | + :param aws_conn_id: The AWS connection ID to use. |
| 1691 | + """ |
| 1692 | + |
| 1693 | + template_fields: Sequence[str] = ("instance_name", "wait_for_completion") |
| 1694 | + |
| 1695 | + ui_color = "#ff7300" |
| 1696 | + |
| 1697 | + def __init__( |
| 1698 | + self, |
| 1699 | + instance_name: str, |
| 1700 | + wait_for_completion: bool = True, |
| 1701 | + aws_conn_id: str = "aws_default", |
| 1702 | + **kwargs, |
| 1703 | + ): |
| 1704 | + super().__init__(**kwargs) |
| 1705 | + self.instance_name = instance_name |
| 1706 | + self.aws_conn_id = aws_conn_id |
| 1707 | + self.wait_for_completion = wait_for_completion |
| 1708 | + |
| 1709 | + @cached_property |
| 1710 | + def hook(self) -> SageMakerHook: |
| 1711 | + """Create and return SageMakerHook.""" |
| 1712 | + return SageMakerHook(aws_conn_id=self.aws_conn_id) |
| 1713 | + |
| 1714 | + def execute(self, context): |
| 1715 | + self.log.info("Deleting SageMaker notebook %s....", self.instance_name) |
| 1716 | + self.hook.conn.delete_notebook_instance(NotebookInstanceName=self.instance_name) |
| 1717 | + |
| 1718 | + if self.wait_for_completion: |
| 1719 | + self.log.info("Waiting for SageMaker notebook %s to delete...", self.instance_name) |
| 1720 | + self.hook.conn.get_waiter("notebook_instance_deleted").wait( |
| 1721 | + NotebookInstanceName=self.instance_name |
| 1722 | + ) |
| 1723 | + |
| 1724 | + |
| 1725 | +class SageMakerStartNoteBookOperator(BaseOperator): |
| 1726 | + """ |
| 1727 | + Start a notebook instance. |
| 1728 | +
|
| 1729 | + .. seealso: |
| 1730 | + For more information on how to use this operator, take a look at the guide: |
| 1731 | + :ref:`howto/operator:SageMakerStartNotebookOperator` |
| 1732 | +
|
| 1733 | + :param instance_name: The name of the notebook instance to start. |
| 1734 | + :param wait_for_completion: Whether or not to wait for notebook to be InService before returning |
| 1735 | + :param aws_conn_id: The AWS connection ID to use. |
| 1736 | + """ |
| 1737 | + |
| 1738 | + template_fields: Sequence[str] = ("instance_name", "wait_for_completion") |
| 1739 | + |
| 1740 | + ui_color = "#ff7300" |
| 1741 | + |
| 1742 | + def __init__( |
| 1743 | + self, |
| 1744 | + instance_name: str, |
| 1745 | + wait_for_completion: bool = True, |
| 1746 | + aws_conn_id: str = "aws_default", |
| 1747 | + **kwargs, |
| 1748 | + ): |
| 1749 | + super().__init__(**kwargs) |
| 1750 | + self.instance_name = instance_name |
| 1751 | + self.aws_conn_id = aws_conn_id |
| 1752 | + self.wait_for_completion = wait_for_completion |
| 1753 | + |
| 1754 | + @cached_property |
| 1755 | + def hook(self) -> SageMakerHook: |
| 1756 | + """Create and return SageMakerHook.""" |
| 1757 | + return SageMakerHook(aws_conn_id=self.aws_conn_id) |
| 1758 | + |
| 1759 | + def execute(self, context): |
| 1760 | + self.log.info("Starting SageMaker notebook %s....", self.instance_name) |
| 1761 | + self.hook.conn.start_notebook_instance(NotebookInstanceName=self.instance_name) |
| 1762 | + |
| 1763 | + if self.wait_for_completion: |
| 1764 | + self.log.info("Waiting for SageMaker notebook %s to start...", self.instance_name) |
| 1765 | + self.hook.conn.get_waiter("notebook_instance_in_service").wait( |
| 1766 | + NotebookInstanceName=self.instance_name |
| 1767 | + ) |
0 commit comments