Skip to content
Snippets Groups Projects
Commit 6bb094cf authored by Pat Alt's avatar Pat Alt
Browse files

uh

parent 0a068020
No related branches found
No related tags found
No related merge requests found
......@@ -27,5 +27,10 @@ function bmk2csv(dataname::String)
DEFAULT_OUTPUT_PATH,
"$(replace(lowercase(dataname), " " => "_"))_bmk.csv",
)
CSV.write(csv_path, bmk()[:,Not(:ce)])
bmk = bmk()
if "ce" in names(bmk)
CSV.write(csv_path, bmk[:,Not(:ce)])
else
CSV.write(csv_path, bmk)
end
end
\ No newline at end of file
\begin{table}
\caption{Results for datasets from different domains: sample averages +/- one standard deviation across counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \label{tab:results-main} \newline}
\centering
\resizebox{\linewidth}{!}{
\begin{tabu} to \linewidth {>{\raggedright}X>{\raggedright}X>{\centering}X>{\centering}X>{\centering}X>{\centering}X>{\centering}X>{\centering}X}
\toprule
\multicolumn{2}{c}{ } & \multicolumn{2}{c}{GMSC} & \multicolumn{2}{c}{California Housing} & \multicolumn{2}{c}{German Credit} \\
\cmidrule(l{3pt}r{3pt}){3-4} \cmidrule(l{3pt}r{3pt}){5-6} \cmidrule(l{3pt}r{3pt}){7-8}
Model & Generator & Unfaithfulness ↓ & Implausibility ↓ & Unfaithfulness ↓ & Implausibility ↓ & Unfaithfulness ↓ & Implausibility ↓\\
\midrule
& ECCCo-L1 & 0.90 ± 0.26\hphantom{*}\hphantom{*} & 43.87 ± 0.12** & 1.02 ± 0.31\hphantom{*}\hphantom{*} & 35.00 ± 0.07** & 4.98 ± 0.76\hphantom{*}\hphantom{*} & 2.26 ± 0.13**\\
& ECCCo-L1 (no CP) & 0.89 ± 0.26\hphantom{*}\hphantom{*} & 39.64 ± 0.15** & 1.02 ± 0.31\hphantom{*}\hphantom{*} & 41.86 ± 0.12** & 5.00 ± 0.76\hphantom{*}\hphantom{*} & 2.42 ± 0.37**\\
& ECCCo-L1 (no EBM) & 0.89 ± 0.26\hphantom{*}\hphantom{*} & \textbf{34.01 ± 0.15}** & 1.02 ± 0.31\hphantom{*}\hphantom{*} & \textbf{30.29 ± 0.13}** & 5.00 ± 0.77\hphantom{*}\hphantom{*} & \textbf{2.22 ± 0.42}**\\
& ECCCo & 0.82 ± 0.27\hphantom{*}\hphantom{*} & 74.71 ± 3.49** & 0.74 ± 0.18*\hphantom{*} & 65.50 ± 1.12** & 4.53 ± 0.49\hphantom{*}\hphantom{*} & 5.64 ± 2.62\hphantom{*}\hphantom{*}\\
& ECCCo+ & 0.72 ± 0.15*\hphantom{*} & 97.53 ± 0.20\hphantom{*}\hphantom{*} & 0.74 ± 0.30\hphantom{*}\hphantom{*} & 84.16 ± 0.29** & \textbf{3.70 ± 0.38}** & 5.30 ± 2.12\hphantom{*}\hphantom{*}\\
& ECCCo (no CP) & 0.81 ± 0.26\hphantom{*}\hphantom{*} & 79.70 ± 0.57** & 0.74 ± 0.17*\hphantom{*} & 83.73 ± 0.29** & 4.53 ± 0.49\hphantom{*}\hphantom{*} & 5.05 ± 2.96\hphantom{*}\hphantom{*}\\
& ECCCo (no EBM) & 0.89 ± 0.26\hphantom{*}\hphantom{*} & 90.34 ± 0.35\hphantom{*}\hphantom{*} & 1.02 ± 0.31\hphantom{*}\hphantom{*} & 71.82 ± 0.18** & 5.00 ± 0.77\hphantom{*}\hphantom{*} & 5.20 ± 2.19\hphantom{*}\hphantom{*}\\
& REVISE & \textbf{0.61 ± 0.09}** & 117.40 ± 1.20\hphantom{*}\hphantom{*} & \textbf{0.68 ± 0.31}*\hphantom{*} & 94.71 ± 1.61** & 3.83 ± 0.60*\hphantom{*} & 5.56 ± 2.24\hphantom{*}\hphantom{*}\\
& Schut & 1.09 ± 0.21\hphantom{*}\hphantom{*} & 113.59 ± 0.49\hphantom{*}\hphantom{*} & 1.03 ± 0.27\hphantom{*}\hphantom{*} & 84.61 ± 0.44** & 4.95 ± 0.71\hphantom{*}\hphantom{*} & 6.43 ± 2.81\hphantom{*}\hphantom{*}\\
\multirow{-10}{*}{\raggedright\arraybackslash JEM} & Wachter & 0.88 ± 0.26\hphantom{*}\hphantom{*} & 83.19 ± 0.37\hphantom{*}\hphantom{*} & 1.02 ± 0.31\hphantom{*}\hphantom{*} & 110.44 ± 0.42\hphantom{*}\hphantom{*} & 5.00 ± 0.76\hphantom{*}\hphantom{*} & 6.23 ± 1.45\hphantom{*}\hphantom{*}\\
\cmidrule{1-8}
& ECCCo-L1 & 1.01 ± 0.77\hphantom{*}\hphantom{*} & 43.76 ± 0.16** & 0.99 ± 0.33\hphantom{*}\hphantom{*} & 35.19 ± 0.09** & 4.81 ± 0.64\hphantom{*}\hphantom{*} & 2.38 ± 0.34**\\
& ECCCo-L1 (no CP) & 0.99 ± 0.77\hphantom{*}\hphantom{*} & 39.79 ± 0.19** & 0.99 ± 0.33\hphantom{*}\hphantom{*} & 42.06 ± 0.12** & 4.83 ± 0.62\hphantom{*}\hphantom{*} & 2.57 ± 0.49**\\
& ECCCo-L1 (no EBM) & 1.00 ± 0.77\hphantom{*}\hphantom{*} & \textbf{33.79 ± 0.23}** & 0.98 ± 0.32\hphantom{*}\hphantom{*} & \textbf{30.52 ± 0.16}** & 4.84 ± 0.66\hphantom{*}\hphantom{*} & \textbf{2.38 ± 0.35}**\\
& ECCCo & 2.03 ± 1.30\hphantom{*}\hphantom{*} & 80.15 ± 1.86** & 3.41 ± 2.28\hphantom{*}\hphantom{*} & 67.91 ± 1.63** & 4.84 ± 0.66\hphantom{*}\hphantom{*} & 6.58 ± 1.69\hphantom{*}\hphantom{*}\\
& ECCCo+ & 1.64 ± 1.01\hphantom{*}\hphantom{*} & 98.25 ± 0.57\hphantom{*}\hphantom{*} & 2.71 ± 2.32\hphantom{*}\hphantom{*} & 82.72 ± 1.12** & 3.79 ± 0.39** & 6.80 ± 1.65\hphantom{*}\hphantom{*}\\
& ECCCo (no CP) & 2.02 ± 1.30\hphantom{*}\hphantom{*} & 82.52 ± 1.18*\hphantom{*} & 3.40 ± 2.28\hphantom{*}\hphantom{*} & 88.72 ± 2.28** & 4.84 ± 0.66\hphantom{*}\hphantom{*} & 6.82 ± 1.56\hphantom{*}\hphantom{*}\\
& ECCCo (no EBM) & 1.00 ± 0.77\hphantom{*}\hphantom{*} & 92.86 ± 1.05\hphantom{*}\hphantom{*} & 0.98 ± 0.32\hphantom{*}\hphantom{*} & 75.47 ± 1.60** & 4.84 ± 0.66\hphantom{*}\hphantom{*} & 6.65 ± 1.87\hphantom{*}\hphantom{*}\\
& REVISE & \textbf{0.71 ± 0.37}\hphantom{*}\hphantom{*} & 118.36 ± 1.68\hphantom{*}\hphantom{*} & \textbf{0.64 ± 0.19}*\hphantom{*} & 98.98 ± 0.23** & \textbf{3.70 ± 0.23}** & 6.78 ± 0.40\hphantom{*}\hphantom{*}\\
& Schut & 1.32 ± 0.72\hphantom{*}\hphantom{*} & 114.37 ± 1.21\hphantom{*}\hphantom{*} & 1.02 ± 0.31\hphantom{*}\hphantom{*} & 87.66 ± 2.05** & 4.92 ± 0.71\hphantom{*}\hphantom{*} & 7.86 ± 1.41\hphantom{*}\hphantom{*}\\
\multirow{-10}{*}{\raggedright\arraybackslash MLP} & Wachter & 0.99 ± 0.77\hphantom{*}\hphantom{*} & 84.37 ± 0.99\hphantom{*}\hphantom{*} & 0.98 ± 0.32\hphantom{*}\hphantom{*} & 114.38 ± 2.14\hphantom{*}\hphantom{*} & 4.84 ± 0.66\hphantom{*}\hphantom{*} & 6.58 ± 2.00\hphantom{*}\hphantom{*}\\
\bottomrule
\end{tabu}}
\end{table}
\begin{table}
\caption{Results for tabular datasets: sample averages +/- one standard deviation across counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \label{tab:results-tabular} \newline}
\centering
\resizebox{\linewidth}{!}{
\begin{tabu} to \linewidth {>{\raggedright}X>{\raggedright}X>{\centering}X>{\centering}X>{\centering}X>{\centering}X>{\centering}X>{\centering}X}
\toprule
\multicolumn{2}{c}{ } & \multicolumn{2}{c}{GMSC} & \multicolumn{2}{c}{California Housing} & \multicolumn{2}{c}{German Credit} \\
\cmidrule(l{3pt}r{3pt}){3-4} \cmidrule(l{3pt}r{3pt}){5-6} \cmidrule(l{3pt}r{3pt}){7-8}
Model & Generator & Unfaithfulness ↓ & Implausibility ↓ & Unfaithfulness ↓ & Implausibility ↓ & Unfaithfulness ↓ & Implausibility ↓\\
\midrule
& ECCCo & 37.08 ± 0.45\hphantom{*}\hphantom{*} & 1.71 ± 4.31\hphantom{*}\hphantom{*} & 39.38 ± 0.21\hphantom{*}\hphantom{*} & 1.72 ± 2.09\hphantom{*}\hphantom{*} & 2.40 ± 0.33\hphantom{*}\hphantom{*} & 4.75 ± 0.79\hphantom{*}\hphantom{*}\\
& ECCCo (no CP) & 40.72 ± 0.46\hphantom{*}\hphantom{*} & 1.70 ± 4.33\hphantom{*}\hphantom{*} & 41.33 ± 0.30\hphantom{*}\hphantom{*} & 1.72 ± 2.09\hphantom{*}\hphantom{*} & 2.51 ± 0.34\hphantom{*}\hphantom{*} & 4.74 ± 0.79\hphantom{*}\hphantom{*}\\
& ECCCo (no EBM) & 49.88 ± 0.32\hphantom{*}\hphantom{*} & 1.69 ± 4.33\hphantom{*}\hphantom{*} & 38.30 ± 0.28\hphantom{*}\hphantom{*} & 1.72 ± 2.08\hphantom{*}\hphantom{*} & 3.11 ± 0.27\hphantom{*}\hphantom{*} & 4.76 ± 0.80\hphantom{*}\hphantom{*}\\
& ECCCo-Δ & 108.32 ± 2.14\hphantom{*}\hphantom{*} & 1.68 ± 4.30\hphantom{*}\hphantom{*} & 93.36 ± 0.83\hphantom{*}\hphantom{*} & 1.72 ± 2.13\hphantom{*}\hphantom{*} & 7.32 ± 1.26\hphantom{*}\hphantom{*} & 4.76 ± 0.80\hphantom{*}\hphantom{*}\\
& ECCCo-Δ (no CP) & 103.50 ± 1.56\hphantom{*}\hphantom{*} & 1.67 ± 4.30\hphantom{*}\hphantom{*} & 98.37 ± 0.89\hphantom{*}\hphantom{*} & 1.72 ± 2.12\hphantom{*}\hphantom{*} & 6.87 ± 1.27\hphantom{*}\hphantom{*} & 4.76 ± 0.80\hphantom{*}\hphantom{*}\\
& ECCCo-Δ (no EBM) & 80.25 ± 1.30\hphantom{*}\hphantom{*} & 1.69 ± 4.33\hphantom{*}\hphantom{*} & 114.56 ± 1.15\hphantom{*}\hphantom{*} & 1.72 ± 2.08\hphantom{*}\hphantom{*} & 7.61 ± 1.51\hphantom{*}\hphantom{*} & 4.76 ± 0.80\hphantom{*}\hphantom{*}\\
& REVISE & 93.47 ± 0.76\hphantom{*}\hphantom{*} & 1.47 ± 3.73\hphantom{*}\hphantom{*} & 96.92 ± 2.21\hphantom{*}\hphantom{*} & 0.94 ± 1.31\hphantom{*}\hphantom{*} & 7.35 ± 1.29\hphantom{*}\hphantom{*} & 3.55 ± 0.21\hphantom{*}\hphantom{*}\\
& Schut & 85.92 ± 1.95\hphantom{*}\hphantom{*} & 1.85 ± 4.30\hphantom{*}\hphantom{*} & 120.89 ± 1.23\hphantom{*}\hphantom{*} & 1.54 ± 2.04\hphantom{*}\hphantom{*} & 7.08 ± 1.42\hphantom{*}\hphantom{*} & 4.85 ± 0.79\hphantom{*}\hphantom{*}\\
\multirow{-9}{*}{\raggedright\arraybackslash MLP} & Wachter & 99.17 ± 2.21\hphantom{*}\hphantom{*} & 1.69 ± 4.33\hphantom{*}\hphantom{*} & 91.30 ± 1.54\hphantom{*}\hphantom{*} & 1.72 ± 2.07\hphantom{*}\hphantom{*} & 7.88 ± 0.73\hphantom{*}\hphantom{*} & 4.76 ± 0.80\hphantom{*}\hphantom{*}\\
\bottomrule
\end{tabu}}
\end{table}
......@@ -15,6 +15,29 @@ dt[,ce:=NULL]
# Dataset families used to tag every row of `dt` with its domain.
synth <- c(
  "Moons",
  "Circles",
  "Linearly Separable"
)
tabular <- c(
  "GMSC",
  "German Credit",
  "California Housing"
)
# Synthetic takes precedence over tabular; anything in neither list is
# labelled "vision".
dt[, source := {
  tabular_or_vision <- ifelse(dataname %in% tabular, "tabular", "vision")
  ifelse(dataname %in% synth, "synthetic", tabular_or_vision)
}]
```
```{r}
# Generator names: relabel factor levels for presentation.
# The renames are applied one after another, top to bottom. The order
# matters: "ECCCo" must become "ECCCo-L1" before "ECCCo-Δ" is renamed to
# "ECCCo", otherwise the two would collide.
dt[, generator := factor(generator)]
generator_renames <- rbind(
  c("ECCCo", "ECCCo-L1"),
  c("ECCCo (no CP)", "ECCCo-L1 (no CP)"),
  c("ECCCo (no EBM)", "ECCCo-L1 (no EBM)"),
  c("ECCCo-Δ", "ECCCo"),
  c("ECCCo-Δ (latent)", "ECCCo+"),
  c("ECCCo-Δ (no CP)", "ECCCo (no CP)"),
  c("ECCCo-Δ (no EBM)", "ECCCo (no EBM)")
)
for (rename_row in seq_len(nrow(generator_renames))) {
  # match() yields NA when the old label is absent, which leaves the levels
  # untouched for that rename.
  level_pos <- match(generator_renames[rename_row, 1], levels(dt$generator))
  levels(dt$generator)[level_pos] <- generator_renames[rename_row, 2]
}
```
```{r}
# Adjust measure names: map the raw distance variables onto the two
# reported measures. Vision rows use the SSIM-based variables, all other
# rows use the L2-based ones.
ssim_to_measure <- c(
  distance_from_targets_ssim = "implausibility",
  distance_from_energy_ssim = "unfaithfulness"
)
l2_to_measure <- c(
  distance_from_targets_l2 = "implausibility",
  distance_from_energy_l2 = "unfaithfulness"
)
for (raw_name in names(ssim_to_measure)) {
  dt[
    source == "vision" & variable == raw_name,
    variable := ssim_to_measure[[raw_name]]
  ]
}
for (raw_name in names(l2_to_measure)) {
  dt[
    source != "vision" & variable == raw_name,
    variable := l2_to_measure[[raw_name]]
  ]
}
```
```{r}
# Flag every row of a counterfactual (identified by the grouping columns
# below) as non-valid as soon as any of its rows is a "validity" record with
# value 0, then keep only the groups without such a record.
dt[
  ,
  non_valid := any(variable == "validity" & value == 0.0),
  by = .(sample, dataname, generator, model, target, factual, source)
]
dt_valid <- dt[non_valid == FALSE]
......@@ -38,8 +61,12 @@ tab$two_std_wachter = F
# Measures to be minimized (lower is better):
min_measures <- c(
  "distance", "implausibility", "unfaithfulness",
  "distance_from_energy", "distance_from_energy_l2",
  "distance_from_targets", "distance_from_targets_l2",
  "set_size_penalty"
)
# Within each model/dataset/measure group, mark the row(s) attaining the
# smallest value.
tab[
  variable %in% min_measures,
  top_val := val == min(val),
  by = .(model, dataname, variable)
]
......@@ -90,8 +117,12 @@ tab_valid$two_std_wachter = F
# Measures to be minimized (lower is better):
min_measures <- c(
  "distance", "implausibility", "unfaithfulness",
  "distance_from_energy", "distance_from_energy_l2",
  "distance_from_targets", "distance_from_targets_l2",
  "set_size_penalty"
)
# Within each model/dataset/measure group, mark the row(s) attaining the
# smallest value.
tab_valid[
  variable %in% min_measures,
  top_val := val == min(val),
  by = .(model, dataname, variable)
]
......@@ -128,19 +159,86 @@ tab_valid[,one_std_wachter:=NULL]
```{r}
# Main results table: filters `tab` to the chosen measures, datasets and
# models, reshapes it to one row per model/generator, and renders it as a
# LaTeX table saved to `file_name`.
# Choices:
# NOTE(review): `measures` and `measure_names` are assigned twice below; the
# second assignment of each is the one that takes effect. This looks like
# removed and added diff lines shown side by side — confirm and drop the
# stale ones.
measures <- c("distance_from_energy", "distance_from_targets")
measure_names <- c("Unfaithfulness ↓","Implausibility ↓")
# NOTE(review): `chosen_source` is not referenced anywhere else in this chunk.
chosen_source <- "real-world"
measures <- c(
"unfaithfulness",
"implausibility"
)
measure_names <- c(
"Unfaithfulness ↓",
"Implausibility ↓"
)
# Order: datasets appear as column groups in the order listed here.
# NOTE(review): "GMSC" is listed twice. factor() rejects duplicated levels,
# so one entry presumably has to go — confirm the intended column order.
chosen_data <- c(
"Linearly Separable",
"GMSC",
"MNIST",
"GMSC"
)
chosen_model <- c(
"MLP",
"JEM",
"LeNet-5"
)
tab_i <- tab
# Logic: keep only the rows for the chosen measures, datasets and models.
tab_i <- tab_i[variable %in% measures]
tab_i <- tab_i[dataname %in% chosen_data]
tab_i <- tab_i[model %in% chosen_model]
# Fix the dataset order through factor levels before reshaping.
tab_i[,dataname:=factor(dataname,levels=chosen_data)]
# Wide layout: one row per model/generator, one column per dataset/measure.
tab_i <- dcast(tab_i, model + generator ~ dataname + variable)
# Column labels: the measure names are repeated once per dataset.
col_names <- c(
"Model",
"Generator",
rep(measure_names,length(chosen_data))
)
caption <- "Results for datasets from different domains: sample averages +/- one standard deviation across counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \\label{tab:results-main} \\newline"
file_name <- "paper/contents/table-main.tex"
# Spanning header: each dataset name covers length(measures) columns.
sub_header <- rep(length(measures), length(chosen_data))
names(sub_header) <- chosen_data
# The leading blank header spans the Model and Generator columns.
header <- c(
" " = 2, sub_header
)
line_sep <- c(rep("",length(measures)-1),"\\addlinespace")
# Left-align the two label columns, centre all remaining columns.
algin_cols <- c(rep('l',2),rep('c',ncol(tab_i)-2))
# Render to LaTeX (escape=F keeps the markup in cells and caption) and save.
kbl(
tab_i, caption = caption,
align = algin_cols, col.names=col_names, booktabs = T, escape=F,
format="latex", linesep = line_sep
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = T) %>%
add_header_above(header) %>%
collapse_rows(columns = 1:2, latex_hline = "major", valign = "middle") %>%
save_kable(file_name)
```
```{r}
# Tabular results table, part 1: choose what to report, then filter and
# reshape `tab` accordingly.
# Choices:
# NOTE(review): the measure-name adjustment earlier in this document renames
# the `*_l2` variables of non-vision rows in `dt` to "unfaithfulness" /
# "implausibility". If `tab` is derived from that renamed data, the `_l2`
# names below would match no tabular rows — confirm.
measures <- c("distance_from_energy_l2", "distance_from_targets_l2")
measure_names <- c("Unfaithfulness ↓", "Implausibility ↓")
chosen_source <- "tabular"
# Order: datasets appear as column groups in the order listed here.
chosen_data <- c("GMSC", "California Housing", "German Credit")
chosen_model <- "MLP"
# Logic: keep only rows matching every choice above. Subsetting returns a
# new table, so the `:=` below does not modify `tab`.
tab_i <- tab[
  variable %in% measures &
    source == chosen_source &
    dataname %in% chosen_data &
    model == chosen_model
]
# Fix the dataset order through factor levels before reshaping.
tab_i[, dataname := factor(dataname, levels = chosen_data)]
# Wide layout: one row per model/generator, one column per dataset/measure.
tab_i <- dcast(tab_i, model + generator ~ dataname + variable)
col_names <- c(
......@@ -170,7 +268,7 @@ kbl(
format="latex", linesep = line_sep
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
add_header_above(header) %>%
collapse_rows(columns = 1:2, latex_hline = "major", valign = "middle") %>%
save_kable(file_name)
......@@ -178,8 +276,15 @@ kbl(
```{r}
# Choices:
measures <- c("distance_from_energy", "distance_from_targets")
measure_names <- c("Unfaithfulness ↓","Implausibility ↓")
# Choices:
measures <- c(
"distance_from_energy_l2",
"distance_from_targets_l2"
)
measure_names <- c(
"Unfaithfulness ↓",
"Implausibility ↓"
)
chosen_source <- "synthetic"
chosen_data <- c(
"Linearly Separable",
......@@ -221,7 +326,67 @@ kbl(
format="latex", linesep = line_sep
) %>%
kable_styling(latex_options = c("scale_down")) %>%
kable_paper(full_width = F) %>%
kable_paper(full_width = T) %>%
add_header_above(header) %>%
collapse_rows(columns = 1:2, latex_hline = "major", valign = "middle") %>%
save_kable(file_name)
```
```{r}
# Vision results table: filters `tab` to the chosen measures, datasets and
# model, reshapes it to one row per model/generator, and renders it as a
# LaTeX table saved to `file_name`.
# Choices:
measures <- c(
  "distance_from_energy_l2",
  "distance_from_targets_l2"
)
measure_names <- c(
  "Unfaithfulness ↓",
  "Implausibility ↓"
)
chosen_source <- "vision"
# Order: datasets appear as column groups in the order listed here.
# (No trailing comma after the last element: an empty argument makes c()
# fail with "argument 3 is empty".)
chosen_data <- c(
  "MNIST",
  "Fashion MNIST"
)
chosen_model <- "LeNet-5"
tab_i <- tab
# Logic: keep only the rows for the chosen measures, source, datasets and
# model.
tab_i <- tab_i[variable %in% measures]
tab_i <- tab_i[source == chosen_source]
tab_i <- tab_i[dataname %in% chosen_data]
tab_i <- tab_i[model == chosen_model]
# Fix the dataset order through factor levels before reshaping.
tab_i[, dataname := factor(dataname, levels = chosen_data)]
# Wide layout: one row per model/generator, one column per dataset/measure.
tab_i <- dcast(tab_i, model + generator ~ dataname + variable)
# Column labels: the measure names are repeated once per dataset.
col_names <- c(
  "Model",
  "Generator",
  rep(measure_names, length(chosen_data))
)
# Caption, label and output path are all derived from `chosen_source`.
caption <- sprintf(
  "Results for %s datasets: sample averages +/- one standard deviation across counterfactuals. Best outcomes are highlighted in bold. Asterisks indicate that the given value is more than one (*) or two (**) standard deviations away from the baseline (Wachter). \\label{tab:results-%s} \\newline",
  chosen_source,
  chosen_source
)
file_name <- sprintf(
  "paper/contents/table-%s.tex",
  chosen_source
)
# Spanning header: each dataset name covers length(measures) columns.
sub_header <- rep(length(measures), length(chosen_data))
names(sub_header) <- chosen_data
# The leading blank header spans the Model and Generator columns.
header <- c(
  " " = 2, sub_header
)
line_sep <- c(rep("", length(measures) - 1), "\\addlinespace")
# Left-align the two label columns, centre all remaining columns.
algin_cols <- c(rep('l', 2), rep('c', ncol(tab_i) - 2))
# Render to LaTeX (escape = FALSE keeps the markup in cells and caption)
# and save.
kbl(
  tab_i, caption = caption,
  align = algin_cols, col.names = col_names, booktabs = TRUE, escape = FALSE,
  format = "latex", linesep = line_sep
) %>%
  kable_styling(latex_options = c("scale_down")) %>%
  kable_paper(full_width = TRUE) %>%
  add_header_above(header) %>%
  collapse_rows(columns = 1:2, latex_hline = "major", valign = "middle") %>%
  save_kable(file_name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment