Skip to content
Snippets Groups Projects
Commit 900d1869 authored by Pat Alt's avatar Pat Alt
Browse files

saving the best

parent 18a688bc
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -36,16 +36,26 @@ params = (
# Best grid search params:
append_best_params!(params, dataname)
if !GRID_SEARCH
run_experiment(
if GRID_SEARCH
grid_search(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
params...
)
elseif FROM_GRID_SEARCH
outcomes_file_path = joinpath(
DEFAULT_OUTPUT_PATH,
"grid_search",
"$(replace(lowercase(dataname), " " => "_")).jls",
)
save_best(outcomes_file_path)
bmk2csv(dataname)
else
grid_search(
run_experiment(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params
model_tuning_params=model_tuning_params,
params...
)
end
\ No newline at end of file
......@@ -24,17 +24,26 @@ params = (
# Best grid search params:
append_best_params!(params, dataname)
if !GRID_SEARCH
run_experiment(
if GRID_SEARCH
grid_search(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
params...
)
elseif FROM_GRID_SEARCH
outcomes_file_path = joinpath(
DEFAULT_OUTPUT_PATH,
"grid_search",
"$(replace(lowercase(dataname), " " => "_")).jls",
)
save_best(outcomes_file_path)
bmk2csv(dataname)
else
grid_search(
run_experiment(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
model_tuning_params=model_tuning_params,
params...
)
end
\ No newline at end of file
......@@ -56,17 +56,26 @@ params = (
dim_reduction=true,
)
if !GRID_SEARCH
run_experiment(
if GRID_SEARCH
grid_search(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
params...
)
elseif FROM_GRID_SEARCH
outcomes_file_path = joinpath(
DEFAULT_OUTPUT_PATH,
"grid_search",
"$(replace(lowercase(dataname), " " => "_")).jls",
)
save_best(outcomes_file_path)
bmk2csv(dataname)
else
grid_search(
run_experiment(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
n_individuals=5
model_tuning_params=model_tuning_params,
params...
)
end
\ No newline at end of file
......@@ -36,16 +36,26 @@ params = (
# Best grid search params:
append_best_params!(params, dataname)
if !GRID_SEARCH
run_experiment(
if GRID_SEARCH
grid_search(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
params...
)
elseif FROM_GRID_SEARCH
outcomes_file_path = joinpath(
DEFAULT_OUTPUT_PATH,
"grid_search",
"$(replace(lowercase(dataname), " " => "_")).jls",
)
save_best(outcomes_file_path)
bmk2csv(dataname)
else
grid_search(
run_experiment(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params
model_tuning_params=model_tuning_params,
params...
)
end
\ No newline at end of file
......@@ -37,16 +37,26 @@ params = (
# Best grid search params:
append_best_params!(params, dataname)
if !GRID_SEARCH
run_experiment(
if GRID_SEARCH
grid_search(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
params...
)
elseif FROM_GRID_SEARCH
outcomes_file_path = joinpath(
DEFAULT_OUTPUT_PATH,
"grid_search",
"$(replace(lowercase(dataname), " " => "_")).jls",
)
save_best(outcomes_file_path)
bmk2csv(dataname)
else
grid_search(
run_experiment(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params
model_tuning_params=model_tuning_params,
params...
)
end
......@@ -39,6 +39,7 @@ elseif FROM_GRID_SEARCH
"$(replace(lowercase(dataname), " " => "_")).jls",
)
save_best(outcomes_file_path)
bmk2csv(dataname)
else
run_experiment(
counterfactual_data, test_data;
......
......@@ -56,17 +56,26 @@ params = (
dim_reduction=true,
)
if !GRID_SEARCH
run_experiment(
if GRID_SEARCH
grid_search(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
params...
)
elseif FROM_GRID_SEARCH
outcomes_file_path = joinpath(
DEFAULT_OUTPUT_PATH,
"grid_search",
"$(replace(lowercase(dataname), " " => "_")).jls",
)
save_best(outcomes_file_path)
bmk2csv(dataname)
else
grid_search(
run_experiment(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
n_individuals=5
model_tuning_params=model_tuning_params,
params...
)
end
\ No newline at end of file
......@@ -24,17 +24,26 @@ params = (
# Best grid search params:
append_best_params!(params, dataname)
if !GRID_SEARCH
run_experiment(
if GRID_SEARCH
grid_search(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
params...
)
elseif FROM_GRID_SEARCH
outcomes_file_path = joinpath(
DEFAULT_OUTPUT_PATH,
"grid_search",
"$(replace(lowercase(dataname), " " => "_")).jls",
)
save_best(outcomes_file_path)
bmk2csv(dataname)
else
grid_search(
run_experiment(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
model_tuning_params=model_tuning_params,
params...
)
end
\ No newline at end of file
include("setup_env.jl");
# User inputs:
all_data_sets = ["linearly_separable", "moons", "circles", "mnist", "fmnist", "gmsc", "german_credit", "california_housing"]
if "run-all" in ARGS
datanames = ["linearly_separable", "moons", "circles", "mnist", "fmnist", "gmsc"]
datanames = all_data_sets
elseif any(contains.(ARGS, "data="))
datanames = [ARGS[findall(contains.(ARGS, "data="))][1] |> x -> replace(x, "data=" => "")]
datanames = replace.(split(datanames[1], ","), " " => "")
else
@warn "No dataset specified, defaulting to all."
datanames = ["linearly_separable", "moons", "circles", "mnist", "fmnist", "gmsc"]
datanames = all_data_sets
end
# Linearly Separable
......
function save_best(outcomes_file_path::String)
# Just load the best model from the grid search:
@assert isfile(outcomes_file_path)
if !isfile(outcomes_file_path)
@info "No grid search file found at $(outcomes_file_path)."
return
end
outcomes = Serialization.deserialize(outcomes_file_path)
outcome = best_absolute_outcome_eccco_Δ(outcomes).outcome
exper = outcome.exper
......@@ -12,4 +15,17 @@ function save_best(outcomes_file_path::String)
Serialization.serialize(joinpath(output_path, "$(exper.save_name)_bmk.jls"), outcome.bmk)
Serialization.serialize(joinpath(output_path, "$(exper.save_name)_models.jls"), outcome.model_dict)
meta(outcome; save_output=true, params_path=params_path)
end
function bmk2csv(dataname::String)
bmk_path = joinpath(
DEFAULT_OUTPUT_PATH,
"$(replace(lowercase(dataname), " " => "_"))_bmk.jls",
)
bmk = Serialization.deserialize(bmk_path)
csv_path = joinpath(
DEFAULT_OUTPUT_PATH,
"$(replace(lowercase(dataname), " " => "_"))_bmk.csv",
)
CSV.write(csv_path, bmk()[:,Not(:ce)])
end
\ No newline at end of file
```{r}
library(data.table)
library(kableExtra)
```
```{r}
res_path <- "results/"
files <- list.files(res_path)
dt <- lapply(files[grepl("_benchmark.csv", files)], function(x) {
fread(file.path(res_path, x))
})
dt <- Reduce(function(x,y) {rbind(x,y, fill=TRUE)}, dt)
dt[,ce:=NULL]
synth <- c("Moons", "Circles", "Linearly Separable")
dt[,source:=ifelse(dataname %in% synth, "synthetic", "real-world")]
reported_data <- c(
"MNIST",
"GMSC",
"Linearly Separable",
"Moons",
"Circles"
)
dt <- dt[dataname %in% reported_data]
dt[,non_valid:=variable=="validity" & value==0.0,.(sample,dataname,generator,model,target,factual,source)]
dt[,non_valid:=any(non_valid==TRUE),.(sample,dataname,generator,model,target,factual,source)]
dt_valid <- dt[non_valid==FALSE]
```
```{r}
tab <- dt[
,
.(
value=sprintf("%1.2f ± %1.2f", mean(value), sd(value)),
val = mean(value),
std = sd(value)
),
.(dataname, generator, model, variable, source)
]
tab$top_val = F
tab$one_std_wachter = F
tab$two_std_wachter = F
# Measures to be minimized:
min_measures <- c(
"distance",
"distance_from_energy",
"distance_from_targets",
"set_size_penalty"
)
tab[variable %in% min_measures,top_val:=val==min(val),.(model, dataname, variable)]
tab[variable %in% min_measures,top_val:=ifelse(rep(all(top_val),length(top_val)),F,top_val),.(model, dataname, variable)]
tab[variable %in% min_measures,two_std_wachter:=val+2*std<val[generator=="Wachter"],.(model, dataname, variable)]
tab[variable %in% min_measures,one_std_wachter:=val+1*std<val[generator=="Wachter"],.(model, dataname, variable)]
# Measures to be maximized:
max_measures <- c(
"validity",
"redundancy"
)
tab[variable %in% max_measures,top_val:=val==max(val),.(model, dataname, variable)]
tab[variable %in% max_measures,top_val:=ifelse(rep(all(top_val),length(top_val)),F,top_val),.(model, dataname, variable)]
tab[variable %in% max_measures,two_std_wachter:=val-2*std>val[generator=="Wachter"],.(model, dataname, variable)]
tab[variable %in% max_measures,one_std_wachter:=val-1*std>val[generator=="Wachter"],.(model, dataname, variable)]
# Add conditional formatting:
tab$value <- cell_spec(tab$value, "latex", bold=tab$top_val)
tab[one_std_wachter==T,value:=paste0(value,"*")]
tab[one_std_wachter==F,value:=paste0(value,"\\hphantom{*}")]
tab[two_std_wachter==T,value:=paste0(value,"*")]
tab[two_std_wachter==F,value:=paste0(value,"\\hphantom{*}")]
# Remove redundant columns:
tab[,val:=NULL]
tab[,std:=NULL]
tab[,top_val:=NULL]
tab[,two_std_wachter:=NULL]
tab[,one_std_wachter:=NULL]
```
```{r}
tab_valid <- dt_valid[
,
.(
value=sprintf("%1.2f ± %1.2f", mean(value), sd(value)),
val = mean(value),
std = sd(value)
),
.(dataname, generator, model, variable, source)
]
tab_valid$top_val = F
tab_valid$one_std_wachter = F
tab_valid$two_std_wachter = F
# Measures to be minimized:
min_measures <- c(
"distance",
"distance_from_energy",
"distance_from_targets",
"set_size_penalty"
)
tab_valid[variable %in% min_measures,top_val:=val==min(val),.(model, dataname, variable)]
tab_valid[variable %in% min_measures,top_val:=ifelse(rep(all(top_val),length(top_val)),F,top_val),.(model, dataname, variable)]
tab_valid[variable %in% min_measures,two_std_wachter:=val+2*std<val[generator=="Wachter"],.(model, dataname, variable)]
tab_valid[variable %in% min_measures,one_std_wachter:=val+1*std<val[generator=="Wachter"],.(model, dataname, variable)]
# Measures to be maximized:
max_measures <- c(
"validity",
"redundancy"
)
tab_valid[variable %in% max_measures,top_val:=val==max(val),.(model, dataname, variable)]
tab_valid[variable %in% max_measures,top_val:=ifelse(rep(all(top_val),length(top_val)),F,top_val),.(model, dataname, variable)]
tab_valid[variable %in% max_measures,two_std_wachter:=val-2*std>val[generator=="Wachter"],.(model, dataname, variable)]
tab_valid[variable %in% max_measures,one_std_wachter:=val-1*std>val[generator=="Wachter"],.(model, dataname, variable)]
# Add conditional formatting:
tab_valid$value <- cell_spec(tab_valid$value, "latex", bold=tab_valid$top_val)
tab_valid[one_std_wachter==T,value:=paste0(value,"*")]
tab_valid[one_std_wachter==F,value:=paste0(value,"\\hphantom{*}")]
tab_valid[two_std_wachter==T,value:=paste0(value,"*")]
tab_valid[two_std_wachter==F,value:=paste0(value,"\\hphantom{*}")]
# Remove redundant columns:
tab_valid[,val:=NULL]
tab_valid[,std:=NULL]
tab_valid[,top_val:=NULL]
tab_valid[,two_std_wachter:=NULL]
tab_valid[,one_std_wachter:=NULL]
```
## Main tables
```{r}
# Choices:
measures <- c("distance_from_energy", "distance_from_targets")
measure_names <- c("Unfaithfulness ↓","Implausibility ↓")
chosen_source <- "real-world"
chosen_data <- c(
"MNIST",
"GMSC"
)
tab_i <- tab
# Logic:
tab_i <- tab_i[variable %in% measures]
tab_i <- tab_i[source == chosen_source]
tab_i <- tab_i[dataname %in% chosen_data]
tab_i[,dataname:=factor(dataname,levels=chosen_data)]
tab_i <- dcast(tab_i, model + generator ~ dataname + variable)
col_names <- c(
"Model",
"Generator",
rep(measure_names,length(chosen_data))
)
caption <- sprintf(
"Results for %s datasets: sample averages +/- one standard deviation across counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \\label{tab:results-%s} \\newline",
chosen_source,
chosen_source
)
file_name <- sprintf(
"paper/contents/table-%s.tex",
chosen_source
)
sub_header <- rep(length(measures), length(chosen_data))
names(sub_header) <- chosen_data
header <- c(
" " = 2, sub_header
)
line_sep <- c(rep("",length(measures)-1),"\\addlinespace")
algin_cols <- c(rep('l',2),rep('c',ncol(tab_i)-2))
kbl(
tab_i, caption = caption,
align = algin_cols, col.names=col_names, booktabs = T, escape=F,
format="latex", linesep = line_sep
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
add_header_above(header) %>%
collapse_rows(columns = 1:2, latex_hline = "major", valign = "middle") %>%
save_kable(file_name)
```
```{r}
# Choices:
measures <- c("distance_from_energy", "distance_from_targets")
measure_names <- c("Unfaithfulness ↓","Implausibility ↓")
chosen_source <- "synthetic"
chosen_data <- c(
"Linearly Separable",
"Moons",
"Circles"
)
tab_i <- tab
# Logic:
tab_i <- tab_i[variable %in% measures]
tab_i <- tab_i[source == chosen_source]
tab_i <- tab_i[dataname %in% chosen_data]
tab_i[,dataname:=factor(dataname,levels=chosen_data)]
tab_i <- dcast(tab_i, model + generator ~ dataname + variable)
col_names <- c(
"Model",
"Generator",
rep(measure_names,length(chosen_data))
)
caption <- sprintf(
"Results for %s datasets: sample averages +/- one standard deviation across counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \\label{tab:results-%s} \\newline",
chosen_source,
chosen_source
)
file_name <- sprintf(
"paper/contents/table-%s.tex",
chosen_source
)
sub_header <- rep(length(measures), length(chosen_data))
names(sub_header) <- chosen_data
header <- c(
" " = 2, sub_header
)
line_sep <- c(rep("",length(measures)-1),"\\addlinespace")
algin_cols <- c(rep('l',2),rep('c',ncol(tab_i)-2))
kbl(
tab_i, caption = caption,
align = algin_cols, col.names=col_names, booktabs = T, escape=F,
format="latex", linesep = line_sep
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
add_header_above(header) %>%
collapse_rows(columns = 1:2, latex_hline = "major", valign = "middle") %>%
save_kable(file_name)
```
## Full table
```{r}
tab_full <- dcast(tab, dataname + model + generator ~ variable)
col_names <- c(
"Model",
"Data",
"Generator",
"Cost ↓",
"Unfaithfulness ↓",
"Implausibility ↓",
"Redundancy ↑",
"Uncertainty ↓",
"Validity ↑"
)
algin_cols <- c(rep('l',3),rep('c',ncol(tab_full)-3))
kbl(
tab_full, caption = "All results for all datasets: sample averages +/- one standard deviation over all counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \\label{tab:results-full} \\newline",
align = "c", col.names=col_names, booktabs = T, escape=F,
format="latex"
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
collapse_rows(columns = 1:3, latex_hline = "custom", valign = "top", custom_latex_hline = 1:2) %>%
save_kable("paper/contents/table_all.tex")
```
## Full table (valid only)
```{r}
tab_full <- dcast(tab_valid, dataname + model + generator ~ variable)
col_names <- c(
"Model",
"Data",
"Generator",
"Cost ↓",
"Unfaithfulness ↓",
"Implausibility ↓",
"Redundancy ↑",
"Uncertainty ↓",
"Validity ↑"
)
algin_cols <- c(rep('l',3),rep('c',ncol(tab_full)-3))
kbl(
tab_full, caption = "All results for all datasets: sample averages +/- one standard deviation over all valid counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \\label{tab:results-full-valid} \\newline",
align = "c", col.names=col_names, booktabs = T, escape=F,
format="latex"
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
collapse_rows(columns = 1:3, latex_hline = "custom", valign = "top", custom_latex_hline = 1:2) %>%
save_kable("paper/contents/table_all_valid.tex")
```
## EBM
```{r}
files <- list.files("artifacts/params/")
dt <- lapply(files[grepl(".csv", files)], function(x) {
fread(file.path("artifacts/params/", x))
})
dt <- Reduce(function(x,y) {rbind(x,y, fill=TRUE)}, dt)
setcolorder(
dt,
c(
"dataname", "n_obs",
"n_hidden", "n_layers", "activation", "n_ens",
"epochs", "batch_size",
"jem_sampling_steps", "sgld_batch_size", "lambda"
)
)
dt[,dataname:=factor(dataname, levels=c("Linearly Separable", "Moons", "Circles", "MNIST", "GMSC"))]
dt <- dt[order(dataname)]
dt_ebm <- dt[,.(dataname, jem_sampling_steps, sgld_batch_size, lambda)]
col_names <- c(
"Dataset",
"SGLD Steps", "Batch Size", "$\\lambda$"
)
kbl(
dt_ebm, caption = "EBM hyperparemeter choices for our experiments. \\label{tab:ebmparams} \\newline",
align = "r", col.names=col_names, booktabs = T, escape=F,
format="latex"
) %>%
kable_styling(font_size = 8) %>%
kable_paper(full_width = F) %>%
save_kable("paper/contents/table_ebm_params.tex")
```
## Experimental setup
```{r}
dt_exp <- dt[,.(dataname, n_obs, n_hidden, n_layers, activation, n_ens, epochs, batch_size)]
col_names <- c(
"Dataset", "Sample Size",
"Hidden Units", "Hidden Layers", "Activation", "Ensemble Size",
"Epochs", "Batch Size"
)
header <- c(" " = 2, "Network Architecture" = 4, "Training" = 2)
kbl(
dt_exp, caption = "Paremeter choices for our experiments. \\label{tab:params} \\newline",
align = "r", col.names=col_names, booktabs = T, escape=F,
format="latex"
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
add_header_above(header) %>%
save_kable("paper/contents/table_params.tex")
```
```{r}
files <- list.files("artifacts/params/generator")
dt <- lapply(files, function(x) {
fread(file.path("artifacts/params/generator", x))
})
dt <- Reduce(function(x,y) {rbind(x,y, fill=TRUE)}, dt)
dt <- dt[,.(dataname,eta,λ1,λ3,λ3)]
dt[,dataname:=factor(dataname, levels=c("Linearly Separable", "Moons", "Circles", "MNIST", "GMSC"))]
dt <- dt[order(dataname)]
col_names <- c(
"Dataset",
"$\\eta$", "$\\lambda_1$", "$\\lambda_2$", "$\\lambda_3$"
)
kbl(
dt, caption = "Generator hyperparameters. \\label{tab:genparams} \\newline",
align = "r", col.names=col_names, booktabs = T, escape=F,
format="latex"
) %>%
kable_styling(font_size = 8) %>%
kable_paper(full_width = F) %>%
save_kable("paper/contents/table_gen_params.tex")
```
```{r}
files <- list.files("artifacts/results/")
dt <- lapply(files[grepl("_model_performance.csv", files)], function(x) {
fread(file.path("artifacts/results/", x))
})
dt <- Reduce(function(x,y) {rbind(x,y, fill=TRUE)}, dt)
dt[,dataname:=factor(dataname, levels=c("Linearly Separable", "Moons", "Circles", "MNIST", "GMSC"))]
dt <- dt[order(dataname,mod_name)]
setcolorder(
dt,
c(
"dataname", "mod_name",
"acc", "precision", "f1score"
)
)
col_names <- c("Dataset", "Model", "Accuracy", "Precision", "F1-Score")
kbl(
dt, caption = "Various standard performance metrics for our different models grouped by dataset. \\label{tab:perf} \\newline",
align = "r", col.names=col_names, booktabs = T, escape=F,
format="latex", digits=2
) %>%
kable_styling(font_size = 8) %>%
kable_paper(full_width = F) %>%
add_header_above(c(" "=2, "Performance Metrics" = 3)) %>%
collapse_rows(columns = 1, latex_hline = "custom", valign = "top", custom_latex_hline = 1) %>%
save_kable("paper/contents/table_perf.tex")
```
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment