diff --git a/mojo/extensions.bzl b/mojo/extensions.bzl index 4379ded..13ac043 100644 --- a/mojo/extensions.bzl +++ b/mojo/extensions.bzl @@ -17,13 +17,13 @@ _PLATFORM_MAPPINGS = { def _mojo_toolchain_impl(rctx): rctx.download_and_extract( - url = "https://dl.modular.com/public/nightly/python/max-{}-py3-none-{}.whl".format( + url = rctx.attr.urls or "https://dl.modular.com/public/nightly/python/max-{}-py3-none-{}.whl".format( rctx.attr.version, _PLATFORM_MAPPINGS[rctx.attr.platform], ), - sha256 = _KNOWN_SHAS.get(rctx.attr.version, {}).get(rctx.attr.platform, ""), + sha256 = rctx.attr.sha256 or _KNOWN_SHAS.get(rctx.attr.version, {}).get(rctx.attr.platform, ""), type = "zip", - strip_prefix = "max-{}.data/platlib/max".format(rctx.attr.version), + strip_prefix = rctx.attr.strip_prefix or "max-{}.data/platlib/max".format(rctx.attr.version), ) rctx.template( @@ -51,6 +51,18 @@ _mojo_toolchain_repository = repository_rule( doc = "Whether to automatically add prebuilt mojopkgs to every mojo target.", mandatory = True, ), + "urls": attr.string_list( + doc = "The URL to download the Mojo toolchain from.", + mandatory = False, + ), + "sha256": attr.string( + doc = "The SHA256 hash of the Mojo toolchain archive.", + mandatory = False, + ), + "strip_prefix": attr.string( + doc = "The prefix to strip from the extracted Mojo toolchain.", + mandatory = False, + ), "_template": attr.label( default = Label("//mojo/private:toolchain.BUILD"), ), @@ -89,7 +101,8 @@ _mojo_toolchain_hub = repository_rule( def _mojo_impl(mctx): # TODO: This requires the root module always call mojo.toolchain(), we # should improve this. - has_toolchains = False + platforms = [] + for module in mctx.modules: if not module.is_root: continue @@ -97,21 +110,24 @@ def _mojo_impl(mctx): if len(module.tags.toolchain) > 1: fail("mojo.toolchain() can only be called once per module.") - has_toolchains = True tags = module.tags.toolchain[0] - for platform in _PLATFORMS: + platforms = tags.urls.keys() if tags.urls else _PLATFORMS + for platform in platforms: name = "mojo_toolchain_{}".format(platform) _mojo_toolchain_repository( name = name, version = tags.version, platform = platform, + urls = tags.urls.get(platform, None), + sha256 = tags.sha256.get(platform, None), + strip_prefix = tags.strip_prefix.get(platform, None), use_prebuilt_packages = tags.use_prebuilt_packages, ) _mojo_toolchain_hub( name = "mojo_toolchains", - platforms = _PLATFORMS if has_toolchains else [], + platforms = platforms, ) return mctx.extension_metadata(reproducible = True) @@ -128,6 +144,50 @@ _toolchain_tag = tag_class( doc = "Whether to automatically add prebuilt mojopkgs to every mojo target.", default = True, ), + "urls": attr.string_list_dict( + mandatory = False, + doc = """\ +URLs to prebuilt archives containing mojo toolchains. They key is the platform +and the value is a list of URLs for the download. Only the provided platforms +will have toolchains created for them. Providing 'sha256's is recommended. + +Example: + +urls = { + "linux_x86_64": ["https://.../max-25.4.0.dev2025050905-py3-none-manylinux_2_34_x86_64.whl"], + "linux_aarch64": ["https://.../max-25.4.0.dev2025050905-py3-none-manylinux_2_34_aarch64.whl"], + "macos_arm64": ["https://.../max-25.4.0.dev2025050905-py3-none-macosx_13_0_arm64.whl"], +} +""", + ), + "sha256": attr.string_dict( + mandatory = False, + doc = """\ +SHA256 hashes for the provided URLs. The key is the platform and the value is the sha256 hash. + +Example: + +sha256 = { + "linux_aarch64": "abc123", + "linux_x86_64": "abc123", + "macos_arm64": "abc123", +} +""", + ), + "strip_prefix": attr.string_dict( + mandatory = False, + doc = """\ +The prefix to strip from the extracted Mojo toolchain. The key is the platform the value is the prefix to strip. Otherwise there is a reasonable default based on the 'version' attribute. + +Example: + +strip_prefix = { + "linux_aarch64": "abc123", + "linux_x86_64": "abc123", + "macos_arm64": "abc123", +} +""", + ), }, )