
Commit 0212f67

openlineage: add support for hook lineage for S3Hook (apache#40819)
Signed-off-by: Maciej Obuchowski <[email protected]>
1 parent 05a5df8 commit 0212f67

16 files changed, +245 -25 lines

airflow/lineage/hook.py

Lines changed: 2 additions & 2 deletions

@@ -139,10 +139,10 @@ class NoOpCollector(HookLineageCollector):
     It is used when you want to disable lineage collection.
     """
 
-    def add_input_dataset(self, *_):
+    def add_input_dataset(self, *_, **__):
         pass
 
-    def add_output_dataset(self, *_):
+    def add_output_dataset(self, *_, **__):
         pass
 
     @property
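
The extra **__ matters because the S3Hook changes below call the collector with keyword arguments only; against the old positional-only signature those calls would fail whenever lineage collection is disabled. A minimal, self-contained sketch of the difference (toy classes, not the real collectors):

    class OldNoOp:
        def add_input_dataset(self, *_):  # swallows positional args only
            pass

    class NewNoOp:
        def add_input_dataset(self, *_, **__):  # also swallows keyword args
            pass

    call = {"context": object(), "scheme": "s3", "dataset_kwargs": {"bucket": "b", "key": "k"}}
    NewNoOp().add_input_dataset(**call)    # fine, silently ignored
    OldNoOp().add_input_dataset(**call)    # TypeError: unexpected keyword argument 'context'
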
airflow/providers/amazon/aws/datasets/__init__.py (new file)

Lines changed: 16 additions & 0 deletions (standard Apache license header only; package marker for the new datasets module).
airflow/providers/amazon/aws/datasets/s3.py (new file)

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.datasets import Dataset
+
+
+def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
+    return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)

airflow/providers/amazon/aws/hooks/s3.py

Lines changed: 26 additions & 1 deletion

@@ -41,6 +41,8 @@
 from urllib.parse import urlsplit
 from uuid import uuid4
 
+from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
+
 if TYPE_CHECKING:
     from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
 
@@ -1111,6 +1113,12 @@ def load_file(
 
         client = self.get_conn()
         client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="file", dataset_kwargs={"path": filename}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     @unify_bucket_name_and_key
     @provide_bucket_name
@@ -1251,6 +1259,10 @@ def _upload_file_obj(
             ExtraArgs=extra_args,
             Config=self.transfer_config,
         )
+        # No input because file_obj can be anything - handle in calling function if possible
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     def copy_object(
         self,
@@ -1306,6 +1318,12 @@ def copy_object(
         response = self.get_conn().copy_object(
             Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
+        )
         return response
 
     @provide_bucket_name
@@ -1425,6 +1443,11 @@ def download_file(
 
             file_path.parent.mkdir(exist_ok=True, parents=True)
 
+            get_hook_lineage_collector().add_output_dataset(
+                context=self,
+                scheme="file",
+                dataset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+            )
             file = open(file_path, "wb")
         else:
             file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
@@ -1435,7 +1458,9 @@ def download_file(
                 ExtraArgs=self.extra_args,
                 Config=self.transfer_config,
             )
-
+            get_hook_lineage_collector().add_input_dataset(
+                context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+            )
         return file.name
 
     def generate_presigned_url(
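
The net effect, sketched with made-up names (illustrative only; it assumes working AWS credentials and that lineage collection is enabled, so get_hook_lineage_collector() returns the real collector rather than the NoOpCollector):

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")
    hook.load_file(filename="/tmp/report.csv", key="reports/report.csv", bucket_name="analytics-bucket")

    # With the s3 and file dataset factories registered (see provider.yaml below),
    # the collector ends up holding roughly:
    #   input:  Dataset(uri="file:///tmp/report.csv")
    #   output: Dataset(uri="s3://analytics-bucket/reports/report.csv")
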

airflow/providers/amazon/provider.yaml

Lines changed: 2 additions & 0 deletions

@@ -91,6 +91,7 @@ dependencies:
   - apache-airflow>=2.7.0
   - apache-airflow-providers-common-sql>=1.3.1
   - apache-airflow-providers-http
+  - apache-airflow-providers-common-compat>=1.1.0
   # We should update minimum version of boto3 and here regularly to avoid `pip` backtracking with the number
   # of candidates to consider. Make sure to configure boto3 version here as well as in all the tools below
   # in the `devel-dependencies` section to be the same minimum version.
@@ -561,6 +562,7 @@ sensors:
 dataset-uris:
   - schemes: [s3]
     handler: null
+    factory: airflow.providers.amazon.aws.datasets.s3.create_dataset
 
 filesystems:
   - airflow.providers.amazon.aws.fs.s3
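
The factory entry is what lets the lineage collector turn the scheme and kwargs passed by the hook into a Dataset. A rough sketch of the lookup (the dataset_factories accessor name is an assumption, mirroring the _dataset_factories registry filled by providers_manager.py below):

    from airflow.providers_manager import ProvidersManager

    factory = ProvidersManager().dataset_factories["s3"]  # accessor name assumed; registered from provider.yaml
    ds = factory(bucket="analytics-bucket", key="reports/report.csv")
    print(ds.uri)  # s3://analytics-bucket/reports/report.csv
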

airflow/providers/common/compat/lineage/hook.py

Lines changed: 2 additions & 2 deletions

@@ -32,10 +32,10 @@ class NoOpCollector:
         It is used when you want to disable lineage collection.
         """
 
-        def add_input_dataset(self, *_):
+        def add_input_dataset(self, *_, **__):
             pass
 
-        def add_output_dataset(self, *_):
+        def add_output_dataset(self, *_, **__):
             pass
 
     return NoOpCollector()

airflow/providers/common/compat/provider.yaml

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ state: ready
 source-date-epoch: 1716287191
 # note that those versions are maintained by release manager - do not update them manually
 versions:
+  - 1.1.0
   - 1.0.0
 
 dependencies:

airflow/providers/common/io/datasets/file.py

Lines changed: 2 additions & 2 deletions

@@ -19,6 +19,6 @@
 from airflow.datasets import Dataset
 
 
-def create_dataset(*, path: str) -> Dataset:
+def create_dataset(*, path: str, extra=None) -> Dataset:
     # We assume that we get absolute path starting with /
-    return Dataset(uri=f"file://{path}")
+    return Dataset(uri=f"file://{path}", extra=extra)

airflow/providers_manager.py

Lines changed: 13 additions & 13 deletions

@@ -886,23 +886,23 @@ def _discover_dataset_uri_handlers_and_factories(self) -> None:
 
         for provider_package, provider in self._provider_dict.items():
             for handler_info in provider.data.get("dataset-uris", []):
-                try:
-                    schemes = handler_info["schemes"]
-                    handler_path = handler_info["handler"]
-                except KeyError:
+                schemes = handler_info.get("schemes")
+                handler_path = handler_info.get("handler")
+                factory_path = handler_info.get("factory")
+                if schemes is None:
                     continue
-                if handler_path is None:
+
+                if handler_path is not None and (
+                    handler := _correctness_check(provider_package, handler_path, provider)
+                ):
+                    pass
+                else:
                     handler = normalize_noop
-                elif not (handler := _correctness_check(provider_package, handler_path, provider)):
-                    continue
                 self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
-                factory_path = handler_info.get("factory")
-                if not (
-                    factory_path is not None
-                    and (factory := _correctness_check(provider_package, factory_path, provider))
+                if factory_path is not None and (
+                    factory := _correctness_check(provider_package, factory_path, provider)
                 ):
-                    continue
-                self._dataset_factories.update((scheme, factory) for scheme in schemes)
+                    self._dataset_factories.update((scheme, factory) for scheme in schemes)
 
     def _discover_taskflow_decorators(self) -> None:
         for name, info in self._provider_dict.items():
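
The control flow is simplified and made more lenient: only the schemes key is now required, a missing or failing handler falls back to normalize_noop instead of skipping the entry, and factory registration no longer depends on the handler outcome. A self-contained toy walk-through of the new flow (helper names mimic, but are not, the real ones):

    def discover(handler_info, check=lambda path: path):
        """Mimics the new logic: factories register independently of the handler."""
        handlers, factories = {}, {}

        schemes = handler_info.get("schemes")
        handler_path = handler_info.get("handler")
        factory_path = handler_info.get("factory")
        if schemes is None:
            return handlers, factories

        handler = check(handler_path) if handler_path is not None else None
        if handler is None:
            handler = "normalize_noop"  # stand-in for the default handler
        handlers.update((scheme, handler) for scheme in schemes)

        if factory_path is not None and (factory := check(factory_path)):
            factories.update((scheme, factory) for scheme in schemes)
        return handlers, factories

    # No "handler" key at all: the old try/except skipped such an entry entirely,
    # the new flow registers normalize_noop plus the factory.
    print(discover({"schemes": ["s3"],
                    "factory": "airflow.providers.amazon.aws.datasets.s3.create_dataset"}))
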

dev/breeze/tests/test_selective_checks.py

Lines changed: 4 additions & 4 deletions

@@ -569,7 +569,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
         ("airflow/providers/amazon/__init__.py",),
         {
             "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-            "common.sql exasol ftp google http imap microsoft.azure "
+            "common.compat common.sql exasol ftp google http imap microsoft.azure "
             "mongo mysql openlineage postgres salesforce ssh teradata",
             "all-python-versions": "['3.8']",
             "all-python-versions-list-as-string": "3.8",
@@ -585,7 +585,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             "upgrade-to-newer-dependencies": "false",
             "run-amazon-tests": "true",
             "parallel-test-types-list-as-string": "Always Providers[amazon] "
-            "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+            "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
             "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
             "needs-mypy": "true",
             "mypy-folders": "['providers']",
@@ -619,7 +619,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
         ("airflow/providers/amazon/file.py",),
         {
             "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-            "common.sql exasol ftp google http imap microsoft.azure "
+            "common.compat common.sql exasol ftp google http imap microsoft.azure "
             "mongo mysql openlineage postgres salesforce ssh teradata",
             "all-python-versions": "['3.8']",
             "all-python-versions-list-as-string": "3.8",
@@ -635,7 +635,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             "run-kubernetes-tests": "false",
             "upgrade-to-newer-dependencies": "false",
             "parallel-test-types-list-as-string": "Always Providers[amazon] "
-            "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+            "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
            "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
             "needs-mypy": "true",
             "mypy-folders": "['providers']",

generated/provider_dependencies.json

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@
   "amazon": {
     "deps": [
       "PyAthena>=3.0.10",
+      "apache-airflow-providers-common-compat>=1.1.0",
       "apache-airflow-providers-common-sql>=1.3.1",
       "apache-airflow-providers-http",
       "apache-airflow>=2.7.0",
@@ -57,6 +58,7 @@
     "cross-providers-deps": [
       "apache.hive",
       "cncf.kubernetes",
+      "common.compat",
       "common.sql",
       "exasol",
       "ftp",

prod_image_installed_providers.txt

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@
 amazon
 celery
 cncf.kubernetes
+common.compat
 common.io
 common.sql
 docker

tests/conftest.py

Lines changed: 10 additions & 0 deletions

@@ -1326,6 +1326,16 @@ def airflow_root_path() -> Path:
     return Path(airflow.__path__[0]).parent
 
 
+@pytest.fixture
+def hook_lineage_collector():
+    from airflow.lineage import hook
+
+    hook._hook_lineage_collector = None
+    hook._hook_lineage_collector = hook.HookLineageCollector()
+    yield hook.get_hook_lineage_collector()
+    hook._hook_lineage_collector = None
+
+
 # This constant is set to True if tests are run with Airflow installed from Packages rather than running
 # the tests within Airflow sources. While most tests in CI are run using Airflow sources, there are
 # also compatibility tests that only use `tests` package and run against installed packages of Airflow in
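
A sketch of how a provider test could use this fixture (hypothetical test; it assumes an S3 bucket reachable from the test, e.g. via moto, and the collected_datasets accessor on HookLineageCollector):

    def test_load_file_emits_lineage(hook_lineage_collector, tmp_path):
        from airflow.providers.amazon.aws.hooks.s3 import S3Hook

        path = tmp_path / "testfile"
        path.write_text("hello")

        S3Hook().load_file(path, "my_key", "some-bucket")

        # One file:// input and one s3:// output should have been collected.
        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
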
tests/providers/amazon/aws/datasets/__init__.py (new file)

Lines changed: 16 additions & 0 deletions (standard Apache license header only; package marker for the new test package).
tests/providers/amazon/aws/datasets/test_s3.py (new file)

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+# (lines 1-16: standard Apache license header, as above)
+from __future__ import annotations
+
+from airflow.datasets import Dataset
+from airflow.providers.amazon.aws.datasets.s3 import create_dataset
+
+
+def test_create_dataset():
+    assert create_dataset(bucket="test-bucket", key="test-path") == Dataset(uri="s3://test-bucket/test-path")
+    assert create_dataset(bucket="test-bucket", key="test-dir/test-path") == Dataset(
+        uri="s3://test-bucket/test-dir/test-path"
+    )
