Skip to content

Replaced artifacts with JLL artifacts #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .pkg/platform_augmentation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using Libdl, Base.BinaryPlatforms

function augment_platform!(p::Platform, tag::Union{String,Nothing} = nothing)
if tag === nothing
return p
end

if tag === "cuda"
if haskey(p, tag)
return p
end

# Open libcuda explicitly, so it gets `dlclose()`'ed after we're done
try
dlopen("libcuda") do lib
# find symbol to ask for driver version; if we can't find it, just silently continue
cuDriverGetVersion = dlsym(lib, "cuDriverGetVersion"; throw_error=false)
if cuDriverGetVersion !== nothing
# Interrogate CUDA driver for driver version:
driverVersion = Ref{Cint}()
ccall(cuDriverGetVersion, UInt32, (Ptr{Cint},), driverVersion)

# Store only the major version
p[tag] = div(driverVersion, 1000)
end
end
catch
end

return p
else
@warn "Unexpected tag $tag for $p"
return p
end
end
19 changes: 19 additions & 0 deletions .pkg/select_artifacts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using TOML, Artifacts, Base.BinaryPlatforms

import ONNXRuntime_jll

include("platform_augmentation.jl")

artifacts_toml = Artifacts.find_artifacts_toml(pathof(ONNXRuntime_jll))

# Get "target triplet" from ARGS, if given (defaulting to the host triplet otherwise)
target_triplet = get(ARGS, 1, Base.BinaryPlatforms.host_triplet())

# Augment this platform object with any special tags we require
platform = augment_platform!(HostPlatform(parse(Platform, target_triplet)))

# Select all downloadable artifacts that match that platform
artifacts = select_downloadable_artifacts(artifacts_toml; platform)

# Output the result to `stdout` as a TOML dictionary
TOML.print(stdout, artifacts)
57 changes: 0 additions & 57 deletions Artifacts.toml

This file was deleted.

3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ version = "1.3.1"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
ONNXRuntime_jll = "09e6dd1b-8208-5c7e-a336-6e9061773d0b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[compat]
ArgCheck = "2"
Expand Down
4 changes: 4 additions & 0 deletions src/ONNXRunTime.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
module ONNXRunTime

using Artifacts
using LazyArtifacts
using ONNXRuntime_jll

function _perm(arr::AbstractArray{T,N}) where {T,N}
ntuple(i->N+1-i, N)
end
Expand Down
42 changes: 7 additions & 35 deletions src/capi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,19 @@ This module closely follows the offical onnxruntime [C-API](https://github.com/m
See [here](https://github.com/microsoft/onnxruntime-inference-examples/blob/d031f879c9a8d33c8b7dc52c5bc65fe8b9e3960d/c_cxx/fns_candy_style_transfer/fns_candy_style_transfer.c) for a C code example.
"""
module CAPI
using ONNXRunTime: reversedims_lazy

using ONNXRunTime: EXECUTION_PROVIDERS, artifact_dir_map, reversedims_lazy

using ONNXRuntime_jll

using DocStringExtensions
using Libdl
using CEnum: @cenum
using ArgCheck
using LazyArtifacts
using Pkg.Artifacts: artifact_path, ensure_artifact_installed, find_artifacts_toml

const LIB_CPU = Ref(C_NULL)
const LIB_CUDA = Ref(C_NULL)

const EXECUTION_PROVIDERS = [:cpu, :cuda]

# For model_path on windows ONNX uses wchar_t while on linux + mac char is used.
# Other strings use char on any platform it seems
# https://github.com/microsoft/onnxruntime/issues/9568#issuecomment-952951564
Expand All @@ -45,36 +44,9 @@ end

function make_lib!(execution_provider)
@argcheck execution_provider in EXECUTION_PROVIDERS
artifact_name = if execution_provider === :cpu
"onnxruntime_cpu"
elseif execution_provider === :cuda
"onnxruntime_gpu"
else
error("Unreachable")
end
artifacts_toml = find_artifacts_toml(joinpath(@__DIR__ , "ONNXRunTime.jl"))
h = artifact_hash(artifact_name, artifacts_toml)
if h === nothing
msg = """
Unsupported execution_provider = $(repr(execution_provider)) for
this architectur.
"""
error(msg)
end
ensure_artifact_installed(artifact_name, artifacts_toml)
root = artifact_path(h)
@check isdir(root)
dir = joinpath(root, only(readdir(root)))
@check isdir(dir)
libname = if Sys.iswindows()
"onnxruntime.dll"
elseif Sys.isapple()
"libonnxruntime.dylib"
else
"libonnxruntime.so"
end
path = joinpath(dir, "lib", libname)
@check isfile(path)
path = ONNXRuntime_jll.libonnxruntime_path
rel_path = joinpath(basename(dirname(path)), basename(path))
path = joinpath(artifact_dir_map[execution_provider], rel_path)
set_lib!(path, execution_provider)
end

Expand Down
3 changes: 1 addition & 2 deletions src/highlevel.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using ArgCheck
using LazyArtifacts
using DataStructures: OrderedDict
using DocStringExtensions
import CEnum
Expand All @@ -12,7 +11,7 @@ end


using .CAPI
using .CAPI: juliatype, EXECUTION_PROVIDERS
using .CAPI: juliatype
export InferenceSession, load_inference, release

"""
Expand Down
Loading