Skip to content

Commit d458c90

Browse files
committed
Use S3 or local paths
1 parent eb57219 commit d458c90

File tree

10 files changed

+422
-292
lines changed

10 files changed

+422
-292
lines changed

.cargo/config.toml

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# See https://pyo3.rs/v0.14.2/building_and_distribution.html#macos
2+
[target.x86_64-apple-darwin]
3+
rustflags = [
4+
"-C", "link-arg=-undefined",
5+
"-C", "link-arg=dynamic_lookup",
6+
]
7+
8+
[target.aarch64-apple-darwin]
9+
rustflags = [
10+
"-C", "link-arg=-undefined",
11+
"-C", "link-arg=dynamic_lookup",
12+
]

Cargo.lock

+8-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
[package]
22
name = "dolma"
3-
version = "0.6.2"
3+
version = "0.6.3"
44
edition = "2021"
55
license = "Apache-2.0"
66

7-
87
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
98
[lib]
109
name = "dolma"
@@ -30,3 +29,4 @@ threadpool = "1.8.1"
3029
tokio = {version = "1.27.0", features = ["full"]}
3130
tokio-util = "0.7.7"
3231
unicode-segmentation = "1.7"
32+
glob = "0.3.1"

Makefile

+6-7
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,24 @@ setup:
2828
publish:
2929
maturin publish
3030

31-
test: setup develop setup-test test-python test-rust clean-test
31+
test: setup develop setup-test test-python test-rust
3232

3333
test-python:
3434
pytest -vs tests/python
3535

36-
test-rust:
37-
cargo test -- --nocapture
38-
39-
clean-test:
36+
test-rust-clean:
4037
rm -rf tests/work/*
4138
aws s3 rm --recursive s3://ai2-llm/pretraining-data/tests/mixer/
4239

43-
setup-test:
40+
test-rust-setup:
4441
aws s3 cp tests/data/documents.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/documents/head/0000.json.gz
4542
aws s3 cp tests/data/pii-attributes.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/pii/head/0000.json.gz
4643
aws s3 cp tests/data/toxicity-attributes.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/toxicity/head/0000.json.gz
4744
aws s3 cp tests/data/sample-attributes.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/sample/head/0000.json.gz
4845
aws s3 cp tests/data/duplicate-paragraphs.json.gz s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/duplicate_paragraphs/head/0000.json.gz
49-
aws s3 sync tests/data/expected s3://ai2-llm/pretraining-data/tests/mixer/expected --exclude ".*" --exclude "*/.*"
46+
47+
test-rust: test-rust-clean test-rust-setup
48+
cargo test -- --nocapture
5049

5150
develop:
5251
maturin develop --extras=dev

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "dolma"
3-
version = "0.6.2"
3+
version = "0.6.3"
44
description = "Data filters"
55
license = {text = "Apache-2.0"}
66
readme = "README.md"

src/deduper.rs

+53-76
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::collections::VecDeque;
22
use std::fs::OpenOptions;
33
use std::io;
44
use std::io::{BufRead, BufReader, BufWriter, Write};
5-
use std::path::{Path, PathBuf};
5+
use std::path::PathBuf;
66
use std::process;
77
use std::sync::atomic::{AtomicU32, Ordering};
88
use std::sync::Arc;
@@ -15,8 +15,8 @@ use threadpool::ThreadPool;
1515

1616
use crate::bloom_filter::BloomFilter;
1717
use crate::s3_util;
18-
use crate::s3_util::{download_to_file, upload_file};
1918
use crate::shard::shard_config::WorkDirConfig;
19+
use crate::shard::FileCache;
2020

2121
use deduper_config::*;
2222

