-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbuild_db_mp.py
85 lines (56 loc) · 1.92 KB
/
build_db_mp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import sound
import hashlib
from db import Database
from mp import MatchingPursuit
from mdct import Mdct
import numpy as np
from sqlite3 import Binary
# The tracks live in a "music" folder located next to this script.
current_dir = os.path.dirname(os.path.abspath(__file__))
music_dir = os.path.join(current_dir, 'music')

# Gather the full path of every file found under the music directory,
# following symlinked folders. No filtering on the file extension is done
# here, so every file in the tree is treated as a track.
music_list = []
for folder, _subfolders, names in os.walk(music_dir, followlinks=True):
    music_list.extend(os.path.join(folder, name) for name in names)

print("%d tracks in database" % len(music_list))
# Start from an empty database: remove the file left over from a previous
# run (presumably the file Database() opens -- confirm in db.py), then call
# Database.create() on a fresh handle.
db_path = current_dir + "/database.sqlite"
if os.path.isfile(db_path):
    os.remove(db_path)

database = Database()
database.create()
# Drop the handle right away; each later step builds its own Database().
del database
# Fingerprinting parameters shared by all tracks.
atoms_per_frame = 20
frame_duration = 5   # presumably seconds -- TODO confirm
Fe = 44100           # sampling rate assumed when sizing the frames

mdct = Mdct()
mp = MatchingPursuit(mdct, atoms_per_frame)
biggest_atom_size = mdct.sizes[-1]

# Frame length in samples: frame_duration * Fe, rounded down to a whole
# multiple of the last entry of mdct.sizes.
whole_atoms_per_frame = int(frame_duration*Fe/biggest_atom_size)
frame_size = whole_atoms_per_frame*biggest_atom_size
mp.buildMask(frame_size)
# Fingerprint every track and store the result in the database.
for track in music_list:
    # (hash, track id, sample position) rows gathered for this track; they
    # are written with a single addFingerprint() call after the last frame.
    query = []

    # Strip the extension whatever its length. The former track[:-4] only
    # gave a clean title for three-letter extensions such as ".mp3".
    track_title = os.path.splitext(track)[0]

    database = Database()
    track_id = database.addTrack(track_title)
    del database

    print("=> Processing %s, id: %d" % (track_title, track_id))
    wavdata = sound.read(track)
    # NOTE(review): the track's frame rate is read but not used afterwards;
    # frame_size is not recomputed from it, so it still reflects the rate
    # set before this loop -- confirm all tracks share that rate.
    # NOTE(review): wavdata is never explicitly closed -- check whether the
    # object returned by sound.read needs it.
    Fe = wavdata.getframerate()

    # Only whole frames are processed; samples left after the last full
    # frame are ignored.
    frame_number = int(wavdata.getnframes() // frame_size)
    print("%d frames to process for this track" % frame_number)

    for i in range(frame_number):
        # Samples are decoded as little-endian signed 16-bit integers
        # (assumes mono 16-bit audio -- TODO confirm against sound.read).
        s = wavdata.readframes(frame_size)
        s = np.frombuffer(s, dtype='<i2')
        y = mp.sparse(s)
        keys = mp.extractKeys(y)

        # Sample index at which this frame starts within the track.
        frame_start = i * frame_size
        # Keep the first five bytes of each hash key, and turn the offset
        # returned by extractKeys into a position relative to the track.
        query.extend(
            (Binary(hash_key)[0:5], track_id, int(frame_start + offset))
            for (hash_key, offset) in keys
        )

    database = Database()
    database.addFingerprint(query)
    del database
    print("100%")