
Commit 5fa5345

Merge pull request #48 from philipperemy/sequential
add sequential examples + keras layer
2 parents: 0309dbf + 75aac86

File tree

8 files changed: +165 / -70 lines

README.md

Lines changed: 36 additions & 5 deletions

@@ -2,18 +2,37 @@
 [![license](https://img.shields.io/badge/License-Apache_2.0-brightgreen.svg)](https://github.com/philipperemy/keras-attention-mechanism/blob/master/LICENSE) [![dep1](https://img.shields.io/badge/Tensorflow-2.0+-brightgreen.svg)](https://www.tensorflow.org/) [![dep2](https://img.shields.io/badge/Keras-2.0+-brightgreen.svg)](https://keras.io/)
 ![Simple Keras Attention CI](https://github.com/philipperemy/keras-attention-mechanism/workflows/Simple%20Keras%20Attention%20CI/badge.svg)

-```
-pip install attention
-```
-
 Many-to-one attention mechanism for Keras.

 <p align="center">
-  <img src="examples/equations.png">
+  <img src="examples/equations.png" width="600">
 </p>

+
+Installation via pip
+
+```bash
+pip install attention
+```
+
+Import in the source code
+
+```python
+from attention import Attention
+
+# [...]
+
+m = Sequential([
+    LSTM(128, input_shape=(seq_length, 1), return_sequences=True),
+    Attention(name='attention_weight'), # <--------- here.
+    Dense(1, activation='linear')
+])
+```
+
 ## Examples

+Install the requirements before running the examples: `pip install -r requirements.txt`.
+
 ### IMDB Dataset

 In this experiment, we demonstrate that using attention yields a higher accuracy on the IMDB dataset. We consider two
@@ -46,6 +65,18 @@ task and the attention map converges to the ground truth.
   <img src="examples/attention.gif" width="320">
 </p>

+### Finding max of a sequence
+
+We consider many 1D sequences of the same length. The task is to find the maximum of each sequence.
+
+We give the full sequence processed by the RNN layer to the attention layer. We expect the attention layer to focus on the maximum of each sequence.
+
+After a few epochs, the attention layer converges perfectly to what we expected.
+
+<p align="center">
+  <img src="examples/readme/example.png" width="320">
+</p>
+
 ## References

 - https://www.cs.cmu.edu/~./hovy/papers/16HLT-hierarchical-attention-networks.pdf
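
The README snippet above wires `Attention` into a `Sequential` model but stops before compiling or training. A minimal smoke test under stated assumptions (the sequence length and the synthetic max-of-sequence data below are illustrative, not part of the commit):

```python
# Minimal smoke test of the README usage snippet on synthetic data.
# seq_length and the random arrays are illustrative assumptions.
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, LSTM

from attention import Attention

seq_length = 10
m = Sequential([
    LSTM(128, input_shape=(seq_length, 1), return_sequences=True),
    Attention(name='attention_weight'),
    Dense(1, activation='linear')
])
m.compile(loss='mae', optimizer='adam')

x = np.random.uniform(size=(256, seq_length, 1))  # (batch, time_steps, features)
y = x.max(axis=1)                                 # same target as examples/find_max.py
m.fit(x, y, epochs=1, verbose=0)
print(m.predict(x[:2]).shape)  # -> (2, 1)
```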

attention/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -1 +1,3 @@
-from attention.attention import attention_3d_block # noqa
+from attention.attention import Attention # noqa
+
+VERSION = '3.0'

attention/attention.py

Lines changed: 29 additions & 23 deletions

@@ -1,26 +1,32 @@
 from tensorflow.keras.layers import Dense, Lambda, dot, Activation, concatenate
+from tensorflow.keras.layers import Layer


-def attention_3d_block(hidden_states):
-    """
-    Many-to-one attention mechanism for Keras.
-    @param hidden_states: 3D tensor with shape (batch_size, time_steps, input_dim).
-    @return: 2D tensor with shape (batch_size, 128)
-    @author: felixhao28.
-    """
-    hidden_size = int(hidden_states.shape[2])
-    # Inside dense layer
-    # hidden_states dot W => score_first_part
-    # (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size)
-    # W is the trainable weight matrix of attention Luong's multiplicative style score
-    score_first_part = Dense(hidden_size, use_bias=False, name='attention_score_vec')(hidden_states)
-    # score_first_part dot last_hidden_state => attention_weights
-    # (batch_size, time_steps, hidden_size) dot (batch_size, hidden_size) => (batch_size, time_steps)
-    h_t = Lambda(lambda x: x[:, -1, :], output_shape=(hidden_size,), name='last_hidden_state')(hidden_states)
-    score = dot([score_first_part, h_t], [2, 1], name='attention_score')
-    attention_weights = Activation('softmax', name='attention_weight')(score)
-    # (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size)
-    context_vector = dot([hidden_states, attention_weights], [1, 1], name='context_vector')
-    pre_activation = concatenate([context_vector, h_t], name='attention_output')
-    attention_vector = Dense(128, use_bias=False, activation='tanh', name='attention_vector')(pre_activation)
-    return attention_vector
+class Attention(Layer):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def __call__(self, hidden_states):
+        """
+        Many-to-one attention mechanism for Keras.
+        @param hidden_states: 3D tensor with shape (batch_size, time_steps, input_dim).
+        @return: 2D tensor with shape (batch_size, 128)
+        @author: felixhao28.
+        """
+        hidden_size = int(hidden_states.shape[2])
+        # Inside dense layer
+        # hidden_states dot W => score_first_part
+        # (batch_size, time_steps, hidden_size) dot (hidden_size, hidden_size) => (batch_size, time_steps, hidden_size)
+        # W is the trainable weight matrix of attention Luong's multiplicative style score
+        score_first_part = Dense(hidden_size, use_bias=False, name='attention_score_vec')(hidden_states)
+        # score_first_part dot last_hidden_state => attention_weights
+        # (batch_size, time_steps, hidden_size) dot (batch_size, hidden_size) => (batch_size, time_steps)
+        h_t = Lambda(lambda x: x[:, -1, :], output_shape=(hidden_size,), name='last_hidden_state')(hidden_states)
+        score = dot([score_first_part, h_t], [2, 1], name='attention_score')
+        attention_weights = Activation('softmax', name='attention_weight')(score)
+        # (batch_size, time_steps, hidden_size) dot (batch_size, time_steps) => (batch_size, hidden_size)
+        context_vector = dot([hidden_states, attention_weights], [1, 1], name='context_vector')
+        pre_activation = concatenate([context_vector, h_t], name='attention_output')
+        attention_vector = Dense(128, use_bias=False, activation='tanh', name='attention_vector')(pre_activation)
+        return attention_vector
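
The new `Attention` class keeps the exact graph-building logic of `attention_3d_block` (Luong-style multiplicative scoring against the last hidden state) and only wraps it in a `Layer` so it can sit inside `Sequential`. Since `__call__` still just composes named Keras layers on the incoming tensor, it should also drop into the functional API the way the old function did; a sketch under that assumption (the `Input`/`Model` wiring below is illustrative, not part of the diff):

```python
# Functional-API sketch; the surrounding model is illustrative.
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, LSTM

from attention import Attention

seq_length, features = 20, 1
i = Input(shape=(seq_length, features))   # (batch, time_steps, features)
h = LSTM(64, return_sequences=True)(i)    # hidden_states: (batch, time_steps, 64)
a = Attention()(h)                        # attention_vector: (batch, 128)
o = Dense(1, activation='linear')(a)

model = Model(inputs=i, outputs=o)
model.compile(loss='mse', optimizer='adam')
# The inner layers keep their fixed names ('attention_score_vec',
# 'attention_weight', 'context_vector', ...), so the softmax weights stay
# addressable by name regardless of how the layer is embedded.
model.summary()
```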

examples/example-attention.py

Lines changed: 10 additions & 14 deletions

@@ -5,14 +5,11 @@
 import numpy
 import numpy as np
 from keract import get_activations
-from tensorflow.keras import Input
-from tensorflow.keras import Model
+from tensorflow.keras import Sequential
 from tensorflow.keras.callbacks import Callback
-from tensorflow.keras.layers import Dense
-from tensorflow.keras.layers import Dropout
-from tensorflow.keras.layers import LSTM
+from tensorflow.keras.layers import Dense, Dropout, LSTM

-from attention import attention_3d_block
+from attention import Attention


 def task_add_two_numbers_after_delimiter(n: int, seq_length: int, delimiter: float = 0.0,
@@ -59,14 +56,13 @@ def main():
     x_test_mask[:, test_index_1:test_index_1 + 1] = 1
     x_test_mask[:, test_index_2:test_index_2 + 1] = 1

-    # model
-    i = Input(shape=(seq_length, 1))
-    x = LSTM(100, return_sequences=True)(i)
-    x = attention_3d_block(x)
-    x = Dropout(0.2)(x)
-    x = Dense(1, activation='linear')(x)
+    model = Sequential([
+        LSTM(100, input_shape=(seq_length, 1), return_sequences=True),
+        Attention(name='attention_weight'),
+        Dropout(0.2),
+        Dense(1, activation='linear')
+    ])

-    model = Model(inputs=[i], outputs=[x])
     model.compile(loss='mse', optimizer='adam')
     print(model.summary())

@@ -79,7 +75,7 @@ def main():
     class VisualiseAttentionMap(Callback):

         def on_epoch_end(self, epoch, logs=None):
-            attention_map = get_activations(model, x_test, layer_name='attention_weight')['attention_weight']
+            attention_map = get_activations(model, x_test, layer_names='attention_weight')['attention_weight']

             # top is attention map.
             # bottom is ground truth.
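
The only behavioural change in the last hunk is the keract keyword: newer keract releases take `layer_names` instead of `layer_name`. Outside the callback, the same call can be used after training to inspect the attention weights; a sketch that reuses the `model` and `x_test` objects built earlier in `examples/example-attention.py` (the plotting details are illustrative):

```python
# Inspect attention weights after training; `model` and `x_test` come from
# examples/example-attention.py, and `layer_names` matches the newer keract API.
import matplotlib.pyplot as plt
from keract import get_activations

attention_map = get_activations(model, x_test, layer_names='attention_weight')['attention_weight']
print(attention_map.shape)  # (num_test_samples, seq_length): one softmax weight per time step

plt.imshow(attention_map, cmap='hot')  # rows = test samples, columns = time steps
plt.xlabel('time steps')
plt.ylabel('test samples')
plt.show()
```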

examples/find_max.py

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from keract import get_activations
+from tensorflow.keras import Sequential
+from tensorflow.keras.callbacks import Callback
+from tensorflow.keras.layers import Dense, LSTM
+
+from attention import Attention
+
+
+class VisualizeAttentionMap(Callback):
+
+    def __init__(self, model, x):
+        super().__init__()
+        self.model = model
+        self.x = x
+
+    def on_epoch_begin(self, epoch, logs=None):
+        attention_map = get_activations(self.model, self.x, layer_names='attention_weight')['attention_weight']
+        x = self.x[..., 0]
+        fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(5, 6))
+        maps = [attention_map, create_argmax_mask(attention_map), create_argmax_mask(x)]
+        maps_names = ['attention layer', 'attention layer - argmax()', 'ground truth - argmax()']
+        for i, ax in enumerate(axes.flat):
+            im = ax.imshow(maps[i], interpolation='none', cmap='jet')
+            ax.set_ylabel(maps_names[i] + '\n#sample axis')
+            ax.set_xlabel('sequence axis')
+            ax.xaxis.set_ticks([])
+            ax.yaxis.set_ticks([])
+        cbar_ax = fig.add_axes([0.75, 0.15, 0.05, 0.7])
+        fig.colorbar(im, cax=cbar_ax)
+        fig.suptitle(f'Epoch {epoch} - training')
+        plt.show()
+
+
+def create_argmax_mask(x):
+    mask = np.zeros_like(x)
+    for i, m in enumerate(x.argmax(axis=1)):
+        mask[i, m] = 1
+    return mask
+
+
+def main():
+    seq_length = 10
+    num_samples = 100000
+    # https://stats.stackexchange.com/questions/485784/which-distribution-has-its-maximum-uniformly-distributed
+    # Choose beta(1/N,1) to have max(X_1,...,X_n) ~ U(0, 1) => minimizes amount of knowledge.
+    # If all the max(s) are concentrated around 1, then it makes the task easy for the model.
+    x_data = np.random.beta(a=1 / seq_length, b=1, size=(num_samples, seq_length, 1))
+    y_data = np.max(x_data, axis=1)
+    model = Sequential([
+        LSTM(128, input_shape=(seq_length, 1), return_sequences=True),
+        Attention(name='attention_weight'),
+        Dense(1, activation='linear')
+    ])
+    model.compile(loss='mae')
+    max_epoch = 100
+    # visualize the attention on the first samples.
+    visualize = VisualizeAttentionMap(model, x_data[0:12])
+    model.fit(x_data, y_data, epochs=max_epoch, validation_split=0.2, callbacks=[visualize])
+
+
+if __name__ == '__main__':
+    main()
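
The `beta(1/N, 1)` comment in `main()` relies on a small fact: if X ~ Beta(1/N, 1) then P(X <= x) = x^(1/N) on (0, 1), so the maximum of N i.i.d. draws has CDF (x^(1/N))^N = x, i.e. it is uniform on (0, 1) and the targets are not bunched up near 1. A standalone numpy check (the bin count is arbitrary):

```python
# Numerical check of the beta(1/N, 1) trick used in examples/find_max.py:
# the per-sequence maxima should be (approximately) Uniform(0, 1).
import numpy as np

seq_length = 10
num_samples = 100000
x = np.random.beta(a=1 / seq_length, b=1, size=(num_samples, seq_length))
maxima = x.max(axis=1)

hist, _ = np.histogram(maxima, bins=10, range=(0.0, 1.0))
print(hist / num_samples)  # every bin should be close to 0.10
```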

examples/imdb.py

Lines changed: 18 additions & 23 deletions

@@ -1,42 +1,37 @@
 import numpy
 import numpy as np
-from tensorflow.keras import Input
-from tensorflow.keras import Model
+from tensorflow.keras import Sequential
 from tensorflow.keras.callbacks import Callback
 from tensorflow.keras.datasets import imdb
-from tensorflow.keras.layers import Dense
-from tensorflow.keras.layers import Dropout
-from tensorflow.keras.layers import Embedding
-from tensorflow.keras.layers import LSTM
+from tensorflow.keras.layers import Dense, Dropout, Embedding, LSTM
 from tensorflow.keras.preprocessing import sequence

-from attention import attention_3d_block
+from attention import Attention


 def train_and_evaluate_model_on_imdb(add_attention=True):
     numpy.random.seed(7)
     # load the dataset but only keep the top n words, zero the rest
     top_words = 5000
-    (X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=top_words)
+    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=top_words)
     # truncate and pad input sequences
     max_review_length = 500
-    X_train = sequence.pad_sequences(X_train, maxlen=max_review_length)
-    X_test = sequence.pad_sequences(X_test, maxlen=max_review_length)
+    x_train = sequence.pad_sequences(x_train, maxlen=max_review_length)
+    x_test = sequence.pad_sequences(x_test, maxlen=max_review_length)
     # create the model
     embedding_vector_length = 32
-    i = Input(shape=(max_review_length,))
-    x = Embedding(top_words, embedding_vector_length, input_length=max_review_length)(i)
-    x = Dropout(0.5)(x)
-    if add_attention:
-        x = LSTM(100, return_sequences=True)(x)
-        x = attention_3d_block(x)
-    else:
-        x = LSTM(100, return_sequences=False)(x)
-        x = Dense(350, activation='relu')(x) # same number of parameters so fair comparison.
-    x = Dropout(0.5)(x)
-    x = Dense(1, activation='sigmoid')(x)

-    model = Model(inputs=[i], outputs=[x])
+    model = Sequential([
+        Embedding(top_words, embedding_vector_length, input_length=max_review_length),
+        Dropout(0.5),
+        # attention vs no attention. same number of parameters so fair comparison.
+        *([LSTM(100, return_sequences=True), Attention()] if add_attention
+          else [LSTM(100), Dense(350, activation='relu')]),
+        Dropout(0.5),
+        Dense(1, activation='sigmoid')
+    ]
+    )
+
     model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
     print(model.summary())

@@ -52,7 +47,7 @@ def on_epoch_end(self, epoch, logs=None):
             self.val_losses.append(logs['val_loss'])

     rbta = RecordBestTestAccuracy()
-    model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=10, batch_size=64, callbacks=[rbta])
+    model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=10, batch_size=64, callbacks=[rbta])

     print(f"Max Test Accuracy: {100 * np.max(rbta.val_accuracies):.2f} %")
     print(f"Mean Test Accuracy: {100 * np.mean(rbta.val_accuracies):.2f} %")
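
The inline comment claims the attention and no-attention branches are a fair comparison because they have (roughly) the same number of parameters: the baseline's `Dense(350)` stands in for the weights the attention block adds. A sketch that rebuilds both variants and prints their parameter counts (assumes the `attention` package is installed; the counts should be close but not identical):

```python
# Rough parameter-count check for the "fair comparison" comment in examples/imdb.py.
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout, Embedding, LSTM

from attention import Attention

top_words, embedding_vector_length, max_review_length = 5000, 32, 500


def build(add_attention):
    return Sequential([
        Embedding(top_words, embedding_vector_length, input_length=max_review_length),
        Dropout(0.5),
        *([LSTM(100, return_sequences=True), Attention()] if add_attention
          else [LSTM(100), Dense(350, activation='relu')]),
        Dropout(0.5),
        Dense(1, activation='sigmoid')
    ])


for flag in (True, False):
    print(f'add_attention={flag}: {build(flag).count_params():,} parameters')
```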

examples/readme/example.png

New image (97.9 KB), referenced by the "Finding max of a sequence" section of the README.

setup.py

Lines changed: 5 additions & 4 deletions

@@ -1,18 +1,19 @@
 from setuptools import setup

+from attention import VERSION
+
 setup(
     name='attention',
-    version='2.2',
-    description='Keras Attention Many to One',
+    version=VERSION,
+    description='Keras Simple Attention',
     author='Philippe Remy',
     license='Apache 2.0',
     long_description_content_type='text/markdown',
     long_description=open('README.md').read(),
     packages=['attention'],
-    # manually install tensorflow or tensorflow-gpu
     install_requires=[
         'numpy>=1.18.1',
         'keras>=2.3.1',
-        'gast>=0.2.2'
+        'tensorflow>=2.1'
     ]
 )
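
Since `setup.py` now imports `VERSION` from `attention/__init__.py`, running `setup.py` loads the package (and therefore tensorflow) at build time, and the installed package exposes both the version string and the new layer. A quick sanity check after `pip install attention` (illustrative):

```python
# Sanity check of the packaging changes (illustrative).
import attention
from tensorflow.keras.layers import Layer

print(attention.VERSION)                       # '3.0' as of this commit
print(issubclass(attention.Attention, Layer))  # True: the export is a Keras Layer
```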
