Skip to content

Commit 8f140ba

Browse files
committed
chore: copy paste code
Signed-off-by: ThibaultFy <[email protected]>
1 parent b5c419c commit 8f140ba

20 files changed

+2364
-0
lines changed

substra/tools/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from substratools.__version__ import __version__
2+
3+
from . import function
4+
from . import opener
5+
from .function import execute
6+
from .function import load_performance
7+
from .function import register
8+
from .function import save_performance
9+
from .opener import Opener
10+
11+
__all__ = [
12+
"__version__",
13+
function,
14+
opener,
15+
Opener,
16+
execute,
17+
load_performance,
18+
register,
19+
save_performance,
20+
]

substra/tools/__version__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.22.0a2"

substra/tools/exceptions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
class InvalidInterfaceError(Exception):
2+
pass
3+
4+
5+
class EmptyInterfaceError(InvalidInterfaceError):
6+
pass
7+
8+
9+
class NotAFileError(Exception):
10+
pass
11+
12+
13+
class MissingFileError(Exception):
14+
pass
15+
16+
17+
class InvalidInputOutputsError(Exception):
18+
pass
19+
20+
21+
class InvalidCLIError(Exception):
22+
pass
23+
24+
25+
class FunctionNotFoundError(Exception):
26+
pass
27+
28+
29+
class ExistingRegisteredFunctionError(Exception):
30+
pass

