Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions rust/lib/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,31 @@ impl<T: SyncInputApi> SyncInputApi for &mut T {
}
}

/// Crate-internal adapter pairing a wrapped source with its pre-computed length.
///
/// The length is measured once by the constructor (by seeking to the end of the
/// source) and stored, so later length queries do not need to touch the source.
pub(crate) struct ReadSeek<T> {
// The wrapped source; the trait impls below require `T: Read + Seek`.
inner: T,
// Offset of the end of `inner`, in bytes, as reported by `seek(SeekFrom::End(0))`.
len: usize,
}

impl<T: Read + Seek> ReadSeek<T> {
    /// Wraps `inner`, measuring its total length by seeking to its end.
    ///
    /// The stream position is left at the end of the source afterwards; this is
    /// harmless because `read_at` always seeks to an absolute offset before reading.
    ///
    /// # Errors
    ///
    /// Returns an error if the seek fails, or if the reported length does not fit
    /// in a `usize` on the current target (instead of silently truncating it).
    pub(crate) fn new(mut inner: T) -> Result<Self> {
        let end = inner.seek(SeekFrom::End(0))?;
        // A plain `as usize` cast would wrap on targets where `usize` is narrower
        // than 64 bits; surface that as an I/O error so it converts like the seek error.
        let len = usize::try_from(end)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
        Ok(Self { inner, len })
    }
}

// Empty impls: no items are provided here for `SyncInput` / `AsyncInput`.
impl<T: Read + Seek> SyncInput for ReadSeek<T> {}
impl<T: Read + Seek> AsyncInput for ReadSeek<T> {}

impl<T: Read + Seek> SyncInputApi for ReadSeek<T> {
    /// Returns the length that was cached when the wrapper was constructed.
    fn length(&self) -> Result<usize> {
        let cached = self.len;
        Ok(cached)
    }

    /// Fills `buffer` completely with the bytes starting at absolute `offset`.
    ///
    /// Fails if the seek fails or if the source cannot supply `buffer.len()` bytes.
    fn read_at(&mut self, buffer: &mut [u8], offset: usize) -> Result<()> {
        let start = SeekFrom::Start(offset as u64);
        self.inner.seek(start)?;
        self.inner.read_exact(buffer)?;
        Ok(())
    }
}

impl<T: SyncInputApi> AsyncInputApi for T {
fn length(&self) -> impl Future<Output = Result<usize>> {
std::future::ready(self.length())
Expand Down
29 changes: 29 additions & 0 deletions rust/lib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,35 @@ mod tests {
}
}

#[test]
fn identify_by_reader_reference() {
    // One entry of the reference file; unknown fields are rejected by serde.
    #[derive(Debug, Deserialize)]
    #[serde(deny_unknown_fields)]
    struct Test {
        prediction_mode: String,
        content_base64: String,
        status: String,
        prediction: Option<Prediction>,
    }

    // Decompress and parse the reference examples.
    let path = format!(
        "../../tests_data/reference/{MODEL_NAME}-inference_examples_by_content.json.gz"
    );
    let archive = File::open(path).unwrap();
    let mut json = String::new();
    GzDecoder::new(archive).read_to_string(&mut json).unwrap();
    let cases: Vec<Test> = serde_json::from_str(&json).unwrap();

    let mut session = Session::new().unwrap();
    // Only the "high-confidence" entries are checked by this test.
    let selected = cases.into_iter().filter(|case| case.prediction_mode == "high-confidence");
    for case in selected {
        assert_eq!(case.status, "ok");
        let expected = case.prediction.unwrap();
        let bytes = BASE64.decode(case.content_base64.as_bytes()).unwrap();
        // Feed the decoded bytes through the reader-based entry point.
        let reader = std::io::Cursor::new(bytes);
        let actual = session.identify_reader_sync(reader).unwrap();
        assert_prediction(actual, expected, &case.content_base64);
    }
}
Comment on lines +131 to +157
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is significant code duplication between the new test `identify_by_reader_reference` and the existing `identify_by_content_reference` test. The struct definition, test data loading, and test loop are nearly identical.

To improve maintainability, you could refactor the common parts. For example, define the `Test` struct once outside the test functions and create a helper function to load the test data. This would make the tests cleaner and easier to manage.


#[test]
fn identify_by_content_reference() {
#[derive(Debug, Deserialize)]
Expand Down
13 changes: 12 additions & 1 deletion rust/lib/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::io::{Read, Seek};
use std::path::Path;

use ndarray::Array2;

use crate::future::{exec, AsyncEnv, Env, SyncEnv};
use crate::input::AsyncInputApi;
use crate::input::{AsyncInputApi, ReadSeek};
use crate::{AsyncInput, Builder, Features, FeaturesOrRuled, FileType, Result, SyncInput};

/// A Magika session to identify files.
Expand Down Expand Up @@ -69,6 +70,16 @@ impl Session {
self.identify_content::<AsyncEnv>(file).await
}

/// Identifies the content of a [`Read`] + [`Seek`] source (synchronously).
///
/// # Errors
///
/// Returns an error if the reader cannot be wrapped (its length is measured by
/// seeking) or if identification of the content fails.
pub fn identify_reader_sync(&mut self, reader: impl Read + Seek) -> Result<FileType> {
    let input = ReadSeek::new(reader)?;
    self.identify_content_sync(input)
}

/// Identifies the content of a [`Read`] + [`Seek`] source (asynchronously).
///
/// NOTE(review): the reader is a blocking [`Read`] + [`Seek`]; its reads appear to run
/// synchronously inside the returned future — confirm this is acceptable for callers.
///
/// # Errors
///
/// Returns an error if the reader cannot be wrapped (its length is measured by
/// seeking) or if identification of the content fails.
pub async fn identify_reader_async(&mut self, reader: impl Read + Seek) -> Result<FileType> {
    let input = ReadSeek::new(reader)?;
    self.identify_content_async(input).await
}

async fn identify_content<E: Env>(&mut self, file: impl AsyncInputApi) -> Result<FileType> {
match FeaturesOrRuled::extract(file).await? {
FeaturesOrRuled::Ruled(content_type) => Ok(FileType::Ruled(content_type)),
Expand Down