7
7
from conclave .job import JiffJob
8
8
9
9
10
- '''
11
- TODO:
12
-
13
- - append output column row to output csv
14
- - agg
15
-
16
- '''
17
-
18
-
19
10
class JiffCodeGen (CodeGen ):
20
11
21
12
def __init__ (self , config , dag : Dag , pid : int ,
@@ -120,9 +111,9 @@ def _generate(self, job_name: [str, None], output_directory: str):
120
111
elif isinstance (node , Concat ):
121
112
op_code += self ._generate_concat (node )
122
113
elif isinstance (node , Close ):
123
- op_code += ''
114
+ op_code += self . _generate_close ( node )
124
115
elif isinstance (node , Create ):
125
- op_code += self . _generate_create ( node )
116
+ pass
126
117
elif isinstance (node , Join ):
127
118
op_code += self ._generate_join (node )
128
119
elif isinstance (node , Open ):
@@ -135,13 +126,53 @@ def _generate(self, job_name: [str, None], output_directory: str):
135
126
op_code += self ._generate_divide (node )
136
127
elif isinstance (node , SortBy ):
137
128
op_code += self ._generate_sort_by (node )
129
+ elif isinstance (node , ConcatCols ):
130
+ op_code += self ._generate_concat_cols (node )
138
131
elif isinstance (node , Open ):
139
132
op_code += self ._generate_open (node )
140
133
else :
141
134
print ("encountered unknown operator type" , repr (node ))
142
135
143
136
return self ._generate_job (job_name , op_code )
144
137
138
+ def _generate_close (self , close_op : Close ):
139
+
140
+ # node.parent.out_rel.stored_with
141
+ copied_set = copy .deepcopy (close_op .parent .out_rel .stored_with )
142
+ data_holder = copied_set .pop ()
143
+
144
+ template = open (
145
+ "{0}/create.tmpl" .format (self .template_directory ), 'r' ).read ()
146
+
147
+ data = {
148
+ "OUTREL" : close_op .out_rel .name ,
149
+ "ID" : data_holder
150
+ }
151
+
152
+ return pystache .render (template , data )
153
+
154
+ def _generate_concat_cols (self , concat_cols_op : ConcatCols ):
155
+
156
+ if len (concat_cols_op .get_in_rels ()) != 2 :
157
+ raise NotImplementedError ("Only support concat cols of two relations" )
158
+
159
+ if concat_cols_op .use_mult :
160
+
161
+ template = open (
162
+ "{0}/matrix_mult.tmpl" .format (self .template_directory ), 'r' ).read ()
163
+
164
+ data = {
165
+ "LEFT_REL" : concat_cols_op .get_in_rels ()[0 ].name ,
166
+ 'RIGHT_REL' : concat_cols_op .get_in_rels ()[1 ].name ,
167
+ "OUTREL" : concat_cols_op .out_rel .name
168
+ }
169
+
170
+ return pystache .render (template , data )
171
+
172
+ else :
173
+ # TODO: implement this
174
+ return ""
175
+
145
176
def _generate_create (self , create_op : Create ):
146
177
147
178
# check that the input data belongs to exactly one party
@@ -165,6 +196,23 @@ def _generate_aggregate(self, agg_op: Aggregate):
165
196
if agg_op .aggregator == 'sum' :
166
197
template = open (
167
198
"{}/agg_sum.tmpl" .format (self .template_directory ), 'r' ).read ()
199
+ elif agg_op .aggregator == 'mean' :
200
+ template = open (
201
+ "{}/agg_mean_with_count_col.tmpl" .format (self .template_directory ), 'r' ).read ()
202
+
203
+ data = {
204
+ "INREL" : agg_op .get_in_rel ().name ,
205
+ "OUTREL" : agg_op .out_rel .name ,
206
+ "KEY_COL" : agg_op .group_cols [0 ].idx ,
207
+ "AGG_COL" : agg_op .agg_col .idx ,
208
+ "COUNT_COL" : 2
209
+ }
210
+
211
+ return pystache .render (template , data )
212
+
213
+ elif agg_op .aggregator == 'std_dev' :
214
+ template = open (
215
+ "{}/agg_std_dev.tmpl" .format (self .template_directory ), 'r' ).read ()
168
216
else :
169
217
raise Exception ("Unknown aggregator encountered: {}\n " .format (agg_op .aggregator ))
170
218
0 commit comments