Skip to content

Commit 112ecbf

Browse files
feat: Support Athena query prepared statements & Athena parameterized queries (#2344)
1 parent baba6c9 commit 112ecbf

File tree

11 files changed

+744
-33
lines changed

11 files changed

+744
-33
lines changed

awswrangler/athena/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
wait_query,
88
)
99
from awswrangler.athena._spark import create_spark_session, run_spark_calculation
10+
from awswrangler.athena._statements import (
11+
create_prepared_statement,
12+
delete_prepared_statement,
13+
list_prepared_statements,
14+
)
1015
from awswrangler.athena._read import ( # noqa
1116
get_query_results,
1217
read_sql_query,
@@ -51,5 +56,8 @@
5156
"stop_query_execution",
5257
"unload",
5358
"wait_query",
59+
"create_prepared_statement",
60+
"list_prepared_statements",
61+
"delete_prepared_statement",
5462
"to_iceberg",
5563
]

awswrangler/athena/_executions.py

+27-7
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,24 @@
44
from typing import (
55
Any,
66
Dict,
7+
List,
78
Optional,
89
Union,
910
cast,
1011
)
1112

1213
import boto3
1314
import botocore
15+
from typing_extensions import Literal
1416

1517
from awswrangler import _utils, exceptions, typing
1618
from awswrangler._config import apply_configs
17-
from awswrangler._sql_formatter import _process_sql_params
1819

