Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 283 additions & 0 deletions util/upload-node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
#!/usr/bin/python
#
# Copyright 2019-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
# is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, express or implied. See the License for the specific language
# governing permissions and limitations under the License.
#
#
# Upload the aws-parallelcluster-node package to S3.
#
# Mirrors upload-cookbook.py: the cookbook installs the node package from
# s3://<region>-aws-parallelcluster/parallelcluster/<version>/node/ in all
# regions (see parallelcluster_node.rb), so the release flow must push it there.
#
# usage: ./upload-node.py --regions "<region>[,<region>, ...]" --node-archive-path "<path to node tgz>" \
#                         --partition <commercial|china|govcloud> \
#                         [--unsupportedregions "<region>[, <region>, ...]"] [--dryrun] [--override] \
#                         [--bucket <bucket>] [--credential <region>,<endpoint>,<arn>,<externalId>]*
import hashlib
import os
import sys
from datetime import datetime
from importlib.metadata import version

import argparse
import boto3
from botocore.exceptions import ClientError

# S3 key prefix of the node package. The version segment is read at import
# time from the metadata of the installed "aws-parallelcluster" distribution.
_NODE_DIR = "parallelcluster/{version}/node".format(version=version("aws-parallelcluster"))
# S3 key prefix under which _aws_s3_bck copies objects before an override.
_BACKUP_DIR = "{0}/backup".format(_NODE_DIR)
# Local-time timestamp captured once at import; _aws_s3_bck appends it to
# every backup key.
_bck_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# Regions recorded by _aws_s3_bck when a backup copy fails.
_bck_error_array = set()
# Regions recorded by _aws_s3_cp when an upload fails.
_cp_error_array = set()
# Regions recorded by _aws_s3_ls when the node archive key is already present.
_ls_error_array = set()
# Tuples parsed from --credential (region, endpoint, role ARN, external ID);
# replaced by _parse_args when the option is given.
_credentials = []
# Region used for the EC2 and STS clients; set by _parse_args from --partition.
_main_region = None


def _get_all_aws_regions(region):
    """Return the set of region names reported by EC2 ``describe_regions``.

    :param region: region used to create the EC2 client that is queried.
    :return: a ``set`` with the ``RegionName`` of every entry in the response.
    """
    ec2 = boto3.client("ec2", region_name=region)
    described = ec2.describe_regions().get("Regions")
    # A set has no ordering, so sorting the names before building it was wasted work.
    return {entry.get("RegionName") for entry in described}


def _aws_s3_ls(s3, region, bucket_name, key):
    """Flag ``region`` in ``_ls_error_array`` when objects already exist for ``key``.

    ``key`` is passed to ``list_objects_v2`` as a prefix, so any object whose
    key starts with it counts as "already present".

    :param s3: S3 client used for the listing.
    :param region: region name recorded when a match is found.
    :param bucket_name: bucket to inspect.
    :param key: key prefix to look for.
    """
    listing = s3.list_objects_v2(Bucket=bucket_name, Prefix=key)
    existing_objects = listing.get("Contents", [])
    if existing_objects:
        _ls_error_array.add(region)


def _aws_s3_bck(s3, args, region, bucket_name, full_name):
    """Copy an existing node object to the backup prefix of the same bucket.

    The backup key is ``_BACKUP_DIR/<full_name><_bck_date>`` (the timestamp is
    appended directly to the file name). In dryrun mode nothing is copied and
    only a message is printed. A failed copy is reported and the region is
    recorded in ``_bck_error_array``; the error is not re-raised.

    :param s3: S3 client for the target region.
    :param args: parsed command-line arguments (``dryrun`` and ``override`` are read).
    :param region: region name recorded on failure.
    :param bucket_name: bucket holding both the source and the backup object.
    :param full_name: file name of the object under ``_NODE_DIR``.
    """
    if args.dryrun:
        message = "Not backing up {0} to bucket {1} override is {2}, dryrun is {3}".format(
            full_name, bucket_name, args.override, args.dryrun
        )
        print(message)
        return

    source_key = _NODE_DIR + "/" + full_name
    backup_key = _BACKUP_DIR + "/" + full_name + _bck_date
    try:
        s3.copy({"Bucket": bucket_name, "Key": source_key}, bucket_name, backup_key)
    except ClientError as err:
        print("Couldn't backup {0}".format(full_name))
        if err.response["Error"]["Code"] == "NoSuchBucket":
            print("Bucket is not present.")
        # NOTE(review): the reviewed copy of this file had lost its indentation;
        # recording the region for every ClientError (not only NoSuchBucket) is
        # assumed here, mirroring _aws_s3_cp -- confirm against the real file.
        _bck_error_array.add(region)


