Commit db91623

Merge branch 'master' of github.com:bootphon/abnet3
2 parents: 92a72fd + 5d75e5b

19 files changed (+2183, -54 lines)

abnet3/dataloader.py (+397, -7)

Large diffs are not rendered by default.

abnet3/embedder.py (+80, -1)
@@ -11,8 +11,9 @@
 import h5features
 import argparse

-from abnet3.utils import read_feats
+from abnet3.utils import read_feats, EmbeddingObserver
 from abnet3.model import *
+from abnet3.integration import BiWeightedDeepLearnt


 class EmbedderBuilder:
@@ -64,6 +65,9 @@ def embed(self):
         if self.network_path is not None:
             self.network.load_network(self.network_path)
         self.network.eval()
+
+        if self.cuda:
+            self.network.cuda()
         print("Done loading network weights")

         with h5features.Reader(self.feature_path, 'features') as fh:
@@ -112,6 +116,9 @@ def embed(self):
             self.network.load_network(self.network_path)
         self.network.eval()

+        if self.cuda:
+            self.network.cuda()
+
         with h5features.Reader(self.feature_path, 'features') as fh:
             features = fh.read()

@@ -140,3 +147,75 @@ def embed(self):

         with h5features.Writer(self.output_path+'.phn') as fh:
             fh.write(data_phn, 'features')
+
+class MultimodalEmbedder(EmbedderBuilder):
+    """
+    Embedder class for multimodal siamese network
+    """
+
+    def __init__(self, *args, **kwargs):
+        super(MultimodalEmbedder, self).__init__(*args, **kwargs)
+        self.observers = []  # list of EmbeddingObserver objects, each built
+                             # from a function that fetches the data and a
+                             # path where it will be saved
+
+        if isinstance(self.network.integration_unit, BiWeightedDeepLearnt):
+            print("Placing observer to save learnt attention weights")
+            self.observers.append(EmbeddingObserver(
+                self.network.integration_unit.get_weights,
+                self.output_path + "attention_weights.features"))
+
+    def embed(self):
+        """
+        Embed method to embed features based on a saved network
+        """
+
+        if self.network_path is not None:
+            self.network.load_network(self.network_path)
+        self.network.eval()
+
+        if self.cuda:
+            self.network.cuda()
+
+        items = None
+        times = None
+        features_list = []
+        for path in self.feature_path:
+            with h5features.Reader(path, 'features') as fh:
+                features = fh.read()
+                features_list.append(features.features())
+                check_items = features.items()
+                check_times = features.labels()
+                if not items:
+                    items = check_items
+                if not times:
+                    times = check_times
+
+        print("Done loading input feature file")
+
+        zipped_feats = zip(*features_list)
+        embeddings = []
+        for feats in zipped_feats:
+            modes_list = []
+            for feat in feats:
+                if feat.dtype != np.float32:
+                    feat = feat.astype(np.float32)
+                feat_torch = Variable(torch.from_numpy(feat), volatile=True)
+                if self.cuda:
+                    feat_torch = feat_torch.cuda()
+                modes_list.append(feat_torch)
+            emb, _ = self.network(modes_list, modes_list)
+            emb = emb.cpu()
+            embeddings.append(emb.data.numpy())

+            # Register activity on observers
+            for observer in self.observers:
+                observer.register_status()
+
+        data = h5features.Data(items, times, embeddings, check=True)
+        with h5features.Writer(self.output_path + "embedded.features") as fh:
+            fh.write(data, 'features')
+
+        # Save observer registers
+        for observer in self.observers:
+            observer.save(items, times)
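
The new MultimodalEmbedder hands side outputs to observers: each observer wraps a getter callable and an output path, register_status() is called once per embedded item right after the forward pass, and save(items, times) writes everything out at the end. The sketch below only illustrates that contract; it is not the actual abnet3.utils.EmbeddingObserver, whose implementation may differ (SketchObserver and the npz output are assumptions).

import numpy as np


class SketchObserver:
    """Illustrative stand-in for the observer interface used by MultimodalEmbedder."""

    def __init__(self, getter, output_path):
        self.getter = getter            # callable returning the quantity to record
        self.output_path = output_path  # where the recordings are written
        self.records = []

    def register_status(self):
        # called once per embedded item, right after the network forward pass
        self.records.append(np.asarray(self.getter()))

    def save(self, items, times):
        # one record per item; times is unused in this sketch, and the real
        # class writes an h5features file rather than an npz archive
        np.savez(self.output_path,
                 **{str(item): record for item, record in zip(items, self.records)})

In the diff above such an observer is attached only when the network's integration unit is a BiWeightedDeepLearnt, so that the learnt attention weights are dumped alongside the embeddings.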

abnet3/features.py (+113, -18)
@@ -12,7 +12,7 @@
 import os
 import h5py
 import shutil
-import tempfile
+import argparse

 from abnet3.utils import read_vad_file, read_feats, Features_Accessor

@@ -341,6 +341,27 @@ def load_mean_variance(self, file_path):

         return {'mean': mean_var[0], 'variance': mean_var[1]}

+
+    def normalize(self, input_features, output_features):
+        print("Normalizing")
+        if self.norm_per_file:
+            self.mean_var_norm_per_file(input_features, output_features,
+                                        vad_file=self.vad_file)
+        else:
+            if self.load_mean_variance_path is not None:
+                params = self.load_mean_variance(
+                    file_path=self.load_mean_variance_path)
+            else:
+                params = None
+            mean, variance = self.mean_variance_normalisation(
+                input_features, output_features, params=params,
+                vad_file=self.vad_file
+            )
+            if self.save_mean_variance_path is not None:
+                self.save_mean_variance(
+                    mean, variance,
+                    output_file=self.save_mean_variance_path)
+
     def generate(self):

         functions = {
@@ -370,23 +391,7 @@ def generate(self):
             if self.normalization:
                 print("Normalizing")
                 h5_temp2 = tempdir + '/temp2'
-                if self.norm_per_file:
-                    self.mean_var_norm_per_file(h5_temp1, h5_temp2,
-                                                vad_file=self.vad_file)
-                else:
-                    if self.load_mean_variance_path is not None:
-                        params = self.load_mean_variance(
-                            file_path=self.load_mean_variance_path)
-                    else:
-                        params = None
-                    mean, variance = self.mean_variance_normalisation(
-                        h5_temp1, h5_temp2, params=params,
-                        vad_file=self.vad_file
-                    )
-                    if self.save_mean_variance_path is not None:
-                        self.save_mean_variance(
-                            mean, variance,
-                            output_file=self.save_mean_variance_path)
+                self.normalize(h5_temp1, h5_temp2)
             else:
                 h5_temp2 = h5_temp1
             if self.stack:
@@ -397,3 +402,93 @@ def generate(self):
                 shutil.copy(h5_temp2, self.output_path)
         finally:
             shutil.rmtree(tempdir)
+
+
+
+def main_wav(args):
+
+    features_generator = FeaturesGenerator(
+        files=args.wav_dir,
+        output_path=args.output_path,
+        method=args.method,
+        n_filters=args.n_filters,
+        save_mean_variance_path=args.save_mean_var,
+        load_mean_variance_path=args.load_mean_var,
+        vad_file=args.vad,
+        normalization=args.normalization,
+        stack=True,
+        norm_per_file=args.norm_per_file,
+        norm_per_channel=args.norm_per_channel,
+    )
+
+    features_generator.generate()
+
+def main_normalize(args):
+    features_generator = FeaturesGenerator(
+        save_mean_variance_path=args.save_mean_var,
+        load_mean_variance_path=args.load_mean_var,
+        vad_file=args.vad,
+        normalization=True,
+        norm_per_file=args.norm_per_file,
+        norm_per_channel=args.norm_per_channel,
+    )
+
+    features_generator.normalize(
+        args.input_features,
+        args.output_features
+    )
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    subparsers = parser.add_subparsers(help='sub-command help')
+
+    parser_wav = subparsers.add_parser("wav")
+    parser_normalize = subparsers.add_parser("norm")
+
+
+    parser_wav.add_argument("wav_dir", help="Path to wav directory")
+    parser_wav.add_argument("output_path", help="Path to output h5f file")
+    parser_wav.add_argument("method", choices=["mfcc", "fbanks"],
+                            help="which features to generate")
+    parser_wav.add_argument("--vad", help="Path to vad file "
+                                          "(CSV, seconds with header)")
+    parser_wav.add_argument("--normalization", "-n", action="store_true")
+    parser_wav.add_argument("--norm-per-file", action="store_true",
+                            help="Independent normalization for each file")
+    parser_wav.add_argument("--norm-per-channel", action="store_true",
+                            help="Normalize each channel independently")
+    parser_wav.add_argument("--n-filters", type=int, default=40)
+    parser_wav.add_argument("--save-mean-var", type=str,
+                            help="Path to location where mean / var "
+                                 "will be saved")
+    parser_wav.add_argument("--load-mean-var", type=str,
+                            help="Path to location where mean / var "
+                                 "are saved. Will be used to compute test features")
+
+    parser_wav.set_defaults(func=main_wav)
+
+    parser_normalize.add_argument("input_features", help="Path to input h5f file")
+    parser_normalize.add_argument("output_features", help="Path to output h5f file")
+    parser_normalize.add_argument("--vad", help="Path to vad file "
+                                                "(CSV, seconds with header)")
+    parser_normalize.add_argument("--norm-per-file", action="store_true",
+                                  help="Independent normalization for each file")
+    parser_normalize.add_argument("--norm-per-channel", action="store_true",
+                                  help="Normalize each channel independently")
+    parser_normalize.add_argument("--save-mean-var", type=str,
+                                  help="Path to location where mean / var "
+                                       "will be saved")
+    parser_normalize.add_argument("--load-mean-var", type=str,
+                                  help="Path to location where mean / var "
+                                       "are saved. Will be used to compute test features")
+
+    parser_normalize.set_defaults(func=main_normalize)
+
+    args = parser.parse_args()
+    if args.func:
+        args.func(args)
+
+
+if __name__ == '__main__':
+    main()
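
features.py now doubles as a command-line tool with two sub-commands: wav generates (and optionally normalises) features from a directory of wav files, while norm re-normalises an existing h5features file, typically with mean/variance statistics saved from the training set. A rough programmatic equivalent of the norm sub-command, mirroring main_normalize above (paths are placeholders, and any constructor defaults not visible in the diff are assumptions):

from abnet3.features import FeaturesGenerator

# normalise held-out features with the mean/variance saved from the training run
generator = FeaturesGenerator(
    load_mean_variance_path="exp/train_mean_var.txt",  # placeholder path
    vad_file=None,                                     # or a VAD CSV (seconds, with header)
    normalization=True,
    norm_per_file=False,
    norm_per_channel=False,
)
generator.normalize("exp/test_raw.h5f", "exp/test_normalized.h5f")

From a shell, this corresponds roughly to python abnet3/features.py norm exp/test_raw.h5f exp/test_normalized.h5f --load-mean-var exp/train_mean_var.txt, assuming the module is invoked directly through its __main__ guard.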

abnet3/gridsearch.py (+5, -1)
@@ -139,6 +139,9 @@ def run_single_experiment(self, single_experiment=None, gpu_id=0):

         os.makedirs(single_experiment['pathname_experience'], exist_ok=True)

+        with open(os.path.join(single_experiment['pathname_experience'], 'exp.yml'), 'w') as f:
+            yaml.dump(single_experiment, f, default_flow_style=False)
+
         features_prop = single_experiment['features']
         features_class = getattr(abnet3.features, features_prop['class'])
         arguments = features_prop['arguments']
@@ -170,7 +173,8 @@ def run_single_experiment(self, single_experiment=None, gpu_id=0):
         dataloader_prop = single_experiment['dataloader']
         dataloader_class = getattr(abnet3.dataloader, dataloader_prop['class'])
         arguments = dataloader_prop['arguments']
-        arguments['pairs_path'] = sampler.directory_output
+        if not 'pairs_path' in arguments:
+            arguments['pairs_path'] = sampler.directory_output
         arguments['features_path'] = features.output_path
         dataloader = dataloader_class(**arguments)
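
With the pairs_path guard above, a grid-search experiment can point its dataloader at precomputed pairs, and the sampler's output directory is only used as a fallback; the full experiment dictionary is also dumped to exp.yml for reproducibility. A sketch of the relevant slice of the single_experiment mapping consumed by run_single_experiment (the class name and paths are placeholders, not values taken from the repository):

single_experiment = {
    'pathname_experience': 'exp/run_001',   # experiment output directory; exp.yml is written here
    'dataloader': {
        'class': 'OriginalDataLoader',      # hypothetical dataloader class name
        'arguments': {
            # kept as-is if present; otherwise filled with sampler.directory_output
            'pairs_path': 'exp/precomputed_pairs',
        },
    },
}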
