diff --git a/src/index.rs b/src/index.rs index 3cdadf6..cac739e 100644 --- a/src/index.rs +++ b/src/index.rs @@ -28,10 +28,36 @@ const RELOAD_POLICY: &str = "commit"; /// on the index object. #[pyclass] pub(crate) struct IndexWriter { - inner_index_writer: tv::IndexWriter, + inner_index_writer: Option<tv::IndexWriter>, schema: tv::schema::Schema, } +impl IndexWriter { + fn inner(&self) -> PyResult<&tv::IndexWriter> { + self.inner_index_writer.as_ref().ok_or_else(|| { + exceptions::PyRuntimeError::new_err( + "IndexWriter was consumed and no longer in a valid state", + ) + }) + } + + fn inner_mut(&mut self) -> PyResult<&mut tv::IndexWriter> { + self.inner_index_writer.as_mut().ok_or_else(|| { + exceptions::PyRuntimeError::new_err( + "IndexWriter was consumed and no longer in a valid state", + ) + }) + } + + fn take_inner(&mut self) -> PyResult<tv::IndexWriter> { + self.inner_index_writer.take().ok_or_else(|| { + exceptions::PyRuntimeError::new_err( + "IndexWriter was consumed and no longer in a valid state", + ) + }) + } +} + #[pymethods] impl IndexWriter { /// Add a document to the index. @@ -45,7 +71,7 @@ impl IndexWriter { pub fn add_document(&mut self, doc: &Document) -> PyResult<u64> { let named_doc = NamedFieldDocument(doc.field_values.clone()); let doc = self.schema.convert_named_doc(named_doc).map_err(to_pyerr)?; - self.inner_index_writer.add_document(doc).map_err(to_pyerr) + self.inner()?.add_document(doc).map_err(to_pyerr) } /// Helper for the `add_document` method, but passing a json string. @@ -58,7 +84,7 @@ impl IndexWriter { /// since the creation of the index. pub fn add_json(&mut self, json: &str) -> PyResult<u64> { let doc = self.schema.parse_document(json).map_err(to_pyerr)?; - let opstamp = self.inner_index_writer.add_document(doc); + let opstamp = self.inner()?.add_document(doc); opstamp.map_err(to_pyerr) } @@ -72,7 +98,7 @@ impl IndexWriter { /// /// Returns the `opstamp` of the last document that made it in the commit. 
fn commit(&mut self) -> PyResult<u64> { - self.inner_index_writer.commit().map_err(to_pyerr) + self.inner_mut()?.commit().map_err(to_pyerr) } /// Rollback to the last commit @@ -81,14 +107,13 @@ impl IndexWriter { /// commit. After calling rollback, the index is in the same state as it /// was after the last commit. fn rollback(&mut self) -> PyResult<u64> { - self.inner_index_writer.rollback().map_err(to_pyerr) + self.inner_mut()?.rollback().map_err(to_pyerr) } /// Detect and removes the files that are not used by the index anymore. fn garbage_collect_files(&mut self) -> PyResult<()> { use futures::executor::block_on; - block_on(self.inner_index_writer.garbage_collect_files()) - .map_err(to_pyerr)?; + block_on(self.inner()?.garbage_collect_files()).map_err(to_pyerr)?; Ok(()) } @@ -100,8 +125,8 @@ impl IndexWriter { /// This is also the opstamp of the commit that is currently available /// for searchers. #[getter] - fn commit_opstamp(&self) -> u64 { - self.inner_index_writer.commit_opstamp() + fn commit_opstamp(&self) -> PyResult<u64> { + Ok(self.inner()?.commit_opstamp()) } /// Delete all documents containing a given term. @@ -144,7 +169,16 @@ impl IndexWriter { Value::Bool(b) => Term::from_field_bool(field, b), Value::IpAddr(i) => Term::from_field_ip_addr(field, i) }; - Ok(self.inner_index_writer.delete_term(term)) + Ok(self.inner()?.delete_term(term)) + } + + /// If there are some merging threads, blocks until they all finish + /// their work and then drop the `IndexWriter`. + /// + /// This will consume the `IndexWriter`. Further accesses to the + /// object will result in an error. 
+ pub fn wait_merging_threads(&mut self) -> PyResult<()> { + self.take_inner()?.wait_merging_threads().map_err(to_pyerr) } } @@ -229,7 +263,7 @@ impl Index { .map_err(to_pyerr)?; let schema = self.index.schema(); Ok(IndexWriter { - inner_index_writer: writer, + inner_index_writer: Some(writer), schema, }) } diff --git a/src/searcher.rs b/src/searcher.rs index 8375c5f..7b82964 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -164,6 +164,12 @@ impl Searcher { self.inner.num_docs() } + /// Returns the number of segments in the index. + #[getter] + fn num_segments(&self) -> usize { + self.inner.segment_readers().len() + } + /// Fetches a document from Tantivy's store given a DocAddress. /// /// Args: diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index c0ea109..799396b 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -355,6 +355,47 @@ class TestClass(object): result = searcher.search(query, 10, order_by_field="order") assert len(result.hits) == 0 + def test_with_merges(self): + # This test is taken from tantivy's test suite: + # https://github.com/quickwit-oss/tantivy/blob/42acd334f49d5ff7e4fe846b5c12198f24409b50/src/indexer/index_writer.rs#L1130 + schema = SchemaBuilder().add_text_field("text", stored=True).build() + + index = Index(schema) + index.config_reader(reload_policy="Manual") + + writer = index.writer() + + for _ in range(100): + doc = Document() + doc.add_text("text", "a") + + writer.add_document(doc) + + writer.commit() + + for _ in range(100): + doc = Document() + doc.add_text("text", "a") + + writer.add_document(doc) + + # This should create 8 segments and trigger a merge. + writer.commit() + writer.wait_merging_threads() + + # Accessing the writer again should result in an error. 
+ with pytest.raises(RuntimeError): + writer.wait_merging_threads() + + index.reload() + + query = index.parse_query("a") + searcher = index.searcher() + result = searcher.search(query, limit=500, count=True) + assert result.count == 200 + + assert searcher.num_segments < 8 + def test_doc_from_dict_schema_validation(self): schema = ( SchemaBuilder()