Skip to content

Commit a79bbfa

Browse files
authored
fix(pypi): handle more URL patterns for requirement sources (#2843)
Summary: - Better handle git references for sdists. - Better handle direct whl references. - Add an extra test that turned out to be not needed in the end, but I left it to increase the code coverage. Work towards #2363 Fixes #2828
1 parent 704ecdd commit a79bbfa

File tree

3 files changed

+77
-1
lines changed

3 files changed

+77
-1
lines changed

python/private/pypi/parse_requirements.bzl

+5
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,17 @@ def _add_dists(*, requirement, index_urls, logger = None):
285285
if requirement.srcs.url:
286286
url = requirement.srcs.url
287287
_, _, filename = url.rpartition("/")
288+
filename, _, _ = filename.partition("#sha256=")
288289
if "." not in filename:
289290
# detected filename has no extension, it might be an sdist ref
290291
# TODO @aignas 2025-04-03: should be handled if the following is fixed:
291292
# https://github.com/bazel-contrib/rules_python/issues/2363
292293
return [], None
293294

295+
if "@" in filename:
296+
# this is most likely foo.git@git_sha, skip special handling of these
297+
return [], None
298+
294299
direct_url_dist = struct(
295300
url = url,
296301
filename = filename,

tests/pypi/index_sources/index_sources_tests.bzl

+13-1
Original file line numberDiff line numberDiff line change
@@ -21,38 +21,50 @@ _tests = []
2121

2222
def _test_no_simple_api_sources(env):
2323
inputs = {
24+
"foo @ git+https://github.com/org/foo.git@deadbeef": struct(
25+
requirement = "foo @ git+https://github.com/org/foo.git@deadbeef",
26+
marker = "",
27+
url = "git+https://github.com/org/foo.git@deadbeef",
28+
shas = [],
29+
version = "",
30+
),
2431
"foo==0.0.1": struct(
2532
requirement = "foo==0.0.1",
2633
marker = "",
2734
url = "",
35+
version = "0.0.1",
2836
),
2937
"foo==0.0.1 @ https://someurl.org": struct(
3038
requirement = "foo==0.0.1 @ https://someurl.org",
3139
marker = "",
3240
url = "https://someurl.org",
41+
version = "0.0.1",
3342
),
3443
"foo==0.0.1 @ https://someurl.org/package.whl": struct(
3544
requirement = "foo==0.0.1 @ https://someurl.org/package.whl",
3645
marker = "",
3746
url = "https://someurl.org/package.whl",
47+
version = "0.0.1",
3848
),
3949
"foo==0.0.1 @ https://someurl.org/package.whl --hash=sha256:deadbeef": struct(
4050
requirement = "foo==0.0.1 @ https://someurl.org/package.whl --hash=sha256:deadbeef",
4151
marker = "",
4252
url = "https://someurl.org/package.whl",
4353
shas = ["deadbeef"],
54+
version = "0.0.1",
4455
),
4556
"foo==0.0.1 @ https://someurl.org/package.whl; python_version < \"2.7\"\\ --hash=sha256:deadbeef": struct(
4657
requirement = "foo==0.0.1 @ https://someurl.org/package.whl --hash=sha256:deadbeef",
4758
marker = "python_version < \"2.7\"",
4859
url = "https://someurl.org/package.whl",
4960
shas = ["deadbeef"],
61+
version = "0.0.1",
5062
),
5163
}
5264
for input, want in inputs.items():
5365
got = index_sources(input)
5466
env.expect.that_collection(got.shas).contains_exactly(want.shas if hasattr(want, "shas") else [])
55-
env.expect.that_str(got.version).equals("0.0.1")
67+
env.expect.that_str(got.version).equals(want.version)
5668
env.expect.that_str(got.requirement).equals(want.requirement)
5769
env.expect.that_str(got.requirement_line).equals(got.requirement)
5870
env.expect.that_str(got.marker).equals(want.marker)

tests/pypi/parse_requirements/parse_requirements_tests.bzl

+59
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,16 @@ foo[extra] @ https://some-url/package.whl
3030
bar @ https://example.org/bar-1.0.whl --hash=sha256:deadbeef
3131
baz @ https://test.com/baz-2.0.whl; python_version < "3.8" --hash=sha256:deadb00f
3232
qux @ https://example.org/qux-1.0.tar.gz --hash=sha256:deadbe0f
33+
torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc
3334
""",
3435
"requirements_extra_args": """\
3536
--index-url=example.org
3637
3738
foo[extra]==0.0.1 \
3839
--hash=sha256:deadbeef
40+
""",
41+
"requirements_git": """
42+
foo @ git+https://github.com/org/foo.git@deadbeef
3943
""",
4044
"requirements_linux": """\
4145
foo==0.0.3 --hash=sha256:deadbaaf
@@ -232,6 +236,31 @@ def _test_direct_urls(env):
232236
whls = [],
233237
),
234238
],
239+
"torch": [
240+
struct(
241+
distribution = "torch",
242+
extra_pip_args = [],
243+
is_exposed = True,
244+
sdist = None,
245+
srcs = struct(
246+
marker = "",
247+
requirement = "torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc",
248+
requirement_line = "torch @ https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc",
249+
shas = [],
250+
url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc",
251+
version = "",
252+
),
253+
target_platforms = ["linux_x86_64"],
254+
whls = [
255+
struct(
256+
filename = "torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl",
257+
sha256 = "",
258+
url = "https://download.pytorch.org/whl/cpu/torch-2.6.0%2Bcpu-cp311-cp311-linux_x86_64.whl#sha256=5b6ae523bfb67088a17ca7734d131548a2e60346c622621e4248ed09dd0790cc",
259+
yanked = False,
260+
),
261+
],
262+
),
263+
],
235264
})
236265

237266
_tests.append(_test_direct_urls)
@@ -623,6 +652,36 @@ def _test_optional_hash(env):
623652

624653
_tests.append(_test_optional_hash)
625654

655+
def _test_git_sources(env):
656+
got = parse_requirements(
657+
ctx = _mock_ctx(),
658+
requirements_by_platform = {
659+
"requirements_git": ["linux_x86_64"],
660+
},
661+
)
662+
env.expect.that_dict(got).contains_exactly({
663+
"foo": [
664+
struct(
665+
distribution = "foo",
666+
extra_pip_args = [],
667+
is_exposed = True,
668+
sdist = None,
669+
srcs = struct(
670+
marker = "",
671+
requirement = "foo @ git+https://github.com/org/foo.git@deadbeef",
672+
requirement_line = "foo @ git+https://github.com/org/foo.git@deadbeef",
673+
shas = [],
674+
url = "git+https://github.com/org/foo.git@deadbeef",
675+
version = "",
676+
),
677+
target_platforms = ["linux_x86_64"],
678+
whls = [],
679+
),
680+
],
681+
})
682+
683+
_tests.append(_test_git_sources)
684+
626685
def parse_requirements_test_suite(name):
627686
"""Create the test suite.
628687

0 commit comments

Comments
 (0)