55import os
66from collections .abc import Sequence
77from typing import NamedTuple
8-
8+ import json
99
1010class BadFile (NamedTuple ):
1111 filename : str
1212 key : str
1313
14-
1514def get_aws_cred_files_from_env () -> set [str ]:
1615 """Extract credential file paths from environment variables."""
1716 return {
@@ -23,17 +22,38 @@ def get_aws_cred_files_from_env() -> set[str]:
2322 if env_var in os .environ
2423 }
2524
26-
2725def get_aws_secrets_from_env () -> set [str ]:
2826 """Extract AWS secrets from environment variables."""
2927 keys = set ()
3028 for env_var in (
31- 'AWS_SECRET_ACCESS_KEY' , 'AWS_SECURITY_TOKEN' , 'AWS_SESSION_TOKEN' ,
29+ 'AWS_SECRET_ACCESS_KEY' , 'AWS_SECURITY_TOKEN' , 'AWS_SESSION_TOKEN' ,
3230 ):
3331 if os .environ .get (env_var ):
3432 keys .add (os .environ [env_var ])
3533 return keys
3634
35+ def get_aws_secrets_from_json_file (json_credentials_file : str ) -> set [str ]:
36+ """Extract AWS secrets from JSON configuration files.
37+
38+ Read a JSON-style configuration file and return a set with all found AWS
39+ secret access keys.
40+ """
41+ aws_credentials_file_path = os .path .expanduser (json_credentials_file )
42+ if not os .path .exists (aws_credentials_file_path ):
43+ return set ()
44+
45+ with open (aws_credentials_file_path , 'r' ) as f :
46+ try :
47+ data = json .load (f )
48+ except json .JSONDecodeError :
49+ return set ()
50+
51+ keys = set ()
52+ for var in ('AccessKeyId' , 'SecretAccessKey' , 'SessionToken' , 'aws_secret_access_key' , 'aws_security_token' , 'aws_session_token' ):
53+ if var in data .get ('Credentials' , {}):
54+ keys .add (data ['Credentials' ][var ])
55+ return keys
56+
3757
3858def get_aws_secrets_from_file (credentials_file : str ) -> set [str ]:
3959 """Extract AWS secrets from configuration files.
@@ -54,8 +74,8 @@ def get_aws_secrets_from_file(credentials_file: str) -> set[str]:
5474 keys = set ()
5575 for section in parser .sections ():
5676 for var in (
57- 'aws_secret_access_key' , 'aws_security_token' ,
58- 'aws_session_token' ,
77+ 'aws_secret_access_key' , 'aws_security_token' ,
78+ 'aws_session_token' ,
5979 ):
6080 try :
6181 key = parser .get (section , var ).strip ()
@@ -104,6 +124,16 @@ def main(argv: Sequence[str] | None = None) -> int:
104124 'secret keys. Can be passed multiple times.'
105125 ),
106126 )
127+ parser .add_argument (
128+ '--json-credentials-file' ,
129+ dest = 'json_credential_file_locations' ,
130+ action = 'append' ,
131+ default = ['~/.aws/cli/cache/' , '~/.aws/login/cache/' ],
132+ help = (
133+ 'Location of additional AWS JSON credential file from which to get '
134+ 'secret keys. Can be passed multiple times.'
135+ ),
136+ )
107137 parser .add_argument (
108138 '--allow-missing-credentials' ,
109139 dest = 'allow_missing_credentials' ,
@@ -113,6 +143,13 @@ def main(argv: Sequence[str] | None = None) -> int:
113143 args = parser .parse_args (argv )
114144
115145 credential_files = set (args .credentials_file )
146+ json_credential_file_locations = set (args .json_credential_file_locations )
147+ json_credential_files = set ()
148+ for json_credential_file_location in json_credential_file_locations :
149+ if os .path .isdir (os .path .expanduser (json_credential_file_location )):
150+ for filename in os .listdir (os .path .expanduser (json_credential_file_location )):
151+ if filename .endswith ('.json' ):
152+ json_credential_files .add (os .path .join (json_credential_file_location , filename ))
116153
117154 # Add the credentials files configured via environment variables to the set
118155 # of files to to gather AWS secrets from.
@@ -122,6 +159,8 @@ def main(argv: Sequence[str] | None = None) -> int:
122159 for credential_file in credential_files :
123160 keys |= get_aws_secrets_from_file (credential_file )
124161
162+ for json_credential_file in json_credential_files :
163+ keys |= get_aws_secrets_from_json_file (json_credential_file )
125164 # Secrets might be part of environment variables, so add such secrets to
126165 # the set of keys.
127166 keys |= get_aws_secrets_from_env ()
@@ -148,4 +187,4 @@ def main(argv: Sequence[str] | None = None) -> int:
148187
149188
150189if __name__ == '__main__' :
151- raise SystemExit (main ())
190+ raise SystemExit (main ())
0 commit comments