Skip to content

Commit 1f6da42

Browse files
authored
fix: stream to data init (#30)
Co-authored-by: Pablo <[email protected]>
1 parent 0b988ea commit 1f6da42

File tree

3 files changed

+48
-26
lines changed

3 files changed

+48
-26
lines changed

sdk/diffgram/core/core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def __init__(
7070
self.directory_id = None
7171
self.name_to_file_id = None
7272

73+
self.client_id = client_id
74+
self.client_secret = client_secret
7375

7476
if init_default_directory is True:
7577
self.set_default_directory(directory = self.directory)
@@ -78,9 +80,6 @@ def __init__(
7880
if refresh_local_label_dict is True:
7981
self.get_label_file_dict()
8082

81-
self.client_id = client_id
82-
self.client_secret = client_secret
83-
8483
self.label_schema_list = self.get_label_schema_list()
8584

8685
self.directory_list = None

sdk/diffgram/core/diffgram_dataset_iterator.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,16 @@
88

99

1010
class DiffgramDatasetIterator:
11-
12-
def __init__(self, project,
11+
diffgram_file_id_list: list
12+
max_size_cache: int = 1073741824
13+
pool: ThreadPoolExecutor
14+
project: 'Project'
15+
file_cache: dict
16+
_internal_file_list: list
17+
current_file_index: int
18+
19+
def __init__(self,
20+
project,
1321
diffgram_file_id_list,
1422
validate_ids = True,
1523
max_size_cache = 1073741824,
@@ -19,6 +27,21 @@ def __init__(self, project,
1927
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
2028
:param diffgram_file_id_list (list): An arbitrary number of file IDs from Diffgram.
2129
"""
30+
self.diffgram_file_id_list = []
31+
self.max_size_cache = 1073741824
32+
self.pool = None
33+
self.file_cache = {}
34+
self._internal_file_list = []
35+
self.current_file_index = 0
36+
self.start_iterator(
37+
project = project,
38+
diffgram_file_id_list = diffgram_file_id_list,
39+
validate_ids = validate_ids,
40+
max_size_cache = max_size_cache,
41+
max_num_concurrent_fetches = max_num_concurrent_fetches)
42+
43+
def start_iterator(self, project, diffgram_file_id_list, validate_ids = True, max_size_cache = 1073741824,
44+
max_num_concurrent_fetches = 25):
2245
self.diffgram_file_id_list = diffgram_file_id_list
2346
self.max_size_cache = max_size_cache
2447
self.pool = ThreadPoolExecutor(max_num_concurrent_fetches)
@@ -62,7 +85,8 @@ def get_next_n_items(self, idx, num_items = 25):
6285
return True
6386

6487
def __get_file_data_for_index(self, idx):
65-
diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True, use_session = False)
88+
diffgram_file = self.project.file.get_by_id(self.diffgram_file_id_list[idx], with_instances = True,
89+
use_session = False)
6690
instance_data = self.get_file_instances(diffgram_file)
6791
self.save_file_in_cache(idx, instance_data)
6892
return instance_data
@@ -88,7 +112,7 @@ def __validate_file_ids(self):
88112
if not self.diffgram_file_id_list:
89113
return
90114
result = self.project.file.file_list_exists(
91-
self.diffgram_file_id_list,
115+
self.diffgram_file_id_list,
92116
use_session = False)
93117
if not result:
94118
raise Exception(

sdk/diffgram/core/directory.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
class Directory(DiffgramDatasetIterator):
99

10-
def __init__(self,
11-
client,
12-
file_id_list_sliced = None,
13-
init_file_ids = True,
10+
def __init__(self,
11+
client,
12+
file_id_list_sliced = None,
13+
init_file_ids = True,
1414
validate_ids = True):
1515

1616
self.client = client
@@ -25,11 +25,8 @@ def __init__(self,
2525
self.file_id_list = file_id_list_sliced
2626
super(Directory, self).__init__(self.client, self.file_id_list, validate_ids)
2727

28-
29-
3028
def init_files(self):
3129
self.file_id_list = self.all_file_ids()
32-
3330
def get_directory_list(self):
3431
"""
3532
Get a list of available directories for a project
@@ -50,7 +47,7 @@ def get_directory_list(self):
5047
self.client.handle_errors(response)
5148

5249
data = response.json()
53-
50+
5451
directory_list_json = data.get('directory_list')
5552
default_directory_json = data.get('default_directory')
5653

@@ -60,7 +57,6 @@ def get_directory_list(self):
6057
directory_list = self.convert_json_to_sdk_object(directory_list_json)
6158

6259
return directory_list
63-
6460

6561
def convert_json_to_sdk_object(self, directory_list_json):
6662

@@ -71,18 +67,21 @@ def convert_json_to_sdk_object(self, directory_list_json):
7167
client = self.client,
7268
init_file_ids = False,
7369
validate_ids = False
74-
)
70+
)
7571
refresh_from_dict(new_directory, directory_json)
7672

7773
# note timing issue, this needs to happen after id is refreshed
78-
new_directory.init_files()
74+
new_directory.init_files()
75+
new_directory.start_iterator(
76+
project = new_directory.project,
77+
diffgram_file_id_list = new_directory.file_id_list,
78+
validate_ids = True
79+
)
7980

8081
directory_list.append(new_directory)
8182

8283
return directory_list
8384

84-
85-
8685
def all_files(self):
8786
"""
8887
Get all the files of the directory.
@@ -93,8 +92,8 @@ def all_files(self):
9392
result = []
9493
while page_num is not None:
9594
diffgram_files = self.list_files(
96-
limit = 1000,
97-
page_num = page_num,
95+
limit = 1000,
96+
page_num = page_num,
9897
file_view_mode = 'base')
9998
page_num = self.file_list_metadata['next_page']
10099
result = result + diffgram_files
@@ -105,9 +104,9 @@ def all_file_ids(self, query = None):
105104
result = []
106105

107106
diffgram_ids = self.list_files(
108-
limit = 5000,
109-
page_num = page_num,
110-
file_view_mode = 'ids_only',
107+
limit = 5000,
108+
page_num = page_num,
109+
file_view_mode = 'ids_only',
111110
query = query)
112111

113112
if diffgram_ids is False:
@@ -299,7 +298,6 @@ def get(self,
299298
TODO refactor set_directory_by_name() to use this
300299
301300
"""
302-
303301
if name is None:
304302
raise Exception("No name provided.")
305303

@@ -312,6 +310,7 @@ def get(self,
312310
for directory in self.client.directory_list:
313311

314312
if directory.nickname == name:
313+
directory.init_files()
315314
return directory
316315

317316
else:

0 commit comments

Comments
 (0)