Skip to content

Commit 6e19cd8

Browse files
authored
Merge pull request #234 from ashvardanian/main-dev
Fix: Jensen Shannon square roots (#233)
2 parents c124410 + d14b654 commit 6e19cd8

File tree

8 files changed

+264
-198
lines changed

8 files changed

+264
-198
lines changed

CONTRIBUTING.md

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,35 +168,65 @@ open target/criterion/report/index.html
168168
169169
## JavaScript
170170
171-
If you don't have NPM installed:
171+
### NodeJS
172+
173+
If you don't have the environment configured, here are the [installation options](https://github.com/nvm-sh/nvm?tab=readme-ov-file#install--update-script) with different tools:
174+
175+
```sh
176+
wget -qO- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.1/install.sh | bash # Linux
177+
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.1/install.sh | bash # macOS
178+
```
179+
180+
Install dependencies:
172181
173182
```sh
174-
wget -qO- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.7/install.sh | bash
175183
nvm install 20
184+
npm install -g typescript # Install the TypeScript compiler globally
185+
npm install --save-dev @types/node # Install the Node.js type definitions as a dev dependency
176186
```
177187
178188
Testing and benchmarking:
179189
180190
```sh
181-
npm install -g typescript
182-
npm run build-js
183-
npm test
184-
npm run bench
191+
npm run build-js # Build the JavaScript code using TypeScript configurations
192+
npm test # Run the test suite
193+
npm run bench # Run the benchmark script
194+
```
195+
196+
### Deno
197+
198+
If you don't have the environment configured, here are the [installation options](https://docs.deno.com/runtime/getting_started/installation/) with different tools:
199+
200+
```sh
201+
wget -qO- https://deno.land/x/install/install.sh | sh # Linux
202+
curl -fsSL https://deno.land/install.sh | sh # macOS
203+
irm https://deno.land/install.ps1 | iex # Windows
185204
```
186205
187-
Running with Deno:
206+
Testing:
188207
189208
```sh
190209
deno test --allow-read
191210
```
192211
193-
Running with Bun:
212+
### Bun
213+
214+
If you don't have the environment configured, here are the [installation options](https://bun.sh/docs/installation) with different tools:
194215
195216
```sh
196-
npm install -g bun
197-
bun test
217+
wget -qO- https://bun.sh/install | bash # for Linux
218+
curl -fsSL https://bun.sh/install | bash # for macOS and WSL
198219
```
199220
221+
Testing:
222+
223+
```sh
224+
bun install
225+
bun test ./scripts/test.mjs
226+
```
227+
228+
Note: running the test suite with Bun doesn't work for now.
229+
200230
## Swift
201231
202232
```sh

include/simsimd/probability.h

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_
108108
d += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon)); \
109109
d += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon)); \
110110
} \
111-
*result = (simsimd_distance_t)d / 2; \
111+
*result = SIMSIMD_SQRT(((simsimd_distance_t)d / 2)); \
112112
}
113113

114114
SIMSIMD_MAKE_KL(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f64_serial
@@ -219,12 +219,13 @@ SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const *a, simsimd_f32_t co
219219
float32x4_t log_ratio_b_vec = _simsimd_log2_f32_neon(ratio_b_vec);
220220
float32x4_t prod_a_vec = vmulq_f32(a_vec, log_ratio_a_vec);
221221
float32x4_t prod_b_vec = vmulq_f32(b_vec, log_ratio_b_vec);
222+
222223
sum_vec = vaddq_f32(sum_vec, vaddq_f32(prod_a_vec, prod_b_vec));
223224
if (n != 0) goto simsimd_js_f32_neon_cycle;
224225

225226
simsimd_f32_t log2_normalizer = 0.693147181f;
226-
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer;
227-
*result = sum / 2;
227+
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer / 2;
228+
*result = _simsimd_sqrt_f32_neon(sum);
228229
}
229230

230231
#pragma clang attribute pop
@@ -296,8 +297,8 @@ SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const *a, simsimd_f16_t co
296297
if (n) goto simsimd_js_f16_neon_cycle;
297298

298299
simsimd_f32_t log2_normalizer = 0.693147181f;
299-
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer;
300-
*result = sum / 2;
300+
simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer / 2;
301+
*result = _simsimd_sqrt_f32_neon(sum);
301302
}
302303

303304
#pragma clang attribute pop
@@ -403,8 +404,8 @@ SIMSIMD_PUBLIC void simsimd_js_f16_haswell(simsimd_f16_t const *a, simsimd_f16_t
403404

404405
simsimd_f32_t log2_normalizer = 0.693147181f;
405406
simsimd_f32_t sum = _simsimd_reduce_f32x8_haswell(sum_vec);
406-
sum *= log2_normalizer;
407-
*result = sum / 2;
407+
sum *= log2_normalizer / 2;
408+
*result = _simsimd_sqrt_f32_haswell(sum);
408409
}
409410

410411
#pragma clang attribute pop
@@ -496,7 +497,9 @@ SIMSIMD_PUBLIC void simsimd_js_f32_skylake(simsimd_f32_t const *a, simsimd_f32_t
496497
if (n) goto simsimd_js_f32_skylake_cycle;
497498

498499
simsimd_f32_t log2_normalizer = 0.693147181f;
499-
*result = _mm512_reduce_add_ps(_mm512_add_ps(sum_a_vec, sum_b_vec)) * log2_normalizer / 2;
500+
simsimd_f32_t sum = _mm512_reduce_add_ps(_mm512_add_ps(sum_a_vec, sum_b_vec));
501+
sum *= log2_normalizer / 2;
502+
*result = _simsimd_sqrt_f32_haswell(sum);
500503
}
501504

502505
#pragma clang attribute pop
@@ -586,7 +589,9 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const *a, simsimd_f16_
586589
if (n) goto simsimd_js_f16_sapphire_cycle;
587590

588591
simsimd_f32_t log2_normalizer = 0.693147181f;
589-
*result = _mm512_reduce_add_ph(_mm512_add_ph(sum_a_vec, sum_b_vec)) * log2_normalizer / 2;
592+
simsimd_f32_t sum = _mm512_reduce_add_ph(_mm512_add_ph(sum_a_vec, sum_b_vec));
593+
sum *= log2_normalizer / 2;
594+
*result = _simsimd_sqrt_f32_haswell(sum);
590595
}
591596

592597
#pragma clang attribute pop

javascript/fallback.ts

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ export const jaccard = (a: Uint8Array, b: Uint8Array): number => {
171171
};
172172

173173
/**
174-
* @brief Computes the kullbackleibler similarity coefficient between two vectors.
174+
* @brief Computes the Kullback-Leibler divergence between two probability distributions.
175175
* @param {Float64Array|Float32Array} a - The first vector.
176176
* @param {Float64Array|Float32Array} b - The second vector.
177177
* @returns {number} The Kullback-Leibler divergence of distribution a from distribution b.
@@ -182,38 +182,49 @@ export const kullbackleibler = (a: Float64Array | Float32Array, b: Float64Array
182182
}
183183

184184
let divergence = 0.0;
185-
186185
for (let i = 0; i < a.length; i++) {
187-
if (a[i] > 0) {
188-
if (b[i] === 0) {
189-
throw new Error(
190-
"Division by zero encountered in KL divergence calculation"
191-
);
192-
}
193-
divergence += a[i] * Math.log(a[i] / b[i]);
186+
if (a[i] < 0 || b[i] < 0) {
187+
throw new Error("Negative values are not allowed in probability distributions");
188+
}
189+
if (b[i] === 0) {
190+
throw new Error(
191+
"Division by zero encountered in KL divergence calculation"
192+
);
194193
}
194+
divergence += a[i] * Math.log(a[i] / b[i]);
195195
}
196196

197197
return divergence;
198198
};
199199

200200
/**
201-
* @brief Computes the jensenshannon similarity coefficient between two vectors.
202-
* @param {Float64Array|Float32Array} a - The first vector.
203-
* @param {Float64Array|Float32Array} b - The second vector.
204-
* @returns {number} The Jaccard similarity coefficient between vectors a and b.
201+
* @brief Computes the Jensen-Shannon distance between two probability distributions.
202+
* @param {Float64Array|Float32Array} a - The first probability distribution.
203+
* @param {Float64Array|Float32Array} b - The second probability distribution.
204+
* @returns {number} The Jensen-Shannon distance between distributions a and b.
205205
*/
206206
export const jensenshannon = (a: Float64Array | Float32Array, b: Float64Array | Float32Array): number => {
207207
if (a.length !== b.length) {
208208
throw new Error("Arrays must be of the same length");
209209
}
210210

211-
const m = a.map((value, index) => (value + b[index]) / 2);
211+
let divergence = 0;
212+
for (let i = 0; i < a.length; i++) {
213+
if (a[i] < 0 || b[i] < 0) {
214+
throw new Error("Negative values are not allowed in probability distributions");
215+
}
216+
const m = (a[i] + b[i]) / 2;
217+
if (m > 0) {
218+
if (a[i] > 0) divergence += a[i] * Math.log(a[i] / m);
219+
if (b[i] > 0) divergence += b[i] * Math.log(b[i] / m);
220+
}
221+
}
212222

213-
const divergence = 0.5 * kullbackleibler(a, m) + 0.5 * kullbackleibler(b, m);
223+
divergence /= 2;
214224
return Math.sqrt(divergence);
215225
};
216226

227+
217228
export default {
218229
sqeuclidean,
219230
euclidean,

package-lock.json

Lines changed: 10 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
},
4949
"devDependencies": {
5050
"@types/bindings": "^1.5.5",
51-
"@types/node": "^20.17.1",
51+
"@types/node": "^20.17.6",
5252
"node-gyp": "^10.0.1",
5353
"prebuildify": "^6.0.0",
5454
"typescript": "^5.3.3"

rust/lib.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -877,19 +877,19 @@ mod tests {
877877
// Adding new tests for probability similarities
878878
#[test]
879879
fn test_js_f32() {
880-
let a = &[0.1, 0.9, 0.0];
881-
let b = &[0.2, 0.8, 0.0];
880+
let a: &[f32; 3] = &[0.1, 0.9, 0.0];
881+
let b: &[f32; 3] = &[0.2, 0.8, 0.0];
882882

883883
if let Some(result) = ProbabilitySimilarity::jensenshannon(a, b) {
884884
println!("The result of js_f32 is {:.8}", result);
885-
assert_almost_equal(0.01, result, 0.01); // Example value
885+
assert_almost_equal(0.099, result, 0.01); // Example value
886886
}
887887
}
888888

889889
#[test]
890890
fn test_kl_f32() {
891-
let a = &[0.1, 0.9, 0.0];
892-
let b = &[0.2, 0.8, 0.0];
891+
let a: &[f32; 3] = &[0.1, 0.9, 0.0];
892+
let b: &[f32; 3] = &[0.2, 0.8, 0.0];
893893

894894
if let Some(result) = ProbabilitySimilarity::kullbackleibler(a, b) {
895895
println!("The result of kl_f32 is {:.8}", result);

0 commit comments

Comments
 (0)