
Commit 51ce25a

Authored Jun 27, 2023
feat: Add CleanRooms read module (#2366)
1 parent e8291b4 commit 51ce25a

File tree

16 files changed: +654 −60 lines

‎.gitignore

+3
@@ -153,9 +153,12 @@ building/lambda/arrow
 *.swp
 
 # CDK
+node_modules
+*package.json
 *package-lock.json
 *.cdk.staging
 *cdk.out
+*cdk.context.json
 
 # ruff
 .ruff_cache/

‎awswrangler/__init__.py

+2
@@ -11,6 +11,7 @@
     athena,
     catalog,
     chime,
+    cleanrooms,
     cloudwatch,
     data_api,
     data_quality,
@@ -43,6 +44,7 @@
     "athena",
     "catalog",
     "chime",
+    "cleanrooms",
     "cloudwatch",
     "emr",
     "emr_serverless",

‎awswrangler/_utils.py

+12
@@ -45,6 +45,7 @@
     from boto3.resources.base import ServiceResource
     from botocore.client import BaseClient
     from mypy_boto3_athena import AthenaClient
+    from mypy_boto3_cleanrooms import CleanRoomsServiceClient
     from mypy_boto3_dynamodb import DynamoDBClient, DynamoDBServiceResource
     from mypy_boto3_ec2 import EC2Client
     from mypy_boto3_emr.client import EMRClient
@@ -68,6 +69,7 @@
 
 ServiceName = Literal[
     "athena",
+    "cleanrooms",
     "dynamodb",
     "ec2",
     "emr",
@@ -286,6 +288,16 @@ def client(
     ...
 
 
+@overload
+def client(
+    service_name: 'Literal["cleanrooms"]',
+    session: Optional[boto3.Session] = None,
+    botocore_config: Optional[Config] = None,
+    verify: Optional[Union[str, bool]] = None,
+) -> "CleanRoomsServiceClient":
+    ...
+
+
 @overload
 def client(
     service_name: 'Literal["lakeformation"]',
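
For context, a short sketch of what the new overload buys callers. This is illustrative only: _utils.client is an internal helper rather than public API, and the example assumes default AWS credentials are configured.

from awswrangler import _utils

# With the overload above, mypy narrows the return type to
# CleanRoomsServiceClient when the literal "cleanrooms" is passed,
# so calls like client.start_protected_query(...) are type-checked.
client = _utils.client(service_name="cleanrooms")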

‎awswrangler/cleanrooms/__init__.py

+9
@@ -0,0 +1,9 @@
+"""Amazon Clean Rooms Module."""
+
+from awswrangler.cleanrooms._read import read_sql_query
+from awswrangler.cleanrooms._utils import wait_query
+
+__all__ = [
+    "read_sql_query",
+    "wait_query",
+]

‎awswrangler/cleanrooms/_read.py

+128
@@ -0,0 +1,128 @@
+"""Amazon Clean Rooms Module hosting read_* functions."""
+
+import logging
+from typing import Any, Dict, Iterator, Optional, Union
+
+import boto3
+
+import awswrangler.pandas as pd
+from awswrangler import _utils, s3
+from awswrangler._sql_formatter import _process_sql_params
+from awswrangler.cleanrooms._utils import wait_query
+
+_logger: logging.Logger = logging.getLogger(__name__)
+
+
+def _delete_after_iterate(
+    dfs: Iterator[pd.DataFrame], keep_files: bool, kwargs: Dict[str, Any]
+) -> Iterator[pd.DataFrame]:
+    for df in dfs:
+        yield df
+    if keep_files is False:
+        s3.delete_objects(**kwargs)
+
+
+def read_sql_query(
+    sql: str,
+    membership_id: str,
+    output_bucket: str,
+    output_prefix: str,
+    keep_files: bool = True,
+    params: Optional[Dict[str, Any]] = None,
+    chunksize: Optional[Union[int, bool]] = None,
+    use_threads: Union[bool, int] = True,
+    boto3_session: Optional[boto3.Session] = None,
+    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
+) -> Union[Iterator[pd.DataFrame], pd.DataFrame]:
+    """Execute Clean Rooms Protected SQL query and return the results as a Pandas DataFrame.
+
+    Parameters
+    ----------
+    sql : str
+        SQL query
+    membership_id : str
+        Membership ID
+    output_bucket : str
+        S3 output bucket name
+    output_prefix : str
+        S3 output prefix
+    keep_files : bool, optional
+        Whether files in S3 output bucket/prefix are retained. 'True' by default
+    params : Dict[str, any], optional
+        Dict of parameters used for constructing the SQL query. Only named parameters are supported.
+        The dict must be in the form {'name': 'value'} and the SQL query must contain
+        `:name`. Note that for varchar columns and similar, you must surround the value in single quotes
+    chunksize : Union[int, bool], optional
+        If passed, the data is split into an iterable of DataFrames (Memory friendly).
+        If `True` an iterable of DataFrames is returned without guarantee of chunksize.
+        If an `INTEGER` is passed, an iterable of DataFrames is returned with maximum rows
+        equal to the received INTEGER
+    use_threads : Union[bool, int], optional
+        True to enable concurrent requests, False to disable multiple threads.
+        If enabled os.cpu_count() is used as the maximum number of threads.
+        If integer is provided, specified number is used
+    boto3_session : boto3.Session, optional
+        Boto3 Session. If None, the default boto3 session is used
+    pyarrow_additional_kwargs : Optional[Dict[str, Any]]
+        Forwarded to `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
+        Valid values include "split_blocks", "self_destruct", "ignore_metadata".
+        e.g. pyarrow_additional_kwargs={'split_blocks': True}
+
+    Returns
+    -------
+    Union[Iterator[pd.DataFrame], pd.DataFrame]
+        Pandas DataFrame or Generator of Pandas DataFrames if chunksize is provided.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> df = wr.cleanrooms.read_sql_query(
+    >>>     sql='SELECT DISTINCT...',
+    >>>     membership_id='membership-id',
+    >>>     output_bucket='output-bucket',
+    >>>     output_prefix='output-prefix',
+    >>> )
+    """
+    client_cleanrooms = _utils.client(service_name="cleanrooms", session=boto3_session)
+
+    query_id: str = client_cleanrooms.start_protected_query(
+        type="SQL",
+        membershipIdentifier=membership_id,
+        sqlParameters={"queryString": _process_sql_params(sql, params, engine_type="partiql")},
+        resultConfiguration={
+            "outputConfiguration": {
+                "s3": {
+                    "bucket": output_bucket,
+                    "keyPrefix": output_prefix,
+                    "resultFormat": "PARQUET",
+                }
+            }
+        },
+    )["protectedQuery"]["id"]
+
+    _logger.debug("query_id: %s", query_id)
+    path: str = wait_query(membership_id=membership_id, query_id=query_id)["protectedQuery"]["result"]["output"]["s3"][
+        "location"
+    ]
+
+    _logger.debug("path: %s", path)
+    chunked: Union[bool, int] = False if chunksize is None else chunksize
+    ret = s3.read_parquet(
+        path=path,
+        use_threads=use_threads,
+        chunked=chunked,
+        boto3_session=boto3_session,
+        pyarrow_additional_kwargs=pyarrow_additional_kwargs,
+    )
+
+    _logger.debug("type(ret): %s", type(ret))
+    kwargs: Dict[str, Any] = {
+        "path": path,
+        "use_threads": use_threads,
+        "boto3_session": boto3_session,
+    }
+    if chunked is False:
+        if keep_files is False:
+            s3.delete_objects(**kwargs)
+        return ret
+    return _delete_after_iterate(ret, keep_files, kwargs)
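
To make the chunked/cleanup path above concrete, a hedged usage sketch; the membership ID, bucket and prefix are placeholders, not values shipped with the module:

import awswrangler as wr

# Chunked read: returns a generator, and with keep_files=False the Parquet
# results under s3://my-bucket/results are deleted only after the generator
# is exhausted (via _delete_after_iterate above).
for df in wr.cleanrooms.read_sql_query(
    sql="SELECT city, COUNT(user_id) FROM users GROUP BY city",
    membership_id="membership-id",
    output_bucket="my-bucket",
    output_prefix="results",
    chunksize=1_000,
    keep_files=False,
):
    print(df.shape)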

‎awswrangler/cleanrooms/_utils.py

+60
@@ -0,0 +1,60 @@
+"""Utilities Module for Amazon Clean Rooms."""
+import logging
+import time
+from typing import TYPE_CHECKING, List, Optional
+
+import boto3
+
+from awswrangler import _utils, exceptions
+
+if TYPE_CHECKING:
+    from mypy_boto3_cleanrooms.type_defs import GetProtectedQueryOutputTypeDef
+
+_QUERY_FINAL_STATES: List[str] = ["CANCELLED", "FAILED", "SUCCESS", "TIMED_OUT"]
+_QUERY_WAIT_POLLING_DELAY: float = 2  # SECONDS
+
+_logger: logging.Logger = logging.getLogger(__name__)
+
+
+def wait_query(
+    membership_id: str, query_id: str, boto3_session: Optional[boto3.Session] = None
+) -> "GetProtectedQueryOutputTypeDef":
+    """Wait for the Clean Rooms protected query to end.
+
+    Parameters
+    ----------
+    membership_id : str
+        Membership ID
+    query_id : str
+        Protected query execution ID
+    boto3_session : boto3.Session, optional
+        Boto3 Session. If None, the default boto3 session is used
+    Returns
+    -------
+    Dict[str, Any]
+        Dictionary with the get_protected_query response.
+
+    Raises
+    ------
+    exceptions.QueryFailed
+        Raises exception with error message if protected query is cancelled, times out or fails.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> res = wr.cleanrooms.wait_query(membership_id='membership-id', query_id='query-id')
+    """
+    client_cleanrooms = _utils.client(service_name="cleanrooms", session=boto3_session)
+    state = "SUBMITTED"
+
+    while state not in _QUERY_FINAL_STATES:
+        time.sleep(_QUERY_WAIT_POLLING_DELAY)
+        response = client_cleanrooms.get_protected_query(
+            membershipIdentifier=membership_id, protectedQueryIdentifier=query_id
+        )
+        state = response["protectedQuery"].get("status")  # type: ignore[assignment]
+
+    _logger.debug("state: %s", state)
+    if state != "SUCCESS":
+        raise exceptions.QueryFailed(response["protectedQuery"].get("Error"))
+    return response
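
read_sql_query drives wait_query internally, but it is exported for callers who start protected queries themselves. A hedged sketch of standalone use, with placeholder IDs:

import awswrangler as wr
from awswrangler import exceptions

try:
    # Polls get_protected_query roughly every 2 seconds until a final state.
    response = wr.cleanrooms.wait_query(membership_id="membership-id", query_id="query-id")
    path = response["protectedQuery"]["result"]["output"]["s3"]["location"]
except exceptions.QueryFailed as e:
    # Raised when the query ends in CANCELLED, FAILED or TIMED_OUT.
    print(f"Protected query did not succeed: {e}")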

‎docs/source/api.rst

+12
@@ -17,6 +17,7 @@ API Reference
 * `Amazon Neptune`_
 * `DynamoDB`_
 * `Amazon Timestream`_
+* `AWS Clean Rooms`_
 * `Amazon EMR`_
 * `Amazon EMR Serverless`_
 * `Amazon CloudWatch Logs`_
@@ -351,6 +352,17 @@ Amazon Timestream
     unload_to_files
     unload
 
+AWS Clean Rooms
+-----------------
+
+.. currentmodule:: awswrangler.cleanrooms
+
+.. autosummary::
+    :toctree: stubs
+
+    read_sql_query
+    wait_query
+
 Amazon EMR
 ----------
‎poetry.lock

+19-2
Some generated files are not rendered by default.

‎pyproject.toml

+1-1
@@ -85,7 +85,7 @@ wheel = "^0.38.1"
 
 # Lint
 black = "^23.1.0"
-boto3-stubs = {version = "1.26.151", extras = ["athena", "chime", "cloudwatch", "dynamodb", "ec2", "emr", "emr-serverless", "glue", "kms", "lakeformation", "logs", "neptune", "opensearch", "opensearchserverless", "quicksight", "rds", "rds-data", "redshift", "redshift-data", "s3", "secretsmanager", "ssm", "sts", "timestream-query", "timestream-write"]}
+boto3-stubs = {version = "^1.26.151", extras = ["athena", "cleanrooms", "chime", "cloudwatch", "dynamodb", "ec2", "emr", "emr-serverless", "glue", "kms", "lakeformation", "logs", "neptune", "opensearch", "opensearchserverless", "quicksight", "rds", "rds-data", "redshift", "redshift-data", "s3", "secretsmanager", "ssm", "sts", "timestream-query", "timestream-write"]}
 doc8 = "^1.0"
 mypy = "^1.0"
 pylint = "^2.17"
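
The new "cleanrooms" extra pulls in mypy-boto3-cleanrooms, which is what lets the typed overload in awswrangler/_utils.py resolve. A small sketch of the effect under mypy, assuming the stubs are installed in the checking environment:

import boto3
from mypy_boto3_cleanrooms import CleanRoomsServiceClient

# boto3-stubs narrows boto3.client("cleanrooms") to CleanRoomsServiceClient,
# so this annotation type-checks without a cast.
client: CleanRoomsServiceClient = boto3.client("cleanrooms")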

‎test_infra/app.py

+7
@@ -3,6 +3,7 @@
 
 from aws_cdk import App, Environment
 from stacks.base_stack import BaseStack
+from stacks.cleanrooms_stack import CleanRoomsStack
 from stacks.databases_stack import DatabasesStack
 from stacks.glueray_stack import GlueRayStack
 from stacks.opensearch_stack import OpenSearchStack
@@ -42,4 +43,10 @@
     **env,
 )
 
+CleanRoomsStack(
+    app,
+    "aws-sdk-pandas-cleanrooms",
+    **env,
+)
+
 app.synth()

‎test_infra/poetry.lock

+61-52
Some generated files are not rendered by default.

‎test_infra/pyproject.toml

+4-4
@@ -7,8 +7,8 @@ license = "Apache License 2.0"
 
 [tool.poetry.dependencies]
 python = ">=3.7.1, <4.0"
-"aws-cdk-lib" = "^2.64.0"
+"aws-cdk-lib" = "^2.85.0"
 "constructs" = ">=10.0.0,<11.0.0"
-"aws-cdk.aws-glue-alpha" = "^2.64.0a0"
-"aws-cdk.aws-redshift-alpha" = "^2.64.0a0"
-"aws-cdk.aws-neptune-alpha" = "^2.64.0a0"
+"aws-cdk.aws-glue-alpha" = "^2.85.0a0"
+"aws-cdk.aws-redshift-alpha" = "^2.85.0a0"
+"aws-cdk.aws-neptune-alpha" = "^2.85.0a0"

‎test_infra/stacks/cleanrooms_stack.py

+241
@@ -0,0 +1,241 @@
+from aws_cdk import CfnOutput, Duration, Stack
+from aws_cdk import aws_cleanrooms as cleanrooms
+from aws_cdk import aws_glue_alpha as glue
+from aws_cdk import aws_iam as iam
+from aws_cdk import aws_s3 as s3
+from aws_cdk import aws_ssm as ssm
+from constructs import Construct
+
+
+class CleanRoomsStack(Stack):  # type: ignore
+    def __init__(
+        self,
+        scope: Construct,
+        construct_id: str,
+        **kwargs: str,
+    ) -> None:
+        super().__init__(scope, construct_id, **kwargs)
+
+        self.collaboration = cleanrooms.CfnCollaboration(
+            self,
+            "Collaboration",
+            name="AWS SDK for pandas - Testing",
+            creator_display_name="Collaborator Creator",
+            creator_member_abilities=["CAN_QUERY", "CAN_RECEIVE_RESULTS"],
+            description="Collaboration Room for AWS SDK for pandas test infrastructure",
+            members=[],
+            query_log_status="ENABLED",
+        )
+
+        self.membership = cleanrooms.CfnMembership(
+            self,
+            "Membership",
+            collaboration_identifier=self.collaboration.attr_collaboration_identifier,
+            query_log_status="ENABLED",
+        )
+
+        self.cleanrooms_service_role = iam.Role(
+            self,
+            "Service Role",
+            assumed_by=iam.CompositePrincipal(
+                iam.ServicePrincipal("cleanrooms.amazonaws.com").with_conditions(
+                    {
+                        "StringLike": {
+                            "sts:ExternalId": f"arn:aws:*:{self.region}:*:dbuser:*/{self.membership.attr_membership_identifier}*"
+                        }
+                    }
+                ),
+                iam.ServicePrincipal("cleanrooms.amazonaws.com").with_conditions(
+                    {
+                        "ForAnyValue:ArnEquals": {
+                            "aws:SourceArn": f"arn:aws:cleanrooms:{self.region}:{self.account}:membership/{self.membership.attr_membership_identifier}"
+                        }
+                    }
+                ),
+            ),
+            managed_policies=[
+                iam.ManagedPolicy.from_aws_managed_policy_name("service-role/AWSGlueServiceRole"),
+                iam.ManagedPolicy.from_aws_managed_policy_name("AmazonS3ReadOnlyAccess"),
+            ],
+        )
+
+        self.bucket = s3.Bucket(
+            self,
+            "Bucket",
+            block_public_access=s3.BlockPublicAccess(
+                block_public_acls=True,
+                block_public_policy=True,
+                ignore_public_acls=True,
+                restrict_public_buckets=True,
+            ),
+            lifecycle_rules=[
+                s3.LifecycleRule(
+                    id="CleaningUp",
+                    enabled=True,
+                    expiration=Duration.days(1),
+                    abort_incomplete_multipart_upload_after=Duration.days(1),
+                ),
+            ],
+            versioned=True,
+        )
+
+        self.database = glue.Database(
+            self,
+            id="Glue Database",
+            database_name="aws_sdk_pandas_cleanrooms",
+            location_uri=f"s3://{self.bucket.bucket_name}",
+        )
+
+        self.users_table = glue.Table(
+            self,
+            "Users Table",
+            database=self.database,
+            table_name="users",
+            columns=[
+                glue.Column(name="user_id", type=glue.Type(input_string="int", is_primitive=True)),
+                glue.Column(name="city", type=glue.Type(input_string="string", is_primitive=True)),
+            ],
+            bucket=self.bucket,
+            s3_prefix="users",
+            data_format=glue.DataFormat(
+                input_format=glue.InputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"),
+                output_format=glue.OutputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"),
+                serialization_library=glue.SerializationLibrary(
+                    "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"
+                ),
+            ),
+        )
+
+        self.purchases_table = glue.Table(
+            self,
+            "Purchases Table",
+            database=self.database,
+            table_name="purchases",
+            columns=[
+                glue.Column(name="purchase_id", type=glue.Type(input_string="int", is_primitive=True)),
+                glue.Column(name="user_id", type=glue.Type(input_string="int", is_primitive=True)),
+                glue.Column(name="sale_value", type=glue.Type(input_string="float", is_primitive=True)),
+            ],
+            bucket=self.bucket,
+            s3_prefix="purchases",
+            data_format=glue.DataFormat(
+                input_format=glue.InputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"),
+                output_format=glue.OutputFormat("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"),
+                serialization_library=glue.SerializationLibrary(
+                    "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"
+                ),
+            ),
+        )
+
+        self.users_configured_table = cleanrooms.CfnConfiguredTable(
+            self,
+            "Users Configured Table",
+            allowed_columns=["user_id", "city"],
+            analysis_method="DIRECT_QUERY",
+            name="users",
+            table_reference=cleanrooms.CfnConfiguredTable.TableReferenceProperty(
+                glue=cleanrooms.CfnConfiguredTable.GlueTableReferenceProperty(
+                    database_name=self.database.database_name,
+                    table_name=self.users_table.table_name,
+                )
+            ),
+            analysis_rules=[
+                cleanrooms.CfnConfiguredTable.AnalysisRuleProperty(
+                    policy=cleanrooms.CfnConfiguredTable.ConfiguredTableAnalysisRulePolicyProperty(
+                        v1=cleanrooms.CfnConfiguredTable.ConfiguredTableAnalysisRulePolicyV1Property(
+                            aggregation=cleanrooms.CfnConfiguredTable.AnalysisRuleAggregationProperty(
+                                aggregate_columns=[
+                                    cleanrooms.CfnConfiguredTable.AggregateColumnProperty(
+                                        column_names=["user_id"], function="COUNT"
+                                    )
+                                ],
+                                dimension_columns=["city"],
+                                join_columns=["user_id"],
+                                output_constraints=[
+                                    cleanrooms.CfnConfiguredTable.AggregationConstraintProperty(
+                                        column_name="user_id", minimum=2, type="COUNT_DISTINCT"
+                                    )
+                                ],
+                                scalar_functions=["LOWER"],
+                                join_required="QUERY_RUNNER",
+                            ),
+                        )
+                    ),
+                    type="AGGREGATION",
+                )
+            ],
+        )
+
+        self.purchases_configured_table = cleanrooms.CfnConfiguredTable(
+            self,
+            "Purchases Configured Table",
+            allowed_columns=["purchase_id", "user_id", "sale_value"],
+            analysis_method="DIRECT_QUERY",
+            name="purchases",
+            table_reference=cleanrooms.CfnConfiguredTable.TableReferenceProperty(
+                glue=cleanrooms.CfnConfiguredTable.GlueTableReferenceProperty(
+                    database_name=self.database.database_name,
+                    table_name=self.purchases_table.table_name,
+                )
+            ),
+            analysis_rules=[
+                cleanrooms.CfnConfiguredTable.AnalysisRuleProperty(
+                    policy=cleanrooms.CfnConfiguredTable.ConfiguredTableAnalysisRulePolicyProperty(
+                        v1=cleanrooms.CfnConfiguredTable.ConfiguredTableAnalysisRulePolicyV1Property(
+                            aggregation=cleanrooms.CfnConfiguredTable.AnalysisRuleAggregationProperty(
+                                aggregate_columns=[
+                                    cleanrooms.CfnConfiguredTable.AggregateColumnProperty(
+                                        column_names=["purchase_id"], function="COUNT"
+                                    ),
+                                    cleanrooms.CfnConfiguredTable.AggregateColumnProperty(
+                                        column_names=["sale_value"], function="AVG"
+                                    ),
+                                    cleanrooms.CfnConfiguredTable.AggregateColumnProperty(
+                                        column_names=["sale_value"], function="SUM"
+                                    ),
+                                ],
+                                dimension_columns=[],
+                                join_columns=["user_id"],
+                                output_constraints=[
+                                    cleanrooms.CfnConfiguredTable.AggregationConstraintProperty(
+                                        column_name="user_id", minimum=2, type="COUNT_DISTINCT"
+                                    )
+                                ],
+                                scalar_functions=[],
+                                join_required="QUERY_RUNNER",
+                            ),
+                        )
+                    ),
+                    type="AGGREGATION",
+                )
+            ],
+        )
+
+        self.users_configured_table_association = cleanrooms.CfnConfiguredTableAssociation(
+            self,
+            "Users Configured Table Association",
+            configured_table_identifier=self.users_configured_table.attr_configured_table_identifier,
+            membership_identifier=self.membership.attr_membership_identifier,
+            name="users",
+            role_arn=self.cleanrooms_service_role.role_arn,
+        )
+
+        self.purchases_configured_table_association = cleanrooms.CfnConfiguredTableAssociation(
+            self,
+            "Purchases Configured Table Association",
+            configured_table_identifier=self.purchases_configured_table.attr_configured_table_identifier,
+            membership_identifier=self.membership.attr_membership_identifier,
+            name="purchases",
+            role_arn=self.cleanrooms_service_role.role_arn,
+        )
+
+        CfnOutput(self, "CleanRoomsMembershipId", value=self.membership.attr_membership_identifier)
+        CfnOutput(self, "CleanRoomsGlueDatabaseName", value=self.database.database_name)
+        CfnOutput(self, "CleanRoomsS3BucketName", value=self.bucket.bucket_name)
+
+        ssm.StringParameter(
+            self,
+            "SSM BucketName",
+            parameter_name="/sdk-pandas/cleanrooms/BucketName",
+            string_value=self.bucket.bucket_name,
+        )
‎tests/_utils.py

+7-1
@@ -489,7 +489,13 @@ def path_generator(bucket: str) -> Iterator[str]:
 def extract_cloudformation_outputs():
     outputs = {}
     client = boto3.client("cloudformation")
-    stacks = ["aws-sdk-pandas-base", "aws-sdk-pandas-databases", "aws-sdk-pandas-opensearch", "aws-sdk-pandas-glueray"]
+    stacks = [
+        "aws-sdk-pandas-base",
+        "aws-sdk-pandas-databases",
+        "aws-sdk-pandas-opensearch",
+        "aws-sdk-pandas-glueray",
+        "aws-sdk-pandas-cleanrooms",
+    ]
     response = try_it(client.describe_stacks, botocore.exceptions.ClientError, max_num_tries=5)
     for stack in response.get("Stacks"):
         if (stack["StackName"] in stacks) and (stack["StackStatus"] in CFN_VALID_STATUS):

‎tests/conftest.py

+15
@@ -449,6 +449,21 @@ def glue_data_quality_role(cloudformation_outputs):
     return cloudformation_outputs["GlueDataQualityRole"]
 
 
+@pytest.fixture(scope="session")
+def cleanrooms_membership_id(cloudformation_outputs):
+    return cloudformation_outputs["CleanRoomsMembershipId"]
+
+
+@pytest.fixture(scope="session")
+def cleanrooms_glue_database_name(cloudformation_outputs):
+    return cloudformation_outputs["CleanRoomsGlueDatabaseName"]
+
+
+@pytest.fixture(scope="session")
+def cleanrooms_s3_bucket_name(cloudformation_outputs):
+    return cloudformation_outputs["CleanRoomsS3BucketName"]
+
+
 @pytest.fixture(scope="function")
 def local_filename() -> Iterator[str]:
     filename = os.path.join(".", f"{get_time_str_with_random_suffix()}.data")

‎tests/unit/test_cleanrooms.py

+73
@@ -0,0 +1,73 @@
+import pytest
+
+import awswrangler as wr
+import awswrangler.pandas as pd
+
+pytestmark = pytest.mark.distributed
+
+
+@pytest.fixture()
+def data(cleanrooms_s3_bucket_name: str, cleanrooms_glue_database_name: str) -> None:
+    df_purchases = pd.DataFrame(
+        {
+            "purchase_id": list(range(100, 109)),
+            "user_id": [1, 2, 3, 1, 2, 3, 4, 5, 6],
+            "sale_value": [2.2, 1.1, 6.2, 2.3, 7.8, 9.9, 7.3, 9.7, 0.7],
+        }
+    )
+    wr.s3.to_parquet(
+        df_purchases,
+        f"s3://{cleanrooms_s3_bucket_name}/purchases/",
+        dataset=True,
+        database=cleanrooms_glue_database_name,
+        table="purchases",
+        mode="overwrite",
+    )
+
+    df_users = pd.DataFrame(
+        {
+            "user_id": list(range(1, 9)),
+            "city": ["LA", "NYC", "Chicago", "NYC", "NYC", "LA", "Seattle", "Seattle"],
+        }
+    )
+    wr.s3.to_parquet(
+        df_users,
+        f"s3://{cleanrooms_s3_bucket_name}/users/",
+        dataset=True,
+        database=cleanrooms_glue_database_name,
+        table="users",
+        mode="overwrite",
+    )
+
+
+def test_read_sql_query(data: None, cleanrooms_membership_id: str, cleanrooms_s3_bucket_name: str):
+    sql = """SELECT city, AVG(p.sale_value)
+    FROM users u
+    INNER JOIN purchases p ON u.user_id = p.user_id
+    GROUP BY city
+    """
+    chunksize = 2
+    df_chunked = wr.cleanrooms.read_sql_query(
+        sql=sql,
+        membership_id=cleanrooms_membership_id,
+        output_bucket=cleanrooms_s3_bucket_name,
+        output_prefix="results",
+        chunksize=chunksize,
+        keep_files=False,
+    )
+    for df in df_chunked:
+        assert df.shape == (chunksize, 2)
+
+    sql = """SELECT COUNT(p.purchase_id), SUM(p.sale_value), city
+    FROM users u
+    INNER JOIN purchases p ON u.user_id = p.user_id
+    GROUP BY city
+    """
+    df = wr.cleanrooms.read_sql_query(
+        sql=sql,
+        membership_id=cleanrooms_membership_id,
+        output_bucket=cleanrooms_s3_bucket_name,
+        output_prefix="results",
+        keep_files=False,
+    )
+    assert df.shape == (2, 3)