Skip to content

Commit 23aa578

Browse files
authored
refactor: Simplify model export (#73)
Replaces OutputMapperLayer with a custom serving function. This means we don't need to save two models (training/serving) anymore, as the (ex training) model will do the "output mapping" through the custom serving function. With this approach, we could even have multiple serving functions with different output mapping strategies.
1 parent a684eb9 commit 23aa578

File tree

6 files changed

+338
-435
lines changed

6 files changed

+338
-435
lines changed
File renamed without changes.

experiments/Train.ipynb

Lines changed: 207 additions & 306 deletions
Large diffs are not rendered by default.

lib/dataset.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import List
55

66
import numpy as np
7+
import pandas as pd
78
import tensorflow as tf
89
import tensorflow_datasets as tfds
910
from tensorflow.data.experimental import dense_to_ragged_batch
@@ -197,3 +198,18 @@ def _has_labels(x, y):
197198
return tf.math.reduce_max(y, 0) > 0
198199

199200
return ds.filter(_has_labels)
201+
202+
203+
def as_dataframe(ds: tf.data.Dataset) -> pd.DataFrame:
204+
"""
205+
Return the dataset as a pandas dataframe.
206+
207+
Same as `tfds.as_dataframe`, but with properly decoded string tensors.
208+
"""
209+
def _maybe_decode(x):
210+
try:
211+
return x.decode()
212+
except (UnicodeDecodeError, AttributeError):
213+
return x
214+
215+
return tfds.as_dataframe(ds).applymap(_maybe_decode)

lib/eval.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

lib/io.py

Lines changed: 67 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,73 @@
1-
import json
21
import pathlib
3-
import shutil
4-
from typing import Dict
2+
import tempfile
3+
from typing import List
54

65
import tensorflow as tf
76

8-
from lib import settings
9-
from lib.model import to_serving_model
107

11-
12-
TRAINING_MODEL_SUBDIR = 'training_model'
13-
SERVING_MODEL_SUBDIR = 'serving_model'
14-
15-
16-
def save_model_bundle(
17-
model_dir: pathlib.Path,
8+
def save_model(
9+
path: pathlib.Path,
1810
model: tf.keras.Model,
19-
categories_vocab: Dict[str, int]):
20-
save_category_vocabulary(categories_vocab, model_dir)
21-
model.save(model_dir/TRAINING_MODEL_SUBDIR)
22-
to_serving_model(model, categories_vocab).save(model_dir/SERVING_MODEL_SUBDIR)
23-
24-
25-
def load_training_model(model_dir: pathlib.Path) -> tf.keras.Model:
26-
return tf.keras.models.load_model(model_dir/TRAINING_MODEL_SUBDIR)
27-
28-
29-
def load_serving_model(model_dir: pathlib.Path) -> tf.keras.Model:
30-
return tf.keras.models.load_model(model_dir/SERVING_MODEL_SUBDIR)
31-
32-
33-
def save_category_vocabulary(category_vocab: Dict[str, int], model_dir: pathlib.Path):
34-
category_to_ind = {name: idx for idx, name in enumerate(category_vocab)}
35-
return save_json(category_to_ind, model_dir / settings.CATEGORY_VOC_NAME)
36-
37-
38-
def load_category_vocabulary(model_dir: pathlib.Path):
39-
return load_json(model_dir / settings.CATEGORY_VOC_NAME)
40-
41-
42-
def copy_category_taxonomy(taxonomy_path: pathlib.Path, model_dir: pathlib.Path):
43-
shutil.copy(str(taxonomy_path), str(model_dir / settings.CATEGORY_TAXONOMY_NAME))
44-
45-
46-
def save_json(obj: object, path: pathlib.Path):
47-
with path.open("w") as f:
48-
return json.dump(obj, f)
49-
50-
51-
def load_json(path: pathlib.Path):
52-
with path.open("r") as f:
53-
return json.load(f)
11+
labels_vocab: List[str],
12+
serving_func: tf.function = None,
13+
**kwargs):
14+
"""
15+
Save the model and labels, with an optional custom serving function.
16+
17+
Parameters
18+
----------
19+
path: pathlib.Path
20+
Path where the model will be saved.
21+
22+
model: tf.keras.Model
23+
Keras model instance to be saved.
24+
25+
labels_vocab: List[str]
26+
Label vocabulary.
27+
28+
serving_func: tf.function, optional
29+
Custom serving function.
30+
If passed, `serving_func` will be the default endpoint in tensorflow serving.
31+
32+
**kwargs: dict, optional
33+
Additional keyword arguments passed to `tf.keras.Model.save`.
34+
"""
35+
tmp_dir = tempfile.TemporaryDirectory()
36+
labels_path = pathlib.Path(tmp_dir.name).joinpath('labels_vocab.txt')
37+
with labels_path.open('w') as w:
38+
w.writelines([f"{label}\n" for label in labels_vocab])
39+
model.labels_file = tf.saved_model.Asset(str(labels_path))
40+
41+
signatures = None
42+
if serving_func:
43+
arg_specs, kwarg_specs = model.save_spec()
44+
concrete_func = serving_func.get_concrete_function(*arg_specs, **kwarg_specs)
45+
signatures = {'serving_default': concrete_func}
46+
47+
model.save(str(path), signatures=signatures, **kwargs)
48+
49+
# must occur after model.save, so Asset source is still around for save
50+
tmp_dir.cleanup()
51+
52+
53+
def load_model(path: pathlib.Path, **kwargs):
54+
"""
55+
Load the model and labels.
56+
57+
Parameters
58+
----------
59+
path: pathlib.Path
60+
Path to the saved model.
61+
62+
**kwargs: dict, optional
63+
Additional keyword arguments passed to `tf.keras.models.load_model`.
64+
65+
Returns
66+
-------
67+
(tf.keras.Model, List[str])
68+
Model and labels.
69+
"""
70+
model = tf.keras.models.load_model(str(path))
71+
labels_file = model.labels_file.asset_path.numpy()
72+
labels = open(labels_file).read().splitlines()
73+
return model, labels

lib/model.py

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,64 @@
1-
from typing import List
1+
from typing import List, Tuple, Union
22

3+
import numpy as np
4+
import pandas as pd
35
import tensorflow as tf
46

57

6-
@tf.keras.utils.register_keras_serializable()
7-
class OutputMapperLayer(tf.keras.layers.Layer):
8-
"""
9-
The OutputMapperLayer converts the label indices produced by the model to
10-
the taxonomy category ids and limits them to top N labels.
8+
@tf.function
9+
def top_labeled_predictions(
10+
predictions: Union[tf.Tensor, np.array],
11+
labels: List[str],
12+
k: int = 10):
1113
"""
14+
Top labeled predictions.
15+
16+
This `@tf.function` can be used as a custom serving function.
1217
13-
def __init__(self, labels: List[str], top_n: int, **kwargs):
14-
self.labels = labels
15-
self.top_n = top_n
18+
Parameters
19+
----------
20+
predictions: tf.Tensor or np.array
21+
Predictions, as returned by `model.predict` or equivalent.
1622
17-
super(OutputMapperLayer, self).__init__(**kwargs)
23+
labels: List[str]
24+
Label vocabulary.
1825
19-
def call(self, x):
20-
batch_size = tf.shape(x)[0]
26+
k: int, optional
27+
Number of top predictions to return.
2128
22-
tf_labels = tf.constant([self.labels], dtype="string")
23-
tf_labels = tf.tile(tf_labels, [batch_size, 1])
29+
Returns
30+
-------
31+
(tf.Tensor, tf.Tensor)
32+
Top predicted labels with their scores, as (scores, labels).
33+
Returned tensors will have shape `(predictions.shape[0], k)`.
34+
"""
35+
tf_labels = tf.constant([labels], dtype='string')
2436

25-
top_n = tf.nn.top_k(x, k=self.top_n, sorted=True, name="top_k").indices
37+
top_indices = tf.nn.top_k(predictions, k=k, sorted=True, name='top_k').indices
2638

27-
top_conf = tf.gather(x, top_n, batch_dims=1)
28-
top_labels = tf.gather(tf_labels, top_n, batch_dims=1)
39+
top_labels = tf.experimental.numpy.take(tf_labels, top_indices)
40+
top_scores = tf.gather(predictions, top_indices, batch_dims=1)
2941

30-
return (top_conf, top_labels)
42+
return top_scores, top_labels
3143

32-
def compute_output_shape(self, input_shape):
33-
batch_size = input_shape[0]
34-
top_shape = (batch_size, self.top_n)
35-
return [top_shape, top_shape]
3644

37-
def get_config(self):
38-
config = {"labels": self.labels, "top_n": self.top_n}
39-
base_config = super(OutputMapperLayer, self).get_config()
40-
return dict(list(base_config.items()) + list(config.items()))
45+
def top_predictions_table(labeled_predictions) -> pd.DataFrame:
46+
"""
47+
Format the top labeled predictions into a pretty table.
48+
49+
Parameters
50+
----------
51+
labeled_predictions: (tf.Tensor, tf.Tensor)
52+
Labeled predictions, as returned by `top_labeled_predictions`.
53+
54+
Returns
55+
-------
56+
pd.DataFrame
57+
"""
58+
labels = labeled_predictions[1].numpy()
59+
scores = labeled_predictions[0].numpy()
4160

61+
cells = np.vectorize(lambda l, s: f"{l.decode()}: {s:.2%}")(labels, scores)
62+
columns = [f"top prediction {i+1}" for i in range(labels.shape[1])]
4263

43-
def to_serving_model(base_model: tf.keras.Model, categories: List[str]) -> tf.keras.Model:
44-
mapper_layer = OutputMapperLayer(categories, 50)(base_model.output)
45-
return tf.keras.Model(base_model.input, mapper_layer)
64+
return pd.DataFrame(cells, columns=columns)

0 commit comments

Comments
 (0)