"""Integration tests for `datasets.inspect` (test_inspect.py)."""
import pytest

from datasets.exceptions import DatasetNotFoundError
from datasets.inspect import (
    get_dataset_config_info,
    get_dataset_config_names,
    get_dataset_default_config_name,
    get_dataset_infos,
    get_dataset_split_names,
)
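
# Every test in this module queries the Hugging Face Hub, so the whole file is
# gated behind the integration marker.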
pytestmark = pytest.mark.integration
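

# get_dataset_config_info returns the DatasetInfo of a single config; it should
# report the expected config name and split names.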
@pytest.mark.parametrize(
    "path, config_name, expected_splits",
    [
        ("rajpurkar/squad", "plain_text", ["train", "validation"]),
        ("dalle-mini/wit", "default", ["train"]),
        ("paws", "labeled_final", ["train", "test", "validation"]),
    ],
)
def test_get_dataset_config_info(path, config_name, expected_splits):
    info = get_dataset_config_info(path, config_name=config_name)
    assert info.config_name == config_name
    assert list(info.splits.keys()) == expected_splits
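

# With a valid token, private repos can be inspected too (hf_token and
# hf_private_dataset_repo_txt_data are pytest fixtures of the test suite).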
def test_get_dataset_config_info_private(hf_token, hf_private_dataset_repo_txt_data):
    info = get_dataset_config_info(hf_private_dataset_repo_txt_data, config_name="default", token=hf_token)
    assert list(info.splits.keys()) == ["train"]
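

# Omitting config_name on a multi-config dataset raises ValueError; non-existing,
# gated, and private repos raise DatasetNotFoundError. Script-backed repos are
# called with trust_remote_code=True so that the access error is what surfaces.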
@pytest.mark.parametrize(
    "path, config_name, expected_exception",
    [
        ("paws", None, ValueError),
        # non-existing, gated, private:
        ("hf-internal-testing/non-existing-dataset", "default", DatasetNotFoundError),
        ("hf-internal-testing/gated_dataset_with_data_files", "default", DatasetNotFoundError),
        ("hf-internal-testing/private_dataset_with_data_files", "default", DatasetNotFoundError),
        ("hf-internal-testing/gated_dataset_with_script", "default", DatasetNotFoundError),
        ("hf-internal-testing/private_dataset_with_script", "default", DatasetNotFoundError),
    ],
)
def test_get_dataset_config_info_raises(path, config_name, expected_exception):
    kwargs = {"trust_remote_code": True} if path.endswith("_with_script") else {}
    with pytest.raises(expected_exception):
        get_dataset_config_info(path, config_name=config_name, **kwargs)
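

# get_dataset_config_names lists every config, whether it comes from a loading
# script, from YAML metadata, or is the implicit single "default" config.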
@pytest.mark.parametrize(
    "path, expected",
    [
        ("acronym_identification", ["default"]),
        ("rajpurkar/squad", ["plain_text"]),
        ("hf-internal-testing/dataset_with_script", ["default"]),
        ("dalle-mini/wit", ["default"]),
        ("hf-internal-testing/librispeech_asr_dummy", ["clean"]),
        ("hf-internal-testing/audiofolder_no_configs_in_metadata", ["default"]),
        ("hf-internal-testing/audiofolder_single_config_in_metadata", ["custom"]),
        ("hf-internal-testing/audiofolder_two_configs_in_metadata", ["v1", "v2"]),
    ],
)
def test_get_dataset_config_names(path, expected):
    config_names = get_dataset_config_names(path, trust_remote_code=True)
    assert config_names == expected
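

# get_dataset_default_config_name returns the only config's name when there is
# exactly one, and None when several configs exist without a designated default.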
@pytest.mark.parametrize(
    "path, expected",
    [
        ("acronym_identification", "default"),
        ("rajpurkar/squad", "plain_text"),
        ("hf-internal-testing/dataset_with_script", "default"),
        ("dalle-mini/wit", "default"),
        ("hf-internal-testing/librispeech_asr_dummy", "clean"),
        ("hf-internal-testing/audiofolder_no_configs_in_metadata", "default"),
        ("hf-internal-testing/audiofolder_single_config_in_metadata", "custom"),
        ("hf-internal-testing/audiofolder_two_configs_in_metadata", None),
    ],
)
def test_get_dataset_default_config_name(path, expected):
    default_config_name = get_dataset_default_config_name(path, trust_remote_code=True)
    if expected:
        assert default_config_name == expected
    else:
        assert default_config_name is None
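

# get_dataset_infos maps every config name to its DatasetInfo; spot-check the
# split names of the first config.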
@pytest.mark.parametrize(
    "path, expected_configs, expected_splits_in_first_config",
    [
        ("rajpurkar/squad", ["plain_text"], ["train", "validation"]),
        ("dalle-mini/wit", ["default"], ["train"]),
        ("paws", ["labeled_final", "labeled_swap", "unlabeled_final"], ["train", "test", "validation"]),
    ],
)
def test_get_dataset_info(path, expected_configs, expected_splits_in_first_config):
    infos = get_dataset_infos(path)
    assert list(infos.keys()) == expected_configs
    expected_config = expected_configs[0]
    assert expected_config in infos
    info = infos[expected_config]
    assert info.config_name == expected_config
    assert list(info.splits.keys()) == expected_splits_in_first_config
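

# get_dataset_split_names returns the list of split names for the requested config.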
@pytest.mark.parametrize(
    "path, expected_config, expected_splits",
    [
        ("rajpurkar/squad", "plain_text", ["train", "validation"]),
        ("dalle-mini/wit", "default", ["train"]),
        ("paws", "labeled_final", ["train", "test", "validation"]),
    ],
)
def test_get_dataset_split_names(path, expected_config, expected_splits):
    split_names = get_dataset_split_names(path, config_name=expected_config)
    assert split_names == expected_splits
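

# Without a config_name, a multi-config dataset is ambiguous and raises ValueError.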
@pytest.mark.parametrize(
    "path, config_name, expected_exception",
    [
        ("paws", None, ValueError),
    ],
)
def test_get_dataset_split_names_error(path, config_name, expected_exception):
    with pytest.raises(expected_exception):
        get_dataset_split_names(path, config_name=config_name)