Skip to content

Commit 3092ad7

Browse files
committed
add models
1 parent 7907d5e commit 3092ad7

File tree

1 file changed

+137
-0
lines changed

1 file changed

+137
-0
lines changed

models/vgg16.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
2+
from src.modules.sequential import Sequential
3+
from src.modules.linear import Linear
4+
from src.modules.convolution import Convolution
5+
from src.modules.convolution_ab import Convolution_ab
6+
from src.modules.reshape import Reshape
7+
from src.modules.proppool import PropPool
8+
from src.utils import load_json
9+
import numpy as np
10+
import cv2
11+
12+
13+
def vgg_addmean(images):
14+
mean = [103.939, 116.779, 123.68]
15+
img_trans = []
16+
for img in images:
17+
blue, green, red = np.split(img, 3, 2)
18+
img_trans.append(np.concatenate(
19+
[blue + mean[0], green + mean[1], red + mean[2]], 2),
20+
)
21+
return img_trans
22+
23+
24+
def bgr_to_rgb(images):
25+
img_trans = []
26+
for img in images:
27+
blue, green, red = np.split(img, 3, 2)
28+
img_trans.append(np.concatenate([red, green, blue], 2))
29+
return img_trans
30+
31+
32+
class VGG16(object):
33+
def __init__(self, weights_file, classes_file):
34+
self.classes = load_json(classes_file)
35+
self.weights = np.load(weights_file, encoding='latin1').item()
36+
37+
def vgg_addmean(self, images):
38+
mean = [103.939, 116.779, 123.68]
39+
img_trans = []
40+
for img in images:
41+
blue, green, red = np.split(img, 3, 2)
42+
img_trans.append(np.concatenate(
43+
[blue + mean[0], green + mean[1], red + mean[2]], 2),
44+
)
45+
return img_trans
46+
47+
def load_image(self, image_file):
48+
image = cv2.imread(image_file)
49+
image = cv2.resize(image, (224, 224))
50+
image = np.expand_dims(image, axis=0)
51+
return self.vgg_addmean(image)
52+
53+
def build_model(self, batch_size, alpha):
54+
return Sequential([
55+
Convolution(batch_size=batch_size,
56+
initializer=self.weights['conv1_1'],
57+
first=True,
58+
name='conv1_1_'),
59+
Convolution_ab(
60+
batch_size=batch_size,
61+
initializer=self.weights['conv1_2'],
62+
alpha=alpha,
63+
name='conv1_2_'),
64+
PropPool(name='PropPool1'),
65+
Convolution_ab(
66+
batch_size=batch_size,
67+
initializer=self.weights['conv2_1'],
68+
alpha=alpha,
69+
name='conv2_1_'),
70+
Convolution_ab(
71+
batch_size=batch_size,
72+
initializer=self.weights['conv2_2'],
73+
alpha=alpha,
74+
name='conv2_2_'),
75+
PropPool(name='PropPool2'),
76+
Convolution_ab(
77+
batch_size=batch_size,
78+
initializer=self.weights['conv3_1'],
79+
alpha=alpha,
80+
name='conv3_1_'),
81+
Convolution_ab(
82+
batch_size=batch_size,
83+
initializer=self.weights['conv3_2'],
84+
alpha=alpha,
85+
name='conv3_2_'),
86+
Convolution_ab(
87+
batch_size=batch_size,
88+
initializer=self.weights['conv3_3'],
89+
alpha=alpha,
90+
name='conv3_3_'),
91+
PropPool(name='PropPool3'),
92+
Convolution_ab(
93+
batch_size=batch_size,
94+
initializer=self.weights['conv4_1'],
95+
alpha=alpha,
96+
name='conv4_1_'),
97+
Convolution_ab(
98+
batch_size=batch_size,
99+
initializer=self.weights['conv4_2'],
100+
alpha=alpha,
101+
name='conv4_2_'),
102+
Convolution_ab(
103+
batch_size=batch_size,
104+
initializer=self.weights['conv4_3'],
105+
alpha=alpha,
106+
name='conv4_3_'),
107+
PropPool(name='PropPool4'),
108+
Convolution_ab(
109+
batch_size=batch_size,
110+
initializer=self.weights['conv5_1'],
111+
alpha=alpha,
112+
name='conv5_1_'),
113+
Convolution_ab(
114+
batch_size=batch_size,
115+
initializer=self.weights['conv5_2'],
116+
alpha=alpha,
117+
name='conv5_2_'),
118+
Convolution_ab(
119+
batch_size=batch_size,
120+
initializer=self.weights['conv5_3'],
121+
alpha=alpha,
122+
name='conv5_3_'),
123+
PropPool(name='PropPool5'),
124+
Reshape(name='flat1'),
125+
Linear(batch_size=batch_size,
126+
initializer=self.weights['fc6'],
127+
alpha=alpha,
128+
name='fc6_'),
129+
Linear(batch_size=batch_size,
130+
initializer=self.weights['fc7'],
131+
alpha=alpha,
132+
name='fc7_'),
133+
Linear(batch_size=batch_size,
134+
initializer=self.weights['fc8'],
135+
alpha=alpha,
136+
name='fc8_'),
137+
])

0 commit comments

Comments
 (0)