@@ -74,40 +74,44 @@ pub fn run(config: DeduperConfig) {
7474
// For doc-level deduping, check the Bloom filter for existence of the configured key and set the configured attribute to true.
7575
// For paragraph-level deduping, check the Bloom filter for existence of a paragraph in the text and add a span to the configured attribute.
7676
fn write_attributes(
77-
doc_path: String,
77+
docs_location: String,
7878
work_dirs: WorkDirConfig,
7979
dedupe_config: DedupeConfig,
8080
bloom_filter: Arc<BloomFilter>,
8181
) -> Result<(), io::Error> {
82-
let rt = tokio::runtime::Builder::new_current_thread()
83-
.enable_all()
84-
.build()
85-
.unwrap();
86-
87-
let s3_client = s3_util::new_client(None)?;
88-
89-
let input_work_dir = Path::new(&work_dirs.input);
90-
let output_work_dir = Path::new(&work_dirs.output);
82+
let cache = FileCache {
83+
s3_client: Box::new(s3_util::new_client(None)?),
84+
work: work_dirs.clone(),
85+
};
9186

92-
let output_path = {
87+
let attrs_location = {
9388
let mut attr_prefix = "/attributes/".to_owned();
9489
attr_prefix.push_str(&dedupe_config.name);
9590
attr_prefix.push_str("/");
96-
doc_path.to_owned().replace("/documents/", &attr_prefix)
91+
docs_location
92+
.to_owned()
93+
.replace("/documents/", &attr_prefix)
9794
};
98-
let local_output = output_work_dir.join(&output_path);
95+
let local_output = cache.prepare_output(&attrs_location)?;
9996
if local_output.exists() {
100-
log::info!("Skipping {:?} because it already exists", output_path);
97+
log::info!("Skipping {:?} because it already exists", attrs_location);
10198
return Ok(());
10299
}
100+
log::info!(
101+
"Writing attributes for {} to {}",
102+
docs_location,
103+
local_output.display()
104+
);
103105

104106
std::fs::create_dir_all(local_output.parent().unwrap())?;
105107

106-
let tmp_output_path = output_work_dir.join(output_path.clone() + ".tmp");
108+
log::info!(
109+
"Writing attributes for {} to {}",
110+
docs_location,
111+
local_output.display()
112+
);
107113
{
108-
let local_input = input_work_dir.join(Path::new(&doc_path));
109-
log::info!("Downloading {} to {}", doc_path, local_input.display());
110-
rt.block_on(download_to_file(&s3_client, &doc_path, &local_input))?;
114+
let local_input = cache.prepare_input(&docs_location)?;
111115
let input_file = OpenOptions::new()
112116
.read(true)
113117
.write(false)
@@ -120,7 +124,7 @@ fn write_attributes(
120124
.write(true)
121125
.create(true)
122126
.truncate(true)
123-
.open(&tmp_output_path)?;
127+
.open(&local_output)?;
124128

125129
let mut writer = BufWriter::with_capacity(
126130
1024 * 1024,
@@ -132,7 +136,12 @@ fn write_attributes(
132136
match line {
133137
Ok(_) => {}
134138
Err(e) => {
135-
log::error!("Error reading line {} of {}: {}", line_number, &doc_path, e);
139+
log::error!(
140+
"Error reading line {} of {}: {}",
141+
line_number,
142+
&docs_location,
143+
e
144+
);
136145
break;
137146
}
138147
}
@@ -223,23 +232,7 @@ fn write_attributes(
223232
}
224233
std::fs::remove_file(local_input)?;
225234
}
226-
227-
log::info!(
228-
"Uploading {} to {}",
229-
&tmp_output_path.display(),
230-
&output_path
231-
);
232-
rt.block_on(upload_file(&s3_client, &output_path, &tmp_output_path))?;
233-
234-
{
235-
// Create empty file to indicate that the shard is done.
236-
OpenOptions::new()
237-
.create(true)
238-
.write(true)
239-
.open(&local_output)?;
240-
std::fs::remove_file(&tmp_output_path)?;
241-
}
242-
235+
cache.finalize_output(&attrs_location)?;
243236
Ok(())
244237
}
245238

@@ -303,16 +296,14 @@ pub mod deduper_config {
303296
}
304297

305298
#[cfg(test)]
306-
pub mod test {
299+
mod test {
307300
use std::fs::OpenOptions;
308301
use std::io;
309302
use std::io::{BufRead, BufReader};
310-
use std::path::Path;
311303

312304
use flate2::read::MultiGzDecoder;
313305

314306
use crate::s3_util;
315-
use crate::s3_util::download_to_file;
316307

317308
use super::*;
318309

@@ -352,53 +343,39 @@ pub mod test {
352343
}
353344

354345
#[test]
355-
pub fn test_dedupe_by_url() -> Result<(), io::Error> {
346+
fn test_dedupe_by_url() -> Result<(), io::Error> {
356347
let config = DeduperConfig::read_from_file("tests/config/dedupe-by-url.json").unwrap();
357-
run(config);
358-
359-
let rt = tokio::runtime::Builder::new_current_thread()
360-
.enable_all()
361-
.build()
362-
.unwrap();
363-
let s3_client = s3_util::new_client(None)?;
364-
365-
let local_output_file = "tests/work/output/dedupe-by-url.json.gz";
366-
let remote_output_file = "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_by_url/head/0000.json.gz";
367-
rt.block_on(download_to_file(
368-
&s3_client,
369-
remote_output_file,
370-
Path::new(local_output_file),
371-
))?;
348+
run(config.clone());
349+
350+
let cache = FileCache {
351+
s3_client: Box::new(s3_util::new_client(None)?),
352+
work: config.work_dir.clone(),
353+
};
354+
355+
let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_by_url/head/0000.json.gz")?;
372356

373357
compare_contents(
374358
"tests/data/expected/dedupe-by-url.json.gz",
375-
local_output_file,
359+
&local_output_file.display().to_string(),
376360
);
377361
Ok(())
378362
}
379363

380364
#[test]
381-
pub fn test_dedupe_paragraphs() -> Result<(), io::Error> {
365+
fn test_dedupe_paragraphs() -> Result<(), io::Error> {
382366
let config = DeduperConfig::read_from_file("tests/config/dedupe-paragraphs.json").unwrap();
383-
run(config);
384-
385-
let rt = tokio::runtime::Builder::new_current_thread()
386-
.enable_all()
387-
.build()
388-
.unwrap();
389-
let s3_client = s3_util::new_client(None)?;
390-
391-
let local_output_file = "tests/work/output/dedupe-paragraphs.json.gz";
392-
let remote_output_file = "s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_paragraphs/head/0000.json.gz";
393-
rt.block_on(download_to_file(
394-
&s3_client,
395-
remote_output_file,
396-
Path::new(local_output_file),
397-
))?;
367+
run(config.clone());
368+
369+
let cache = FileCache {
370+
s3_client: Box::new(s3_util::new_client(None)?),
371+
work: config.work_dir.clone(),
372+
};
373+
374+
let local_output_file = cache.prepare_input("s3://ai2-llm/pretraining-data/tests/mixer/inputs/v0/attributes/dedupe_paragraphs/head/0000.json.gz")?;
398375

399376
compare_contents(
400377
"tests/data/expected/dedupe-paragraphs.json.gz",
401-
local_output_file,
378+
&local_output_file.display().to_string(),
402379
);
403380
Ok(())
404381
}

0 commit comments

Comments (0)