@@ -2,7 +2,7 @@ use std::collections::VecDeque;
2
2
use std:: fs:: OpenOptions ;
3
3
use std:: io;
4
4
use std:: io:: { BufRead , BufReader , BufWriter , Write } ;
5
- use std:: path:: { Path , PathBuf } ;
5
+ use std:: path:: PathBuf ;
6
6
use std:: process;
7
7
use std:: sync:: atomic:: { AtomicU32 , Ordering } ;
8
8
use std:: sync:: Arc ;
@@ -15,8 +15,8 @@ use threadpool::ThreadPool;
15
15
16
16
use crate :: bloom_filter:: BloomFilter ;
17
17
use crate :: s3_util;
18
- use crate :: s3_util:: { download_to_file, upload_file} ;
19
18
use crate :: shard:: shard_config:: WorkDirConfig ;
19
+ use crate :: shard:: FileCache ;
20
20
21
21
use deduper_config:: * ;
22
22
@@ -74,40 +74,44 @@ pub fn run(config: DeduperConfig) {
74
74
// For doc-level deduping, check the Bloom filter for existence of the configured key and set the configured attribute to true.
75
75
// For paragraph-level deduping, check the Bloom filter for existence of a paragraph in the text and add a span to the configured attribute.
76
76
fn write_attributes (
77
- doc_path : String ,
77
+ docs_location : String ,
78
78
work_dirs : WorkDirConfig ,
79
79
dedupe_config : DedupeConfig ,
80
80
bloom_filter : Arc < BloomFilter > ,
81
81
) -> Result < ( ) , io:: Error > {
82
- let rt = tokio:: runtime:: Builder :: new_current_thread ( )
83
- . enable_all ( )
84
- . build ( )
85
- . unwrap ( ) ;
86
-
87
- let s3_client = s3_util:: new_client ( None ) ?;
88
-
89
- let input_work_dir = Path :: new ( & work_dirs. input ) ;
90
- let output_work_dir = Path :: new ( & work_dirs. output ) ;
82
+ let cache = FileCache {
83
+ s3_client : Box :: new ( s3_util:: new_client ( None ) ?) ,
84
+ work : work_dirs. clone ( ) ,
85
+ } ;
91
86
92
- let output_path = {
87
+ let attrs_location = {
93
88
let mut attr_prefix = "/attributes/" . to_owned ( ) ;
94
89
attr_prefix. push_str ( & dedupe_config. name ) ;
95
90
attr_prefix. push_str ( "/" ) ;
96
- doc_path. to_owned ( ) . replace ( "/documents/" , & attr_prefix)
91
+ docs_location
92
+ . to_owned ( )
93
+ . replace ( "/documents/" , & attr_prefix)
97
94
} ;
98
- let local_output = output_work_dir . join ( & output_path ) ;
95
+ let local_output = cache . prepare_output ( & attrs_location ) ? ;
99
96
if local_output. exists ( ) {
100
- log:: info!( "Skipping {:?} because it already exists" , output_path ) ;
97
+ log:: info!( "Skipping {:?} because it already exists" , attrs_location ) ;
101
98
return Ok ( ( ) ) ;
102
99
}
100
+ log:: info!(
101
+ "Writing attributes for {} to {}" ,
102
+ docs_location,
103
+ local_output. display( )
104
+ ) ;
103
105
104
106
std:: fs:: create_dir_all ( local_output. parent ( ) . unwrap ( ) ) ?;
105
107
106
- let tmp_output_path = output_work_dir. join ( output_path. clone ( ) + ".tmp" ) ;
108
+ log:: info!(
109
+ "Writing attributes for {} to {}" ,
110
+ docs_location,
111
+ local_output. display( )
112
+ ) ;
107
113
{
108
- let local_input = input_work_dir. join ( Path :: new ( & doc_path) ) ;
109
- log:: info!( "Downloading {} to {}" , doc_path, local_input. display( ) ) ;
110
- rt. block_on ( download_to_file ( & s3_client, & doc_path, & local_input) ) ?;
114
+ let local_input = cache. prepare_input ( & docs_location) ?;
111
115
let input_file = OpenOptions :: new ( )
112
116
. read ( true )
113
117
. write ( false )
@@ -120,7 +124,7 @@ fn write_attributes(
120
124
. write ( true )
121
125
. create ( true )
122
126
. truncate ( true )
123
- . open ( & tmp_output_path ) ?;
127
+ . open ( & local_output ) ?;
124
128
125
129
let mut writer = BufWriter :: with_capacity (
126
130
1024 * 1024 ,
@@ -132,7 +136,12 @@ fn write_attributes(
132
136
match line {
133
137
Ok ( _) => { }
134
138
Err ( e) => {
135
- log:: error!( "Error reading line {} of {}: {}" , line_number, & doc_path, e) ;
139
+ log:: error!(
140
+ "Error reading line {} of {}: {}" ,
141
+ line_number,
142
+ & docs_location,
143
+ e
144
+ ) ;
136
145
break ;
137
146
}
138
147
}
@@ -223,23 +232,7 @@ fn write_attributes(
223
232
}
224
233
std:: fs:: remove_file ( local_input) ?;
225
234
}
226
-
227
- log:: info!(
228
- "Uploading {} to {}" ,
229
- & tmp_output_path. display( ) ,
230
- & output_path
231
- ) ;
232
- rt. block_on ( upload_file ( & s3_client, & output_path, & tmp_output_path) ) ?;
233
-
234
- {
235
- // Create empty file to indicate that the shard is done.
236
- OpenOptions :: new ( )
237
- . create ( true )
238
- . write ( true )
239
- . open ( & local_output) ?;
240
- std:: fs:: remove_file ( & tmp_output_path) ?;
241
- }
242
-
235
+ cache. finalize_output ( & attrs_location) ?;
243
236
Ok ( ( ) )
244
237
}
245
238
@@ -303,16 +296,14 @@ pub mod deduper_config {
303
296
}
304
297
305
298
#[ cfg( test) ]
306
- pub mod test {
299
+ mod test {
307
300
use std:: fs:: OpenOptions ;
308
301
use std:: io;
309
302
use std:: io:: { BufRead , BufReader } ;
310
- use std:: path:: Path ;
311
303
312
304
use flate2:: read:: MultiGzDecoder ;
313
305
314
306
use crate :: s3_util;
315
- use crate :: s3_util:: download_to_file;
316
307
317
308
use super :: * ;
318
309
@@ -352,53 +343,39 @@ pub mod test {
352
343
}
353
344
354
345
#[ test]
355
- pub fn test_dedupe_by_url ( ) -> Result < ( ) , io:: Error > {
346
+ fn test_dedupe_by_url ( ) -> Result < ( ) , io:: Error > {
356
347
let config = DeduperConfig :: read_from_file ( "tests/config/dedupe-by-url.json" ) . unwrap ( ) ;
357
- run ( config) ;
358
-
359
- let rt = tokio:: runtime:: Builder :: new_current_thread ( )
360
- . enable_all ( )
361
- . build ( )
362
- . unwrap ( ) ;
363
- let s3_client = s3_util:: new_client ( None ) ?;
364
-
365
- let local_output_file = "tests/work/output/dedupe-by-url.json.gz" ;
366
- let remote_output_file = "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_by_url/head/0000.json.gz" ;
367
- rt. block_on ( download_to_file (
368
- & s3_client,
369
- remote_output_file,
370
- Path :: new ( local_output_file) ,
371
- ) ) ?;
348
+ run ( config. clone ( ) ) ;
349
+
350
+ let cache = FileCache {
351
+ s3_client : Box :: new ( s3_util:: new_client ( None ) ?) ,
352
+ work : config. work_dir . clone ( ) ,
353
+ } ;
354
+
355
+ let local_output_file = cache. prepare_input ( "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_by_url/head/0000.json.gz" ) ?;
372
356
373
357
compare_contents (
374
358
"tests/data/expected/dedupe-by-url.json.gz" ,
375
- local_output_file,
359
+ & local_output_file. display ( ) . to_string ( ) ,
376
360
) ;
377
361
Ok ( ( ) )
378
362
}
379
363
380
364
#[ test]
381
- pub fn test_dedupe_paragraphs ( ) -> Result < ( ) , io:: Error > {
365
+ fn test_dedupe_paragraphs ( ) -> Result < ( ) , io:: Error > {
382
366
let config = DeduperConfig :: read_from_file ( "tests/config/dedupe-paragraphs.json" ) . unwrap ( ) ;
383
- run ( config) ;
384
-
385
- let rt = tokio:: runtime:: Builder :: new_current_thread ( )
386
- . enable_all ( )
387
- . build ( )
388
- . unwrap ( ) ;
389
- let s3_client = s3_util:: new_client ( None ) ?;
390
-
391
- let local_output_file = "tests/work/output/dedupe-paragraphs.json.gz" ;
392
- let remote_output_file = "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_paragraphs/head/0000.json.gz" ;
393
- rt. block_on ( download_to_file (
394
- & s3_client,
395
- remote_output_file,
396
- Path :: new ( local_output_file) ,
397
- ) ) ?;
367
+ run ( config. clone ( ) ) ;
368
+
369
+ let cache = FileCache {
370
+ s3_client : Box :: new ( s3_util:: new_client ( None ) ?) ,
371
+ work : config. work_dir . clone ( ) ,
372
+ } ;
373
+
374
+ let local_output_file = cache. prepare_input ( "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_paragraphs/head/0000.json.gz" ) ?;
398
375
399
376
compare_contents (
400
377
"tests/data/expected/dedupe-paragraphs.json.gz" ,
401
- local_output_file,
378
+ & local_output_file. display ( ) . to_string ( ) ,
402
379
) ;
403
380
Ok ( ( ) )
404
381
}
0 commit comments