From e468cde19d5a4713e88b0a133cd1c620431f8b77 Mon Sep 17 00:00:00 2001
From: Jesper Stemann Andersen
Date: Sat, 10 Sep 2022 12:45:04 +0200
Subject: [PATCH] Replaced artifacts with JLL artifacts

Uses extended platform selection based on platform augmentation tags
(see https://pkgdocs.julialang.org/v1.7/artifacts/#Extending-Platform-Selection),
adapted for use with JLL artifacts.
---
 .pkg/platform_augmentation.jl | 35 +++++++++++++++++++++
 .pkg/select_artifacts.jl      | 19 ++++++++++++
 Artifacts.toml                | 57 -----------------------------------
 Project.toml                  |  3 ++
 src/ONNXRunTime.jl            |  4 +++
 src/capi.jl                   | 42 +++++---------------------
 src/highlevel.jl              |  3 +-
 7 files changed, 69 insertions(+), 94 deletions(-)
 create mode 100644 .pkg/platform_augmentation.jl
 create mode 100644 .pkg/select_artifacts.jl
 delete mode 100644 Artifacts.toml

diff --git a/.pkg/platform_augmentation.jl b/.pkg/platform_augmentation.jl
new file mode 100644
index 0000000..71fa9d6
--- /dev/null
+++ b/.pkg/platform_augmentation.jl
@@ -0,0 +1,35 @@
+using Libdl, Base.BinaryPlatforms
+
+function augment_platform!(p::Platform, tag::Union{String,Nothing} = nothing)
+    if tag === nothing
+        return p
+    end
+
+    if tag === "cuda"
+        if haskey(p, tag)
+            return p
+        end
+
+        # Open libcuda explicitly, so it gets `dlclose()`'ed after we're done
+        try
+            dlopen("libcuda") do lib
+                # find symbol to ask for driver version; if we can't find it, just silently continue
+                cuDriverGetVersion = dlsym(lib, "cuDriverGetVersion"; throw_error=false)
+                if cuDriverGetVersion !== nothing
+                    # Interrogate CUDA driver for driver version:
+                    driverVersion = Ref{Cint}()
+                    ccall(cuDriverGetVersion, UInt32, (Ptr{Cint},), driverVersion)
+
+                    # Store only the major version (platform tag values must be strings)
+                    p[tag] = string(div(driverVersion[], 1000))
+                end
+            end
+        catch
+        end
+
+        return p
+    else
+        @warn "Unexpected tag $tag for $p"
+        return p
+    end
+end
diff --git a/.pkg/select_artifacts.jl b/.pkg/select_artifacts.jl
new file mode 100644
index 0000000..f26630a
--- /dev/null
+++ b/.pkg/select_artifacts.jl
@@ -0,0 +1,19 @@
+using TOML, Artifacts, Base.BinaryPlatforms
+
+import ONNXRuntime_jll
+
+include("platform_augmentation.jl")
+
+artifacts_toml = Artifacts.find_artifacts_toml(pathof(ONNXRuntime_jll))
+
+# Get "target triplet" from ARGS, if given (defaulting to the host triplet otherwise)
+target_triplet = get(ARGS, 1, Base.BinaryPlatforms.host_triplet())
+
+# Augment this platform object with any special tags we require
+platform = augment_platform!(HostPlatform(parse(Platform, target_triplet)))
+
+# Select all downloadable artifacts that match that platform
+artifacts = select_downloadable_artifacts(artifacts_toml; platform)
+
+# Output the result to `stdout` as a TOML dictionary
+TOML.print(stdout, artifacts)
diff --git a/Artifacts.toml b/Artifacts.toml
deleted file mode 100644
index e620db0..0000000
--- a/Artifacts.toml
+++ /dev/null
@@ -1,57 +0,0 @@
-[[onnxruntime_cpu]]
-arch = "x86_64"
-git-tree-sha1 = "03ca4920f8c66efacaa58cdb30ee6da83b7a40e0"
-lazy = true
-os = "windows"
-
-    [[onnxruntime_cpu.download]]
-    sha256 = "0dbe0bb59e2c000d9b331d1dfd2a61a7af2cfea1d5391d369fcd8d3ea504b3b8"
-    url = "https://github.com/jw3126/ONNXRunTimeArtifacts/releases/download/v1.20.1-rc1/onnxruntime-win-x64-1.20.1.tgz"
-[[onnxruntime_cpu]]
-arch = "x86_64"
-git-tree-sha1 = "7a5a70a99ed5931ff451b6e70a29944cd3d4b54d"
-lazy = true
-libc = "glibc"
-os = "linux"
-
-    [[onnxruntime_cpu.download]]
-    sha256 = "67db4dc1561f1e3fd42e619575c82c601ef89849afc7ea85a003abbac1a1a105"
-    url = "https://github.com/microsoft/onnxruntime/releases/download/v1.20.1/onnxruntime-linux-x64-1.20.1.tgz"
-[[onnxruntime_cpu]]
-arch = "x86_64"
-git-tree-sha1 = "4652621efdd6d8c61987734403a7e7e1fff19079"
-lazy = true
-os = "macos"
-
-    [[onnxruntime_cpu.download]]
-    sha256 = "da4349e01a7e997f5034563183c7183d069caadc1d95f499b560961787813efd"
-    url = "https://github.com/microsoft/onnxruntime/releases/download/v1.20.1/onnxruntime-osx-universal2-1.20.1.tgz"
-[[onnxruntime_cpu]]
-arch = "aarch64"
-git-tree-sha1 = "9f203a6745ce1e17f0afb8528688b95d1791d851"
-lazy = true
-os = "macos"
-
-    [[onnxruntime_cpu.download]]
-    sha256 = "b678fc3c2354c771fea4fba420edeccfba205140088334df801e7fc40e83a57a"
-    url = "https://github.com/microsoft/onnxruntime/releases/download/v1.20.1/onnxruntime-osx-arm64-1.20.1.tgz"
-
-[[onnxruntime_gpu]]
-arch = "x86_64"
-git-tree-sha1 = "9740e9c731f3aeaa8ea93e25a787e91cf8f64bd7"
-lazy = true
-os = "windows"
-
-    [[onnxruntime_gpu.download]]
-    sha256 = "622828adb36268ce58d69eade91827d92354388f6e6a3be777fffe54acab84a2"
-    url = "https://github.com/jw3126/ONNXRunTimeArtifacts/releases/download/v1.20.1-rc1/onnxruntime-win-x64-gpu-1.20.1.tgz"
-[[onnxruntime_gpu]]
-arch = "x86_64"
-git-tree-sha1 = "478f998b8d737218ef0cc06d26e91288d1e72f3b"
-lazy = true
-libc = "glibc"
-os = "linux"
-
-    [[onnxruntime_gpu.download]]
-    sha256 = "6bfb87c6ebe55367a94509b8ef062239e188dccf8d5caac8d6909b2344893bf0"
-    url = "https://github.com/microsoft/onnxruntime/releases/download/v1.20.1/onnxruntime-linux-x64-gpu-1.20.1.tgz"
diff --git a/Project.toml b/Project.toml
index 8925983..ff08e75 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,12 +5,15 @@ version = "1.3.1"
 
 [deps]
 ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
+Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
 CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+ONNXRuntime_jll = "09e6dd1b-8208-5c7e-a336-6e9061773d0b"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
 
 [compat]
 ArgCheck = "2"
diff --git a/src/ONNXRunTime.jl b/src/ONNXRunTime.jl
index 067dfef..f939c9c 100644
--- a/src/ONNXRunTime.jl
+++ b/src/ONNXRunTime.jl
@@ -1,5 +1,9 @@
 module ONNXRunTime
 
+using Artifacts
+using LazyArtifacts
+using ONNXRuntime_jll
+
 function _perm(arr::AbstractArray{T,N}) where {T,N}
     ntuple(i->N+1-i, N)
 end
diff --git a/src/capi.jl b/src/capi.jl
index bad5e34..3f0f2ee 100644
--- a/src/capi.jl
+++ b/src/capi.jl
@@ -5,20 +5,19 @@ This module closely follows the offical onnxruntime [C-API](https://github.com/m
 See [here](https://github.com/microsoft/onnxruntime-inference-examples/blob/d031f879c9a8d33c8b7dc52c5bc65fe8b9e3960d/c_cxx/fns_candy_style_transfer/fns_candy_style_transfer.c)
 for a C code example.
 """
 module CAPI
-using ONNXRunTime: reversedims_lazy
+
+using ONNXRunTime: EXECUTION_PROVIDERS, artifact_dir_map, reversedims_lazy
+
+using ONNXRuntime_jll
 using DocStringExtensions
 using Libdl
 using CEnum: @cenum
 using ArgCheck
-using LazyArtifacts
-using Pkg.Artifacts: artifact_path, ensure_artifact_installed, find_artifacts_toml
 
 const LIB_CPU = Ref(C_NULL)
 const LIB_CUDA = Ref(C_NULL)
 
-const EXECUTION_PROVIDERS = [:cpu, :cuda]
-
 # For model_path on windows ONNX uses wchar_t while on linux + mac char is used.
 # Other strings use char on any platform it seems
 # https://github.com/microsoft/onnxruntime/issues/9568#issuecomment-952951564
@@ -45,36 +44,9 @@ end
 
 function make_lib!(execution_provider)
     @argcheck execution_provider in EXECUTION_PROVIDERS
-    artifact_name = if execution_provider === :cpu
-        "onnxruntime_cpu"
-    elseif execution_provider === :cuda
-        "onnxruntime_gpu"
-    else
-        error("Unreachable")
-    end
-    artifacts_toml = find_artifacts_toml(joinpath(@__DIR__ , "ONNXRunTime.jl"))
-    h = artifact_hash(artifact_name, artifacts_toml)
-    if h === nothing
-        msg = """
-        Unsupported execution_provider = $(repr(execution_provider)) for
-        this architectur.
-        """
-        error(msg)
-    end
-    ensure_artifact_installed(artifact_name, artifacts_toml)
-    root = artifact_path(h)
-    @check isdir(root)
-    dir = joinpath(root, only(readdir(root)))
-    @check isdir(dir)
-    libname = if Sys.iswindows()
-        "onnxruntime.dll"
-    elseif Sys.isapple()
-        "libonnxruntime.dylib"
-    else
-        "libonnxruntime.so"
-    end
-    path = joinpath(dir, "lib", libname)
-    @check isfile(path)
+    path = ONNXRuntime_jll.libonnxruntime_path
+    rel_path = joinpath(basename(dirname(path)), basename(path))
+    path = joinpath(artifact_dir_map[execution_provider], rel_path)
     set_lib!(path, execution_provider)
 end
 
diff --git a/src/highlevel.jl b/src/highlevel.jl
index 1aa24d4..e793e5b 100644
--- a/src/highlevel.jl
+++ b/src/highlevel.jl
@@ -1,5 +1,4 @@
 using ArgCheck
-using LazyArtifacts
 using DataStructures: OrderedDict
 using DocStringExtensions
 import CEnum
@@ -12,7 +11,7 @@ end
 
 using .CAPI
 
-using .CAPI: juliatype, EXECUTION_PROVIDERS
+using .CAPI: juliatype
 
 export InferenceSession, load_inference, release
 """
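
Not part of the patch: a minimal sketch for exercising the new platform-augmentation
path locally, assuming ONNXRuntime_jll is installed in the active environment and the
snippet is run from the repository root. The "cuda" tag is passed explicitly because
augment_platform! returns the platform unchanged when no tag is given; on machines
without a usable libcuda the dlopen failure is caught and the platform also comes
back unchanged, so the check is safe everywhere. The example triplet in the comment
is illustrative only.

    using Base.BinaryPlatforms

    # Defines augment_platform!(p, tag) and pulls in Libdl.
    include(".pkg/platform_augmentation.jl")

    # With "cuda" the tag is set to the major driver version reported by
    # libcuda, if one can be queried; otherwise the platform passes through.
    p = augment_platform!(HostPlatform(), "cuda")
    @show p

    # .pkg/select_artifacts.jl prints the ONNXRuntime_jll artifacts matching a
    # (possibly augmented) platform as a TOML dictionary on stdout, which is
    # how Pkg consumes it, e.g.:
    #   julia --project .pkg/select_artifacts.jl x86_64-linux-gnu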