Skip to content

Commit 31b6750

Browse files
committed
add system_details to SysOutput class
Former-commit-id: 217d13b
1 parent a5b4109 commit 31b6750

File tree

6 files changed

+77
-1
lines changed

6 files changed

+77
-1
lines changed

data/system_details/test.json

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"learning_rate": 0.0001,
3+
"number_of_layers": 10
4+
}

explainaboard/analyzers/bar_chart.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
mlogger.setLevel(logging.WARNING)
1010

1111

12-
def mp_format(data: list[dict]) -> dict:
12+
def mp_format(data) -> dict:
1313
"""
1414
Adapt the format of data
1515
:param data:

explainaboard/explainaboard_main.py

+20
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,14 @@ def main():
161161
help="the directory of output files",
162162
)
163163

164+
parser.add_argument(
165+
'--system_details',
166+
type=str,
167+
required=False,
168+
default="system_details.json",
169+
help="a json file to store detailed information for a system",
170+
)
171+
164172
args = parser.parse_args()
165173

166174
dataset = args.dataset
@@ -181,6 +189,17 @@ def main():
181189
datasets_aggregation = args.datasets_aggregation
182190
languages_aggregation = args.languages_aggregation
183191

192+
system_details_path = args.system_details
193+
194+
# get system_details from input json file
195+
system_details = None
196+
if system_details_path is not None:
197+
try:
198+
with open(system_details_path) as fin:
199+
system_details = json.load(fin)
200+
except ValueError as e:
201+
print('invalid json: %s' % e)
202+
184203
# If reports have been specified, ExplainaBoard cli will perform analysis
185204
# over report files.
186205
if reports is not None:
@@ -329,6 +348,7 @@ def main():
329348
"task_name": task,
330349
"reload_stat": reload_stat,
331350
"conf_value": args.conf_value,
351+
"system_details": system_details,
332352
}
333353

334354
if metric_names is not None:

explainaboard/info.py

+3
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ class SysOutputInfo:
9090
is_print_case: bool = True
9191
language: str = "en"
9292
conf_value: float = 0.05
93+
system_details: dict = (
94+
None # TODO(Pengfei): we can define a schema using `dataclass` in the future
95+
)
9396

9497
# set later
9598
# code: str = None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"learning_rate": 0.0001,
3+
"number_of_layers": 10
4+
}
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import json
2+
import os
3+
import pathlib
4+
import unittest
5+
6+
from explainaboard import FileType, get_loader, get_processor, Source, TaskType
7+
8+
artifacts_path = os.path.dirname(pathlib.Path(__file__)) + "/artifacts/"
9+
10+
11+
class TestSysDetails(unittest.TestCase):
12+
def test_generate_system_analysis(self):
13+
"""TODO: should add harder tests"""
14+
15+
path_system_details = artifacts_path + "test_system_details.json"
16+
path_data = artifacts_path + "sys_out1.tsv"
17+
18+
with open(path_system_details) as fin:
19+
system_details = json.load(fin)
20+
21+
metadata = {
22+
"task_name": TaskType.text_classification.value,
23+
"metric_names": ["Accuracy"],
24+
"system_details": system_details,
25+
}
26+
27+
loader = get_loader(
28+
TaskType.text_classification,
29+
path_data,
30+
Source.local_filesystem,
31+
FileType.tsv,
32+
)
33+
data = list(loader.load())
34+
processor = get_processor(TaskType.text_classification)
35+
36+
sys_info = processor.process(metadata, data)
37+
38+
# analysis.write_to_directory("./")
39+
self.assertIsNotNone(
40+
sys_info.system_details, {"learning_rate": 0.0001, "number_of_layers": 10}
41+
)
42+
43+
44+
if __name__ == '__main__':
45+
unittest.main()

0 commit comments

Comments
 (0)