Skip to content

Commit 7554e62

Browse files
committed
refactor set_best_split
1 parent c04c809 commit 7554e62

File tree

2 files changed

+42
-37
lines changed

2 files changed

+42
-37
lines changed

R/RcppExports.R

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
22
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393
33

4-
### Efficient log-linear cumulative median.
4+
#' Efficient log-linear cumulative median.
55
cum_median_interface <- function(data_vec, weight_vec) {
66
.Call(`_binsegRcpp_cum_median_interface`, data_vec, weight_vec)
77
}
88

9-
### Use depth first search to compute a data.frame
10-
### with one row for each segment, and columns
11-
### splits and depth, number/depth of candidate
12-
### splits that need to be
13-
### computed after splitting that segment.
9+
#' Use depth first search to compute a data.frame
10+
#' with one row for each segment, and columns
11+
#' splits and depth, number/depth of candidate
12+
#' splits that need to be
13+
#' computed after splitting that segment.
1414
depth_first_interface <- function(n_data, min_segment_length) {
1515
.Call(`_binsegRcpp_depth_first_interface`, n_data, min_segment_length)
1616
}
1717

18-
### Compute a data.frame with one row for each distribution
19-
### implemented in the C++ code, and columns distribution.str,
20-
### parameters, description.
18+
#' Compute a data.frame with one row for each distribution
19+
#' implemented in the C++ code, and columns distribution.str,
20+
#' parameters, description.
2121
get_distribution_info <- function() {
2222
.Call(`_binsegRcpp_get_distribution_info`)
2323
}
2424

25-
### Low-level interface to binary segmentation algorithm.
25+
#' Low-level interface to binary segmentation algorithm.
2626
binseg_interface <- function(data_vec, weight_vec, max_segments, min_segment_length, distribution_str, container_str, is_validation_vec, position_vec) {
2727
.Call(`_binsegRcpp_binseg_interface`, data_vec, weight_vec, max_segments, min_segment_length, distribution_str, container_str, is_validation_vec, position_vec)
2828
}

src/binseg.cpp

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class CumDistribution : public Distribution {
139139
var_param = VARIANCE; \
140140
description = DESC; \
141141
param_names_vec.push_back("mean"); \
142-
if(var_param)param_names_vec.push_back("var"); \
142+
if(var_param)param_names_vec.push_back("var"); \
143143
dist_umap.emplace( #NAME, this ); \
144144
} \
145145
}; \
@@ -216,22 +216,27 @@ class absDistribution : public Distribution {
216216
int after_i = n_candidates-1-before_i;
217217
candidate_split_ptr->set_end_dist
218218
(first_data, first_candidate+before_i, last_data);
219-
candidate_split_ptr->before.center =
220-
before_median_vec[before_i];
221-
candidate_split_ptr->after.center =
222-
after_median_vec[after_i];
223-
candidate_split_ptr->before.spread =
224-
before_loss_vec[before_i]/before_weight_vec[before_i];
225-
candidate_split_ptr->after.spread =
226-
after_loss_vec[after_i]/after_weight_vec[after_i];
227-
candidate_split_ptr->before.loss = adjust
228-
(before_loss_vec[before_i],
229-
before_weight_vec[before_i],
230-
candidate_split_ptr->before.spread);
231-
candidate_split_ptr->after.loss = adjust
232-
(after_loss_vec[after_i],
233-
after_weight_vec[after_i],
234-
candidate_split_ptr->after.spread);
219+
std::vector<double> *loss_ptr, *median_ptr, *weight_ptr;
220+
ParamsLoss *pl_ptr;
221+
int i;
222+
for(int direction=0; direction<2; direction++){
223+
if(direction==0){
224+
loss_ptr = &before_loss_vec;
225+
median_ptr = &before_median_vec;
226+
weight_ptr = &before_weight_vec;
227+
pl_ptr = &candidate_split_ptr->before;
228+
i = before_i;
229+
}else{
230+
loss_ptr = &after_loss_vec;
231+
median_ptr = &after_median_vec;
232+
weight_ptr = &after_weight_vec;
233+
pl_ptr = &candidate_split_ptr->after;
234+
i = after_i;
235+
}
236+
pl_ptr->center = (*median_ptr)[i];
237+
pl_ptr->spread = (*loss_ptr)[i]/(*weight_ptr)[i];
238+
pl_ptr->loss = adjust((*loss_ptr)[i], (*weight_ptr)[i], pl_ptr->spread);
239+
}
235240
best_split_ptr->maybe_update(candidate_split_ptr);
236241
}
237242
}
@@ -270,17 +275,17 @@ class absDistribution : public Distribution {
270275
}
271276
};
272277

273-
#define ABS_DIST(NAME, DESC, VARIANCE) \
274-
class CONCAT(NAME, Distribution) : public absDistribution { \
275-
public: \
276-
CONCAT(NAME, Distribution) (){ \
277-
var_param = VARIANCE; \
278-
description = DESC; \
279-
param_names_vec.push_back("median"); \
280-
if(var_param)param_names_vec.push_back("scale"); \
278+
#define ABS_DIST(NAME, DESC, VARIANCE) \
279+
class CONCAT(NAME, Distribution) : public absDistribution { \
280+
public: \
281+
CONCAT(NAME, Distribution) (){ \
282+
var_param = VARIANCE; \
283+
description = DESC; \
284+
param_names_vec.push_back("median"); \
285+
if(var_param)param_names_vec.push_back("scale"); \
281286
dist_umap.emplace( #NAME, this ); \
282-
} \
283-
}; \
287+
} \
288+
}; \
284289
static CONCAT(NAME, Distribution) NAME;
285290

286291
ABS_DIST(l1, "change in median (loss is total absolute deviation)", false)

0 commit comments

Comments
 (0)