Skip to content

Commit ecff7f6

Browse files
IgnacioHerediaalvarolopez
authored andcommitted
feat: allow multiple input files
1 parent f4e9ec4 commit ecff7f6

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

deepaas/cmd/cli.py

+27-8
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
fields.URL: str,
6161
fields.Url: str,
6262
fields.UUID: str,
63+
fields.Field: str,
6364
}
6465

6566

@@ -159,6 +160,18 @@ def _get_model_name(model_name=None):
159160
sys.exit(1)
160161

161162

163+
def _get_file_args(fields_in):
164+
"""Function to retrieve a list of file-type fields
165+
:param fields_in: mashmallow fields
166+
:return: list
167+
"""
168+
file_fields = []
169+
for k, v in fields_in.items():
170+
if type(v) is fields.Field:
171+
file_fields.append(k)
172+
return file_fields
173+
174+
162175
# Get the model name
163176
model_name = CONF.model_name
164177

@@ -174,6 +187,10 @@ def _get_model_name(model_name=None):
174187
predict_args = _fields_to_dict(model_obj.get_predict_args())
175188
train_args = _fields_to_dict(model_obj.get_train_args())
176189

190+
# Find which of the arguments are going to be files
191+
file_args = {}
192+
file_args['predict'] = _get_file_args(model_obj.get_predict_args())
193+
file_args['train'] = _get_file_args(model_obj.get_train_args())
177194

178195
# Function to add later these arguments to CONF via SubCommandOpt
179196
def _add_methods(subparsers):
@@ -285,29 +302,31 @@ def main():
285302
if CONF.deepaas_with_multiprocessing:
286303
mp.set_start_method("spawn", force=True)
287304

288-
# TODO(multi-file): change to many files ('for' itteration)
289-
if CONF.methods.__contains__("files"):
290-
if CONF.methods.files:
305+
# Create file wrapper for file args (if provided)
306+
for farg in file_args[CONF.methods.name]:
307+
if getattr(CONF.methods, farg, None):
308+
fpath = conf_vars[farg]
309+
291310
# create tmp file as later it supposed
292311
# to be deleted by the application
293312
temp = tempfile.NamedTemporaryFile()
294313
temp.close()
295314
# copy original file into tmp file
296-
with open(conf_vars["files"], "rb") as f:
315+
with open(fpath, "rb") as f:
297316
with open(temp.name, "wb") as f_tmp:
298317
for line in f:
299318
f_tmp.write(line)
300319

301320
# create file object
302-
file_type = mimetypes.MimeTypes().guess_type(conf_vars["files"])[0]
321+
file_type = mimetypes.MimeTypes().guess_type(fpath)[0]
303322
file_obj = v2_wrapper.UploadedFile(
304323
name="data",
305324
filename=temp.name,
306325
content_type=file_type,
307-
original_filename=conf_vars["files"],
326+
original_filename=fpath,
308327
)
309-
# re-write 'files' parameter in conf_vars
310-
conf_vars["files"] = file_obj
328+
# re-write parameter in conf_vars
329+
conf_vars[farg] = file_obj
311330

312331
# debug of input parameters
313332
LOG.debug("[DEBUG provided options, conf_vars]: {}".format(conf_vars))

0 commit comments

Comments
 (0)