-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathNetworkBuilder.swift
138 lines (109 loc) · 4.53 KB
/
NetworkBuilder.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Copyright © 2016 Alejandro Isaza. All rights reserved.
import BrainCore
import HDF5Kit
import Upsurge
/// Single-sample input layer for the network.
///
/// Every call to `nextBatch` hands back the same `data` blob, which is
/// allocated with `size` elements when the layer is created.
class Source: DataLayer {
    let id = NSUUID()
    let name: String? = nil

    /// Number of values in one sample.
    var size: Int

    /// Buffer returned by `nextBatch`.
    var data: Blob

    init(size: Int) {
        self.size = size
        self.data = Blob(count: size)
    }

    /// Width of this layer's output; always equal to `size`.
    var outputSize: Int {
        return self.size
    }

    /// Returns `data`. Only a batch size of exactly one is supported.
    func nextBatch(batchSize: Int) -> Blob {
        precondition(batchSize == 1)
        return self.data
    }
}
/// Terminal layer that keeps the most recent blob it was given in `data`.
class Sink: SinkLayer {
    let id = NSUUID()
    let name: String? = nil

    /// Expected width of the incoming blob.
    var size: Int

    /// Last blob passed to `consume`; empty until the first call.
    var data: Blob = []

    /// Width of this layer's input; always equal to `size`.
    var inputSize: Int {
        return size
    }

    init(size: Int) {
        self.size = size
    }

    /// Stores `input` in `data`, replacing whatever was there before.
    func consume(input: Blob) {
        data = input
    }
}
/// Builds a `Source => lstm_1 => lstm_2 => dense_1 => Sink` network whose
/// layer weights are read from an HDF5 file.
class NetworkBuilder {
    /// Input layer of the network; sized to the `inputSize` given at init.
    var dataLayer: Source
    /// Output layer of the network; sized to the `outputSize` given at init.
    var sinkLayer: Sink

    init(inputSize: Int, outputSize: Int) {
        dataLayer = Source(size: inputSize)
        sinkLayer = Sink(size: outputSize)
    }

    /// Loads the `lstm_1`, `lstm_2` and `dense_1` layers from the HDF5 file at
    /// `path` and chains them between `dataLayer` and `sinkLayer`.
    ///
    /// Terminates with a descriptive message if the file, a group or a dataset
    /// is missing, a dataset cannot be read, or a weight dataset is not 2-D.
    func loadNetFromFile(path: String) -> Net {
        guard let file = File.open(path, mode: .ReadOnly) else {
            fatalError("File not found '\(path)'")
        }

        // Read errors are surfaced with the file path and the underlying error
        // instead of the context-free crash a `try!` would produce.
        do {
            let lstm_1 = try loadLSTMLayerFromFile(file, name: "lstm_1")
            let lstm_2 = try loadLSTMLayerFromFile(file, name: "lstm_2")
            let denseLayer = try loadDenseLayerFromFile(file)
            return Net.build {
                self.dataLayer => lstm_1 => lstm_2 => denseLayer => self.sinkLayer
            }
        } catch {
            fatalError("Failed to read network weights from '\(path)': \(error)")
        }
    }

    /// Reads the twelve `<name>_{W,U,b}_{c,f,i,o}` datasets from group `name`
    /// and builds an `LSTMLayer` with batch size 1.
    ///
    /// - Throws: any error raised while reading a dataset's contents.
    private func loadLSTMLayerFromFile(file: File, name: String) throws -> LSTMLayer {
        guard let group = file.openGroup(name) else {
            fatalError("LSTM \(name) group not found in file")
        }
        guard let
            ucDataset = group.openFloatDataset("\(name)_U_c"),
            ufDataset = group.openFloatDataset("\(name)_U_f"),
            uiDataset = group.openFloatDataset("\(name)_U_i"),
            uoDataset = group.openFloatDataset("\(name)_U_o"),
            wcDataset = group.openFloatDataset("\(name)_W_c"),
            wfDataset = group.openFloatDataset("\(name)_W_f"),
            wiDataset = group.openFloatDataset("\(name)_W_i"),
            woDataset = group.openFloatDataset("\(name)_W_o"),
            bcDataset = group.openFloatDataset("\(name)_b_c"),
            bfDataset = group.openFloatDataset("\(name)_b_f"),
            biDataset = group.openFloatDataset("\(name)_b_i"),
            boDataset = group.openFloatDataset("\(name)_b_o")
        else {
            fatalError("LSTM weights for \(name) not found in file")
        }

        // W_c's shape determines both sizes; check its rank before indexing dims.
        let wDims = wcDataset.space.dims
        guard wDims.count == 2 else {
            fatalError("LSTM \(name): dataset \(name)_W_c must be 2-D, got dims \(wDims)")
        }
        let inputSize = wDims[0]
        let unitCount = wDims[1]

        // W_* are inputSize x unitCount; U_* (recurrent) are unitCount x unitCount.
        let weights = LSTMLayer.makeWeightsFromComponents(
            Wc: Matrix(rows: inputSize, columns: unitCount, elements: try wcDataset.read()),
            Wf: Matrix(rows: inputSize, columns: unitCount, elements: try wfDataset.read()),
            Wi: Matrix(rows: inputSize, columns: unitCount, elements: try wiDataset.read()),
            Wo: Matrix(rows: inputSize, columns: unitCount, elements: try woDataset.read()),
            Uc: Matrix(rows: unitCount, columns: unitCount, elements: try ucDataset.read()),
            Uf: Matrix(rows: unitCount, columns: unitCount, elements: try ufDataset.read()),
            Ui: Matrix(rows: unitCount, columns: unitCount, elements: try uiDataset.read()),
            Uo: Matrix(rows: unitCount, columns: unitCount, elements: try uoDataset.read()))

        // Biases are concatenated in i, c, f, o order.
        // NOTE(review): presumably this mirrors the gate layout produced by
        // makeWeightsFromComponents; confirm against BrainCore before changing.
        let biases = ValueArray([
            try biDataset.read(),
            try bcDataset.read(),
            try bfDataset.read(),
            try boDataset.read()
        ].flatMap({ $0 }))
        return LSTMLayer(weights: weights, biases: biases, batchSize: 1, name: name)
    }

    /// Reads `<name>_W` (inputSize x outputSize) and `<name>_b` from group
    /// `name` and builds an `InnerProductLayer` with that name.
    ///
    /// - Parameter name: HDF5 group / layer name; defaults to `"dense_1"`.
    /// - Throws: any error raised while reading a dataset's contents. Read
    ///   failures are no longer misreported as "not found".
    private func loadDenseLayerFromFile(file: File, name: String = "dense_1") throws -> InnerProductLayer {
        guard let group = file.openGroup(name) else {
            fatalError("Dense group not found in file")
        }
        guard let weightsDataset = group.openFloatDataset("\(name)_W") else {
            fatalError("Dense weights not found in file")
        }
        guard let biasesDataset = group.openFloatDataset("\(name)_b") else {
            fatalError("Dense biases not found in file")
        }

        let dims = weightsDataset.space.dims
        guard dims.count == 2 else {
            fatalError("Dense \(name): dataset \(name)_W must be 2-D, got dims \(dims)")
        }
        let inputSize = dims[0]
        let outputSize = dims[1]

        let weights = try weightsDataset.read()
        let biases = try biasesDataset.read()

        let weightsMatrix = Matrix(rows: inputSize, columns: outputSize, elements: weights)
        return InnerProductLayer(weights: weightsMatrix, biases: ValueArray(biases), name: name)
    }
}