Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 2 additions & 32 deletions nidx/nidx_text/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub struct FieldUid {
}

// Unique id for a field, equivalent to {rid}/{field_type}/{field_id}[/{split}]/{paragraph_start}-{paragraph_end}
#[derive(Clone, PartialEq, Eq, Hash)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ParagraphUid {
pub rid: String,
pub field_type: String,
Expand Down Expand Up @@ -224,37 +224,7 @@ impl TextSearcher {
&self,
paragraph_uids: Vec<ParagraphUid>,
) -> anyhow::Result<HashMap<ParagraphUid, Option<String>>> {
let mut paragraph_fields = HashMap::new();
for paragraph_id in paragraph_uids {
let field_id = FieldUid::from(paragraph_id.clone());
paragraph_fields
.entry(field_id)
.and_modify(|v: &mut Vec<ParagraphUid>| v.push(paragraph_id.clone()))
.or_insert(vec![paragraph_id]);
}

let fields_text = self
.reader
.get_fields_text(paragraph_fields.keys().cloned().collect())?;

let mut paragraphs_text = HashMap::new();

for (field_id, field_text) in fields_text {
if let Some(paragraphs) = paragraph_fields.remove(&field_id) {
for paragraph_id in paragraphs {
let paragraph_text = field_text.as_ref().map(|field_text| {
field_text
.chars()
.skip(paragraph_id.paragraph_start as usize)
.take((paragraph_id.paragraph_end - paragraph_id.paragraph_start) as usize)
.collect()
});
paragraphs_text.insert(paragraph_id, paragraph_text);
}
}
}

Ok(paragraphs_text)
self.reader.get_paragraphs_text(paragraph_uids)
}