substra/tools/function.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# coding: utf8
2+
import argparse
3+
import json
4+
import logging
5+
import os
6+
import sys
7+
from copy import deepcopy
8+
from typing import Any
9+
from typing import Callable
10+
from typing import Dict
11+
from typing import Optional
12+
13+
from substratools import exceptions
14+
from substratools import opener
15+
from substratools import utils
16+
from substratools.exceptions import ExistingRegisteredFunctionError
17+
from substratools.exceptions import FunctionNotFoundError
18+
from substratools.task_resources import StaticInputIdentifiers
19+
from substratools.task_resources import TaskResources
20+
from substratools.workspace import FunctionWorkspace
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
def _parser_add_default_arguments(parser):
26+
parser.add_argument(
27+
"--function-name",
28+
type=str,
29+
help="The name of the function to execute from the given file",
30+
)
31+
parser.add_argument(
32+
"-r",
33+
"--task-properties",
34+
type=str,
35+
default="{}",
36+
help="Define the task properties",
37+
),
38+
parser.add_argument(
39+
"-d",
40+
"--fake-data",
41+
action="store_true",
42+
default=False,
43+
help="Enable fake data mode",
44+
)
45+
parser.add_argument(
46+
"--n-fake-samples",
47+
default=None,
48+
type=int,
49+
help="Number of fake samples if fake data is used.",
50+
)
51+
parser.add_argument(
52+
"--log-path",
53+
default=None,
54+
help="Define log filename path",
55+
)
56+
parser.add_argument(
57+
"--log-level",
58+
default="info",
59+
choices=utils.MAPPING_LOG_LEVEL.keys(),
60+
help="Choose log level",
61+
)
62+
parser.add_argument(
63+
"--inputs",
64+
type=str,
65+
default="[]",
66+
help="Inputs of the compute task",
67+
)
68+
parser.add_argument(
69+
"--outputs",
70+
type=str,
71+
default="[]",
72+
help="Outputs of the compute task",
73+
)
74+
75+
76+
class FunctionRegister:
77+
"""Class to create a decorator to register function in substratools. The functions are registered in the _functions
78+
dictionary, with the function.__name__ as key.
79+
Register a function in substratools means that this function can be access by the function.execute functions through
80+
the --function-name CLI argument."""
81+
82+
def __init__(self):
83+
self._functions = {}
84+
85+
def __call__(self, function: Callable, function_name: Optional[str] = None):
86+
"""Function called when using an instance of the class as a decorator.
87+
88+
Args:
89+
function (Callable): function to register in substratools.
90+
function_name (str, optional): function name to register the given function.
91+
If None, function.__name__ is used for registration.
92+
Raises:
93+
ExistingRegisteredFunctionError: Raise if a function with the same function.__name__
94+
has already been registered in substratools.
95+
96+
Returns:
97+
Callable: returns the function without decorator
98+
"""
99+
100+
function_name = function_name or function.__name__
101+
if function_name not in self._functions:
102+
self._functions[function_name] = function
103+
else:
104+
raise ExistingRegisteredFunctionError("A function with the same name is already registered.")
105+
106+
return function
107+
108+
def get_registered_functions(self):
109+
return self._functions
110+
111+
112+
# Instance of the decorator to store the function to register in memory.
113+
# Can be imported directly from substratools.
114+
register = FunctionRegister()
115+
116+
117+
class FunctionWrapper(object):
118+
"""Wrapper to execute a function on the platform."""
119+
120+
def __init__(self, workspace: FunctionWorkspace, opener_wrapper: Optional[opener.OpenerWrapper]):
121+
self._workspace = workspace
122+
self._opener_wrapper = opener_wrapper
123+
124+
def _assert_outputs_exists(self, outputs: Dict[str, str]):
125+
for key, path in outputs.items():
126+
if os.path.isdir(path):
127+
raise exceptions.NotAFileError(f"Expected output file at {path}, found dir for output `{key}`")
128+
if not os.path.isfile(path):
129+
raise exceptions.MissingFileError(f"Output file {path} used to save argument `{key}` does not exists.")
130+
131+
@utils.Timer(logger)
132+
def execute(
133+
self, function: Callable, task_properties: dict = {}, fake_data: bool = False, n_fake_samples: int = None
134+
):
135+
"""Execute a compute task"""
136+
137+
# load inputs
138+
inputs = deepcopy(self._workspace.task_inputs)
139+
140+
# load data from opener
141+
if self._opener_wrapper:
142+
loaded_datasamples = self._opener_wrapper.get_data(fake_data, n_fake_samples)
143+
144+
if fake_data:
145+
logger.info("Using fake data with %i fake samples." % int(n_fake_samples))
146+
147+
assert (
148+
StaticInputIdentifiers.datasamples.value not in inputs.keys()
149+
), f"{StaticInputIdentifiers.datasamples.value} must be an input of kind `datasamples`"
150+
inputs.update({StaticInputIdentifiers.datasamples.value: loaded_datasamples})
151+
152+
# load outputs
153+
outputs = deepcopy(self._workspace.task_outputs)
154+
155+
logger.info("Launching task: executing `%s` function." % function.__name__)
156+
function(
157+
inputs=inputs,
158+
outputs=outputs,
159+
task_properties=task_properties,
160+
)
161+
162+
self._assert_outputs_exists(
163+
self._workspace.task_outputs,
164+
)
165+
166+
167+
def _generate_function_cli():
168+
"""Helper to generate a command line interface client."""
169+
170+
def _function_from_args(args):
171+
inputs = TaskResources(args.inputs)
172+
outputs = TaskResources(args.outputs)
173+
log_path = args.log_path
174+
chainkeys_path = inputs.chainkeys_path
175+
176+
workspace = FunctionWorkspace(
177+
log_path=log_path,
178+
chainkeys_path=chainkeys_path,
179+
inputs=inputs,
180+
outputs=outputs,
181+
)
182+
183+
utils.configure_logging(workspace.log_path, log_level=args.log_level)
184+
185+
opener_wrapper = opener.load_from_module(
186+
workspace=workspace,
187+
)
188+
189+
return FunctionWrapper(workspace, opener_wrapper)
190+
191+
def _user_func(args, function):
192+
function_wrapper = _function_from_args(args)
193+
function_wrapper.execute(
194+
function=function,
195+
task_properties=json.loads(args.task_properties),
196+
fake_data=args.fake_data,
197+
n_fake_samples=args.n_fake_samples,
198+
)
199+
200+
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
201+
_parser_add_default_arguments(parser)
202+
parser.set_defaults(func=_user_func)
203+
204+
return parser
205+
206+
207+
def _get_function_from_name(functions: dict, function_name: str):
208+
209+
if function_name not in functions:
210+
raise FunctionNotFoundError(
211+
f"The function {function_name} given as --function-name argument as not been found."
212+
)
213+
214+
return functions[function_name]
215+
216+
217+
def save_performance(performance: Any, path: os.PathLike):
218+
with open(path, "w") as f:
219+
json.dump({"all": performance}, f)
220+
221+
222+
def load_performance(path: os.PathLike) -> Any:
223+
with open(path, "r") as f:
224+
performance = json.load(f)["all"]
225+
return performance
226+
227+
228+
def execute(sysargs=None):
229+
"""Launch function command line interface."""
230+
231+
cli = _generate_function_cli()
232+
233+
sysargs = sysargs if sysargs is not None else sys.argv[1:]
234+
args = cli.parse_args(sysargs)
235+
function = _get_function_from_name(register.get_registered_functions(), args.function_name)
236+
args.func(args, function)
237+
238+
return args

0 commit comments

Comments
 (0)