Expose IndexWriter::wait_merging_threads() (#100)

master
Chris Tam 2023-07-22 15:57:30 -04:00 committed by GitHub
parent 6bc86d0e12
commit 35ed22e6d5
3 changed files with 92 additions and 11 deletions


@@ -28,10 +28,36 @@ const RELOAD_POLICY: &str = "commit";
 /// on the index object.
 #[pyclass]
 pub(crate) struct IndexWriter {
-    inner_index_writer: tv::IndexWriter,
+    inner_index_writer: Option<tv::IndexWriter>,
     schema: tv::schema::Schema,
 }
+
+impl IndexWriter {
+    fn inner(&self) -> PyResult<&tv::IndexWriter> {
+        self.inner_index_writer.as_ref().ok_or_else(|| {
+            exceptions::PyRuntimeError::new_err(
+                "IndexWriter was consumed and no longer in a valid state",
+            )
+        })
+    }
+
+    fn inner_mut(&mut self) -> PyResult<&mut tv::IndexWriter> {
+        self.inner_index_writer.as_mut().ok_or_else(|| {
+            exceptions::PyRuntimeError::new_err(
+                "IndexWriter was consumed and no longer in a valid state",
+            )
+        })
+    }
+
+    fn take_inner(&mut self) -> PyResult<tv::IndexWriter> {
+        self.inner_index_writer.take().ok_or_else(|| {
+            exceptions::PyRuntimeError::new_err(
+                "IndexWriter was consumed and no longer in a valid state",
+            )
+        })
+    }
+}
+
 #[pymethods]
 impl IndexWriter {
     /// Add a document to the index.
@@ -45,7 +71,7 @@ impl IndexWriter {
     pub fn add_document(&mut self, doc: &Document) -> PyResult<u64> {
         let named_doc = NamedFieldDocument(doc.field_values.clone());
         let doc = self.schema.convert_named_doc(named_doc).map_err(to_pyerr)?;
-        self.inner_index_writer.add_document(doc).map_err(to_pyerr)
+        self.inner()?.add_document(doc).map_err(to_pyerr)
     }
 
     /// Helper for the `add_document` method, but passing a json string.
@@ -58,7 +84,7 @@ impl IndexWriter {
     /// since the creation of the index.
     pub fn add_json(&mut self, json: &str) -> PyResult<u64> {
         let doc = self.schema.parse_document(json).map_err(to_pyerr)?;
-        let opstamp = self.inner_index_writer.add_document(doc);
+        let opstamp = self.inner()?.add_document(doc);
         opstamp.map_err(to_pyerr)
     }
@@ -72,7 +98,7 @@ impl IndexWriter {
     ///
     /// Returns the `opstamp` of the last document that made it in the commit.
     fn commit(&mut self) -> PyResult<u64> {
-        self.inner_index_writer.commit().map_err(to_pyerr)
+        self.inner_mut()?.commit().map_err(to_pyerr)
     }
 
     /// Rollback to the last commit
@@ -81,14 +107,13 @@ impl IndexWriter {
     /// commit. After calling rollback, the index is in the same state as it
     /// was after the last commit.
     fn rollback(&mut self) -> PyResult<u64> {
-        self.inner_index_writer.rollback().map_err(to_pyerr)
+        self.inner_mut()?.rollback().map_err(to_pyerr)
     }
 
     /// Detects and removes the files that are not used by the index anymore.
     fn garbage_collect_files(&mut self) -> PyResult<()> {
         use futures::executor::block_on;
-        block_on(self.inner_index_writer.garbage_collect_files())
-            .map_err(to_pyerr)?;
+        block_on(self.inner()?.garbage_collect_files()).map_err(to_pyerr)?;
         Ok(())
     }
@@ -100,8 +125,8 @@ impl IndexWriter {
     /// This is also the opstamp of the commit that is currently available
     /// for searchers.
     #[getter]
-    fn commit_opstamp(&self) -> u64 {
-        self.inner_index_writer.commit_opstamp()
+    fn commit_opstamp(&self) -> PyResult<u64> {
+        Ok(self.inner()?.commit_opstamp())
     }
 
     /// Delete all documents containing a given term.
@@ -144,7 +169,16 @@ impl IndexWriter {
             Value::Bool(b) => Term::from_field_bool(field, b),
             Value::IpAddr(i) => Term::from_field_ip_addr(field, i)
         };
-        Ok(self.inner_index_writer.delete_term(term))
+        Ok(self.inner()?.delete_term(term))
     }
+
+    /// If there are some merging threads, blocks until they all finish
+    /// their work, and then drops the `IndexWriter`.
+    ///
+    /// This will consume the `IndexWriter`. Further accesses to the
+    /// object will result in an error.
+    pub fn wait_merging_threads(&mut self) -> PyResult<()> {
+        self.take_inner()?.wait_merging_threads().map_err(to_pyerr)
+    }
 }
@@ -229,7 +263,7 @@ impl Index {
             .map_err(to_pyerr)?;
         let schema = self.index.schema();
         Ok(IndexWriter {
-            inner_index_writer: writer,
+            inner_index_writer: Some(writer),
            schema,
         })
     }
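
Taken together, the changes above wrap the underlying `tv::IndexWriter` in an `Option`: every method now reaches it through `inner()`/`inner_mut()`, while `wait_merging_threads()` takes ownership via `take_inner()` and leaves `None` behind. A minimal Python-side sketch of the resulting behavior (assuming only the tantivy-py API exercised elsewhere in this diff; the "text" field name mirrors the test below):

from tantivy import Document, Index, SchemaBuilder

schema = SchemaBuilder().add_text_field("text", stored=True).build()
index = Index(schema)

writer = index.writer()
doc = Document()
doc.add_text("text", "hello")
writer.add_document(doc)
writer.commit()

# Blocks until background merge threads finish, then drops the
# underlying tantivy IndexWriter, leaving None in its place.
writer.wait_merging_threads()

# Any further call goes through inner()/inner_mut()/take_inner() and
# now raises: "IndexWriter was consumed and no longer in a valid state".
try:
    writer.commit()
except RuntimeError as err:
    print(err)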


@@ -164,6 +164,12 @@ impl Searcher {
         self.inner.num_docs()
     }
 
+    /// Returns the number of segments in the index.
+    #[getter]
+    fn num_segments(&self) -> usize {
+        self.inner.segment_readers().len()
+    }
+
     /// Fetches a document from Tantivy's store given a DocAddress.
     ///
     /// Args:
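
The new `num_segments` getter is what lets the test below assert that merging actually reduced the segment count. Continuing the sketch above, it reads as a plain property:

index.reload()
searcher = index.searcher()
# After wait_merging_threads(), merged segments should leave fewer
# segment readers than the commits originally produced.
print(searcher.num_segments)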


@@ -355,6 +355,47 @@ class TestClass(object):
         result = searcher.search(query, 10, order_by_field="order")
         assert len(result.hits) == 0
 
+    def test_with_merges(self):
+        # This test is taken from tantivy's test suite:
+        # https://github.com/quickwit-oss/tantivy/blob/42acd334f49d5ff7e4fe846b5c12198f24409b50/src/indexer/index_writer.rs#L1130
+        schema = SchemaBuilder().add_text_field("text", stored=True).build()
+        index = Index(schema)
+        index.config_reader(reload_policy="Manual")
+
+        writer = index.writer()
+
+        for _ in range(100):
+            doc = Document()
+            doc.add_text("text", "a")
+            writer.add_document(doc)
+
+        writer.commit()
+
+        for _ in range(100):
+            doc = Document()
+            doc.add_text("text", "a")
+            writer.add_document(doc)
+
+        # This should create 8 segments and trigger a merge.
+        writer.commit()
+        writer.wait_merging_threads()
+
+        # Accessing the writer again should result in an error.
+        with pytest.raises(RuntimeError):
+            writer.wait_merging_threads()
+
+        index.reload()
+
+        query = index.parse_query("a")
+        searcher = index.searcher()
+
+        result = searcher.search(query, limit=500, count=True)
+        assert result.count == 200
+        assert searcher.num_segments < 8
+
     def test_doc_from_dict_schema_validation(self):
         schema = (
             SchemaBuilder()