pub fn iterator(&self, request: &StreamRequest) -> anyhow::Result<impl Iterator<Item = DocumentItem> + use<>> {
Expand Down
271 changes: 250 additions & 21 deletions nidx/nidx_text/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::time::*;

use crate::schema::{datetime_utc_to_timestamp, decode_field_id, encode_field_id_bytes};
use crate::search_query::filter_to_query;
use crate::{DocumentSearchRequest, FieldUid, prefilter::*};
use crate::{DocumentSearchRequest, FieldUid, ParagraphUid, prefilter::*};

use super::schema::TextSchema;
use super::search_query;
Expand Down Expand Up @@ -469,29 +469,10 @@ impl TextReaderService {
}

pub fn get_fields_text(&self, field_uids: Vec<FieldUid>) -> anyhow::Result<HashMap<FieldUid, Option<String>>> {
let limit = field_uids.len();

// due to implementation details, we use here a BooleanQuery as it's
// around 2 orders of magnitude faster than a TermSetQuery
let mut subqueries: Vec<Box<dyn Query>> = vec![];
for uid in field_uids {
subqueries.push(Box::new(TermQuery::new(
Term::from_field_bytes(
self.schema.encoded_field_id_bytes,
&encode_field_id_bytes(
Uuid::parse_str(&uid.rid)?,
&format!("{}/{}", uid.field_type, uid.field_name),
),
),
IndexRecordOption::Basic,
)));
}
let query: Box<dyn Query> = Box::new(BooleanQuery::union(subqueries));
let collector = TopDocs::with_limit(limit).order_by_score();
let searcher = self.reader.searcher();
let results = self.search_fields(searcher.clone(), field_uids.iter())?;

let mut texts = HashMap::new();
let results = searcher.search(&query, &collector)?;
for (_score, doc_id) in results {
let doc = searcher.doc::<TantivyDocument>(doc_id)?;
let doc_value = doc.get_first(self.schema.text);
Expand Down Expand Up @@ -525,6 +506,162 @@ impl TextReaderService {

Ok(texts)
}

pub fn get_paragraphs_text(
&self,
paragraph_uids: Vec<ParagraphUid>,
) -> anyhow::Result<HashMap<ParagraphUid, Option<String>>> {
let mut field_paragraph_ids = HashMap::new();
for paragraph_id in paragraph_uids {
let field_id = FieldUid::from(paragraph_id.clone());
field_paragraph_ids
.entry(field_id)
.and_modify(|v: &mut Vec<ParagraphUid>| v.push(paragraph_id.clone()))
.or_insert(vec![paragraph_id]);
}

let searcher = self.reader.searcher();
let results = self.search_fields(searcher.clone(), field_paragraph_ids.keys())?;

let mut paragraphs_text = HashMap::new();
for (_score, doc_id) in results {
let doc = searcher.doc::<TantivyDocument>(doc_id)?;

let Some(text) = doc.get_first(self.schema.text).map(|value| value.as_str().unwrap()) else {
// can't do anything without extracted text
continue;
};
let rid = String::from_utf8(
doc.get_first(self.schema.uuid)
.expect("document doesn't appear to have uuid.")
.as_bytes()
.unwrap()
.to_vec(),
)
.unwrap();
let field = decode_facet(
doc.get_first(self.schema.field)
.expect("document doesn't appear to have field.")
.as_facet()
.unwrap(),
)
.to_path_string();

let parts: Vec<_> = field.split('/').collect(); // e.g. /a/title
let field_uid = FieldUid {
rid,
field_type: parts[1].to_string(),
field_name: parts[2].to_string(),
split: parts.get(3).map(|x| x.to_string()),
};

if let Some(paragraph_ids) = field_paragraph_ids.remove(&field_uid) {
// iterate the text by unicode characters only once, reusing the same iterator for
// all paragraphs on the field. This is more useful for multiple paragraphs per
// field on a large text
let mut paragraphs = Self::extract_paragraphs(paragraph_ids.into_iter(), text.chars());
for (k, v) in paragraphs.drain() {
paragraphs_text.insert(k, v);
}
}
}

Ok(paragraphs_text)
}

/// Searches the index for the documents matching the given field ids and
/// returns their (score, address) pairs. We store one document per field, so
/// at most one result per unique field id is expected.
fn search_fields<'a>(
    &self,
    searcher: Searcher,
    field_uids: impl Iterator<Item = &'a FieldUid>,
) -> anyhow::Result<Vec<(f32, DocAddress)>> {
    // due to implementation details, we use here a BooleanQuery as it's
    // around 2 orders of magnitude faster than a TermSetQuery
    let mut subqueries: Vec<Box<dyn Query>> = vec![];
    for field_uid in field_uids {
        subqueries.push(Box::new(TermQuery::new(
            Term::from_field_bytes(
                self.schema.encoded_field_id_bytes,
                &encode_field_id_bytes(
                    Uuid::parse_str(&field_uid.rid)?,
                    &format!("{}/{}", field_uid.field_type, field_uid.field_name),
                ),
            ),
            IndexRecordOption::Basic,
        )));
    }
    // TopDocs::with_limit panics when the limit is 0, so bail out early when
    // there are no fields to look for
    if subqueries.is_empty() {
        return Ok(Vec::new());
    }
    // we store a doc per field, so we expect at most the number of unique fields
    let limit = subqueries.len();
    let query: Box<dyn Query> = Box::new(BooleanQuery::union(subqueries));
    let collector = TopDocs::with_limit(limit).order_by_score();
    let results = searcher.search(&query, &collector)?;
    Ok(results)
}

/// Extracts the text of each paragraph (a character range) from a field's
/// text, consuming the character iterator in a single forward pass so a large
/// text is scanned only once regardless of how many paragraphs it backs.
///
/// Overlapping paragraph ranges are grouped into a "window" that is
/// materialized in one read. Ranges extending past the end of the text are
/// clamped to the characters actually available (possibly yielding an empty
/// string).
fn extract_paragraphs(
    ids: impl Iterator<Item = ParagraphUid>,
    mut text: std::str::Chars<'_>,
) -> HashMap<ParagraphUid, Option<String>> {
    let mut paragraphs = HashMap::new();

    // sort paragraph_ids by (start, end) so the text iterator only ever needs
    // to move forward, never to re-read already consumed chars
    let mut ids = ids.sorted_by_key(|id| (id.paragraph_start, id.paragraph_end));

    let Some(first) = ids.next() else {
        return paragraphs;
    };
    let mut window = first.paragraph_start..first.paragraph_end;
    let mut window_paragraphs = vec![first];

    // number of characters (nominally) consumed from `text` so far
    let mut consumed = 0;

    // Reads the window's span from `text` and slices out every queued
    // paragraph. Shared by the loop body and the final flush below.
    let mut flush_window = |start, end, queued: &mut Vec<ParagraphUid>| {
        // advance to the window start and read its whole span
        let relative_skip = start - consumed;
        let length = end - start;
        let chunk: Vec<char> = text.by_ref().skip(relative_skip as usize).take(length as usize).collect();
        consumed = end;

        for id in queued.drain(..) {
            // clamp both bounds to the chunk size: we don't have more text,
            // and a start past the end must not panic on slicing
            let from = std::cmp::min((id.paragraph_start - start) as usize, chunk.len());
            let to = std::cmp::min((id.paragraph_end - start) as usize, chunk.len());
            let paragraph: String = chunk[from..to].iter().collect();
            paragraphs.insert(id, Some(paragraph));
        }
    };

    for paragraph_id in ids {
        if paragraph_id.paragraph_start < window.end {
            // This paragraph overlaps with the window. We can't be sure if
            // there will be more in the future, so we widen the window and
            // continue
            window.end = std::cmp::max(window.end, paragraph_id.paragraph_end);
            window_paragraphs.push(paragraph_id);
        } else {
            // A non-overlapping paragraph means we won't find any other
            // paragraph that needs the text from the window: extract it now
            flush_window(window.start, window.end, &mut window_paragraphs);

            // As the new paragraph could overlap with future ones, we reset
            // the window with it
            window = paragraph_id.paragraph_start..paragraph_id.paragraph_end;
            window_paragraphs.push(paragraph_id);
        }
    }

    // with no more paragraphs, we can finish with the last window
    flush_window(window.start, window.end, &mut window_paragraphs);

    paragraphs
}
}

pub struct BatchProducer {
Expand Down Expand Up @@ -588,3 +725,95 @@ impl Iterator for BatchProducer {
Some(items)
}
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_paragraphs_using_a_single_iterator() {
        let text = "This is my test text";

        // builds a ParagraphUid on the same test field for a character range,
        // deduplicating the struct boilerplate across all fixtures below
        let uid = |(start, end)| ParagraphUid {
            rid: "rid".to_string(),
            field_type: "a".to_string(),
            field_name: "title".to_string(),
            split: None,
            paragraph_start: start,
            paragraph_end: end,
        };

        // one paragraph per word of the text
        let words: Vec<ParagraphUid> =
            [(0, 4), (5, 7), (8, 10), (11, 15), (16, 20)].into_iter().map(uid).collect();

        // longer paragraphs overlapping with the words above
        let overlapping: Vec<ParagraphUid> = [
            (0, 7),
            // intersects with the above but has content outside
            (5, 15),
            // same as above
            (5, 15),
            // subset of the above
            (8, 15),
        ]
        .into_iter()
        .map(uid)
        .collect();

        // paragraphs (partially or totally) past the end of the text
        let out_of_bounds: Vec<ParagraphUid> =
            [(8, 100), (16, 100), (200, 300)].into_iter().map(uid).collect();

        // feed the paragraphs in a shuffled order: extract_paragraphs sorts
        // them internally before reading the text
        let paragraphs = TextReaderService::extract_paragraphs(
            [
                overlapping[2].clone(),
                words[3].clone(),
                overlapping[3].clone(),
                out_of_bounds[2].clone(),
                words[1].clone(),
                overlapping[0].clone(),
                out_of_bounds[1].clone(),
                words[4].clone(),
                words[0].clone(),
                overlapping[1].clone(),
                out_of_bounds[0].clone(),
                words[2].clone(),
            ]
            .into_iter(),
            text.chars(),
        );
        assert_eq!(
            paragraphs,
            HashMap::from_iter([
                (words[0].clone(), Some("This".to_string())),
                (overlapping[0].clone(), Some("This is".to_string())),
                (words[1].clone(), Some("is".to_string())),
                (overlapping[1].clone(), Some("is my test".to_string())),
                (overlapping[2].clone(), Some("is my test".to_string())),
                (words[2].clone(), Some("my".to_string())),
                (overlapping[3].clone(), Some("my test".to_string())),
                (out_of_bounds[0].clone(), Some("my test text".to_string())),
                (words[3].clone(), Some("test".to_string())),
                (words[4].clone(), Some("text".to_string())),
                (out_of_bounds[1].clone(), Some("text".to_string())),
                (out_of_bounds[2].clone(), Some("".to_string())),
            ])
        );
    }
}
20 changes: 18 additions & 2 deletions nidx/src/searcher/shard_text.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::time::Instant;

use nidx_protos::{ExtractedTextsRequest, ExtractedTextsResponse};
use nidx_text::{FieldUid, ParagraphUid, TextSearcher};
use tracing::Span;

use crate::errors::{NidxError, NidxResult};
use crate::searcher::index_cache::IndexCache;
Expand All @@ -47,8 +48,24 @@ pub async fn extracted_texts(
return Err(NidxError::NotFound);
};
let index = index_cache.get(&text_index_id).await?;
let searcher: &TextSearcher = index.as_ref().into();

let span = Span::current();
let extracted_texts = tokio::task::spawn_blocking(move || {
span.in_scope(|| {
let searcher: &TextSearcher = index.as_ref().into();
blocking_extracted_texts(searcher, request)
})
})
.await??;

tracing::debug!("Extracted texts took {:?}", start.elapsed());
Ok(extracted_texts)
}

fn blocking_extracted_texts(
searcher: &TextSearcher,
request: ExtractedTextsRequest,
) -> NidxResult<ExtractedTextsResponse> {
let mut extracted_texts = ExtractedTextsResponse::default();

if !request.field_ids.is_empty() {
Expand Down Expand Up @@ -85,6 +102,5 @@ pub async fn extracted_texts(
}
}

tracing::info!("Extracted texts took {:?}µs", start.elapsed().as_micros());
Ok(extracted_texts)
}
Loading