Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linux ARM64 bindings (as with Mac ARM64, no prebuilt available) #79

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions onnxruntime-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ const ORT_ENV_SYSTEM_LIB_LOCATION: &str = "ORT_LIB_LOCATION";
const ORT_ENV_GPU: &str = "ORT_USE_CUDA";

/// Subdirectory (of the 'target' directory) into which to extract the prebuilt library.
const ORT_PREBUILT_EXTRACT_DIR: &str = "onnxruntime";
const ORT_EXTRACT_DIR: &str = "onnxruntime";

#[cfg(feature = "disable-sys-build-script")]
fn main() {
Expand Down Expand Up @@ -206,6 +206,7 @@ fn prebuilt_archive_url() -> (PathBuf, String) {
let arch_str = match arch.as_str() {
"x86_64" => "x64",
"x86" => "x86",
"aarch64" => "arm64",
unsupported => panic!("Unsupported architecture {:?}", unsupported),
};

Expand All @@ -216,6 +217,14 @@ fn prebuilt_archive_url() -> (PathBuf, String) {
);
}

if arch_str == "arm64" {
panic!(
"{}",
"ONNX on ARM64 has no prebuilt packages - ".to_string() +
"run again with ORT_STRATEGY=\"compile\""
);
}

// Only Windows and Linux x64 support GPU
if !gpu_str.is_empty() {
if arch_str == "x64" && (os == "windows" || os == "linux") {
Expand All @@ -241,14 +250,23 @@ fn prebuilt_archive_url() -> (PathBuf, String) {
ORT_RELEASE_BASE_URL, ORT_VERSION, prebuilt_archive
);


(PathBuf::from(prebuilt_archive), prebuilt_url)
}

// Get URL of source archive corresponding to release, can't be a simple
// constant because of formatting macro.
fn get_source_url(ort_version: String) -> String {
// FIXME: This won't work because of submodules, need to download git repo
// in order to get version with submodules.
format!("https://github.com/microsoft/onnxruntime/archive/refs/tags/v{}.tar.gz", ort_version)
}

fn prepare_libort_dir_prebuilt() -> PathBuf {
let (prebuilt_archive, prebuilt_url) = prebuilt_archive_url();

let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let extract_dir = out_dir.join(ORT_PREBUILT_EXTRACT_DIR);
let extract_dir = out_dir.join(ORT_EXTRACT_DIR);
let downloaded_file = out_dir.join(&prebuilt_archive);

if !downloaded_file.exists() {
Expand All @@ -271,6 +289,15 @@ fn prepare_libort_dir_prebuilt() -> PathBuf {
extract_dir.join(prebuilt_archive.file_stem().unwrap())
}

// Compiles ONNX runtime lib from scratch
// TODO: Clone repo, checkout tag version, run build.sh script with passed-in
// settings.
// Then have to copy over the artifacts.
// fn prepare_libort_dir_compile() -> PathBuf {
// let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
// out_dir
// }

fn prepare_libort_dir() -> PathBuf {
let strategy = env::var(ORT_ENV_STRATEGY);
println!(
Expand Down
21 changes: 11 additions & 10 deletions onnxruntime-sys/examples/c_api_sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::os::unix::ffi::OsStrExt;
#[cfg(target_family = "windows")]
use std::os::windows::ffi::OsStrExt;

use std::os::raw::c_char;
use onnxruntime_sys::*;

// https://github.com/microsoft/onnxruntime/blob/v1.4.0/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp
Expand Down Expand Up @@ -117,7 +118,7 @@ fn main() {
// iterate over all input nodes
for i in 0..num_input_nodes {
// print input node names
let mut input_name: *mut i8 = std::ptr::null_mut();
let mut input_name: *mut c_char = std::ptr::null_mut();
let status = unsafe {
g_ort.as_ref().unwrap().SessionGetInputName.unwrap()(
session_ptr,
Expand Down Expand Up @@ -282,24 +283,24 @@ fn main() {
.into_iter()
.map(|n| std::ffi::CString::new(n).unwrap())
.collect();
let input_node_names_ptr: Vec<*const i8> = input_node_names_cstring
let input_node_names_ptr: Vec<*const c_char> = input_node_names_cstring
.into_iter()
.map(|n| n.into_raw() as *const i8)
.map(|n| n.into_raw() as *const c_char)
.collect();
let input_node_names_ptr_ptr: *const *const i8 = input_node_names_ptr.as_ptr();
let input_node_names_ptr_ptr: *const *const c_char = input_node_names_ptr.as_ptr();

let output_node_names_cstring: Vec<std::ffi::CString> = output_node_names
.into_iter()
.map(|n| std::ffi::CString::new(n.clone()).unwrap())
.collect();
let output_node_names_ptr: Vec<*const i8> = output_node_names_cstring
let output_node_names_ptr: Vec<*const c_char> = output_node_names_cstring
.iter()
.map(|n| n.as_ptr() as *const i8)
.map(|n| n.as_ptr() as *const c_char)
.collect();
let output_node_names_ptr_ptr: *const *const i8 = output_node_names_ptr.as_ptr();
let output_node_names_ptr_ptr: *const *const c_char = output_node_names_ptr.as_ptr();

let _input_node_names_cstring =
unsafe { std::ffi::CString::from_raw(input_node_names_ptr[0] as *mut i8) };
unsafe { std::ffi::CString::from_raw(input_node_names_ptr[0] as *mut c_char) };
let run_options_ptr: *const OrtRunOptions = std::ptr::null();
let mut output_tensor_ptr: *mut OrtValue = std::ptr::null_mut();
let output_tensor_ptr_ptr: *mut *mut OrtValue = &mut output_tensor_ptr;
Expand Down Expand Up @@ -371,7 +372,7 @@ fn CheckStatus(g_ort: *const OrtApi, status: *const OrtStatus) -> Result<(), Str
}
}

fn char_p_to_str<'a>(raw: *const i8) -> Result<&'a str, std::str::Utf8Error> {
let c_str = unsafe { std::ffi::CStr::from_ptr(raw as *mut i8) };
fn char_p_to_str<'a>(raw: *const c_char) -> Result<&'a str, std::str::Utf8Error> {
let c_str = unsafe { std::ffi::CStr::from_ptr(raw as *mut c_char) };
c_str.to_str()
}
6 changes: 6 additions & 0 deletions onnxruntime-sys/src/generated/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ include!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/generated/windows/x86_64/bindings.rs"
));

#[cfg(all(target_os = "linux", target_arch = "aarch64"))]
include!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/generated/linux/aarch64/bindings.rs"
));
Loading