Skip to content

Commit 0cf858d

Browse files
committed
Replaced artifacts with JLL artifacts
Using extended platform selection based on platform augmentation tags, i.e. https://pkgdocs.julialang.org/v1.7/artifacts/#Extending-Platform-Selection - adapted for use with JLL Artifacts.
1 parent f79d863 commit 0cf858d

7 files changed

+90
-88
lines changed

.pkg/platform_augmentation.jl

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using Libdl, Base.BinaryPlatforms
2+
3+
function augment_platform!(p::Platform, tag::Union{String,Nothing} = nothing)
4+
if tag === nothing
5+
return p
6+
end
7+
8+
if tag === "cuda"
9+
if haskey(p, tag)
10+
return p
11+
end
12+
13+
# Open libcuda explicitly, so it gets `dlclose()`'ed after we're done
14+
try
15+
dlopen("libcuda") do lib
16+
# find symbol to ask for driver version; if we can't find it, just silently continue
17+
cuDriverGetVersion = dlsym(lib, "cuDriverGetVersion"; throw_error=false)
18+
if cuDriverGetVersion !== nothing
19+
# Interrogate CUDA driver for driver version:
20+
driverVersion = Ref{Cint}()
21+
ccall(cuDriverGetVersion, UInt32, (Ptr{Cint},), driverVersion)
22+
23+
# Store only the major version
24+
p[tag] = div(driverVersion, 1000)
25+
end
26+
end
27+
catch
28+
end
29+
30+
return p
31+
else
32+
@warn "Unexpected tag $tag for $p"
33+
return p
34+
end
35+
end

.pkg/select_artifacts.jl

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using TOML, Artifacts, Base.BinaryPlatforms
2+
3+
import ONNXRuntime_jll
4+
5+
include("platform_augmentation.jl")
6+
7+
artifacts_toml = Artifacts.find_artifacts_toml(pathof(ONNXRuntime_jll))
8+
9+
# Get "target triplet" from ARGS, if given (defaulting to the host triplet otherwise)
10+
target_triplet = get(ARGS, 1, Base.BinaryPlatforms.host_triplet())
11+
12+
# Augment this platform object with any special tags we require
13+
platform = augment_platform!(HostPlatform(parse(Platform, target_triplet)))
14+
15+
# Select all downloadable artifacts that match that platform
16+
artifacts = select_downloadable_artifacts(artifacts_toml; platform)
17+
18+
# Output the result to `stdout` as a TOML dictionary
19+
TOML.print(stdout, artifacts)

Artifacts.toml

-48
This file was deleted.

Project.toml

+3
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@ version = "0.3.0"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
8+
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
89
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
910
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1011
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1112
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
1213
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
14+
ONNXRuntime_jll = "09e6dd1b-8208-5c7e-a336-6e9061773d0b"
1315
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1416
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
17+
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
1518

1619
[compat]
1720
ArgCheck = "2"

src/ONNXRunTime.jl

+25-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
module ONNXRunTime
2-
using Requires:@require
2+
3+
using Artifacts
4+
using LazyArtifacts
5+
using ONNXRuntime_jll
6+
7+
using Requires: @require
38

49
function _perm(arr::AbstractArray{T,N}) where {T,N}
510
ntuple(i->N+1-i, N)
@@ -11,11 +16,28 @@ function reversedims_lazy(arr)
1116
PermutedDimsArray(arr, _perm(arr))
1217
end
1318

14-
include("capi.jl")
15-
include("highlevel.jl")
19+
const EXECUTION_PROVIDERS = [:cpu, :cuda]
20+
21+
const artifact_dir_map = Dict{Symbol, String}()
22+
23+
include("../.pkg/platform_augmentation.jl")
1624