1920
from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results
2021
from ._utils import (
2122
_QUERY_FINAL_STATES,
2223
_QUERY_WAIT_POLLING_DELAY,
24+
_apply_formatter,
2325
_get_workgroup_config,
2426
_start_query_execution,
2527
_WorkGroupConfig,
@@ -36,7 +38,8 @@ def start_query_execution(
3638
workgroup: Optional[str] = None,
3739
encryption: Optional[str] = None,
3840
kms_key: Optional[str] = None,
39-
params: Optional[Dict[str, Any]] = None,
41+
params: Union[Dict[str, Any], List[str], None] = None,
42+
paramstyle: Literal["qmark", "named"] = "named",
4043
boto3_session: Optional[boto3.Session] = None,
4144
athena_cache_settings: Optional[typing.AthenaCacheSettings] = None,
4245
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
@@ -64,10 +67,25 @@ def start_query_execution(
6467
None, 'SSE_S3', 'SSE_KMS', 'CSE_KMS'.
6568
kms_key : str, optional
6669
For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
67-
params: Dict[str, any], optional
68-
Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported.
69-
The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain
70-
`:name`. Note that for varchar columns and similar, you must surround the value in single quotes.
70+
params: Dict[str, any] | List[str], optional
71+
Parameters that will be used for constructing the SQL query.
72+
Only named or question mark parameters are supported.
73+
The parameter style needs to be specified in the ``paramstyle`` parameter.
74+
75+
For ``paramstyle="named"``, this value needs to be a dictionary.
76+
The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain
77+
``:name``.
78+
The formatter will be applied client-side in this scenario.
79+
80+
For ``paramstyle="qmark"``, this value needs to be a list of strings.
81+
The formatter will be applied server-side.
82+
The values are applied sequentially to the parameters in the query in the order in which the parameters occur.
83+
paramstyle: str, optional
84+
Determines the style of ``params``.
85+
Possible values are:
86+
87+
- ``named``
88+
- ``qmark``
7189
boto3_session : boto3.Session(), optional
7290
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
7391
athena_cache_settings: typing.AthenaCacheSettings, optional
@@ -103,7 +121,8 @@ def start_query_execution(
103121
>>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...', data_source='...')
104122
105123
"""
106-
sql = _process_sql_params(sql, params)
124+
# Substitute query parameters if applicable
125+
sql, execution_params = _apply_formatter(sql, params, paramstyle)
107126
_logger.debug("Executing query:\n%s", sql)
108127

109128
athena_cache_settings = athena_cache_settings if athena_cache_settings else {}
@@ -139,6 +158,7 @@ def start_query_execution(
139158
workgroup=workgroup,
140159
encryption=encryption,
141160
kms_key=kms_key,
161+
execution_params=execution_params,
142162
boto3_session=boto3_session,
143163
)
144164
if wait:

awswrangler/athena/_executions.pyi

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import (
22
Any,
33
Dict,
4+
List,
45
Literal,
56
Optional,
67
Union,
@@ -19,7 +20,8 @@ def start_query_execution(
1920
workgroup: Optional[str] = ...,
2021
encryption: Optional[str] = ...,
2122
kms_key: Optional[str] = ...,
22-
params: Optional[Dict[str, Any]] = ...,
23+
params: Union[Dict[str, Any], List[str], None] = ...,
24+
paramstyle: Literal["qmark", "named"] = ...,
2325
boto3_session: Optional[boto3.Session] = ...,
2426
athena_cache_settings: Optional[typing.AthenaCacheSettings] = ...,
2527
athena_query_wait_polling_delay: float = ...,
@@ -35,7 +37,8 @@ def start_query_execution(
3537
workgroup: Optional[str] = ...,
3638
encryption: Optional[str] = ...,
3739
kms_key: Optional[str] = ...,
38-
params: Optional[Dict[str, Any]] = ...,
40+
params: Union[Dict[str, Any], List[str], None] = ...,
41+
paramstyle: Literal["qmark", "named"] = ...,
3942
boto3_session: Optional[boto3.Session] = ...,
4043
athena_cache_settings: Optional[typing.AthenaCacheSettings] = ...,
4144
athena_query_wait_polling_delay: float = ...,
@@ -51,7 +54,8 @@ def start_query_execution(
5154
workgroup: Optional[str] = ...,
5255
encryption: Optional[str] = ...,
5356
kms_key: Optional[str] = ...,
54-
params: Optional[Dict[str, Any]] = ...,
57+
params: Union[Dict[str, Any], List[str], None] = ...,
58+
paramstyle: Literal["qmark", "named"] = ...,
5559
boto3_session: Optional[boto3.Session] = ...,
5660
athena_cache_settings: Optional[typing.AthenaCacheSettings] = ...,
5761
athena_query_wait_polling_delay: float = ...,

awswrangler/athena/_read.py

+62-16
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from awswrangler import _utils, catalog, exceptions, s3, typing
1717
from awswrangler._config import apply_configs
1818
from awswrangler._data_types import cast_pandas_with_athena_types
19-
from awswrangler._sql_formatter import _process_sql_params
2019
from awswrangler.athena._utils import (
2120
_QUERY_WAIT_POLLING_DELAY,
21+
_apply_formatter,
2222
_apply_query_metadata,
2323
_empty_dataframe_response,
2424
_get_query_metadata,
@@ -287,6 +287,7 @@ def _resolve_query_without_cache_ctas(
287287
s3_additional_kwargs: Optional[Dict[str, Any]],
288288
boto3_session: Optional[boto3.Session],
289289
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
290+
execution_params: Optional[List[str]] = None,
290291
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
291292
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
292293
ctas_query_info: Dict[str, Union[str, _QueryMetadata]] = create_ctas_table(
@@ -304,6 +305,7 @@ def _resolve_query_without_cache_ctas(
304305
wait=True,
305306
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
306307
boto3_session=boto3_session,
308+
execution_params=execution_params,
307309
)
308310
fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"'
309311
ctas_query_metadata = cast(_QueryMetadata, ctas_query_info["ctas_query_metadata"])
@@ -342,6 +344,7 @@ def _resolve_query_without_cache_unload(
342344
s3_additional_kwargs: Optional[Dict[str, Any]],
343345
boto3_session: Optional[boto3.Session],
344346
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
347+
execution_params: Optional[List[str]] = None,
345348
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
346349
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
347350
query_metadata = _unload(
@@ -358,6 +361,7 @@ def _resolve_query_without_cache_unload(
358361
boto3_session=boto3_session,
359362
data_source=data_source,
360363
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
364+
execution_params=execution_params,
361365
)
362366
if file_format == "PARQUET":
363367
return _fetch_parquet_result(
@@ -389,6 +393,7 @@ def _resolve_query_without_cache_regular(
389393
athena_query_wait_polling_delay: float,
390394
s3_additional_kwargs: Optional[Dict[str, Any]],
391395
boto3_session: Optional[boto3.Session],
396+
execution_params: Optional[List[str]] = None,
392397
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
393398
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
394399
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
@@ -404,6 +409,7 @@ def _resolve_query_without_cache_regular(
404409
workgroup=workgroup,
405410
encryption=encryption,
406411
kms_key=kms_key,
412+
execution_params=execution_params,
407413
boto3_session=boto3_session,
408414
)
409415
_logger.debug("Query id: %s", query_id)
@@ -450,6 +456,7 @@ def _resolve_query_without_cache(
450456
s3_additional_kwargs: Optional[Dict[str, Any]],
451457
boto3_session: Optional[boto3.Session],
452458
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
459+
execution_params: Optional[List[str]] = None,
453460
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
454461
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
455462
"""
@@ -483,6 +490,7 @@ def _resolve_query_without_cache(
483490
s3_additional_kwargs=s3_additional_kwargs,
484491
boto3_session=boto3_session,
485492
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
493+
execution_params=execution_params,
486494
dtype_backend=dtype_backend,
487495
)
488496
finally:
@@ -510,6 +518,7 @@ def _resolve_query_without_cache(
510518
s3_additional_kwargs=s3_additional_kwargs,
511519
boto3_session=boto3_session,
512520
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
521+
execution_params=execution_params,
513522
dtype_backend=dtype_backend,
514523
)
515524
return _resolve_query_without_cache_regular(
@@ -527,6 +536,7 @@ def _resolve_query_without_cache(
527536
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
528537
s3_additional_kwargs=s3_additional_kwargs,
529538
boto3_session=boto3_session,
539+
execution_params=execution_params,
530540
dtype_backend=dtype_backend,
531541
)
532542

@@ -545,6 +555,7 @@ def _unload(
545555
boto3_session: Optional[boto3.Session],
546556
data_source: Optional[str],
547557
athena_query_wait_polling_delay: float,
558+
execution_params: Optional[List[str]],
548559
) -> _QueryMetadata:
549560
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
550561
s3_output: str = _get_s3_output(s3_output=path, wg_config=wg_config, boto3_session=boto3_session)
@@ -576,6 +587,7 @@ def _unload(
576587
encryption=encryption,
577588
kms_key=kms_key,
578589
boto3_session=boto3_session,
590+
execution_params=execution_params,
579591
)
580592
except botocore.exceptions.ClientError as ex:
581593
msg: str = str(ex)
@@ -735,7 +747,8 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
735747
athena_cache_settings: Optional[typing.AthenaCacheSettings] = None,
736748
data_source: Optional[str] = None,
737749
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
738-
params: Optional[Dict[str, Any]] = None,
750+
params: Union[Dict[str, Any], List[str], None] = None,
751+
paramstyle: Literal["qmark", "named"] = "named",
739752
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
740753
s3_additional_kwargs: Optional[Dict[str, Any]] = None,
741754
pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
@@ -905,10 +918,25 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
905918
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
906919
athena_query_wait_polling_delay: float, default: 0.25 seconds
907920
Interval in seconds for how often the function will check if the Athena query has completed.
908-
params: Dict[str, any], optional
909-
Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported.
910-
The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain
911-
`:name`. Note that for varchar columns and similar, you must surround the value in single quotes.
921+
params: Dict[str, any] | List[str], optional
922+
Parameters that will be used for constructing the SQL query.
923+
Only named or question mark parameters are supported.
924+
The parameter style needs to be specified in the ``paramstyle`` parameter.
925+
926+
For ``paramstyle="named"``, this value needs to be a dictionary.
927+
The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain
928+
``:name``.
929+
The formatter will be applied client-side in this scenario.
930+
931+
For ``paramstyle="qmark"``, this value needs to be a list of strings.
932+
The formatter will be applied server-side.
933+
The values are applied sequentially to the parameters in the query in the order in which the parameters occur.
934+
paramstyle: str, optional
935+
Determines the style of ``params``.
936+
Possible values are:
937+
938+
- ``named``
939+
- ``qmark``
912940
dtype_backend: str, optional
913941
Which dtype_backend to use, e.g. whether a DataFrame should have NumPy arrays,
914942
nullable dtypes are used for all dtypes that have a nullable implementation when
@@ -964,15 +992,15 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
964992
raise exceptions.InvalidArgumentCombination("Only PARQUET file format is supported if unload_approach=True")
965993
chunksize = sys.maxsize if ctas_approach is False and chunksize is True else chunksize
966994

995+
# Substitute query parameters if applicable
996+
sql, execution_params = _apply_formatter(sql, params, paramstyle)
997+
967998
athena_cache_settings = athena_cache_settings if athena_cache_settings else {}
968999
max_cache_seconds = athena_cache_settings.get("max_cache_seconds", 0)
9691000
max_cache_query_inspections = athena_cache_settings.get("max_cache_query_inspections", 50)
9701001
max_remote_cache_entries = athena_cache_settings.get("max_remote_cache_entries", 50)
9711002
max_local_cache_entries = athena_cache_settings.get("max_local_cache_entries", 100)
9721003

973-
# Substitute query parameters
974-
sql = _process_sql_params(sql, params)
975-
9761004
max_remote_cache_entries = min(max_remote_cache_entries, max_local_cache_entries)
9771005

9781006
_cache_manager.max_cache_size = max_local_cache_entries
@@ -1032,6 +1060,7 @@ def read_sql_query( # pylint: disable=too-many-arguments,too-many-locals
10321060
s3_additional_kwargs=s3_additional_kwargs,
10331061
boto3_session=boto3_session,
10341062
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
1063+
execution_params=execution_params,
10351064
dtype_backend=dtype_backend,
10361065
)
10371066

@@ -1288,7 +1317,8 @@ def unload(
12881317
kms_key: Optional[str] = None,
12891318
boto3_session: Optional[boto3.Session] = None,
12901319
data_source: Optional[str] = None,
1291-
params: Optional[Dict[str, Any]] = None,
1320+
params: Union[Dict[str, Any], List[str], None] = None,
1321+
paramstyle: Literal["qmark", "named"] = "named",
12921322
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
12931323
) -> _QueryMetadata:
12941324
"""Write query results from a SELECT statement to the specified data format using UNLOAD.
@@ -1325,10 +1355,25 @@ def unload(
13251355
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
13261356
data_source : str, optional
13271357
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
1328-
params: Dict[str, any], optional
1329-
Dict of parameters that will be used for constructing the SQL query. Only named parameters are supported.
1330-
The dict needs to contain the information in the form {'name': 'value'} and the SQL query needs to contain
1331-
`:name`. Note that for varchar columns and similar, you must surround the value in single quotes.
1358+
params: Dict[str, any] | List[str], optional
1359+
Parameters that will be used for constructing the SQL query.
1360+
Only named or question mark parameters are supported.
1361+
The parameter style needs to be specified in the ``paramstyle`` parameter.
1362+
1363+
For ``paramstyle="named"``, this value needs to be a dictionary.
1364+
The dict needs to contain the information in the form ``{'name': 'value'}`` and the SQL query needs to contain
1365+
``:name``.
1366+
The formatter will be applied client-side in this scenario.
1367+
1368+
For ``paramstyle="qmark"``, this value needs to be a list of strings.
1369+
The formatter will be applied server-side.
1370+
The values are applied sequentially to the parameters in the query in the order in which the parameters occur.
1371+
paramstyle: str, optional
1372+
Determines the style of ``params``.
1373+
Possible values are:
1374+
1375+
- ``named``
1376+
- ``qmark``
13321377
athena_query_wait_polling_delay: float, default: 0.25 seconds
13331378
Interval in seconds for how often the function will check if the Athena query has completed.
13341379
@@ -1346,8 +1391,8 @@ def unload(
13461391
... )
13471392
13481393
"""
1349-
# Substitute query parameters
1350-
sql = _process_sql_params(sql, params)
1394+
# Substitute query parameters if applicable
1395+
sql, execution_params = _apply_formatter(sql, params, paramstyle)
13511396
return _unload(
13521397
sql=sql,
13531398
path=path,
@@ -1362,4 +1407,5 @@ def unload(
13621407
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
13631408
boto3_session=boto3_session,
13641409
data_source=data_source,
1410+
execution_params=execution_params,
13651411
)

0 commit comments

Comments
 (0)