diff --git a/.gitignore b/.gitignore index 72ff37d7..8205e70a 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ __pycache__/ tantivy.so tantivy/tantivy.cpython*.so tantivy.egg-info/ +.python-version \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 8d047590..f61fe17b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tantivy" -version = "0.13.2" +version = "0.14.0" readme = "README.md" authors = ["Damir Jelić "] edition = "2018" @@ -12,12 +12,12 @@ crate-type = ["cdylib"] [dependencies] chrono = "0.4.19" -tantivy = "0.13.2" +tantivy = "0.14.0" itertools = "0.9.0" futures = "0.3.5" [dependencies.pyo3] -version = "0.13.2" +version = "0.14.1" features = ["extension-module"] [package.metadata.maturin] diff --git a/src/index.rs b/src/index.rs index 89347532..38460d2b 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,6 +1,8 @@ #![allow(clippy::new_ret_no_self)] -use pyo3::{exceptions, prelude::*, types::PyAny}; +use pyo3::exceptions; +use pyo3::prelude::*; +use pyo3::types::PyAny; use crate::{ document::{extract_value, Document}, @@ -277,7 +279,7 @@ impl Index { #[staticmethod] fn exists(path: &str) -> PyResult { let directory = MmapDirectory::open(path).map_err(to_pyerr)?; - Ok(tv::Index::exists(&directory)) + Ok(tv::Index::exists(&directory).unwrap()) } /// The schema of the current index. @@ -341,7 +343,6 @@ impl Index { let parser = tv::query::QueryParser::for_index(&self.index, default_fields); let query = parser.parse_query(query).map_err(to_pyerr)?; - Ok(Query { inner: query }) } } diff --git a/src/schemabuilder.rs b/src/schemabuilder.rs index 58b2a275..980b0967 100644 --- a/src/schemabuilder.rs +++ b/src/schemabuilder.rs @@ -253,11 +253,18 @@ impl SchemaBuilder { /// /// Args: /// name (str): The name of the field. - fn add_bytes_field(&mut self, name: &str) -> PyResult { + fn add_bytes_field( + &mut self, + name: &str, + stored: bool, + indexed: bool, + fast: bool, + ) -> PyResult { let builder = &mut self.builder; + let opts = SchemaBuilder::build_bytes_option(stored, indexed, fast)?; if let Some(builder) = builder.write().unwrap().as_mut() { - builder.add_bytes_field(name); + builder.add_bytes_field(name, opts); } else { return Err(exceptions::PyValueError::new_err( "Schema builder object isn't valid anymore.", @@ -316,4 +323,18 @@ impl SchemaBuilder { Ok(opts) } + + fn build_bytes_option( + stored: bool, + indexed: bool, + fast: bool, + ) -> PyResult { + let opts = schema::BytesOptions::default(); + + let opts = if stored { opts.set_stored() } else { opts }; + let opts = if indexed { opts.set_indexed() } else { opts }; + let opts = if fast { opts.set_fast() } else { opts }; + + Ok(opts) + } } diff --git a/src/searcher.rs b/src/searcher.rs index 2f0cc1bf..7d964073 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -1,7 +1,9 @@ #![allow(clippy::new_ret_no_self)] use crate::{document::Document, get_field, query::Query, to_pyerr}; +use pyo3::types::{PyDict, PyList, PyTuple}; use pyo3::{exceptions::PyValueError, prelude::*, PyObjectProtocol}; +use std::collections::BTreeMap; use tantivy as tv; use tantivy::collector::{Count, MultiCollector, TopDocs}; @@ -41,6 +43,7 @@ impl ToPyObject for Fruit { /// Object holding a results successful search. pub(crate) struct SearchResult { hits: Vec<(Fruit, DocAddress)>, + facets_result: BTreeMap>, #[pyo3(get)] /// How many documents matched the query. Only available if `count` was set /// to true during the search. @@ -52,11 +55,17 @@ impl PyObjectProtocol for SearchResult { fn __repr__(&self) -> PyResult { if let Some(count) = self.count { Ok(format!( - "SearchResult(hits: {:?}, count: {})", - self.hits, count + "SearchResult(hits: {:?}, count: {}, facets: {})", + self.hits, + count, + self.facets_result.len() )) } else { - Ok(format!("SearchResult(hits: {:?})", self.hits)) + Ok(format!( + "SearchResult(hits: {:?}, facets: {})", + self.hits, + self.facets_result.len() + )) } } } @@ -74,6 +83,16 @@ impl SearchResult { .collect(); Ok(ret) } + + #[getter] + /// The list of facets that are requested on the search based on the + /// search results. + fn facets( + &self, + _py: Python, + ) -> PyResult>> { + Ok(self.facets_result.clone()) + } } #[pymethods] @@ -90,6 +109,8 @@ impl Searcher { /// should be ordered by. The field must be declared as a fast field /// when building the schema. Note, this only works for unsigned /// fields. + /// facets (PyDict, optional): A dictionary of facet fields and keys to + /// filter. /// offset (Field, optional): The offset from which the results have /// to be returned. /// @@ -104,6 +125,7 @@ impl Searcher { limit: usize, count: bool, order_by_field: Option<&str>, + facets: Option<&PyDict>, offset: usize, ) -> PyResult { let mut multicollector = MultiCollector::new(); @@ -114,6 +136,37 @@ impl Searcher { None }; + let mut facets_requests = BTreeMap::new(); + + // We create facets collector for each field and terms defined on the facets args + if let Some(facets_dict) = facets { + for key_value_any in facets_dict.items() { + if let Ok(key_value) = key_value_any.downcast::() { + if key_value.len() != 2 { + continue; + } + let key: String = key_value.get_item(0).extract()?; + let field = get_field(&self.inner.index().schema(), &key)?; + + let mut facet_collector = + tv::collector::FacetCollector::for_field(field); + + if let Ok(value_list) = + key_value.get_item(1).downcast::() + { + for value_element in value_list { + if let Ok(s) = value_element.extract::() { + facet_collector.add_facet(&s); + } + } + let facet_handler = + multicollector.add_collector(facet_collector); + facets_requests.insert(key, facet_handler); + } + } + } + } + let (mut multifruit, hits) = { if let Some(order_by) = order_by_field { let field = get_field(&self.inner.index().schema(), order_by)?; @@ -162,7 +215,52 @@ impl Searcher { None => None, }; - Ok(SearchResult { hits, count }) + let mut facets_result: BTreeMap> = + BTreeMap::new(); + + // Go though all collectors that are registered + for (key, facet_collector) in facets_requests { + let facet_count = facet_collector.extract(&mut multifruit); + let mut facet_vec = Vec::new(); + if let Some(facets_dict) = facets { + match facets_dict.get_item(key.clone()) { + Some(facets_list_by_key) => { + if let Ok(facets_list_by_key_native) = + facets_list_by_key.downcast::() + { + for facet_value in facets_list_by_key_native { + if let Ok(s) = facet_value.extract::() { + let facet_value_vec: Vec<( + &tv::schema::Facet, + u64, + )> = facet_count.get(&s).collect(); + + // Go for all elements on facet and count to add on vector + for ( + facet_value_vec_element, + facet_count, + ) in facet_value_vec + { + facet_vec.push(( + facet_value_vec_element.to_string(), + facet_count, + )) + } + } + } + } + } + None => println!("Not found."), + } + } + facets_result.insert(key.clone(), facet_vec); + } + + Ok(SearchResult { + hits, + count, + facets_result, + }) } /// Returns the overall number of documents in the index. diff --git a/tests/tantivy_test.py b/tests/tantivy_test.py index 8c3b6368..dad39c67 100644 --- a/tests/tantivy_test.py +++ b/tests/tantivy_test.py @@ -5,7 +5,14 @@ def schema(): - return SchemaBuilder().add_text_field("title", stored=True).add_text_field("body").build() + return ( + SchemaBuilder() + .add_text_field("title", stored=True) + .add_text_field("body") + .add_facet_field("facet") + .build() + ) + def create_index(dir=None): # assume all tests will use the same documents for now @@ -27,6 +34,7 @@ def create_index(dir=None): "now without taking a fish." ), ) + doc.add_facet("facet", tantivy.Facet.from_string("/mytag")) writer.add_document(doc) # 2 use the built-in json support # keys need to coincide with field names @@ -99,7 +107,9 @@ def test_simple_search_in_ram(self, ram_index): def test_and_query(self, ram_index): index = ram_index - query = index.parse_query("title:men AND body:summer", default_field_names=["title", "body"]) + query = index.parse_query( + "title:men AND body:summer", default_field_names=["title", "body"] + ) # look for an intersection of documents searcher = index.searcher() result = searcher.search(query, 10) @@ -114,17 +124,48 @@ def test_and_query(self, ram_index): def test_and_query_parser_default_fields(self, ram_index): query = ram_index.parse_query("winter", default_field_names=["title"]) - assert repr(query) == """Query(TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114])))""" + assert ( + repr(query) + == """Query(TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114])))""" + ) def test_and_query_parser_default_fields_undefined(self, ram_index): - query = ram_index.parse_query("winter") + query = ram_index.parse_query("/winter") assert ( repr(query) == "Query(BooleanQuery { subqueries: [" "(Should, TermQuery(Term(field=0,bytes=[119, 105, 110, 116, 101, 114]))), " - "(Should, TermQuery(Term(field=1,bytes=[119, 105, 110, 116, 101, 114])))] " + "(Should, TermQuery(Term(field=1,bytes=[119, 105, 110, 116, 101, 114]))), " + "(Should, TermQuery(Term(field=2,bytes=[119, 105, 110, 116, 101, 114])))] " "})" ) + def test_and_query_parser_default_fields_facets(self, ram_index): + index = ram_index + query = index.parse_query( + "old +facet:/mytag", default_field_names=["title", "body"] + ) + # look for an intersection of documents + searcher = index.searcher() + result = searcher.search(query, 10) + assert result.count == 1 + + query = index.parse_query( + "old +facet:/wrong", default_field_names=["title", "body"] + ) + # look for an intersection of documents + searcher = index.searcher() + result = searcher.search(query, 10) + assert result.count == 0 + + def test_search_facets(self, ram_index): + index = ram_index + query = index.parse_query("old", default_field_names=["title", "body"]) + # look for an intersection of documents + searcher = index.searcher() + result = searcher.search(query, 10, facets={"facet": ["/"]}) + assert result.count == 1 + assert ("/mytag", 1) in result.facets["facet"] + def test_query_errors(self, ram_index): index = ram_index # no "bod" field @@ -132,9 +173,11 @@ def test_query_errors(self, ram_index): index.parse_query("bod:men", ["title", "body"]) def test_order_by_search(self): - schema = (SchemaBuilder() + schema = ( + SchemaBuilder() .add_unsigned_field("order", fast="single") - .add_text_field("title", stored=True).build() + .add_text_field("title", stored=True) + .build() ) index = Index(schema) @@ -155,7 +198,6 @@ def test_order_by_search(self): doc.add_unsigned("order", 1) doc.add_text("title", "Another test title") - writer.add_document(doc) writer.commit() @@ -163,7 +205,6 @@ def test_order_by_search(self): query = index.parse_query("test") - searcher = index.searcher() result = searcher.search(query, 10, offset=2, order_by_field="order") @@ -187,9 +228,11 @@ def test_order_by_search(self): assert searched_doc["title"] == ["Test title"] def test_order_by_search_without_fast_field(self): - schema = (SchemaBuilder() + schema = ( + SchemaBuilder() .add_unsigned_field("order") - .add_text_field("title", stored=True).build() + .add_text_field("title", stored=True) + .build() ) index = Index(schema)