Commit 8a1c2c2

added featurizers
1 parent 5eb6045 commit 8a1c2c2

File tree

13 files changed: +328 −99 lines changed

.gitmodules

+3
@@ -13,3 +13,6 @@
 [submodule "lib/cpp/pattern_matching"]
 	path = lib/cpp/pattern_matching
 	url = https://github.com/roberto-trani/pattern_matching.git
+[submodule "lib/cpp/buffered_stream"]
+	path = lib/cpp/buffered_stream
+	url = https://github.com/roberto-trani/buffered_stream.git
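
Since the commit adds a new submodule, a fresh checkout needs the standard submodule step (plain git, nothing repo-specific):

    git submodule update --init --recursive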

cfg.py

+13-7
@@ -4,18 +4,24 @@
 
 # base directory
 base_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
-# add the python libraries to the sys path
-sys.path.append(base_dir + "lib/python/")
-sys.path.append(base_dir + "lib/cpp/")
+data_dir = base_dir + "data/"
+lib_dir = base_dir + "lib/"
+
+# add the python libraries to the python path
+sys.path.append(lib_dir + "python/")
+sys.path.append(lib_dir + "cpp/")
+sys.path.append(lib_dir + "cython/")
 
 # other directories
-raw_dir = base_dir + "raw/"
-processed_dir = base_dir + "processed/"
-thesaurus_dir = base_dir + "thesaurus/"
+processed_dir = data_dir + "processed/"
+raw_dir = data_dir + "raw/"
+thesaurus_dir = data_dir + "thesaurus/"
+tmp_dir = data_dir + "tmp/"
 
 # number of parts the wikipedia file must be split into
 wiki_preprocessing_split_into = 10
 
 # some checks for consistency
 assert os.path.isdir(base_dir)
-assert os.path.isdir(raw_dir)
+assert os.path.isdir(data_dir)
+assert os.path.isdir(lib_dir)
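
A downstream script would pick up the reorganized layout through cfg (a hypothetical sketch; the script location and file names are placeholders, only the cfg attributes come from the diff above):

    import sys
    sys.path.append("../")  # adjust to wherever cfg.py lives relative to the script
    import cfg

    # raw inputs now live under data/raw/, intermediate output under data/processed/
    in_path = cfg.raw_dir + "wikipedia.xml.gz"
    out_path = cfg.processed_dir + "wikipedia_sentences.txt"
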
File renamed without changes.

lib/cpp/buffered_stream

Submodule buffered_stream added at 99a8405

lib/cython/setup.py

+55
@@ -0,0 +1,55 @@
+import numpy
+import os
+import sys
+from Cython.Build import cythonize
+from distutils.core import setup, Extension
+
+
+def add_extension(extensions, name, **kwargs):
+    assert isinstance(name, str) and "/" not in name
+    assert name not in extensions
+
+    source = name.replace(".", "/") + ".pyx"
+    if "language" not in kwargs:
+        kwargs["language"] = "c++"
+    if "extra_compile_args" in kwargs:
+        kwargs["extra_compile_args"] += ["-std=c++11", "-O3"]
+    else:
+        kwargs["extra_compile_args"] = ["-std=c++11", "-O3"]
+
+    extensions[name] = Extension(
+        name,
+        sources=[source],
+        **kwargs
+    )
+
+
+if __name__ == "__main__":
+    sys.path.append("../../")
+    import cfg
+    extensions = e = dict()
+
+    # set the compiler
+    os.environ["CC"] = "g++-7"
+
+    # collection_stats
+    kwargs = {"include_dirs": []}
+    kwargs["include_dirs"].append(cfg.lib_dir + "cpp")
+    kwargs["include_dirs"].append(cfg.lib_dir + "cpp/pattern_matching")
+    add_extension(e, 'collection_stats.collection_stats', extra_link_args=['-fopenmp'], extra_compile_args=['-fopenmp'], **kwargs)
+    add_extension(e, 'collection_stats.collection_stats_restricted', extra_link_args=['-fopenmp'], extra_compile_args=['-fopenmp'], **kwargs)
+    # featurizers
+    kwargs["include_dirs"].append(numpy.get_include())
+    add_extension(e, 'feature_extraction.featurizer_textual', **kwargs)
+    add_extension(e, 'feature_extraction.featurizer_tags', **kwargs)
+    add_extension(e, 'feature_extraction.featurizer_w2v', **kwargs)
+    add_extension(e, 'feature_extraction.featurizer_sigir08', **kwargs)
+    add_extension(e, 'feature_extraction.featurizer_sigir08extended', **kwargs)
+    add_extension(e, 'feature_extraction.featurizer_custom', **kwargs)
+    add_extension(e, 'feature_extraction.featurizer_qpp', **kwargs)
+
+    # setup
+    setup(
+        ext_modules=cythonize(extensions.values()),
+        packages=extensions.keys(),
+    )
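
Building the extensions in place would follow the usual Cython/distutils pattern (assuming a Python 2 environment with Cython, NumPy, and g++-7 installed, per the hard-coded CC above):

    cd lib/cython && python setup.py build_ext --inplace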

lib/python/efficient_query_expansion/__init__.py

Whitespace-only changes.

lib/python/documents_utils.py renamed to lib/python/efficient_query_expansion/documents_utils.py

+25-2
@@ -1,6 +1,11 @@
+import codecs
+import gzip
 import nltk
-from utils import get_reader, get_emitter_from_generator
+import os.path
+import sys
+
 from normalize_text import normalize_text, normalize_text_step_1
+from parallel_stream.utils import get_emitter_from_iterable
 
 
 class Doc(object):
@@ -164,6 +169,24 @@ def _xml_extractor_reader_to_doc_generator(reader):
     raise Exception("A content was expected before the end of file")
 
 
+def get_reader(infilename, encoding=None):
+    if infilename == "-":
+        reader = sys.stdin
+    else:
+        if not os.path.isfile(infilename):
+            raise Exception("File {} doesn't exist".format(infilename))
+
+        if infilename.endswith(".gz"):
+            reader = gzip.open(infilename, "rb")
+        else:
+            reader = open(infilename, "r")
+
+    if encoding is None or encoding.upper() == "ASCII":
+        return reader
+
+    return codecs.getreader(encoding)(reader)
+
+
 def doc_generator_from_file(infilenames, encoding=None, file_format="custom"):
     if isinstance(infilenames, (str, unicode)):
         infilenames = [infilenames]
@@ -197,7 +220,7 @@ def sentence_generator_from_doc_file(*args, **kwargs):
 
 
 def get_doc_emitter_from_files(*args, **kwargs):
-    return get_emitter_from_generator(doc_generator_from_file(*args, **kwargs))
+    return get_emitter_from_iterable(doc_generator_from_file(*args, **kwargs))
 
 
 def get_doc_normalizer_worker():
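
The inlined get_reader replaces what was previously imported from utils: it transparently handles stdin ("-"), gzipped files, and optional decoding via codecs. A hypothetical call (file name and process() are placeholders):

    # iterate over a gzipped UTF-8 dump line by line
    for line in get_reader("docs.xml.gz", encoding="UTF-8"):
        process(line)  # process() stands in for whatever consumes the lines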
@@ -0,0 +1,177 @@
+import cPickle
+import collections
+import json
+import socket
+import struct
+
+from utils import query_repr_to_sql_query
+
+
+QueryPerformanceSubset = collections.namedtuple(
+    "QueryPerformanceSubset",
+    ["num_ret", "exe_time"]
+)
+QueryPerformance = collections.namedtuple(
+    "QueryPerformance",
+    ["num_ret", "num_rel", "num_rel_ret", "exe_time"]
+)
+
+
+class SocketChannel(object):
+    _length_format = "<I"
+    _length_size = struct.calcsize(_length_format)
+
+    def __init__(self, host, port):
+        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self._sock.connect((host, port))
+
+    def close(self):
+        self._sock.close()
+
+    def send_request(self, request):
+        assert isinstance(request, dict)
+        self._send_msg(json.dumps(request))
+
+    def receive_reply(self):
+        return json.loads(self._recv_msg())
+
+    def _recvall(self, n):
+        # Helper function to recv n bytes or return None if EOF is hit
+        data = b''
+        while len(data) < n:
+            packet = self._sock.recv(n - len(data))
+            if not packet:
+                return None
+            data += packet
+        return data
+
+    def _send_msg(self, msg):
+        # Prefix each message with a 4-byte length (network byte order)
+        msg = struct.pack(SocketChannel._length_format, len(msg)) + msg
+        self._sock.sendall(msg)
+
+    def _recv_msg(self):
+        # Read message length and unpack it into an integer
+        raw_msglen = self._recvall(SocketChannel._length_size)
+        if not raw_msglen:
+            return None
+        msglen = struct.unpack(SocketChannel._length_format, raw_msglen)[0]
+        # Read the message data
+        return self._recvall(msglen)
+
+
+class IndexCursor(object):
+    def __init__(self, index_cache, db_cursor):
+        self._index_cache = index_cache
+        self._db_cursor = db_cursor
+
+    def close(self):
+        self._db_cursor.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *exc_info):
+        self.close()
+
+    def get_performance(self, query_repr, document_id_list=None, document_id_list_key=None,
+                        include_time=True, force=False):
+        # check parameters
+        assert isinstance(query_repr, (list, tuple))
+        assert document_id_list is None or (isinstance(document_id_list, (list, tuple)) and len(document_id_list) > 0 and all(isinstance(doc_id, (int, long)) for doc_id in document_id_list))
+        assert document_id_list_key is None or isinstance(document_id_list_key, (int, long))
+        assert (document_id_list_key is None) == (document_id_list is None)
+        assert isinstance(include_time, bool)
+        assert isinstance(force, bool)
+
+        zero_document_id_list = (document_id_list is None) or (len(document_id_list) == 0)
+
+        # transform the query representation in a query string
+        sql_str = query_repr_to_sql_query(query_repr)
+
+        # get the entry from the cache
+        key = sql_str if zero_document_id_list else (sql_str, document_id_list_key)
+        if not force:
+            query_performance = self._index_cache._get(key)
+            if query_performance is not None and (not include_time or query_performance.exe_time is not None):
+                return query_performance
+
+        # transform the document_id_list
+        document_id_list = [] if document_id_list is None else list(set(document_id_list))
+
+        request = {
+            "query": sql_str,
+            "query_type": "cnf"
+        }
+        if not zero_document_id_list:
+            request["rel"] = document_id_list
+        self._db_cursor.send_request(request)
+
+        result = self._db_cursor.receive_reply()
+        if "error" in result:
+            raise Exception(result["error"])
+
+        # compose the resulting object
+        if zero_document_id_list:
+            query_performance = QueryPerformanceSubset(
+                num_ret=int(result["num_ret"]),
+                exe_time=float(result["exe_time"])
+            )
+        else:
+            query_performance = QueryPerformance(
+                num_ret=int(result["num_ret"]),
+                num_rel=int(result["num_rel"]),
+                num_rel_ret=int(result["num_rel_ret"]),
+                exe_time=float(result["exe_time"])
+            )
+
+        # put the result into the cache
+        self._index_cache._put(key, query_performance)
+        if not zero_document_id_list:
+            qps = self._index_cache._get(key[0])
+            if qps is None or (include_time and qps.exe_time is None):
+                qps = QueryPerformanceSubset(
+                    num_ret=query_performance.num_ret,
+                    exe_time=query_performance.exe_time
+                )
+                self._index_cache._put(key[0], qps)
+
+        # return
+        return query_performance
+
+
+class IndexCache(object):
+    def __init__(self, host, port):
+        assert isinstance(host, str)
+        assert isinstance(port, int)
+
+        self._host = host
+        self._port = port
+        self._cache = dict()
+
+    @staticmethod
+    def load(file_path):
+        host, port, cache = cPickle.load(open(file_path, "rb"))
+        index_cache = IndexCache(host, port)
+        index_cache._cache = cache
+        return index_cache
+
+    def dump(self, file_path):
+        cPickle.dump(
+            (self._host, self._port, self._cache),
+            open(file_path, "wb"),
+            protocol=cPickle.HIGHEST_PROTOCOL
+        )
+
+    def __len__(self):
+        return len(self._cache)
+
+    def _get(self, key):
+        return self._cache.get(key, None)
+
+    def _put(self, key, value):
+        self._cache[key] = value
+
+    def cursor(self):
+        connection = SocketChannel(host=self._host, port=self._port)
+        return IndexCursor(self, connection)
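
End to end, the module speaks a length-prefixed JSON protocol to an index server and memoizes the replies. A hypothetical session (host, port, query terms, and the pickle path are placeholders; the query_repr structure matches utils.query_repr_to_sql_query below):

    cache = IndexCache("localhost", 9090)  # placeholder host/port for the index server
    with cache.cursor() as cursor:
        # one AND-query of two synsets: (efficient | fast) (expansion)
        query_repr = [[[("efficient",), ("fast",)], [("expansion",)]]]
        perf = cursor.get_performance(query_repr)
        print perf.num_ret, perf.exe_time  # a repeated call is served from the cache
    cache.dump("index_cache.pkl")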
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
+def query_repr_to_sql_query(query_repr, uniq_repr=True):
+    join_fun = (
+        lambda l, m, r, it: "{}{}{}".format(
+            l,
+            m.join(sorted(set(it)) if uniq_repr else it),
+            r
+        )
+    )
+
+    return \
+        join_fun("(", ") | (", ")", (
+            join_fun("(", ") (", ")", (
+                join_fun("", " | ", "", (
+                    "\"{}\"".format(syn_tag[0]) if " " in syn_tag[0] else syn_tag[0]
+                    for syn_tag in synset
+                ))
+                for synset in and_query
+            ))
+            for and_query in query_repr
+        ))
+
+
+def sql_query_to_query_repr(sql_query):
+    assert sql_query[:2] == "((" and sql_query[-2:] == "))"
+
+    query_repr = \
+        [
+            [
+                [
+                    (syn[1:-1] if (syn[0] == syn[-1] == "\"") else syn, )
+                    for syn in synset.split(" | ")
+                ]
+                for synset in and_query.split(") (")
+            ]
+            for and_query in sql_query[2:-2].split(")) | ((")
+        ]
+
+    assert all(
+        " " not in syn_tag[0] or syn_tag[0].find("\"", 1, -1) == -1
+        for and_query in query_repr
+        for synset in and_query
+        for syn_tag in synset
+    )
+    return query_repr
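
A worked round trip of the two helpers (the query terms are placeholders; multi-word synonyms get quoted, synsets are OR-groups, and synsets within a clause are joined by AND):

    >>> repr_in = [[[("efficient",), ("fast",)], [("query expansion",)]]]
    >>> query_repr_to_sql_query(repr_in)
    '((efficient | fast) ("query expansion"))'
    >>> sql_query_to_query_repr('((efficient | fast) ("query expansion"))')
    [[[('efficient',), ('fast',)], [('query expansion',)]]]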