def _aws_s3_cp(s3, args, region, bucket_name, folder, src_file):
    """Upload a local file to ``<folder>/<basename of src_file>`` with a public-read ACL.

    In dryrun mode nothing is uploaded and only a message is printed. When the
    upload fails the region is recorded in ``_cp_error_array`` and the error is
    re-raised to the caller.

    :param s3: S3 client for the target region.
    :param args: parsed command-line arguments (``dryrun`` and ``override`` are read).
    :param region: region name recorded on failure.
    :param bucket_name: destination bucket.
    :param folder: destination key prefix.
    :param src_file: path of the local file to upload.
    :raises ClientError: when the upload is rejected by S3.
    """
    dest_key = folder + "/" + os.path.basename(src_file)
    print("Bucket dest key: {0}".format(dest_key))

    if args.dryrun:
        message = "Not uploading {0} to bucket {1}, override is {2}, dryrun is {3}".format(
            src_file, bucket_name, args.override, args.dryrun
        )
        print(message)
        return

    try:
        s3.upload_file(src_file, bucket_name, dest_key, ExtraArgs={"ACL": "public-read"})
    except ClientError as err:
        print("Couldn't upload {0} to bucket s3://{1}/{2}".format(src_file, bucket_name, dest_key))
        _cp_error_array.add(region)
        if err.response["Error"]["Code"] == "NoSuchBucket":
            print("Bucket is not present.")

        raise err
    else:
        print("Successfully uploaded {0} to s3://{1}/{2}".format(src_file, bucket_name, dest_key))


def _create_s3_client(region):
    """Build an S3 client for ``region``.

    When ``_credentials`` holds an entry for the region, the first such entry
    is used to assume its role through STS (client created in ``_main_region``
    against the entry's endpoint) and the client is built from the temporary
    credentials. Otherwise a client with the default credentials is returned.

    :param region: target region name.
    :return: a boto3 S3 client.
    :raises ClientError: when the STS or client setup fails; a warning is
        printed before the error is re-raised.
    """
    matching = [entry for entry in _credentials if entry[0] == region]
    if not matching:
        return boto3.client("s3", region_name=region)

    # Entry layout: (region, STS endpoint, role ARN, external ID).
    credential = matching[0]
    credential_region = credential[0]
    sts_endpoint = credential[1]
    role_arn = credential[2]
    external_id = credential[3]

    try:
        sts = boto3.client("sts", region_name=_main_region, endpoint_url=sts_endpoint)
        assumed_role = sts.assume_role(
            RoleArn=role_arn,
            ExternalId=external_id,
            RoleSessionName=credential_region + "upload_node_sts_session",
        )
        temporary = assumed_role["Credentials"]
        return boto3.client(
            "s3",
            region_name=credential_region,
            aws_access_key_id=temporary.get("AccessKeyId"),
            aws_secret_access_key=temporary.get("SecretAccessKey"),
            aws_session_token=temporary.get("SessionToken"),
        )
    except ClientError as err:
        print("Warning: non authorized in region '{0}', skipping".format(credential_region))
        raise err


def _get_bucket_name(args, region):
    """Return the destination bucket for ``region``.

    :param args: parsed command-line arguments; ``args.bucket`` wins when set.
    :param region: region name used to build the default bucket name.
    :return: ``args.bucket`` if truthy, else ``<region>-aws-parallelcluster``.
    """
    if args.bucket:
        return args.bucket
    return region + "-aws-parallelcluster"


