From 08edd980091a99f5fa08ce4819dc26ff95234cd5 Mon Sep 17 00:00:00 2001 From: oskarbraten <5655425+oskarbraten@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:51:06 +0100 Subject: [PATCH] Add identify_reader_sync/async for Read + Seek types --- rust/lib/src/input.rs | 25 +++++++++++++++++++++++++ rust/lib/src/lib.rs | 29 +++++++++++++++++++++++++++++ rust/lib/src/session.rs | 13 ++++++++++++- 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/rust/lib/src/input.rs b/rust/lib/src/input.rs index 7383489d..19b627ba 100644 --- a/rust/lib/src/input.rs +++ b/rust/lib/src/input.rs @@ -81,6 +81,31 @@ impl SyncInputApi for &mut T { } } +pub(crate) struct ReadSeek { + inner: T, + len: usize, +} + +impl ReadSeek { + pub(crate) fn new(mut inner: T) -> Result { + let len = inner.seek(SeekFrom::End(0))? as usize; + Ok(Self { inner, len }) + } +} + +impl SyncInput for ReadSeek {} +impl SyncInputApi for ReadSeek { + fn length(&self) -> Result { + Ok(self.len) + } + + fn read_at(&mut self, buffer: &mut [u8], offset: usize) -> Result<()> { + self.inner.seek(SeekFrom::Start(offset as u64))?; + Ok(self.inner.read_exact(buffer)?) + } +} +impl AsyncInput for ReadSeek {} + impl AsyncInputApi for T { fn length(&self) -> impl Future> { std::future::ready(self.length()) diff --git a/rust/lib/src/lib.rs b/rust/lib/src/lib.rs index e0b490f7..75bc82c8 100644 --- a/rust/lib/src/lib.rs +++ b/rust/lib/src/lib.rs @@ -127,6 +127,35 @@ mod tests { } } + #[test] + fn identify_by_reader_reference() { + #[derive(Debug, Deserialize)] + #[serde(deny_unknown_fields)] + struct Test { + prediction_mode: String, + content_base64: String, + status: String, + prediction: Option, + } + let path = format!( + "../../tests_data/reference/{MODEL_NAME}-inference_examples_by_content.json.gz" + ); + let mut tests = String::new(); + GzDecoder::new(File::open(path).unwrap()).read_to_string(&mut tests).unwrap(); + let tests: Vec = serde_json::from_str(&tests).unwrap(); + let mut session = Session::new().unwrap(); + for test in tests { + if test.prediction_mode != "high-confidence" { + continue; + } + assert_eq!(test.status, "ok"); + let expected = test.prediction.unwrap(); + let content = BASE64.decode(test.content_base64.as_bytes()).unwrap(); + let actual = session.identify_reader_sync(std::io::Cursor::new(content)).unwrap(); + assert_prediction(actual, expected, &test.content_base64); + } + } + #[test] fn identify_by_content_reference() { #[derive(Debug, Deserialize)] diff --git a/rust/lib/src/session.rs b/rust/lib/src/session.rs index cc1d592c..3af92a85 100644 --- a/rust/lib/src/session.rs +++ b/rust/lib/src/session.rs @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::io::{Read, Seek}; use std::path::Path; use ndarray::Array2; use crate::future::{exec, AsyncEnv, Env, SyncEnv}; -use crate::input::AsyncInputApi; +use crate::input::{AsyncInputApi, ReadSeek}; use crate::{AsyncInput, Builder, Features, FeaturesOrRuled, FileType, Result, SyncInput}; /// A Magika session to identify files. @@ -69,6 +70,16 @@ impl Session { self.identify_content::(file).await } + /// Identifies a single file from a [`Read`] + [`Seek`] source (synchronously). + pub fn identify_reader_sync(&mut self, reader: impl Read + Seek) -> Result { + self.identify_content_sync(ReadSeek::new(reader)?) + } + + /// Identifies a single file from a [`Read`] + [`Seek`] source (asynchronously). + pub async fn identify_reader_async(&mut self, reader: impl Read + Seek) -> Result { + self.identify_content_async(ReadSeek::new(reader)?).await + } + async fn identify_content(&mut self, file: impl AsyncInputApi) -> Result { match FeaturesOrRuled::extract(file).await? { FeaturesOrRuled::Ruled(content_type) => Ok(FileType::Ruled(content_type)),