// Negative Log Likelihood Loss Function
//
// The `neg_log_likelihood` function calculates the Negative Log Likelihood loss,
// which is a loss function used for classification problems in machine learning.
//
// ## Formula
//
// For a pair of actual and predicted values, represented as vectors `y_true` and
// `y_pred`, the Negative Log Likelihood loss is calculated as:
//
// - loss = `-y_true * log(y_pred) - (1 - y_true) * log(1 - y_pred)`.
//
// The function returns the average loss, i.e. `total_loss` divided by the
// number of elements.
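//
// ## Example
//
// For `y_true = [1.0, 0.0]` and `y_pred = [0.9, 0.1]`, both per-element terms
// reduce to `-ln(0.9) ≈ 0.105`, so the averaged loss is also roughly `0.105`.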
//
// https://towardsdatascience.com/cross-entropy-negative-log-likelihood-and-all-that-jazz-47a95bd2e81
// http://neuralnetworksanddeeplearning.com/chap3.html
// Derivation of the formula:
// https://medium.com/@bhardwajprakarsh/negative-log-likelihood-loss-why-do-we-use-it-for-binary-classification-7625f9e3c944
pub fn neg_log_likelihood(
    y_true: &[f64],
    y_pred: &[f64],
) -> Result<f64, NegativeLogLikelihoodLossError> {
    // Checks that the actual and predicted values have the same length
    if y_true.len() != y_pred.len() {
        return Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength);
    }
    // Checks that the inputs are not empty
    if y_pred.is_empty() {
        return Err(NegativeLogLikelihoodLossError::EmptyInputs);
    }
    // Checks that all values are between 0 and 1
    if !are_all_values_in_range(y_true) || !are_all_values_in_range(y_pred) {
        return Err(NegativeLogLikelihoodLossError::InvalidValues);
    }

    // Accumulates the per-element binary cross-entropy terms
    let mut total_loss: f64 = 0.0;
    for (p, a) in y_pred.iter().zip(y_true.iter()) {
        let loss: f64 = -a * p.ln() - (1.0 - a) * (1.0 - p).ln();
        total_loss += loss;
    }
    Ok(total_loss / (y_pred.len() as f64))
}
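
// A minimal usage sketch (mirrors `set_0` from the tests below):
//
//     let y_true = vec![1.0, 0.0, 1.0];
//     let y_pred = vec![0.9, 0.1, 0.8];
//     let loss = neg_log_likelihood(&y_true, &y_pred).unwrap();
//     assert!((loss - 0.14462152754328741).abs() < 1e-12);
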
#[derive(Debug, PartialEq, Eq)]
pub enum NegativeLogLikelihoodLossError {
    InputsHaveDifferentLength,
    EmptyInputs,
    InvalidValues,
}
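
// Returns `true` when every value lies within the closed interval [0.0, 1.0].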
fn are_all_values_in_range(values: &[f64]) -> bool {
    values.iter().all(|&x| (0.0..=1.0).contains(&x))
}
#[cfg(test)]
mod tests {
    use super::*;

    macro_rules! test_with_wrong_inputs {
        ($($name:ident: $inputs:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (values_a, values_b, expected_error) = $inputs;
                    assert_eq!(neg_log_likelihood(&values_a, &values_b), expected_error);
                    assert_eq!(neg_log_likelihood(&values_b, &values_a), expected_error);
                }
            )*
        }
    }

    test_with_wrong_inputs! {
        different_length: (vec![0.9, 0.0, 0.8], vec![0.9, 0.1], Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength)),
        different_length_one_empty: (vec![], vec![0.9, 0.1], Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength)),
        value_greater_than_1: (vec![1.1, 0.0, 0.8], vec![0.1, 0.2, 0.3], Err(NegativeLogLikelihoodLossError::InvalidValues)),
        value_smaller_than_0: (vec![0.9, 0.0, -0.1], vec![0.1, 0.2, 0.3], Err(NegativeLogLikelihoodLossError::InvalidValues)),
        empty_input: (vec![], vec![], Err(NegativeLogLikelihoodLossError::EmptyInputs)),
    }

    macro_rules! test_neg_log_likelihood {
        ($($name:ident: $inputs:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (actual_values, predicted_values, expected) = $inputs;
                    assert_eq!(neg_log_likelihood(&actual_values, &predicted_values).unwrap(), expected);
                }
            )*
        }
    }

    test_neg_log_likelihood! {
        set_0: (vec![1.0, 0.0, 1.0], vec![0.9, 0.1, 0.8], 0.14462152754328741),
        set_1: (vec![1.0, 0.0, 1.0], vec![0.1, 0.2, 0.3], 1.2432338162113972),
        set_2: (vec![0.0, 1.0, 0.0], vec![0.1, 0.2, 0.3], 0.6904911240102196),
        set_3: (vec![1.0, 0.0, 1.0, 0.0], vec![0.9, 0.1, 0.8, 0.2], 0.164252033486018),
    }
}