def _md5sum(node_archive_file, md5sum_file):
    """Write the MD5 checksum of ``node_archive_file`` to ``md5sum_file``.

    The output file holds a single ``<hex digest> <archive basename>`` entry
    with no trailing newline.

    :param node_archive_file: path of the archive to hash.
    :param md5sum_file: path of the checksum file to (over)write.
    """
    chunk_size = 65536
    digest = hashlib.md5()  # nosec
    with open(node_archive_file, "rb") as archive:
        # Hash in fixed-size chunks so the archive is never fully loaded in memory.
        for chunk in iter(lambda: archive.read(chunk_size), b""):
            digest.update(chunk)

    checksum_line = "{0} {1}".format(digest.hexdigest(), os.path.basename(node_archive_file))
    with open(md5sum_file, "w+", encoding="utf-8") as checksum_file:
        checksum_file.write(checksum_line)


def _parse_args():
    """Parse the command line and resolve the final set of target regions.

    Side effects: sets the module-level ``_main_region`` from ``--partition``
    and, when ``--credential`` is given, replaces ``_credentials`` with the
    parsed tuples.

    On the returned namespace ``regions`` is a ``set``: the requested regions
    (or every region reported by EC2 for ``all``) minus ``unsupportedregions``,
    plus the region of every credential entry. ``unsupportedregions`` is a list.

    :return: the populated ``argparse.Namespace``.
    :raises SystemExit: with status 1 for an unsupported partition or a
        ``--credential`` value with fewer than four comma-separated fields.
    """
    global _credentials
    global _main_region
    parser = argparse.ArgumentParser(description="Uploads aws-parallelcluster-node to S3")

    parser.add_argument(
        "--regions",
        type=str,
        help='Valid Regions, can include "all", or comma separated list of regions',
        required=True,
    )
    parser.add_argument(
        "--unsupportedregions", type=str, help="Unsupported regions, comma separated", default="", required=False
    )
    parser.add_argument(
        "--override",
        action="store_true",
        help="If override is false, the file will not be pushed if it already exists in the bucket",
        default=False,
        required=False,
    )
    parser.add_argument(
        "--bucket", type=str, help="Buckets to upload to, defaults to [region]-aws-parallelcluster", required=False
    )
    parser.add_argument("--node-archive-path", type=str, help="Node package archive path", required=True)
    parser.add_argument(
        "--dryrun", action="store_true", help="Doesn't push anything to S3, just outputs", default=False, required=False
    )

    parser.add_argument("--partition", type=str, help="commercial | china | govcloud", required=True)
    parser.add_argument(
        "--credential",
        type=str,
        action="append",
        help="STS credential endpoint, in the format <region>,<endpoint>,<ARN>,<externalId>. "
        "Could be specified multiple times",
        required=False,
    )

    args = parser.parse_args()

    main_region_by_partition = {
        "commercial": "us-east-1",
        "govcloud": "us-gov-west-1",
        "china": "cn-north-1",
    }
    if args.partition not in main_region_by_partition:
        print("Unsupported partition {0}".format(args.partition))
        # sys.exit instead of the builtin exit(): the latter is injected by the
        # site module and may be missing (e.g. when running with python -S).
        sys.exit(1)
    _main_region = main_region_by_partition[args.partition]

    if args.credential:
        parsed_credentials = [
            tuple(credential_tuple.strip().split(","))
            for credential_tuple in args.credential
            if credential_tuple.strip()
        ]
        # Fail early: _create_s3_client indexes fields 0..3 of every entry and
        # would otherwise raise IndexError in the middle of the upload.
        for parsed in parsed_credentials:
            if len(parsed) < 4:
                print(
                    "Malformed credential '{0}', expected <region>,<endpoint>,<ARN>,<externalId>".format(
                        ",".join(parsed)
                    )
                )
                sys.exit(1)
        _credentials = parsed_credentials

    if args.regions == "all":
        requested_regions = _get_all_aws_regions(_main_region)
    else:
        # Ignore empty entries such as the one produced by a trailing comma.
        requested_regions = [x.strip() for x in args.regions.split(",") if x.strip()]

    args.unsupportedregions = [x.strip() for x in args.unsupportedregions.split(",") if x.strip()]

    # Purging regions
    args.regions = set(requested_regions) - set(args.unsupportedregions)

    # Adds all opt-in regions
    args.regions.update(credential[0] for credential in _credentials)

    return args


