Skip to content

Commit 926ce39

Browse files
committed
Track cluster creation time and use it in the DRA create token
1 parent 6a4ee86 commit 926ce39

File tree

4 files changed

+30
-9
lines changed

4 files changed

+30
-9
lines changed

hpc_provisioner/src/hpc_provisioner/aws_queries.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from hpc_provisioner.dynamodb_actions import (
1818
SubnetAlreadyRegisteredException,
1919
dynamodb_client,
20+
dynamodb_resource,
2021
free_subnet,
22+
get_cluster_by_name,
2123
get_registered_subnets,
2224
get_subnet,
2325
register_subnet,
@@ -418,8 +420,7 @@ def create_dra(
418420
filesystem_id: str,
419421
mountpoint: str,
420422
bucket: str,
421-
vlab_id: str,
422-
project_id: str,
423+
cluster: Cluster,
423424
writable: bool = False,
424425
) -> dict:
425426
logger.debug(
@@ -439,6 +440,13 @@ def create_dra(
439440
s3_config["AutoExportPolicy"] = {"Events": ["NEW", "CHANGED", "DELETED"]}
440441

441442
logger.debug(f"s3 config: {s3_config}")
443+
dynamo_cluster = get_cluster_by_name(
444+
dynamodb_resource=dynamodb_resource(), cluster_name=cluster.name
445+
)
446+
if not dynamo_cluster:
447+
raise RuntimeError(f"Clould not retrieve cluster {cluster.name} from dynamodb")
448+
if dynamo_cluster["creation_time"] == 0:
449+
raise ValueError(f"Creation time for {cluster.name} is 0; something is wrong")
442450

443451
dra = fsx_client.create_data_repository_association(
444452
FileSystemId=filesystem_id,
@@ -447,12 +455,12 @@ def create_dra(
447455
BatchImportMetaDataOnCreate=True,
448456
ImportedFileChunkSize=1024,
449457
S3=s3_config,
450-
ClientRequestToken=f"{vlab_id}-{project_id}-{mountpoint.split('/')[-1]}",
458+
ClientRequestToken=f"{dynamo_cluster['creation_time']}-{cluster.vlab_id[:21]}-{cluster.project_id[:21]}-{mountpoint.split('/')[-1]}",
451459
Tags=[
452460
{"Key": "Name", "Value": f"{filesystem_id}-{mountpoint}"},
453461
{"Key": BILLING_TAG_KEY, "Value": BILLING_TAG_VALUE},
454-
{"Key": VLAB_TAG_KEY, "Value": vlab_id},
455-
{"Key": PROJECT_TAG_KEY, "Value": project_id},
462+
{"Key": VLAB_TAG_KEY, "Value": cluster.vlab_id},
463+
{"Key": PROJECT_TAG_KEY, "Value": cluster.project_id},
456464
],
457465
)
458466

hpc_provisioner/src/hpc_provisioner/cluster.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@ class Cluster:
2020
tier: str
2121
vlab_id: str
2222
provisioning_launched: bool
23+
creation_time: int
2324

2425
def __init__(
2526
self,
2627
project_id: str,
2728
vlab_id: str,
29+
creation_time: int,
2830
tier: str = "debug",
2931
benchmark: bool = False,
3032
dev: bool = False,
@@ -45,6 +47,7 @@ def __init__(
4547
else:
4648
self.admin_ssh_key_name = self.name
4749
self.provisioning_launched = provisioning_launched
50+
self.creation_time = creation_time
4851

4952
@property
5053
def name(self):

hpc_provisioner/src/hpc_provisioner/handlers.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import logging
44
import logging.config
5+
import time
56
from importlib.metadata import version
67

78
import boto3
@@ -142,7 +143,7 @@ def pcluster_create_request_handler(event, _context=None):
142143
* precreate ssh keys
143144
"""
144145

145-
cluster = _get_vlab_query_params(event)
146+
cluster = _get_vlab_query_params(event, set_creation_time=True)
146147
dynamo = dynamodb_resource()
147148
try:
148149
register_cluster(dynamo, cluster)
@@ -267,7 +268,13 @@ def pcluster_delete_handler(event, _context=None):
267268
return response_json(pc_output)
268269

269270

270-
def _get_vlab_query_params(incoming_event) -> Cluster:
271+
def _get_vlab_query_params(incoming_event, set_creation_time=False) -> Cluster:
272+
"""
273+
Retrieve the query parameters from the incoming event and create a Cluster object based on them.
274+
If set_creation_time is set, this implies that this is the first time creating this Cluster object
275+
(ie. we're being called from pcluster_create_request_handler) - this helps us keep track of
276+
when the cluster was requested.
277+
"""
271278
logger.debug(f"Getting query params from event {incoming_event}")
272279
event = copy.deepcopy(incoming_event)
273280

@@ -292,13 +299,17 @@ def _get_vlab_query_params(incoming_event) -> Cluster:
292299
)
293300
if param in event.get("queryStringParameters", {}):
294301
params[param] = event["queryStringParameters"][param]
302+
# we set creation_time to 0 if this is not the initial request
303+
# this shouldn't even be necessary as we only use it in a context where it was retrieved from dynamo
304+
# but it will allow us to check for a "suspicious" value if we ever need it somewhere else
295305
cluster = Cluster(
296306
project_id=params["project_id"],
297307
vlab_id=params["vlab_id"],
298308
tier=params["tier"],
299309
benchmark=params.get("benchmark", "").lower() == "true",
300310
dev=params.get("dev", "").lower() == "true",
301311
include_lustre=params.get("include_lustre", "").lower() == "true",
312+
creation_time=int(time.time()) if set_creation_time else 0,
302313
)
303314

304315
logger.debug(f"Params: {params}")

hpc_provisioner/src/hpc_provisioner/pcluster_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ def fsx_precreate(cluster: Cluster, filesystems: list) -> bool:
193193
filesystem_id=fs["FileSystemId"],
194194
mountpoint=dra["mountpoint"],
195195
bucket=get_fs_bucket(dra["name"], cluster),
196-
vlab_id=cluster.vlab_id,
197-
project_id=cluster.project_id,
196+
cluste=cluster,
198197
writable=dra["writable"],
199198
)
200199
return True

0 commit comments

Comments
 (0)