7
7
import argparse
8
8
import json
9
9
import logging
10
+ from json import JSONDecodeError
10
11
from pathlib import Path
11
12
from os import PathLike
12
13
from typing import Union , Dict , Any , Optional
@@ -35,14 +36,14 @@ def add_subparser(self, parser: argparse._SubParsersAction) -> argparse.Argument
35
36
subparser .add_argument (
36
37
"input_file" ,
37
38
type = str ,
38
- help = "path to the file containing the evaluation data (for mutiple "
39
+ help = "path to the file containing the evaluation data (for multiple "
39
40
"files, put between filenames e.g., input1.txt,input2.txt)" ,
40
41
)
41
42
42
43
subparser .add_argument (
43
44
"--output-file" ,
44
45
type = str ,
45
- help = "optional path to write the metrics to as JSON (for mutiple "
46
+ help = "optional path to write the metrics to as JSON (for multiple "
46
47
"files, put between filenames e.g., output1.txt,output2.txt)" ,
47
48
)
48
49
@@ -258,17 +259,26 @@ def evaluate_from_archive(
258
259
dataset_reader = archive .validation_dataset_reader
259
260
260
261
# split files
261
- evaluation_data_path_list = input_file .split ("," )
262
+ try :
263
+ # Try reading it as a list of JSON objects first. Some readers require
264
+ # that kind of input.
265
+ evaluation_data_path_list = json .loads (f"[{ input_file } ]" )
266
+ except JSONDecodeError :
267
+ evaluation_data_path_list = input_file .split ("," )
262
268
263
269
# TODO(gabeorlanski): Is it safe to always default to .outputs and .preds?
264
270
# TODO(gabeorlanski): Add in way to save to specific output directory
265
271
if metrics_output_file is not None :
266
272
if auto_names == "METRICS" or auto_names == "ALL" :
267
273
logger .warning (
268
- f"Passed output_files will be ignored, auto_names is" f" set to { auto_names } "
274
+ f"Passed output_files will be ignored, auto_names is set to { auto_names } "
269
275
)
270
276
271
277
# Keep the path of the parent otherwise it will write to the CWD
278
+ assert all (isinstance (p , str ) for p in evaluation_data_path_list ), (
279
+ "When specifying JSON blobs as input, the output files must be explicitly named with "
280
+ "--output-file."
281
+ )
272
282
output_file_list = [
273
283
p .parent .joinpath (f"{ p .stem } .outputs" ) for p in map (Path , evaluation_data_path_list )
274
284
]
@@ -285,6 +295,10 @@ def evaluate_from_archive(
285
295
)
286
296
287
297
# Keep the path of the parent otherwise it will write to the CWD
298
+ assert all (isinstance (p , str ) for p in evaluation_data_path_list ), (
299
+ "When specifying JSON blobs as input, the predictions output files must be explicitly named with "
300
+ "--predictions-output-file."
301
+ )
288
302
predictions_output_file_list = [
289
303
p .parent .joinpath (f"{ p .stem } .preds" ) for p in map (Path , evaluation_data_path_list )
290
304
]
@@ -307,13 +321,15 @@ def evaluate_from_archive(
307
321
)
308
322
309
323
all_metrics = {}
310
- for index in range ( len ( evaluation_data_path_list ) ):
324
+ for index , evaluation_data_path in enumerate ( evaluation_data_path_list ):
311
325
config = deepcopy (archive .config )
312
- evaluation_data_path = evaluation_data_path_list [index ]
313
326
314
327
# Get the eval file name so we can save each metric by file name in the
315
328
# output dictionary.
316
- eval_file_name = Path (evaluation_data_path ).stem
329
+ if isinstance (evaluation_data_path , str ):
330
+ eval_file_name = Path (evaluation_data_path ).stem
331
+ else :
332
+ eval_file_name = str (index )
317
333
318
334
if metrics_output_file is not None :
319
335
# noinspection PyUnboundLocalVariable
@@ -323,7 +339,7 @@ def evaluate_from_archive(
323
339
# noinspection PyUnboundLocalVariable
324
340
predictions_output_file_path = predictions_output_file_list [index ]
325
341
326
- logger .info ("Reading evaluation data from %s" , evaluation_data_path )
342
+ logger .info ("Reading evaluation data from %s" , eval_file_name )
327
343
data_loader_params = config .get ("validation_data_loader" , None )
328
344
if data_loader_params is None :
329
345
data_loader_params = config .get ("data_loader" )
0 commit comments