Skip to content

Commit b2cc21e

Browse files
committed
Add support for binary operators on columns
e.g. foo DIV 1000 - by using “table.foo::DIV::1000”
1 parent 92100ac commit b2cc21e

File tree

3 files changed

+78
-44
lines changed

3 files changed

+78
-44
lines changed

grice/complex_filter.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
LIST_FILTERS = ['in', 'not_in', 'bt', 'nbt']
1010
FILTER_TYPES = ['lt', 'lte', 'eq', 'neq', 'gt', 'gte'] + LIST_FILTERS
1111

12-
ColumnFunction = namedtuple('ColumnFunction', ['table_name', 'column_name', 'func_name'])
12+
ColumnFunction = namedtuple('ColumnFunction', ['table_name', 'column_name', 'func_name', 'operator_name', 'operator_value'])
1313

1414

1515
def _get_column(table_name: str, column_name: str, tables: List[Table]) -> Column:
@@ -33,11 +33,14 @@ def get_column(column_name: str, tables: List[Table]):
3333
if isinstance(column_name, ColumnFunction):
3434
func_name = column_name.func_name
3535
table_name = column_name.table_name
36+
operator_name = column_name.operator_name
37+
operator_value = column_name.operator_value
3638
column_name = column_name.column_name
3739

3840
else:
3941
func_name = None
4042
table_name = None
43+
operator_name = None
4144

4245
try:
4346
column_name, table_name = column_name.split('.')
@@ -46,8 +49,13 @@ def get_column(column_name: str, tables: List[Table]):
4649
pass
4750

4851
column = _get_column(table_name, column_name, tables)
52+
53+
if operator_name:
54+
column = column.op(operator_name)(operator_value)
55+
4956
if func_name:
5057
return getattr(sql_func, func_name)(column)
58+
5159
return column
5260

5361
def parse_filter(filter_string: str):

grice/db_controller.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,21 @@ def parse_column_func(column_string):
214214
215215
expected format: column_name
216216
expected format: function:column_name where function is 'avg' or 'count' etc
217+
expected format: function:column_name::operator::value where operator is DIV, + etc
218+
expected format: column_name::operator::value
217219
218220
:param column_string: string
219221
:return:
220222
"""
221223
table_name = None
224+
clean_vals = [s.strip() for s in column_string.split('::')]
225+
if len(clean_vals) == 3:
226+
column_string, operator_name, operator_value = clean_vals
227+
else:
228+
column_string = clean_vals[0]
229+
operator_name = None
230+
operator_value = None
231+
222232
clean_vals = [s.strip() for s in column_string.split(':')]
223233
column_name = clean_vals[-1]
224234
func_name = None
@@ -238,7 +248,7 @@ def parse_column_func(column_string):
238248
# This means the column name is not in the table_name.column_name format, which is fine.
239249
pass
240250

241-
return ColumnFunction(table_name, column_name, func_name)
251+
return ColumnFunction(table_name, column_name, func_name, operator_name, operator_value)
242252

243253
def parse_column_funcs(column_list):
244254
"""
@@ -267,7 +277,7 @@ def parse_col_names(column_names):
267277
:return: column_names: list
268278
"""
269279
if column_names:
270-
clean_cols = (column_name.strip() for column_name in column_names.split(','))
280+
clean_cols = (column_name.strip() for column_name in column_names)
271281
unique_ordered = OrderedDict.fromkeys(clean_cols)
272282
return list(unique_ordered)
273283

@@ -285,7 +295,7 @@ def parse_query_args(query_args):
285295
sorts = parse_sorts(query_args.getlist('sort'))
286296
join = parse_join(query_args.get('join'), False) or parse_join(query_args.get('outerjoin'), True)
287297
column_names = parse_column_funcs(query_args.getlist('columns')) or parse_column_funcs(query_args.get('cols', '').split(','))
288-
group_by = parse_col_names(query_args.getlist('group_by', None))
298+
group_by = parse_column_funcs(query_args.getlist('group_by', None))
289299

290300
return column_names, page, per_page, filters, sorts, join, group_by
291301

@@ -318,7 +328,7 @@ def get_query_args(self):
318328
sorts = parse_sorts(content.get('sort', []))
319329
join = parse_join(content.get('join'), False) or parse_join(content.get('outerjoin'), True)
320330
column_names = parse_column_funcs(content.get('columns', [])) or parse_column_funcs(content.get('cols', '').split(','))
321-
group_by = parse_col_names(content.get('group_by', []))
331+
group_by = parse_column_funcs(content.get('group_by', []))
322332
quargs = QueryArguments(column_names, page, per_page, filters, sorts, join, group_by, content.get('_list'))
323333

324334
return quargs

grice/db_service.py

+55-39
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import logging
22
from collections import namedtuple
3+
from typing import Union
34
import urllib
45

56
from sqlalchemy import create_engine, MetaData, Column, Table, select, asc, desc, and_
67
from sqlalchemy import engine
78
from sqlalchemy.sql import Select
89
from sqlalchemy.sql.functions import Function
10+
from sqlalchemy.sql.expression import BinaryExpression
911
from sqlalchemy.engine import reflection
1012
from grice.complex_filter import ComplexFilter, get_column
1113
from grice.errors import ConfigurationError, NotFoundError, JoinError
@@ -15,7 +17,7 @@
1517
DEFAULT_PAGE = 0
1618
DEFAULT_PER_PAGE = 50
1719
SORT_DIRECTIONS = ['asc', 'desc']
18-
SUPPORTED_FUNCS = ['avg', 'count', 'min', 'max', 'sum']
20+
SUPPORTED_FUNCS = ['avg', 'count', 'min', 'max', 'sum', 'stddev_pop']
1921
ColumnSort = namedtuple('ColumnSort', ['table_name', 'column_name', 'direction'])
2022
ColumnPair = namedtuple('ColumnPair', ['from_column', 'to_column'])
2123
TableJoin = namedtuple('TableJoin', ['table_name', 'column_pairs', 'outer_join'])
@@ -48,15 +50,30 @@ def init_database(db_config):
4850
return create_engine(eng_url)
4951

5052

51-
def function_to_dict(func: Function):
52-
data = {
53-
'name': str(func),
54-
'primary_key': func.primary_key,
55-
'table': '<Function {}>'.format(func.name),
56-
}
53+
def computed_column_to_dict(column: Union[Function, BinaryExpression]):
    """
    Converts a SqlAlchemy object for a column that contains a computed value to a dict so we can return JSON.

    :param column: a SqlAlchemy Function or a SqlAlchemy BinaryExpression
    :return: dict with the keys 'name', 'primary_key', 'table' and 'type'
    :raises TypeError: if column is neither a Function nor a BinaryExpression
    """
    # Only the 'table' entry differs between the two supported kinds of
    # computed column; Function is tested first, as before.
    if isinstance(column, Function):
        table = '<Function {}>'.format(column.name)
    elif isinstance(column, BinaryExpression):
        table = '<BinaryExpression {}>'.format(column)
    else:
        # Without this branch the function reached its return statement with
        # nothing assigned and failed with an UnboundLocalError that said
        # nothing about the offending object.
        raise TypeError(
            'Cannot convert {} to a dict: expected a Function or a BinaryExpression'.format(
                type(column).__name__
            )
        )

    return {
        'name': str(column),
        'primary_key': column.primary_key,
        'table': table,
        'type': column.type.__class__.__name__,
    }
5875

59-
def column_to_dict(column: Column):
76+
def _column_to_dict(column: Column):
6077
"""
6178
Converts a SqlAlchemy Column object to a dict so we can return JSON.
6279
@@ -80,6 +97,16 @@ def column_to_dict(column: Column):
8097

