Skip to content

Commit 4afa84f

Browse files
authored
Merge pull request #9 from ghost/master
OC codegen extensions
2 parents 6b5dd53 + 75d9e31 commit 4afa84f

File tree

164 files changed

+1782
-6774
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

164 files changed

+1782
-6774
lines changed

.travis.yml

-24
This file was deleted.

code-style.xml

-3
This file was deleted.

conclave/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def generate_code(protocol: callable, cfg: CodeGenConfig, mpc_frameworks: list,
3737

3838
# only apply optimizations if required
3939
if apply_optimizations:
40-
dag = comp.rewrite_dag(dag, all_parties=cfg.all_pids, use_leaky_ops=cfg.use_leaky_ops)
40+
dag = comp.rewrite_dag(dag, cfg)
4141

4242
# partition into sub-dags that will run in specific frameworks
4343
mapping = part.heupart(dag, mpc_frameworks, local_frameworks)

conclave/codegen/jiff.py

+59-11
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,6 @@
77
from conclave.job import JiffJob
88

99

10-
'''
11-
TODO:
12-
13-
- append output column row to output csv
14-
- agg
15-
16-
'''
17-
18-
1910
class JiffCodeGen(CodeGen):
2011

2112
def __init__(self, config, dag: Dag, pid: int,
@@ -120,9 +111,9 @@ def _generate(self, job_name: [str, None], output_directory: str):
120111
elif isinstance(node, Concat):
121112
op_code += self._generate_concat(node)
122113
elif isinstance(node, Close):
123-
op_code += ''
114+
op_code += self._generate_close(node)
124115
elif isinstance(node, Create):
125-
op_code += self._generate_create(node)
116+
pass
126117
elif isinstance(node, Join):
127118
op_code += self._generate_join(node)
128119
elif isinstance(node, Open):
@@ -135,13 +126,53 @@ def _generate(self, job_name: [str, None], output_directory: str):
135126
op_code += self._generate_divide(node)
136127
elif isinstance(node, SortBy):
137128
op_code += self._generate_sort_by(node)
129+
elif isinstance(node, ConcatCols):
130+
op_code += self._generate_concat_cols(node)
138131
elif isinstance(node, Open):
139132
op_code += self._generate_open(node)
140133
else:
141134
print("encountered unknown operator type", repr(node))
142135

143136
return self._generate_job(job_name, op_code)
144137

138+
def _generate_close(self, close_op: Close):
139+
140+
# node.parent.out_rel.stored_with
141+
copied_set = copy.deepcopy(close_op.parent.out_rel.stored_with)
142+
data_holder = copied_set.pop()
143+
144+
template = open(
145+
"{0}/create.tmpl".format(self.template_directory), 'r').read()
146+
147+
data = {
148+
"OUTREL": close_op.out_rel.name,
149+
"ID": data_holder
150+
}
151+
152+
return pystache.render(template, data)
153+
154+
def _generate_concat_cols(self, concat_cols_op: ConcatCols):
155+
156+
if len(concat_cols_op.get_in_rels()) != 2:
157+
raise NotImplementedError("Only support concat cols of two relations")
158+
159+
if concat_cols_op.use_mult:
160+
161+
template = open(
162+
"{0}/matrix_mult.tmpl".format(self.template_directory), 'r').read()
163+
164+
data = {
165+
"LEFT_REL": concat_cols_op.get_in_rels()[0].name,
166+
'RIGHT_REL': concat_cols_op.get_in_rels()[1].name,
167+
"OUTREL": concat_cols_op.out_rel.name
168+
}
169+
170+
return pystache.render(template, data)
171+
172+
else:
173+
# TODO: implement this
174+
return ""
175+
145176
def _generate_create(self, create_op: Create):
146177

147178
# check that the input data belongs to exactly one party
@@ -165,6 +196,23 @@ def _generate_aggregate(self, agg_op: Aggregate):
165196
if agg_op.aggregator == 'sum':
166197
template = open(
167198
"{}/agg_sum.tmpl".format(self.template_directory), 'r').read()
199+
elif agg_op.aggregator == 'mean':
200+
template = open(
201+
"{}/agg_mean_with_count_col.tmpl".format(self.template_directory), 'r').read()
202+
203+
data = {
204+
"INREL": agg_op.get_in_rel().name,
205+
"OUTREL": agg_op.out_rel.name,
206+
"KEY_COL": agg_op.group_cols[0].idx,
207+
"AGG_COL": agg_op.agg_col.idx,
208+
"COUNT_COL": 2
209+
}
210+
211+
return pystache.render(template, data)
212+
213+
elif agg_op.aggregator == 'std_dev':
214+
template = open(
215+
"{}/agg_std_dev.tmpl".format(self.template_directory), 'r').read()
168216
else:
169217
raise Exception("Unknown aggregator encountered: {}\n".format(agg_op.aggregator))
170218

conclave/codegen/oblivc.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ def _generate(self, job_name: [str, None], output_directory: [str, None]):
7070
op_code += self._generate_divide(node)
7171
elif isinstance(node, SortBy):
7272
op_code += self._generate_sort_by(node)
73-
elif isinstance(node, Open):
74-
op_code += self._generate_open(node)
7573
elif isinstance(node, DistinctCount):
7674
op_code += self._generate_distinct_count(node)
7775
elif isinstance(node, Filter):
@@ -393,6 +391,12 @@ def _generate_aggregate(self, agg_op: Aggregate):
393391
elif agg_op.aggregator == "count":
394392
template = open(
395393
"{}/agg_count.tmpl".format(self.template_directory), 'r').read()
394+
elif agg_op.aggregator == 'mean':
395+
template = open(
396+
"{}/agg_mean_with_count_col.tmpl".format(self.template_directory), 'r').read()
397+
elif agg_op.aggregator == "std_dev":
398+
template = open(
399+
"{}/std_dev.tmpl".format(self.template_directory), 'r').read()
396400
else:
397401
raise Exception("Unknown aggregator encountered: {}".format(agg_op.aggregator))
398402

@@ -409,7 +413,9 @@ def _generate_aggregate(self, agg_op: Aggregate):
409413
"OUT_REL": agg_op.out_rel.name,
410414
"KEY_COL": agg_op.group_cols[0].idx,
411415
"AGG_COL": agg_op.agg_col.idx,
412-
"USE_LEAKY": leaky
416+
"USE_LEAKY": leaky,
417+
"COUNT_COL": 2,
418+
"LEAKY": "Leaky" if leaky else ""
413419
}
414420

415421
return pystache.render(template, data)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
var {{{OUTREL}}}RESULT = aggregateMean({{{INREL}}}, {{{INREL}}}KeepRows, {{{KEY_COL}}}, {{{AGG_COL}}}, 0);
3+
var {{{OUTREL}}} = {{{OUTREL}}}RESULT[0];
4+
var {{{OUTREL}}}KeepRows = {{{OUTREL}}}RESULT[1];
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
var {{{OUTREL}}}RESULT = await aggregateMeanWithCountCol({{{INREL}}}, {{{INREL}}}KeepRows, {{{KEY_COL}}}, {{{AGG_COL}}}, {{{COUNT_COL}}}, 0, jiff_instance);
3+
var {{{OUTREL}}} = {{{OUTREL}}}RESULT[0];
4+
var {{{OUTREL}}}KeepRows = {{{OUTREL}}}RESULT[1];
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
var {{{OUTREL}}}RESULT = stdDev({{{INREL}}}, {{{INREL}}}KeepRows, {{{KEY_COL}}}, {{{AGG_COL}}});
3+
var {{{OUTREL}}} = {{{OUTREL}}}RESULT[0];
4+
var {{{OUTREL}}}KeepRows = {{{OUTREL}}}RESULT[1];
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11

2-
var {{{OUTREL}}}RESULT = aggregate({{{INREL}}}, {{{INREL}}}KeepRows, {{{KEY_COL}}}, {{{AGG_COL}}});
2+
var {{{OUTREL}}}RESULT = await aggregate({{{INREL}}}, {{{INREL}}}KeepRows, {{{KEY_COL}}}, {{{AGG_COL}}}, jiff_instance);
33
var {{{OUTREL}}} = {{{OUTREL}}}RESULT[0];
44
var {{{OUTREL}}}KeepRows = {{{OUTREL}}}RESULT[1];

conclave/codegen/templates/jiff/bash.tmpl

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ export NODE_PATH="{{{JIFF_PATH}}}/node_modules"
44

55
cd {{{CODE_PATH}}}
66

7-
node party.js {{{INPUT_PATH}}} {{{PARTY_COUNT}}} default {{{PARTY_ID}}}
7+
node --max-old-space-size=8192 party.js {{{INPUT_PATH}}} {{{PARTY_COUNT}}} default {{{PARTY_ID}}}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
2+
var {{{OUTREL}}}RESULT = multiplyMatrices({{{LEFT_REL}}}, {{{RIGHT_REL}}}, {{{LEFT_REL}}}KeepRows, {{{RIGHT_REL}}}KeepRows);
3+
var {{{OUTREL}}} = {{{OUTREL}}}RESULT[0];
4+
var {{{OUTREL}}}KeepRows = {{{OUTREL}}}RESULT[1];

0 commit comments

Comments
 (0)