Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 34 additions & 185 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ extern crate bindgen;

use std::env;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

fn get_catboost_version() -> String {
Expand Down Expand Up @@ -38,213 +37,63 @@ fn get_platform_info() -> (String, String) {
}

fn download_model_interface_headers(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
let version = get_catboost_version();

// Create the model_interface directory
let model_interface_dir = out_dir.join("libs/model_interface");
fs::create_dir_all(&model_interface_dir)?;

// Download the c_api.h file
let c_api_url = format!(
"https://raw.githubusercontent.com/catboost/catboost/v{}/catboost/libs/model_interface/c_api.h",
version
);
// Use bundled c_api.h file (hardcoded for testing)
let manifest_dir = std::path::Path::new(file!()).parent().unwrap();
let bundled_c_api = manifest_dir.join("c_api.h");
let c_api_path = model_interface_dir.join("c_api.h");

println!("cargo:warning=Downloading c_api.h from: {}", c_api_url);
println!("cargo:warning=Using bundled c_api.h from: {}", bundled_c_api.display());

let response = ureq::get(&c_api_url).call()?;
let status = response.status();
if !(200..300).contains(&status) {
return Err(format!("Failed to download c_api.h: HTTP {}", status).into());
}

let c_api_path = model_interface_dir.join("c_api.h");
let mut file = fs::File::create(&c_api_path)?;
io::copy(&mut response.into_reader(), &mut file)?;
fs::copy(&bundled_c_api, &c_api_path)?;

Ok(())
}

fn download_compiled_library(out_dir: &Path) -> Result<(), Box<dyn std::error::Error>> {
let (os, arch) = get_platform_info();
let version = get_catboost_version();

// Create the library directory early
// Create the library directory
let lib_dir = out_dir.join("libs");
fs::create_dir_all(&lib_dir)?;

// Parse version to determine URL format
// v1.0.x - v1.1.x use simple filenames
// v1.2+ use platform-specific versioned filenames
let version_parts: Vec<&str> = version.split('.').collect();
let major: u32 = version_parts
.first()
.and_then(|s| s.parse().ok())
.unwrap_or(1);
let minor: u32 = version_parts
.get(1)
.and_then(|s| s.parse().ok())
.unwrap_or(0);

let use_new_format = major > 1 || (major == 1 && minor >= 2);

// Determine download URL based on version and platform
let (lib_filename, download_url) = if use_new_format {
// v1.2+ format with platform and version in filename
match (os.as_str(), arch.as_str()) {
("linux", "x86_64") => (
"libcatboostmodel.so".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-x86_64-{}.so",
version, version
),
),
("linux", "aarch64") => (
"libcatboostmodel.so".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-linux-aarch64-{}.so",
version, version
),
),
("darwin", "x86_64") | ("darwin", "aarch64") => (
"libcatboostmodel.dylib".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel-darwin-universal2-{}.dylib",
version, version
),
),
("windows", "x86_64") => {
// On Windows, we need to download both the DLL and LIB files
// First download the DLL
let dll_url = format!(
"https://github.com/catboost/catboost/releases/download/v{}/catboostmodel-windows-x86_64-{}.dll",
version, version
);
println!("cargo:warning=Downloading Windows DLL from: {}", dll_url);
let dll_response = ureq::get(&dll_url).call()?;
if !(200..300).contains(&dll_response.status()) {
return Err(
format!("Failed to download DLL: HTTP {}", dll_response.status()).into(),
);
}
let dll_path = lib_dir.join("catboostmodel.dll");
let mut dll_file = fs::File::create(&dll_path)?;
io::copy(&mut dll_response.into_reader(), &mut dll_file)?;

// Then download the LIB file
let lib_url = format!(
"https://github.com/catboost/catboost/releases/download/v{}/catboostmodel-windows-x86_64-{}.lib",
version, version
);
println!("cargo:warning=Downloading Windows LIB from: {}", lib_url);
let lib_response = ureq::get(&lib_url).call()?;
if !(200..300).contains(&lib_response.status()) {
return Err(
format!("Failed to download LIB: HTTP {}", lib_response.status()).into(),
);
}
let lib_path = lib_dir.join("catboostmodel.lib");
let mut lib_file = fs::File::create(&lib_path)?;
io::copy(&mut lib_response.into_reader(), &mut lib_file)?;

// Return early for Windows since we've already downloaded both files
println!(
"cargo:warning=Downloaded CatBoost library to: {}",
dll_path.display()
);
return Ok(());
}
("windows", "aarch64") => (
"catboostmodel.dll".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/catboostmodel-windows-aarch64-{}.dll",
version, version
),
),
_ => return Err(format!("Unsupported platform: {}-{}", os, arch).into()),
}
} else {
// v1.0.x - v1.1.x format with simple filenames
match os.as_str() {
"linux" => (
"libcatboostmodel.so".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel.so",
version
),
),
"darwin" => (
"libcatboostmodel.dylib".to_string(),
format!(
"https://github.com/catboost/catboost/releases/download/v{}/libcatboostmodel.dylib",
version
),
),
"windows" => {
// On Windows, we need to download both the DLL and LIB files
// First download the DLL
let dll_url = format!(
"https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.dll",
version
);
println!("cargo:warning=Downloading Windows DLL from: {}", dll_url);
let dll_response = ureq::get(&dll_url).call()?;
if !(200..300).contains(&dll_response.status()) {
return Err(
format!("Failed to download DLL: HTTP {}", dll_response.status()).into(),
);
}
let dll_path = lib_dir.join("catboostmodel.dll");
let mut dll_file = fs::File::create(&dll_path)?;
io::copy(&mut dll_response.into_reader(), &mut dll_file)?;

// Then download the LIB file
let lib_url = format!(
"https://github.com/catboost/catboost/releases/download/v{}/catboostmodel.lib",
version
);
println!("cargo:warning=Downloading Windows LIB from: {}", lib_url);
let lib_response = ureq::get(&lib_url).call()?;
if !(200..300).contains(&lib_response.status()) {
return Err(
format!("Failed to download LIB: HTTP {}", lib_response.status()).into(),
);
}
let lib_path = lib_dir.join("catboostmodel.lib");
let mut lib_file = fs::File::create(&lib_path)?;
io::copy(&mut lib_response.into_reader(), &mut lib_file)?;

// Return early for Windows since we've already downloaded both files
println!(
"cargo:warning=Downloaded CatBoost library to: {}",
dll_path.display()
);
return Ok(());
}
_ => return Err(format!("Unsupported platform: {}", os).into()),
// Use bundled library file based on target platform (hardcoded for testing)
let manifest_dir = std::path::Path::new(file!()).parent().unwrap();

// Determine source and target filenames based on OS and architecture
let (bundled_lib, lib_filename) = match (os.as_str(), arch.as_str()) {
("windows", _) => (
manifest_dir.join("catboostmodel.dll"),
"catboostmodel.dll"
),
("darwin", _) => (
manifest_dir.join("libcatboostmodel.dylib"),
"libcatboostmodel.dylib"
),
("linux", "x86_64") => (
manifest_dir.join("libcatboostmodel-x86_64.so"),
"libcatboostmodel.so"
),
("linux", "aarch64") => (
manifest_dir.join("libcatboostmodel.so"),
"libcatboostmodel.so"
),
_ => {
return Err(format!("Unsupported platform: {}-{}", os, arch).into());
}
};

println!(
"cargo:warning=Downloading CatBoost v{} library from: {}",
version, download_url
);
let lib_path = lib_dir.join(lib_filename);

// Download the library directly into the `libs` directory with its correct name
let lib_path = lib_dir.join(&lib_filename);
let mut dest = fs::File::create(&lib_path)?;

let response = ureq::get(&download_url).call()?;
let status = response.status();
if !(200..300).contains(&status) {
return Err(format!("Failed to download library: HTTP {}", status).into());
}
println!("cargo:warning=Using bundled {} library from: {}", arch, bundled_lib.display());

// SIMPLIFIED: No need for extraction, just copy the downloaded content
io::copy(&mut response.into_reader(), &mut dest)?;
fs::copy(&bundled_lib, &lib_path)?;

println!(
"cargo:warning=Downloaded CatBoost library to: {}",
"cargo:warning=Copied CatBoost library to: {}",
lib_path.display()
);

Expand Down
Loading