diff --git a/kernel-builder/Cargo.lock b/kernel-builder/Cargo.lock index ec5dbd91..16efb5e0 100644 --- a/kernel-builder/Cargo.lock +++ b/kernel-builder/Cargo.lock @@ -245,6 +245,7 @@ dependencies = [ "serde_json", "thiserror", "toml", + "url", ] [[package]] @@ -810,6 +811,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/kernel-builder/Cargo.toml b/kernel-builder/Cargo.toml index 8df3bab0..c5e01406 100644 --- a/kernel-builder/Cargo.toml +++ b/kernel-builder/Cargo.toml @@ -23,6 +23,7 @@ serde_json = "1" serde-value = "0.7" thiserror = "1" toml = "0.8" +url = { version = "2", features = ["serde"] } [build-dependencies] minijinja-embed = "2.5" diff --git a/kernel-builder/src/config/compat.rs b/kernel-builder/src/config/compat.rs index c8553a4a..78d13d96 100644 --- a/kernel-builder/src/config/compat.rs +++ b/kernel-builder/src/config/compat.rs @@ -5,6 +5,7 @@ use serde_value::Value; use super::{v1, v2, v3, Build}; #[derive(Debug)] +#[allow(clippy::large_enum_variant)] pub enum BuildCompat { V1(v1::Build), V2(v2::Build), diff --git a/kernel-builder/src/config/mod.rs b/kernel-builder/src/config/mod.rs index 8f632989..dbd3ff78 100644 --- a/kernel-builder/src/config/mod.rs +++ b/kernel-builder/src/config/mod.rs @@ -61,6 +61,9 @@ pub struct General { /// Hugging Face Hub license identifier. pub license: Option, + /// Source repository or reference for the kernel code. + pub upstream: Option, + pub backends: Vec, pub hub: Option, pub python_depends: Option>, diff --git a/kernel-builder/src/config/v1.rs b/kernel-builder/src/config/v1.rs index c4f523bb..01c969f1 100644 --- a/kernel-builder/src/config/v1.rs +++ b/kernel-builder/src/config/v1.rs @@ -106,6 +106,7 @@ impl TryFrom for super::Build { name: build.general.name, version: None, license: None, + upstream: None, backends, hub: None, neuron: None, diff --git a/kernel-builder/src/config/v2.rs b/kernel-builder/src/config/v2.rs index 7c94eac0..7a6f7384 100644 --- a/kernel-builder/src/config/v2.rs +++ b/kernel-builder/src/config/v2.rs @@ -171,6 +171,7 @@ impl General { name: general.name, version: None, license: None, + upstream: None, backends, cuda, hub: general.hub.map(Into::into), diff --git a/kernel-builder/src/config/v3.rs b/kernel-builder/src/config/v3.rs index 625d8102..a74cd7ee 100644 --- a/kernel-builder/src/config/v3.rs +++ b/kernel-builder/src/config/v3.rs @@ -38,6 +38,8 @@ pub struct General { pub license: Option, + pub upstream: Option, + pub backends: Vec, pub cuda: Option, @@ -183,6 +185,7 @@ impl From for super::General { name: general.name, version: general.version, license: general.license, + upstream: general.upstream, backends: general.backends.into_iter().map(Into::into).collect(), cuda: general.cuda.map(Into::into), hub: general.hub.map(Into::into), @@ -363,6 +366,7 @@ impl From for General { name: general.name, version: general.version, license: general.license, + upstream: general.upstream, backends: general.backends.into_iter().map(Into::into).collect(), cuda: general.cuda.map(Into::into), hub: general.hub.map(Into::into), diff --git a/kernel-builder/src/pyproject/common.rs b/kernel-builder/src/pyproject/common.rs index d16cd2d4..af398c3c 100644 --- a/kernel-builder/src/pyproject/common.rs +++ b/kernel-builder/src/pyproject/common.rs @@ -34,6 +34,7 @@ pub fn write_metadata(general: &General, file_set: &mut FileSet) -> Result<()> { let metadata = Metadata { version: general.version, license: general.license.clone(), + upstream: general.upstream.clone(), python_depends, }; diff --git a/kernel-builder/src/pyproject/metadata.rs b/kernel-builder/src/pyproject/metadata.rs index c8635d90..d40d0f58 100644 --- a/kernel-builder/src/pyproject/metadata.rs +++ b/kernel-builder/src/pyproject/metadata.rs @@ -7,5 +7,7 @@ pub struct Metadata { pub version: Option, #[serde(skip_serializing_if = "Option::is_none")] pub license: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub upstream: Option, pub python_depends: Vec, } diff --git a/template/build.toml b/template/build.toml index 09db8326..caf4548f 100644 --- a/template/build.toml +++ b/template/build.toml @@ -8,6 +8,7 @@ backends = [ ] name = "__KERNEL_NAME__" version = 1 +upstream = "__UPSTREAM_URL__" [torch] src = [