-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfit_models.R
160 lines (126 loc) · 5.19 KB
/
fit_models.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
################################################################################
# Fit Detection Models
################################################################################
library(Metrics)
library(ranger)
library(cli)
library(cowplot)
library(ggplot2)
library(ParBayesianOptimization) # only for xgboost
library(doParallel)
library(parallelly)
library(xgboost)
library(parallel)
library(data.table)
# Register a cli theme so that `.emph` spans in messages use the colour #3c77b9
cli_div(
  theme = list(
    span.emph = list(color = "#3c77b9")
  )
)

# Seed the random number generator for the top-level script
set.seed(2024)
################################################################################
# Global Settings
################################################################################
# Define the settings for fitting the detection models -------------------------
# Note: Can be very time consuming (especially for XGBoost)
# Full cross product of dataset x detection model x synthesizer. It is used
# further below to select which settings are actually fitted.
filter_df <- data.table(
  expand.grid(
    dataset_name = c(
      "adult_complete",
      "car_evaluation",
      "chess_king_rook_vs_king",
      "connect_4",
      "diabetes",
      "diabetes_HI",
      "diamonds",
      "letter_recognition",
      "magic_gamma_telescope",
      "nursery",
      "statlog_landsat_satellite"
    ),
    model_name = c("ranger", "logReg", "xgboost"),
    syn_name = c("ARF", "CTAB-GAN+", "CTGAN", "synthpop", "TabSyn", "TVAE")
  )
)

# Forwarded to fit_model() as its `max_runs` argument
max_runs <- 10
# The setting used in the main text of the manuscript
#filter_df <- data.table(
# dataset_name = rev(c("adult_complete", "nursery")),
# model_name = "xgboost",
# syn_name = rev(c("TabSyn", "CTGAN"))
#)
# Other global settings --------------------------------------------------------
# Threading/Number of CPUs
# Note: The total number of cores used will be 'mc.cores * n_threads'
# Threads used inside each parallel worker
n_threads <- 15
# Number of parallel workers (read by the parallel package via the
# 'mc.cores' option)
mc.cores <- 16L

# Publish the settings: R options for parallel/ranger, environment variable
# for OpenMP
options(mc.cores = mc.cores, ranger.num.threads = n_threads)
Sys.setenv(OMP_NUM_THREADS = n_threads)

# XGBoost tuning parameters (all forwarded to fit_model() further below)
n_parallel <- 15 # an alternative value of 18 was noted here before
init_points <- n_parallel
# 30 * 60 = 1800, i.e. 30 minutes if fit_model() reads this as seconds.
# NOTE(review): the previous comment claimed 45 minutes -- confirm which
# value is intended.
time_limit <- 30 * 60
# Load utility methods and create dirs -----------------------------------------
# Pull in the shared helper functions from utils.R
source("utils.R")

# Make sure the folder for the tuning logs exists (including parent folders)
if (!dir.exists("./tmp/tuning_logs")) {
  dir.create("./tmp/tuning_logs", recursive = TRUE)
}
################################################################################
# Main script to fit detection models
################################################################################
# Create data.frame for all settings -------------------------------------------
cli_progress_step("Creating settings data.frame for fitting models")

# Columns that identify one setting (dataset x detection model x synthesizer)
setting_keys <- c("dataset_name", "model_name", "syn_name")

# Find all available settings: every entry of ./data is taken as a dataset
# name and every entry of its syn/ subfolder as a synthesizer name
args <- lapply(list.files("./data"), function(dat_name) {
  expand.grid(
    dataset_name = dat_name,
    model_name = c("ranger", "logReg", "xgboost"),
    syn_name = list.files(paste0("./data/", dat_name, "/syn/"))
  )
})
args <- do.call(rbind, args)
if (is.null(args) || nrow(args) == 0) {
  stop("No datasets with synthetic data found in './data'.", call. = FALSE)
}
args <- data.table(args)

# Report requested settings for which no data is available on disk
missing_settings <- filter_df[!args, on = setting_keys]
if (nrow(missing_settings) > 0) {
  cli_alert_warning(
    "Skipping {nrow(missing_settings)} requested setting(s) that have no data in './data'."
  )
}

# Filter settings (as defined in the global settings).
# `nomatch = NULL` turns this into an inner join: without it, data.table keeps
# every row of `filter_df`, including settings that do not exist on disk, so
# nothing would be filtered. The row order follows `filter_df`.
args <- args[filter_df, on = setting_keys, nomatch = NULL]
# Fitting detection models -----------------------------------------------------
cli_h1("Fitting detection models")

# Fit models: one parallel job per row of `args`. Each job returns the value
# of fit_model(), or NULL if loading the data or fitting the model failed.
result <- mclapply(seq_len(nrow(args)), function(i) {
  # Set seed (identical for every job)
  set.seed(42)

  cli_progress_step(paste0(
    "[{i}/{nrow(args)}] ",
    "Dataset: {.emph {args$dataset_name[i]}} --- ",
    "Model: {.emph {args$model_name[i]}} --- ",
    "Synthesizer: {.emph {args$syn_name[i]}}"))

  # Loading the data happens inside tryCatch as well: with prescheduling, an
  # uncaught error in one job makes mclapply return an error for every job
  # scheduled on the same core.
  tryCatch({
    # Load data
    data <- load_data(args$dataset_name[i], args$syn_name[i], test_split = 0.3)

    # Fit model for the first 'max_runs' runs
    fit_model(data, args$model_name[i], max_runs = max_runs,
              n_threads = n_threads, time_limit = time_limit,
              n_parallel = n_parallel, init_points = init_points)
  }, error = function(e) {
    cli_alert(col_red(paste0("[RUN {i}] Error in fitting model. Skipping.")))
    print(e)
    NULL
  })
})

# Collect the indices of failed runs in the parent process. The jobs run in
# forked child processes, so an assignment with `<<-` inside a job never
# reaches this session. A job failed if it returned NULL (caught error or
# lost worker) or a "try-error" object (uncaught error).
failed <- vapply(
  result,
  function(res) is.null(res) || inherits(res, "try-error"),
  logical(1)
)
error_idx <- which(failed)

# Combine the successful runs only
result <- data.table(do.call(rbind, result[!failed]))

# Show errors if any
if (length(error_idx) > 0) {
  cli_alert(col_red(paste0("[ERROR SUMMARY] Error in fitting models for runs: ",
                           paste(error_idx, collapse = ", "))))
}
# Update and save results ------------------------------------------------------
# Location of the stored results and the columns identifying one result row
result_dir <- "./results/model_performance"
result_file <- file.path(result_dir, "model_performance.rds")
result_keys <- c("dataset", "syn_name", "run", "model_name", "train", "metric")

if (!dir.exists(result_dir)) {
  dir.create(result_dir, recursive = TRUE)
}

if (nrow(result) == 0) {
  # No run produced a result: the joins below would fail on the missing key
  # columns, and writing an empty table would break the next update.
  cli_alert_warning("No new results to save. Existing results are left untouched.")
} else {
  if (file.exists(result_file)) {
    cli_progress_step("Updating results")
    result_old <- data.table(readRDS(result_file))

    # An empty stored table has nothing to merge; the new results replace it
    if (nrow(result_old) > 0) {
      # Append the rows that are new ...
      result_old <- rbind(result_old, result[!result_old, on = result_keys])
      # ... and overwrite `value` for the rows that already existed
      result_old[result, on = result_keys, value := i.value]
      result <- result_old
    }
  }

  cli_progress_step("Saving results")
  saveRDS(result, file = result_file)
}
cli_progress_done()