Skip to content

Commit b2d1a91

Browse files
author
Yusuke Oda
authored
reload_stat -> use_cache (#585)
1 parent 858e0b0 commit b2d1a91

33 files changed

+28
-55
lines changed

data/reports/absa-confidence-report.json

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
"task_name": "aspect-based-sentiment-classification",
33
"source_language": "en",
44
"target_language": "en",
5-
"reload_stat": true,
65
"source_tokenizer": {
76
"cls_name": "SingleSpaceTokenizer"
87
},
@@ -1666,4 +1665,4 @@
16661665
}
16671666
]
16681667
}
1669-
}
1668+
}

data/reports/report_kg.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"dataset_split": null,
77
"source_language": null,
88
"target_language": null,
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"confidence_alpha": 0.05,
1211
"system_details": null,

explainaboard/explainaboard_main.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,10 @@ def create_parser():
245245
)
246246

247247
parser.add_argument(
248-
"--reload-stat",
249-
type=str,
250-
required=False,
251-
default=None,
252-
help="reload precomputed statistics over training set (if exists)",
248+
"--no-use-cache",
249+
dest="use_cache",
250+
action="store_false",
251+
help="Disable cached statistics over training set.",
253252
)
254253

255254
parser.add_argument(
@@ -362,7 +361,7 @@ def main():
362361
"""The main function to be executed."""
363362
args = create_parser().parse_args()
364363

365-
reload_stat: bool = False if args.reload_stat == "0" else True
364+
use_cache: bool = args.use_cache
366365
system_outputs: list[str] = args.system_outputs
367366

368367
reports: list[str] | None = args.reports
@@ -479,7 +478,6 @@ def load_system_details_path():
479478
"split_name": split,
480479
"source_language": source_language,
481480
"target_language": target_language,
482-
"reload_stat": reload_stat,
483481
"confidence_alpha": args.confidence_alpha,
484482
"system_details": system_details,
485483
"custom_features": system_datasets[0].metadata.custom_features,
@@ -510,6 +508,7 @@ def load_system_details_path():
510508
metadata=metadata_copied,
511509
sys_output=system_dataset.samples,
512510
skip_failed_analyses=args.skip_failed_analyses,
511+
use_cache=use_cache,
513512
)
514513
reports.append(report)
515514

explainaboard/info.py

