Skip to content

Commit 0fa9693

Browse files
authored
Support Tags and Global Attributes in Streaming context (#32)
* fix: stream to data init * feat: add support for tags and global attributes in iterator Co-authored-by: Pablo <[email protected]>
1 parent 1f6da42 commit 0fa9693

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

sdk/diffgram/core/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,21 @@ def get_label_schema_list(self):
132132
data = response.json()
133133
return data
134134

135+
def get_attributes(self, schema_id = None):
136+
if schema_id is None:
137+
schema = self.get_default_label_schema()
138+
if schema is not None:
139+
schema_id = schema.get('id')
140+
url = f'/api/v1/project/{self.project_string_id}/attribute/template/list'
141+
data = {
142+
'schema_id': schema_id,
143+
'mode': "from_project",
144+
}
145+
response = self.session.post(url = self.host + url, json=data)
146+
self.handle_errors(response)
147+
data = response.json()
148+
return data.get('attribute_group_list')
149+
135150
def get_http_auth(self):
136151
return HTTPBasicAuth(self.client_id, self.client_secret)
137152

sdk/diffgram/core/diffgram_dataset_iterator.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,30 @@ def get_image_data(self, diffgram_file):
136136
else:
137137
raise Exception('Pytorch datasets only support images. Please provide only file_ids from images')
138138

139+
def gen_global_attrs(self, instance_list):
140+
res = []
141+
for inst in instance_list:
142+
if inst['type'] != 'global':
143+
continue
144+
res.append(inst['attribute_groups'])
145+
return res
146+
147+
def gen_tag_instances(self, instance_list):
148+
result = []
149+
for inst in instance_list:
150+
if inst['type'] != 'tag':
151+
continue
152+
for k in list(inst.keys()):
153+
val = inst[k]
154+
if val is None:
155+
inst.pop(k)
156+
elm = {
157+
'label': inst['label_file']['label']['name'],
158+
'label_file_id': inst['label_file']['id'],
159+
}
160+
result.append(elm)
161+
return result
162+
139163
def get_file_instances(self, diffgram_file):
140164
if diffgram_file.type not in ['image', 'frame']:
141165
raise NotImplementedError('File type "{}" is not supported yet'.format(diffgram_file['type']))
@@ -147,6 +171,9 @@ def get_file_instances(self, diffgram_file):
147171
sample = {'image': image, 'diffgram_file': diffgram_file}
148172
has_boxes = False
149173
has_poly = False
174+
has_tags = False
175+
has_global = False
176+
sample['raw_instance_list'] = instance_list
150177
if 'box' in instance_types_in_file:
151178
has_boxes = True
152179
x_min_list, x_max_list, y_min_list, y_max_list = self.extract_bbox_values(instance_list, diffgram_file)
@@ -164,12 +191,19 @@ def get_file_instances(self, diffgram_file):
164191
has_poly = True
165192
mask_list = self.extract_masks_from_polygon(instance_list, diffgram_file)
166193
sample['polygon_mask_list'] = mask_list
194+
if 'tag' in instance_types_in_file:
195+
has_tags = True
196+
sample['tags'] = self.gen_tag_instances(instance_list)
197+
if 'global' in instance_types_in_file:
198+
has_global = True
199+
sample['global_attributes'] = self.gen_global_attrs(instance_list)
200+
167201
else:
168202
sample['polygon_mask_list'] = []
169203

170-
if len(instance_types_in_file) > 2 and has_boxes and has_boxes:
204+
if len(instance_types_in_file) > 4 and has_poly and has_boxes and has_tags and has_global:
171205
raise NotImplementedError(
172-
'SDK only supports boxes and polygon types currently. If you want a new instance type to be supported please contact us!'
206+
'SDK Streaming only supports boxes and polygon, tags and global attributes types currently. If you want a new instance type to be supported please contact us!'
173207
)
174208

175209
label_id_list, label_name_list = self.extract_labels(instance_list)
@@ -198,11 +232,13 @@ def extract_masks_from_polygon(self, instance_list, diffgram_file, empty_value =
198232
def extract_labels(self, instance_list, allowed_instance_types = None):
199233
label_file_id_list = []
200234
label_names_list = []
201-
202235
for inst in instance_list:
236+
if inst['type'] == 'global':
237+
continue
238+
if inst is None:
239+
continue
203240
if allowed_instance_types and inst['type'] in allowed_instance_types:
204241
continue
205-
206242
label_file_id_list.append(inst['label_file']['id'])
207243
label_names_list.append(inst['label_file']['label']['name'])
208244

0 commit comments

Comments
 (0)