@@ -157,14 +157,15 @@ def _get_statistics_resources(self, sys_info: SysOutputInfo) -> dict[str, Any]:
157
157
def _statistics_func (self , samples : Iterable [Any ], sys_info : SysOutputInfo ) -> Any :
158
158
...
159
159
160
- def _gen_external_stats (self , sys_info : SysOutputInfo ) -> Any :
160
+ def _gen_external_stats (self , sys_info : SysOutputInfo , use_cache : bool ) -> Any :
161
161
"""Generate external statistics.
162
162
163
163
These are gathered from a relatively costly source, such as the training set,
164
164
then cached for future use.
165
165
166
166
Args:
167
167
sys_info: Information about the system outputs
168
+ use_cache: whether to reload the statistics from cache or not.
168
169
169
170
Returns:
170
171
Statistics from, usually, the training set that are used to calculate
@@ -179,7 +180,7 @@ def _gen_external_stats(self, sys_info: SysOutputInfo) -> Any:
179
180
else sys_info .sub_dataset_name
180
181
)
181
182
# read statistics from cache
182
- if sys_info . reload_stat :
183
+ if use_cache :
183
184
statistics = read_statistics_from_cache (
184
185
sys_info .dataset_name , sub_dataset
185
186
)
@@ -497,13 +498,17 @@ def sort_bucket_info(
497
498
raise ValueError (f"Invalid sort_by: { sort_by } " )
498
499
499
500
def get_overall_statistics (
500
- self , metadata : dict , sys_output : list [dict ]
501
+ self ,
502
+ metadata : dict ,
503
+ sys_output : list [dict ],
504
+ use_cache : bool = True ,
501
505
) -> OverallStatistics :
502
506
"""Get the overall statistics information of the system output.
503
507
504
508
Args:
505
509
metadata: The metadata of the system
506
510
sys_output: The system output itself
511
+ use_cache: whether to reload the statistics from cache or not.
507
512
"""
508
513
if metadata is None :
509
514
metadata = {}
@@ -542,7 +547,7 @@ def get_overall_statistics(
542
547
)
543
548
544
549
# get scoring statistics
545
- external_stats = self ._gen_external_stats (sys_info )
550
+ external_stats = self ._gen_external_stats (sys_info , use_cache )
546
551
547
552
# generate cases for each level
548
553
analysis_cases : list [list [AnalysisCase ]] = []
@@ -561,19 +566,28 @@ def get_overall_statistics(
561
566
562
567
@final
563
568
def process (
564
- self , metadata : dict , sys_output : list [dict ], skip_failed_analyses : bool = False
569
+ self ,
570
+ metadata : dict ,
571
+ sys_output : list [dict ],
572
+ skip_failed_analyses : bool = False ,
573
+ use_cache : bool = True ,
565
574
) -> SysOutputInfo :
566
575
"""Run the whole process of processing the output.
567
576
568
577
Args:
569
578
metadata: The metadata used to specify information about processing.
570
579
sys_output: They list of system outputs.
571
580
skip_failed_analyses: Whether to skip failed analyses.
581
+ use_cache: whether to reload the statistics or not.
572
582
573
583
Returns:
574
584
Information about the processed system output.
575
585
"""
576
- overall_statistics = self .get_overall_statistics (metadata , sys_output )
586
+ overall_statistics = self .get_overall_statistics (
587
+ metadata ,
588
+ sys_output ,
589
+ use_cache ,
590
+ )
577
591
sys_info = unwrap (overall_statistics .sys_info )
578
592
analyses = self .perform_analyses (
579
593
sys_info ,
0 commit comments