Skip to content

Expose tantivy's TermQuery #175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use ::tantivy as tv;
use ::tantivy::schema::{Term, Value};
use pyo3::{exceptions, prelude::*, wrap_pymodule};

mod document;
Expand All @@ -20,6 +21,8 @@ use schemabuilder::SchemaBuilder;
use searcher::{DocAddress, Order, SearchResult, Searcher};
use snippet::{Snippet, SnippetGenerator};

use crate::document::extract_value;

/// Python bindings for the search engine library Tantivy.
///
/// Tantivy is a full text search engine library written in rust.
Expand Down Expand Up @@ -153,3 +156,29 @@ pub(crate) fn get_field(

Ok(field)
}

pub(crate) fn make_term(
schema: &tv::schema::Schema,
field_name: &str,
field_value: &PyAny,
) -> PyResult<tv::Term> {
let field = get_field(schema, field_name)?;
let value = extract_value(field_value)?;
let term = match value {
Value::Str(text) => Term::from_field_text(field, &text),
Value::U64(num) => Term::from_field_u64(field, num),
Value::I64(num) => Term::from_field_i64(field, num),
Value::F64(num) => Term::from_field_f64(field, num),
Value::Date(d) => Term::from_field_date(field, d),
Value::Facet(facet) => Term::from_facet(field, &facet),
Value::Bool(b) => Term::from_field_bool(field, b),
Value::IpAddr(i) => Term::from_field_ip_addr(field, i),
_ => {
return Err(exceptions::PyValueError::new_err(format!(
"Can't create a term for Field `{field_name}` with value `{field_value}`."
)))
}
};

Ok(term)
}
27 changes: 26 additions & 1 deletion src/query.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use pyo3::prelude::*;
use crate::{make_term, Schema};
use pyo3::{exceptions, prelude::*, types::PyAny};
use tantivy as tv;

/// Tantivy's Query
Expand All @@ -18,4 +19,28 @@ impl Query {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("Query({:?})", self.get()))
}

/// Construct a Tantivy's TermQuery
#[staticmethod]
#[pyo3(signature = (schema, field_name, field_value, index_option = "position"))]
pub(crate) fn term_query(
schema: &Schema,
field_name: &str,
field_value: &PyAny,
index_option: &str,
) -> PyResult<Query> {
let term = make_term(&schema.inner, field_name, field_value)?;
let index_option = match index_option {
"position" => tv::schema::IndexRecordOption::WithFreqsAndPositions,
"freq" => tv::schema::IndexRecordOption::WithFreqs,
"basic" => tv::schema::IndexRecordOption::Basic,
_ => return Err(exceptions::PyValueError::new_err(
"Invalid index option, valid choices are: 'basic', 'freq' and 'position'"
))
};
let inner = tv::query::TermQuery::new(term, index_option);
Ok(Query {
inner: Box::new(inner),
})
}
}
4 changes: 3 additions & 1 deletion tantivy/tantivy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ class Document:


class Query:
pass
@staticmethod
def term_query(schema: Schema, field_name: str, field_value: Any, index_option: str = "position") -> Query:
pass


class Order(Enum):
Expand Down
14 changes: 13 additions & 1 deletion tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pickle
import pytest
import tantivy
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator
from tantivy import Document, Index, SchemaBuilder, SnippetGenerator, Query


def schema():
Expand Down Expand Up @@ -925,3 +925,15 @@ def test_document_snippet(self, dir_index):
assert first.end == 23
html_snippet = snippet.to_html()
assert html_snippet == "The Old Man and the <b>Sea</b>"


class TestQuery(object):
def test_term_query(self, ram_index):
index = ram_index
query = Query.term_query(index.schema, "title", "sea")

result = index.searcher().search(query, 10)
assert len(result.hits) == 1
_, doc_address = result.hits[0]
searched_doc = index.searcher().doc(doc_address)
assert searched_doc["title"] == ["The Old Man and the Sea"]