def main():
    """Upload the node archive, its md5 file and a ``.tgz.date`` marker to every target region.

    Steps:
      1. Parse arguments and verify that the archive exists locally.
      2. Write ``<base_name>.md5`` in the current directory.
      3. Check each region's bucket for an existing archive; abort unless
         ``--override`` was given.
      4. Per region: optionally back up the existing objects, upload the archive
         and md5 file, then (unless ``--dryrun``) write the archive's
         ``LastModified`` timestamp to ``<base_name>.tgz.date`` and upload it.

    :raises SystemExit: with status 1 when the archive is missing, when the
        archive already exists and ``--override`` is not set, or when upload
        errors were recorded.
    """
    args = _parse_args()

    # Check if archive exists
    if not os.path.exists(args.node_archive_path):
        print("Node archive {0} not found".format(args.node_archive_path))
        # sys.exit instead of the builtin exit(): the latter is injected by the
        # site module and is not guaranteed to exist.
        sys.exit(1)

    base_name = os.path.splitext(os.path.basename(args.node_archive_path))[0]
    _md5sum(args.node_archive_path, "{0}.md5".format(base_name))

    # Build each region's client and bucket name once and reuse them in both
    # passes, so a role is not assumed twice for the same region.
    targets = [(region, _create_s3_client(region), _get_bucket_name(args, region)) for region in args.regions]

    archive_key = _NODE_DIR + "/" + base_name + ".tgz"
    for region, s3, bucket_name in targets:
        print("Listing node package for region: {0}, bucket: {1}, key: {2}".format(region, bucket_name, archive_key))
        _aws_s3_ls(s3, region, bucket_name, archive_key)

    if _ls_error_array:
        if not args.override:
            print("We know the node archives are already there, in this round we need to upload the .date files!")
            print("Failed to push node, already present for regions: {0} ".format(" ".join(_ls_error_array)))
            sys.exit(1)
        print("Some or all of the node archives are already there but OVERRIDE=true")

    date_file = base_name + ".tgz.date"
    for region, s3, bucket_name in targets:
        if args.override:
            print("Backup node package for region: {0}".format(region))
            _aws_s3_bck(s3, args, region, bucket_name, base_name + ".tgz")
            _aws_s3_bck(s3, args, region, bucket_name, base_name + ".md5")
            _aws_s3_bck(s3, args, region, bucket_name, base_name + ".tgz.date")

        print("Pushing node package for region: {0}".format(region))
        _aws_s3_cp(s3, args, region, bucket_name, _NODE_DIR, args.node_archive_path)
        _aws_s3_cp(s3, args, region, bucket_name, _NODE_DIR, base_name + ".md5")

        if not args.dryrun:
            # Stores LastModified info into .tgz.date file and uploads it back to bucket.
            # The object is queried before the local file is opened so a failed
            # request does not leave a truncated .tgz.date file behind.
            response = s3.head_object(Bucket=bucket_name, Key=archive_key)
            with open(date_file, "w", encoding="utf-8") as f:
                f.write(response.get("LastModified").strftime("%Y-%m-%d_%H-%M-%S"))

            _aws_s3_cp(s3, args, region, bucket_name, _NODE_DIR, date_file)
        else:
            print("File {0}.{1} not stored to bucket {2} due to dryrun mode".format(base_name, "tgz.date", bucket_name))

    if _bck_error_array:
        # Backup failures are reported but do not change the exit status.
        print("Failed to backup node for region ({0})".format(" ".join(_bck_error_array)))

    if _cp_error_array:
        print("Failed to push node for region ({0})".format(" ".join(_cp_error_array)))
        sys.exit(1)


# Run the upload only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Loading