8198
return data
8299

100+
def column_to_dict(column):
    """
    Serialises a SqlAlchemy Column, or a column-like object, to a dict so we can return JSON.

    :param column: a column
    :return: dict
    """
    # Plain Column objects have their own converter; anything else is handed
    # to the converter for computed columns.
    converter = _column_to_dict if isinstance(column, Column) else computed_column_to_dict
    return converter(column)
83110

84111
def table_to_dict(table: Table):
85112
"""
@@ -184,9 +211,7 @@ def apply_group_by(query, table: Table, join_table: Table, group_by: list):
184211
:return: A SQLAlchemy select object modified with the group by clauses.
185212
"""
186213
for group in group_by:
187-
column = table.columns.get(group, None)
188-
if join_table is not None and not column:
189-
column = join_table.columns.get(group, None)
214+
column = get_column(group, [table, join_table])
190215

191216
if column is not None:
192217
query = query.group_by(column)
@@ -305,34 +330,25 @@ def query_table(self, table_name: str, quargs: QueryArguments): # pylint: disab
305330
log.debug("Query %s", query)
306331
result = conn.execute(query)
307332

308-
for row in result:
309-
count_of_map = {}
310-
if quargs.format_as_list:
311-
data = []
312-
for column in columns:
313-
if isinstance(column, Function):
314-
counter = count_of_map.get(column.name, 0) + 1
315-
count_of_map[column.name] = counter
316-
column_label = column.name + '_' + str(counter)
317-
else:
318-
column_label = column.table.name + '_' + column.name
319-
data.append(row[column_label])
320-
else:
321-
data = {}
322-
for column in columns:
323-
if isinstance(column, Function):
324-
counter = count_of_map.get(column.name, 0) + 1
325-
count_of_map[column.name] = counter
326-
full_column_name = column.name + '_' + str(counter)
327-
column_label = column.name + '_' + str(counter)
328-
else:
329-
full_column_name = column.table.name + '.' + column.name
330-
column_label = column.table.name + '_' + column.name
331-
data[full_column_name] = row[column_label]
332-
333-
rows.append(data)
334-
335-
column_data = [column_to_dict(column) if isinstance(column, Column) else function_to_dict(column) for column in columns]
333+
if quargs.format_as_list:
334+
# SQLalchemy is giving us the data in the correct format
335+
rows = result
336+
else:
337+
column_name_map = {}
338+
first_row = True
339+
for row in result:
340+
# Make friendlier names if possible
341+
if first_row:
342+
for column, column_label in zip(columns, row.keys()):
343+
if isinstance(column, Column):
344+
full_column_name = column.table.name + '.' + column.name
345+
column_name_map[column_label] = full_column_name
346+
first_row = False
347+
348+
data = {column_name_map.get(key, key): val for key, val in row.items()}
349+
rows.append(data)
350+
351+
column_data = [column_to_dict(column) for column in columns]
336352

337353
return rows, column_data
338354

0 commit comments

Comments
 (0)