-7
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,12 @@ class SysOutputInfo(Serializable):
8181
dataset_split (str): the name of the split.
8282
source_language (str): the language of the input
8383
target_language (str): the language of the output
84-
reload_stat (bool): whether to reload the statistics or not
8584
system_details (dict): a dictionary of system details
8685
source_tokenizer (Tokenizer): the tokenizer for source sentences
8786
target_tokenizer (Tokenizer): the tokenizer for target sentences
8887
analysis_levels: the levels of analysis to perform
8988
"""
9089

91-
DEFAULT_RELOAD_STAT: ClassVar[bool] = True
9290
DEFAULT_CONFIDENCE_ALPHA: ClassVar[float] = 0.05
9391

9492
task_name: str | None = None
@@ -98,7 +96,6 @@ class SysOutputInfo(Serializable):
9896
dataset_split: str | None = None
9997
source_language: str | None = None
10098
target_language: str | None = None
101-
reload_stat: bool = DEFAULT_RELOAD_STAT
10299
# NOTE(odashi): confidence_alpha == None has a meaning beyond "unset": it prevents
103100
# calculating confidence intervals.
104101
confidence_alpha: float | None = DEFAULT_CONFIDENCE_ALPHA
@@ -182,7 +179,6 @@ def serialize(self) -> dict[str, SerializableData]:
182179
"dataset_split": self.dataset_split,
183180
"source_language": self.source_language,
184181
"target_language": self.target_language,
185-
"reload_stat": self.reload_stat,
186182
"confidence_alpha": self.confidence_alpha,
187183
"system_details": self.system_details,
188184
"source_tokenizer": self.source_tokenizer,
@@ -223,9 +219,6 @@ def deserialize(cls, data: dict[str, SerializableData]) -> Serializable:
223219
dataset_split=_get_value(data, str, "dataset_split"),
224220
source_language=_get_value(data, str, "source_language"),
225221
target_language=_get_value(data, str, "target_language"),
226-
reload_stat=unwrap_or(
227-
_get_value(data, bool, "reload_stat"), cls.DEFAULT_RELOAD_STAT
228-
),
229222
confidence_alpha=confidence_alpha,
230223
system_details=system_details,
231224
source_tokenizer=_get_value(

explainaboard/info_test.py

-4
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ def test_serialization(self) -> None:
100100
dataset_split="quux",
101101
source_language="en",
102102
target_language="zh",
103-
reload_stat=True,
104103
confidence_alpha=None,
105104
system_details={"detail": 123},
106105
source_tokenizer=tokenizer1,
@@ -126,7 +125,6 @@ def test_serialization(self) -> None:
126125
"dataset_split": "quux",
127126
"source_language": "en",
128127
"target_language": "zh",
129-
"reload_stat": True,
130128
"system_details": {"detail": 123},
131129
"source_tokenizer": tokenizer1_serialized,
132130
"target_tokenizer": tokenizer2_serialized,
@@ -168,7 +166,6 @@ def test_serialization(self) -> None:
168166
self.assertEqual(deserialized.dataset_split, sysout.dataset_split)
169167
self.assertEqual(deserialized.source_language, sysout.source_language)
170168
self.assertEqual(deserialized.target_language, sysout.target_language)
171-
self.assertEqual(deserialized.reload_stat, sysout.reload_stat)
172169
self.assertEqual(deserialized.confidence_alpha, sysout.confidence_alpha)
173170
self.assertEqual(deserialized.system_details, sysout.system_details)
174171
self.assertIsInstance(deserialized.source_tokenizer, SingleSpaceTokenizer)
@@ -191,7 +188,6 @@ def test_from_any_dict(self) -> None:
191188
self.assertIsNone(deserialized.dataset_split)
192189
self.assertIsNone(deserialized.source_language)
193190
self.assertIsNone(deserialized.target_language)
194-
self.assertEqual(deserialized.reload_stat, SysOutputInfo.DEFAULT_RELOAD_STAT)
195191
self.assertEqual(
196192
deserialized.confidence_alpha, SysOutputInfo.DEFAULT_CONFIDENCE_ALPHA
197193
)

explainaboard/processors/processor.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,15 @@ def _get_statistics_resources(self, sys_info: SysOutputInfo) -> dict[str, Any]:
157157
def _statistics_func(self, samples: Iterable[Any], sys_info: SysOutputInfo) -> Any:
158158
...
159159

160-
def _gen_external_stats(self, sys_info: SysOutputInfo) -> Any:
160+
def _gen_external_stats(self, sys_info: SysOutputInfo, use_cache: bool) -> Any:
161161
"""Generate external statistics.
162162
163163
These are gathered from a relatively costly source, such as the training set,
164164
then cached for future use.
165165
166166
Args:
167167
sys_info: Information about the system outputs
168+
use_cache: whether to reload the statistics from cache or not.
168169
169170
Returns:
170171
Statistics from, usually, the training set that are used to calculate
@@ -179,7 +180,7 @@ def _gen_external_stats(self, sys_info: SysOutputInfo) -> Any:
179180
else sys_info.sub_dataset_name
180181
)
181182
# read statistics from cache
182-
if sys_info.reload_stat:
183+
if use_cache:
183184
statistics = read_statistics_from_cache(
184185
sys_info.dataset_name, sub_dataset
185186
)
@@ -497,13 +498,17 @@ def sort_bucket_info(
497498
raise ValueError(f"Invalid sort_by: {sort_by}")
498499

499500
def get_overall_statistics(
500-
self, metadata: dict, sys_output: list[dict]
501+
self,
502+
metadata: dict,
503+
sys_output: list[dict],
504+
use_cache: bool = True,
501505
) -> OverallStatistics:
502506
"""Get the overall statistics information of the system output.
503507
504508
Args:
505509
metadata: The metadata of the system
506510
sys_output: The system output itself
511+
use_cache: whether to reload the statistics from cache or not.
507512
"""
508513
if metadata is None:
509514
metadata = {}
@@ -542,7 +547,7 @@ def get_overall_statistics(
542547
)
543548

544549
# get scoring statistics
545-
external_stats = self._gen_external_stats(sys_info)
550+
external_stats = self._gen_external_stats(sys_info, use_cache)
546551

547552
# generate cases for each level
548553
analysis_cases: list[list[AnalysisCase]] = []
@@ -561,19 +566,28 @@ def get_overall_statistics(
561566

562567
@final
563568
def process(
564-
self, metadata: dict, sys_output: list[dict], skip_failed_analyses: bool = False
569+
self,
570+
metadata: dict,
571+
sys_output: list[dict],
572+
skip_failed_analyses: bool = False,
573+
use_cache: bool = True,
565574
) -> SysOutputInfo:
566575
"""Run the whole process of processing the output.
567576
568577
Args:
569578
metadata: The metadata used to specify information about processing.
570579
sys_output: They list of system outputs.
571580
skip_failed_analyses: Whether to skip failed analyses.
581+
use_cache: whether to reload the statistics or not.
572582
573583
Returns:
574584
Information about the processed system output.
575585
"""
576-
overall_statistics = self.get_overall_statistics(metadata, sys_output)
586+
overall_statistics = self.get_overall_statistics(
587+
metadata,
588+
sys_output,
589+
use_cache,
590+
)
577591
sys_info = unwrap(overall_statistics.sys_info)
578592
analyses = self.perform_analyses(
579593
sys_info,

integration_tests/artifacts/reports/test-ar_6960.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "ar",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-de_7213.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "de",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-de_9330.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "de",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-de_9335.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "de",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-en_7676.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "en",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-en_7872.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "en",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-en_8113.json

-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"F1ScoreQA",
88
"ExactMatchQA"
99
],
10-
"reload_stat": true,
1110
"is_print_case": true,
1211
"language": "en",
1312
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-en_8235.json

-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"F1ScoreQA",
88
"ExactMatchQA"
99
],
10-
"reload_stat": true,
1110
"is_print_case": true,
1211
"language": "en",
1312
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-en_9152.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "en",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-en_9200.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "en",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-es_7377.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "es",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-es_7678.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "es",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-es_7687.json

-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"F1ScoreQA",
88
"ExactMatchQA"
99
],
10-
"reload_stat": true,
1110
"is_print_case": true,
1211
"language": "es",
1312
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-es_7698.json

-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"F1ScoreQA",
88
"ExactMatchQA"
99
],
10-
"reload_stat": true,
1110
"is_print_case": true,
1211
"language": "es",
1312
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-es_9340.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "es",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-es_9342.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "es",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-fr_9262.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "fr",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-fr_9332.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "fr",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-ja_9137.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "ja",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-ja_9150.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "ja",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-zh_7117.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "zh",
1211
"confidence_alpha": 0.05,

integration_tests/artifacts/reports/test-zh_7311.json

-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"metric_names": [
77
"Accuracy"
88
],
9-
"reload_stat": true,
109
"is_print_case": true,
1110
"language": "zh",
1211
"confidence_alpha": 0.05,

0 commit comments

Comments
 (0)