Skip to content

Commit 58f8699

Browse files
mash: optimize mash similarity function, add benchmark (#395)
* mash: optimize mash similarity function, add benchmark Signed-off-by: Matias Insaurralde <[email protected]> * got test coverage to 100% and change variable names. * fixed TODO. --------- Signed-off-by: Matias Insaurralde <[email protected]> Co-authored-by: Timothy Stiles <[email protected]>
1 parent db703c3 commit 58f8699

File tree

2 files changed

+47
-20
lines changed

2 files changed

+47
-20
lines changed

mash/mash.go

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -106,35 +106,28 @@ func (mash *Mash) Sketch(sequence string) {
106106
// Similarity returns the Jaccard similarity between two sketches (number of matching hashes / sketch size)
107107
func (mash *Mash) Similarity(other *Mash) float64 {
108108
var sameHashes int
109+
largerSketch := mash
110+
smallerSketch := other
109111

110-
var largerSketch *Mash
111-
var smallerSketch *Mash
112-
113-
if mash.SketchSize > other.SketchSize {
114-
largerSketch = mash
115-
smallerSketch = other
116-
} else {
112+
if mash.SketchSize < other.SketchSize {
117113
largerSketch = other
118114
smallerSketch = mash
119115
}
120116

121-
largerSketchSizeShifted := largerSketch.SketchSize - 1
122-
smallerSketchSizeShifted := smallerSketch.SketchSize - 1
123-
124-
// if the largest hash in the larger sketch is smaller than the smallest hash in the smaller sketch, the distance is 1
125-
if largerSketch.Sketches[largerSketchSizeShifted] < smallerSketch.Sketches[0] {
126-
return 0
127-
}
128-
129-
// if the largest hash in the smaller sketch is smaller than the smallest hash in the larger sketch, the distance is 1
130-
if smallerSketch.Sketches[smallerSketchSizeShifted] < largerSketch.Sketches[0] {
117+
if largerSketch.Sketches[largerSketch.SketchSize-1] < smallerSketch.Sketches[0] || smallerSketch.Sketches[smallerSketch.SketchSize-1] < largerSketch.Sketches[0] {
131118
return 0
132119
}
133120

134-
for _, hash := range smallerSketch.Sketches {
135-
ind := sort.Search(largerSketchSizeShifted, func(ind int) bool { return largerSketch.Sketches[ind] <= hash })
136-
if largerSketch.Sketches[ind] == hash {
121+
smallSketchIndex, largeSketchIndex := 0, 0
122+
for smallSketchIndex < smallerSketch.SketchSize && largeSketchIndex < largerSketch.SketchSize {
123+
if smallerSketch.Sketches[smallSketchIndex] == largerSketch.Sketches[largeSketchIndex] {
137124
sameHashes++
125+
smallSketchIndex++
126+
largeSketchIndex++
127+
} else if smallerSketch.Sketches[smallSketchIndex] < largerSketch.Sketches[largeSketchIndex] {
128+
smallSketchIndex++
129+
} else {
130+
largeSketchIndex++
138131
}
139132
}
140133

mash/mash_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,38 @@ func TestMash(t *testing.T) {
3737
if distance != 1 {
3838
t.Errorf("Expected distance to be 1, got %f", distance)
3939
}
40+
41+
fingerprint1 = mash.New(17, 10)
42+
fingerprint1.Sketch("ATGCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA")
43+
44+
fingerprint2 = mash.New(17, 5)
45+
fingerprint2.Sketch("ATCGATCGATCGATCGATCGATCGATCGATCGATCGAATGCGATCGATCGATCGATCGATCG")
46+
47+
distance = fingerprint1.Distance(fingerprint2)
48+
if !(distance > 0.19 && distance < 0.21) {
49+
t.Errorf("Expected distance to be 0.19999999999999996, got %f", distance)
50+
}
51+
52+
fingerprint1 = mash.New(17, 10)
53+
fingerprint1.Sketch("ATCGATCGATCGATCGATCGATCGATCGATCGATCGAATGCGATCGATCGATCGATCGATCG")
54+
55+
fingerprint2 = mash.New(17, 5)
56+
fingerprint2.Sketch("ATGCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA")
57+
58+
distance = fingerprint1.Distance(fingerprint2)
59+
if distance != 0 {
60+
t.Errorf("Expected distance to be 0, got %f", distance)
61+
}
62+
}
63+
64+
func BenchmarkMashDistancee(b *testing.B) {
65+
fingerprint1 := mash.New(17, 10)
66+
fingerprint1.Sketch("ATGCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA")
67+
68+
fingerprint2 := mash.New(17, 9)
69+
fingerprint2.Sketch("ATGCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGA")
70+
71+
for i := 0; i < b.N; i++ {
72+
fingerprint1.Distance(fingerprint2)
73+
}
4074
}

0 commit comments

Comments
 (0)