1725
function __init__()
1826
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" include("cuda.jl")
27+
28+
# Workaround/replacement for Artifacts.@artifact_str using the local Artifacts.toml
29+
function artifact_dir(m::Module, artifact_name::String, p::Platform)
30+
artifacts_toml = find_artifacts_toml(pathof(m))
31+
h = artifact_hash(artifact_name, artifacts_toml; platform = p)
32+
path = artifact_path(h)
33+
return path
34+
end
35+
36+
artifact_dir_map[:cpu] = artifact_dir(ONNXRuntime_jll, "ONNXRuntime", HostPlatform())
37+
artifact_dir_map[:cuda] = artifact_dir(ONNXRuntime_jll, "ONNXRuntime", augment_platform!(HostPlatform(), "cuda"))
1938
end
2039

40+
include("capi.jl")
41+
include("highlevel.jl")
42+
2143
end #module

src/capi.jl

+7-35
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,19 @@ This module closely follows the offical onnxruntime [C-API](https://github.com/m
55
See [here](https://github.com/microsoft/onnxruntime-inference-examples/blob/d031f879c9a8d33c8b7dc52c5bc65fe8b9e3960d/c_cxx/fns_candy_style_transfer/fns_candy_style_transfer.c) for a C code example.
66
"""
77
module CAPI
8-
using ONNXRunTime: reversedims_lazy
8+
9+
using ONNXRunTime: EXECUTION_PROVIDERS, artifact_dir_map, reversedims_lazy
10+
11+
using ONNXRuntime_jll
912

1013
using DocStringExtensions
1114
using Libdl
1215
using CEnum: @cenum
1316
using ArgCheck
14-
using LazyArtifacts
15-
using Pkg.Artifacts: artifact_path, ensure_artifact_installed, find_artifacts_toml
1617

1718
const LIB_CPU = Ref(C_NULL)
1819
const LIB_CUDA = Ref(C_NULL)
1920

20-
const EXECUTION_PROVIDERS = [:cpu, :cuda]
21-
2221
# For model_path on windows ONNX uses wchar_t while on linux + mac char is used.
2322
# Other strings use char on any platform it seems
2423
# https://github.com/microsoft/onnxruntime/issues/9568#issuecomment-952951564
@@ -45,36 +44,9 @@ end
4544

4645
function make_lib!(execution_provider)
4746
@argcheck execution_provider in EXECUTION_PROVIDERS
48-
artifact_name = if execution_provider === :cpu
49-
"onnxruntime_cpu"
50-
elseif execution_provider === :cuda
51-
"onnxruntime_gpu"
52-
else
53-
error("Unreachable")
54-
end
55-
artifacts_toml = find_artifacts_toml(joinpath(@__DIR__ , "ONNXRunTime.jl"))
56-
h = artifact_hash(artifact_name, artifacts_toml)
57-
if h === nothing
58-
msg = """
59-
Unsupported execution_provider = $(repr(execution_provider)) for
60-
this architectur.
61-
"""
62-
error(msg)
63-
end
64-
ensure_artifact_installed(artifact_name, artifacts_toml)
65-
root = artifact_path(h)
66-
@check isdir(root)
67-
dir = joinpath(root, only(readdir(root)))
68-
@check isdir(dir)
69-
libname = if Sys.iswindows()
70-
"onnxruntime.dll"
71-
elseif Sys.isapple()
72-
"libonnxruntime.dylib"
73-
else
74-
"libonnxruntime.so"
75-
end
76-
path = joinpath(dir, "lib", libname)
77-
@check isfile(path)
47+
path = ONNXRuntime_jll.libonnxruntime_path
48+
rel_path = joinpath(basename(dirname(path)), basename(path))
49+
path = joinpath(artifact_dir_map[execution_provider], rel_path)
7850
set_lib!(path, execution_provider)
7951
end
8052

src/highlevel.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using ArgCheck
2-
using LazyArtifacts
32
using DataStructures: OrderedDict
43
using DocStringExtensions
54
################################################################################
@@ -11,7 +10,7 @@ end
1110

1211

1312
using .CAPI
14-
using .CAPI: juliatype, EXECUTION_PROVIDERS
13+
using .CAPI: juliatype
1514
export InferenceSession, load_inference
1615

1716
"""

0 commit comments

Comments
 (0)