32
32
USER_POOL_ID = os .getenv ("USER_POOL_ID" )
33
33
AUTH_PATH = os .getenv ("AUTH_PATH" )
34
34
API_BASE_URL = os .getenv ("API_BASE_URL" )
35
- API_VERSION = sorted (os .getenv ("API_VERSION" , "3.1.0" ).split ("," ), key = lambda x : [- int (n ) for n in x .split ('.' )])
35
+ API_VERSION = sorted (os .getenv ("API_VERSION" , "3.1.0" ).strip (). split ("," ), key = lambda x : [- int (n ) for n in x .split ('.' )])
36
36
DEFAULT_API_VERSION = API_VERSION [0 ]
37
37
API_USER_ROLE = os .getenv ("API_USER_ROLE" )
38
38
OIDC_PROVIDER = os .getenv ("OIDC_PROVIDER" )
39
39
CLIENT_ID = os .getenv ("CLIENT_ID" )
40
40
CLIENT_SECRET = os .getenv ("CLIENT_SECRET" )
41
41
SECRET_ID = os .getenv ("SECRET_ID" )
42
- SITE_URL = os .getenv ("SITE_URL" )
43
42
SCOPES_LIST = os .getenv ("SCOPES_LIST" )
44
43
REGION = os .getenv ("AWS_DEFAULT_REGION" )
45
44
TOKEN_URL = os .getenv ("TOKEN_URL" , f"{ AUTH_PATH } /oauth2/token" )
49
48
AUDIENCE = os .getenv ("AUDIENCE" )
50
49
USER_ROLES_CLAIM = os .getenv ("USER_ROLES_CLAIM" , "cognito:groups" )
51
50
SSM_LOG_GROUP_NAME = os .getenv ("SSM_LOG_GROUP_NAME" )
51
+ ARG_VERSION = "version"
52
52
53
53
try :
54
54
if (not USER_POOL_ID or USER_POOL_ID == "" ) and SECRET_ID :
63
63
if not JWKS_URL :
64
64
JWKS_URL = os .getenv ("JWKS_URL" ,
65
65
f"https://cognito-idp.{ REGION } .amazonaws.com/{ USER_POOL_ID } /" ".well-known/jwks.json" )
66
- API_BASE_URL_MAPPING = {}
67
66
68
- if API_BASE_URL :
69
- for url in API_BASE_URL .split ("," ):
70
- if url :
71
- pair = url .split ("=" )
72
- API_BASE_URL_MAPPING [pair [0 ]] = pair [1 ]
67
+ def create_url_map (url_list ):
68
+ url_map = {}
69
+ if url_list :
70
+ for url in url_list .split ("," ):
71
+ if url :
72
+ pair = url .split ("=" )
73
+ url_map [pair [0 ]] = pair [1 ]
74
+ return url_map
75
+
76
+ API_BASE_URL_MAPPING = create_url_map (API_BASE_URL )
77
+ SITE_URL = os .getenv ("SITE_URL" , API_BASE_URL_MAPPING .get (DEFAULT_API_VERSION ))
73
78
74
79
75
80
@@ -242,9 +247,9 @@ def ec2_action():
242
247
def get_cluster_config_text (cluster_name , region = None ):
243
248
url = f"/v3/clusters/{ cluster_name } "
244
249
if region :
245
- info_resp = sigv4_request ("GET" , get_base_url (request . args . get ( "version" ) ), url , params = {"region" : region })
250
+ info_resp = sigv4_request ("GET" , get_base_url (request ), url , params = {"region" : region })
246
251
else :
247
- info_resp = sigv4_request ("GET" , get_base_url (request . args . get ( "version" ) ), url )
252
+ info_resp = sigv4_request ("GET" , get_base_url (request ), url )
248
253
if info_resp .status_code != 200 :
249
254
abort (info_resp .status_code )
250
255
@@ -493,7 +498,7 @@ def get_dcv_session():
493
498
494
499
495
500
def get_custom_image_config ():
496
- image_info = sigv4_request ("GET" , get_base_url (request . args . get ( "version" ) ), f"/v3/images/custom/{ request .args .get ('image_id' )} " ).json ()
501
+ image_info = sigv4_request ("GET" , get_base_url (request ), f"/v3/images/custom/{ request .args .get ('image_id' )} " ).json ()
497
502
configuration = requests .get (image_info ["imageConfiguration" ]["url" ])
498
503
return configuration .text
499
504
@@ -744,9 +749,10 @@ def _get_params(_request):
744
749
params .pop ("path" )
745
750
return params
746
751
747
- def get_base_url (v ):
748
- if v and str (v ) in API_VERSION :
749
- return API_BASE_URL_MAPPING [str (v )]
752
+ def get_base_url (request ):
753
+ version = request .args .get (ARG_VERSION )
754
+ if version and str (version ) in API_VERSION :
755
+ return API_BASE_URL_MAPPING [str (version )]
750
756
return API_BASE_URL_MAPPING [DEFAULT_API_VERSION ]
751
757
752
758
@@ -756,7 +762,7 @@ def get_base_url(v):
756
762
@authenticated ({'admin' })
757
763
@validated (params = PCProxyArgs )
758
764
def pc_proxy_get ():
759
- response = sigv4_request (request .method , get_base_url (request . args . get ( "version" ) ), request .args .get ("path" ), _get_params (request ))
765
+ response = sigv4_request (request .method , get_base_url (request ), request .args .get ("path" ), _get_params (request ))
760
766
return response .json (), response .status_code
761
767
762
768
@pc .route ('/' , methods = ['POST' ,'PUT' ,'PATCH' ,'DELETE' ], strict_slashes = False )
@@ -770,5 +776,5 @@ def pc_proxy():
770
776
except :
771
777
pass
772
778
773
- response = sigv4_request (request .method , get_base_url (request . args . get ( "version" ) ), request .args .get ("path" ), _get_params (request ), body = body )
779
+ response = sigv4_request (request .method , get_base_url (request ), request .args .get ("path" ), _get_params (request ), body = body )
774
780
return response .json (), response .status_code
0 commit comments