Skip to content
Snippets Groups Projects
Commit d1b3a9ba authored by Pat Alt's avatar Pat Alt
Browse files

uh

parent 8aee1501
No related branches found
No related tags found
No related merge requests found
Showing
with 114 additions and 47 deletions
acc,precision,f1score,mod_name
0.8668,0.8676755973746118,0.8667206508066643,JEM Ensemble
0.8692,0.8695976279120128,0.8691648105674503,MLP
0.8776,0.8776241679467486,0.877598041568665,MLP Ensemble
0.866,0.8660114781199538,0.8659989494317635,JEM
acc,precision,f1score,mod_name
1.0,1.0,1.0,MLP
0.976,0.9770992366412213,0.9759861680327868,JEM
acc,precision,f1score,mod_name,dataname
1.0,1.0,1.0,MLP,Circles
0.976,0.9770992366412213,0.9759861680327868,JEM,Circles
No preview for this file type
acc,precision,f1score,mod_name
0.7591,0.7630959766537089,0.759309436325091,JEM Ensemble
0.7801,0.7866657893024347,0.7817885884965358,MLP
0.7821,0.7899289201498965,0.7844916726299729,MLP Ensemble
0.7689,0.7769300969892394,0.7709547228694767,JEM
acc,precision,f1score,mod_name
0.738,0.7510506151794907,0.734550214588796,JEM Ensemble
0.741,0.746854398924903,0.73945523005902,MLP
0.732,0.7377867790550845,0.7303595072420607,MLP Ensemble
0.72,0.7544599896365386,0.7101881293057764,JEM
acc,precision,f1score,mod_name,dataname
0.7336923997606224,0.7502647333049264,0.7292095166233554,JEM Ensemble,GMSC
0.7474566128067026,0.7540904211718942,0.745797428709856,MLP,GMSC
0.7495511669658887,0.7534403970041377,0.7485866360360287,MLP Ensemble,GMSC
0.734590065828845,0.7523243523092521,0.7298431609679099,JEM,GMSC
No preview for this file type
No preview for this file type
acc,precision,f1score,mod_name
0.992,0.992,0.992,MLP
0.992,0.9921259842519685,0.9919994879672299,JEM
acc,precision,f1score,mod_name,dataname
0.992,0.992,0.992,MLP,Linearly Separable
0.992,0.9921259842519685,0.9919994879672299,JEM,Linearly Separable
No preview for this file type
acc,precision,f1score,mod_name
0.8956999999999999,0.8958883425639675,0.8943865992125204,JEM Ensemble
0.9476,0.947323276242302,0.9470241845063315,MLP
0.9495,0.9491766613039271,0.9489687025103262,MLP Ensemble
0.8314,0.8427605598842418,0.8272255498224308,JEM
acc,precision,f1score,mod_name,dataname
0.8956999999999999,0.8958883425639675,0.8943865992125204,JEM Ensemble,MNIST
0.9476,0.947323276242302,0.9470241845063315,MLP,MNIST
0.9495,0.9491766613039271,0.9489687025103262,MLP Ensemble,MNIST
0.8314,0.8427605598842418,0.8272255498224308,JEM,MNIST
No preview for this file type
acc,precision,f1score,mod_name
1.0,1.0,1.0,MLP
0.9967948717948718,0.9968152866242038,0.9967948388687424,JEM
acc,precision,f1score,mod_name,dataname
1.0,1.0,1.0,MLP,Moons
0.9967948717948718,0.9968152866242038,0.9967948388687424,JEM,Moons
No preview for this file type
......@@ -148,6 +148,7 @@ for (mod_name, model) in model_dict
_perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
_perf.dataname .= "Circles"
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"circles_model_performance.jls"), model_performance)
......
......@@ -11,7 +11,7 @@ _retrain = true
# Data:
test_size = 0.2
counterfactual_data, test_data = train_test_split(load_gmsc(); test_size=test_size)
counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=test_size)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
......@@ -115,6 +115,25 @@ else
end
```
```{julia}
# Gather the hyperparameters used for the GMSC experiment into a one-row table
# and write it to `gmsc.csv` under `params_path`. Every value is read from
# variables defined earlier in the notebook.
params = Dict(
    :dataname => "GMSC",
    :n_obs => Int.(round(n_obs / 10) * 10),  # rounded to the nearest multiple of 10
    :batch_size => batch_size,
    :epochs => epochs,
    :activation => string(activation),
    :n_hidden => n_hidden,
    # One less than the length of the fitted MLP chain.
    :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
    :n_ens => n_ens,
    :lambda => string(α[3]),
    :jem_sampling_steps => jem.sampling_steps,
    :sgld_batch_size => sampler.batch_size,
) |> DataFrame

CSV.write(joinpath(params_path, "gmsc.csv"), params)
```
```{julia}
# Evaluate models:
......@@ -129,6 +148,7 @@ for (mod_name, mod) in model_dict
_perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
_perf.dataname .= "GMSC"
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"gmsc_model_performance.jls"), model_performance)
......
......@@ -151,6 +151,7 @@ for (mod_name, model) in model_dict
_perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
_perf.dataname .= "Linearly Separable"
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"linearly_separable_model_performance.jls"), model_performance)
......
......@@ -346,6 +346,7 @@ for (mod_name, mod) in model_dict
_perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
_perf.dataname .= "MNIST"
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"mnist_model_performance.jls"), model_performance)
......
......@@ -148,6 +148,7 @@ for (mod_name, model) in model_dict
_perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
_perf.dataname .= "Moons"
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"moons_model_performance.jls"), model_performance)
......
......@@ -291,7 +291,7 @@ kbl(
```{r}
files <- list.files("artifacts/params/")
dt <- lapply(files, function(x) {
dt <- lapply(files[grepl(".csv", files)], function(x) {
fread(file.path("artifacts/params/", x))
})
dt <- Reduce(function(x,y) {rbind(x,y, fill=TRUE)}, dt)
......@@ -304,6 +304,8 @@ setcolorder(
"jem_sampling_steps", "sgld_batch_size", "lambda"
)
)
dt[,dataname:=factor(dataname, levels=c("Linearly Separable", "Moons", "Circles", "MNIST", "GMSC"))]
dt <- dt[order(dataname)]
dt_ebm <- dt[,.(dataname, jem_sampling_steps, sgld_batch_size, lambda)]
col_names <- c(
"Dataset",
......@@ -338,4 +340,55 @@ kbl(
kable_paper(full_width = F) %>%
add_header_above(header) %>%
save_kable("paper/contents/table_params.tex")
```
```{r}
# Build the LaTeX table of generator hyperparameters, one row per dataset.
#
# Reads every CSV in artifacts/params/generator, stacks them, keeps the
# step size (eta) and the three penalty strengths (λ1, λ2, λ3), orders the
# rows by dataset and writes paper/contents/table_gen_params.tex.
gen_params_dir <- "artifacts/params/generator"

# Only pick up CSV files so stray artifacts in the folder are not parsed.
files <- list.files(gen_params_dir, pattern = "\\.csv$")
dt <- lapply(files, function(x) {
  fread(file.path(gen_params_dir, x))
})

# fill = TRUE pads columns that are missing from some files with NA.
dt <- rbindlist(dt, fill = TRUE)

# Select λ1, λ2, λ3 so that each column matches its header in `col_names`
# (the previous selection repeated λ3 in place of λ2).
dt <- dt[, .(dataname, eta, λ1, λ2, λ3)]

# Fix the row order of the datasets in the printed table.
dt[, dataname := factor(
  dataname,
  levels = c("Linearly Separable", "Moons", "Circles", "MNIST", "GMSC")
)]
dt <- dt[order(dataname)]

col_names <- c(
  "Dataset",
  "$\\eta$", "$\\lambda_1$", "$\\lambda_2$", "$\\lambda_3$"
)

kbl(
  dt,
  caption = "Generator hyperparameters. \\label{tab:genparams} \\newline",
  align = "r", col.names = col_names, booktabs = TRUE, escape = FALSE,
  format = "latex"
) %>%
  kable_styling(font_size = 8) %>%
  kable_paper(full_width = FALSE) %>%
  save_kable("paper/contents/table_gen_params.tex")
```
```{r}
# Build the LaTeX table of model performance metrics grouped by dataset.
#
# Reads every *_model_performance.csv in artifacts/results, stacks them,
# orders the rows by dataset and model, and writes
# paper/contents/table_perf.tex.
results_dir <- "artifacts/results"

# Escaped and anchored pattern: the dot is literal and the name must end in
# ".csv", so look-alike files (e.g. backups) are not read.
files <- list.files(results_dir, pattern = "_model_performance\\.csv$")
dt <- lapply(files, function(x) {
  fread(file.path(results_dir, x))
})

# fill = TRUE pads columns that are missing from some files with NA.
dt <- rbindlist(dt, fill = TRUE)

# Fix the row order of the datasets in the printed table.
dt[, dataname := factor(
  dataname,
  levels = c("Linearly Separable", "Moons", "Circles", "MNIST", "GMSC")
)]
dt <- dt[order(dataname, mod_name)]

# Identifier columns first, then the metrics, matching `col_names` below.
setcolorder(
  dt,
  c(
    "dataname", "mod_name",
    "acc", "precision", "f1score"
  )
)
col_names <- c("Dataset", "Model", "Accuracy", "Precision", "F1-Score")

kbl(
  dt,
  caption = "Various standard performance metrics for our different models grouped by dataset. \\label{tab:perf} \\newline",
  align = "r", col.names = col_names, booktabs = TRUE, escape = FALSE,
  format = "latex", digits = 2
) %>%
  kable_styling(font_size = 8) %>%
  kable_paper(full_width = FALSE) %>%
  add_header_above(c(" " = 2, "Performance Metrics" = 3)) %>%
  # Merge repeated dataset names in the first column into one cell.
  collapse_rows(
    columns = 1, latex_hline = "custom", valign = "top",
    custom_latex_hline = 1
  ) %>%
  save_kable("paper/contents/table_perf.tex")
```
\ No newline at end of file
......@@ -9,37 +9,37 @@
\cmidrule(l{3pt}r{3pt}){3-4} \cmidrule(l{3pt}r{3pt}){5-6}
Model & Generator & Unfaithfulness ↓ & Implausibility ↓ & Unfaithfulness ↓ & Implausibility ↓\\
\midrule
& ECCCo & \textbf{19.27 ± 5.02}** & 314.54 ± 32.54*\hphantom{*} & \textbf{79.18 ± 13.01}** & 19.67 ± 6.27**\\
& ECCCo & \textbf{19.27 ± 5.02}** & 314.54 ± 32.54*\hphantom{*} & \textbf{61.26 ± 13.46}** & 19.87 ± 2.37**\\
& REVISE & 188.54 ± 26.22*\hphantom{*} & \textbf{254.32 ± 41.55}** & 186.05 ± 31.81\hphantom{*}\hphantom{*} & \textbf{5.38 ± 1.89}**\\
& REVISE & 188.54 ± 26.22*\hphantom{*} & \textbf{254.32 ± 41.55}** & 152.22 ± 22.93\hphantom{*}\hphantom{*} & \textbf{5.01 ± 0.63}**\\
& Schut & 199.70 ± 28.43\hphantom{*}\hphantom{*} & 273.01 ± 39.60** & 185.40 ± 38.43\hphantom{*}\hphantom{*} & 6.54 ± 0.98**\\
& Schut & 199.70 ± 28.43\hphantom{*}\hphantom{*} & 273.01 ± 39.60** & 161.85 ± 21.18\hphantom{*}\hphantom{*} & 5.44 ± 0.97**\\
\multirow{-4}{*}{\raggedright\arraybackslash JEM} & Wachter & 222.81 ± 26.22\hphantom{*}\hphantom{*} & 361.38 ± 39.55\hphantom{*}\hphantom{*} & 188.81 ± 41.72\hphantom{*}\hphantom{*} & 71.97 ± 60.09\hphantom{*}\hphantom{*}\\
\multirow{-4}{*}{\raggedright\arraybackslash JEM} & Wachter & 222.81 ± 26.22\hphantom{*}\hphantom{*} & 361.38 ± 39.55\hphantom{*}\hphantom{*} & 170.58 ± 22.20\hphantom{*}\hphantom{*} & 78.21 ± 74.12\hphantom{*}\hphantom{*}\\
\cmidrule{1-6}
& ECCCo & \textbf{15.99 ± 3.06}** & 294.72 ± 30.75** & \textbf{79.65 ± 11.83}** & 17.81 ± 5.44**\\
& ECCCo & \textbf{15.99 ± 3.06}** & 294.72 ± 30.75** & \textbf{61.25 ± 14.17}** & 18.66 ± 3.05**\\
& REVISE & 173.05 ± 20.38** & \textbf{246.20 ± 37.74}** & 204.14 ± 36.13\hphantom{*}\hphantom{*} & \textbf{4.90 ± 0.95}**\\
& REVISE & 173.05 ± 20.38** & \textbf{246.20 ± 37.74}** & 158.77 ± 24.61\hphantom{*}\hphantom{*} & \textbf{4.77 ± 0.56}**\\
& Schut & 186.91 ± 22.98*\hphantom{*} & 264.68 ± 37.58** & 186.24 ± 36.18\hphantom{*}\hphantom{*} & 6.35 ± 1.22**\\
& Schut & 186.91 ± 22.98*\hphantom{*} & 264.68 ± 37.58** & 154.90 ± 38.82\hphantom{*}\hphantom{*} & 5.49 ± 0.72**\\
\multirow{-4}{*}{\raggedright\arraybackslash JEM Ensemble} & Wachter & 217.37 ± 23.93\hphantom{*}\hphantom{*} & 362.91 ± 39.40\hphantom{*}\hphantom{*} & 184.05 ± 23.11\hphantom{*}\hphantom{*} & 61.40 ± 48.29\hphantom{*}\hphantom{*}\\
\multirow{-4}{*}{\raggedright\arraybackslash JEM Ensemble} & Wachter & 217.37 ± 23.93\hphantom{*}\hphantom{*} & 362.91 ± 39.40\hphantom{*}\hphantom{*} & 147.47 ± 30.47\hphantom{*}\hphantom{*} & 80.44 ± 44.94\hphantom{*}\hphantom{*}\\
\cmidrule{1-6}
& ECCCo & \textbf{41.95 ± 6.50}** & 591.58 ± 36.24\hphantom{*}\hphantom{*} & \textbf{80.51 ± 16.59}** & 23.43 ± 6.09**\\
& ECCCo & \textbf{41.95 ± 6.50}** & 591.58 ± 36.24\hphantom{*}\hphantom{*} & \textbf{58.19 ± 15.18}** & 19.17 ± 3.76*\hphantom{*}\\
& REVISE & 365.69 ± 14.90*\hphantom{*} & 245.36 ± 39.69** & 180.18 ± 30.75\hphantom{*}\hphantom{*} & \textbf{5.05 ± 1.05}**\\
& REVISE & 365.69 ± 14.90*\hphantom{*} & 245.36 ± 39.69** & 171.10 ± 21.55\hphantom{*}\hphantom{*} & \textbf{4.99 ± 1.03}**\\
& Schut & 371.12 ± 19.99\hphantom{*}\hphantom{*} & \textbf{245.11 ± 35.72}** & 199.88 ± 45.58\hphantom{*}\hphantom{*} & 7.25 ± 1.88**\\
& Schut & 371.12 ± 19.99\hphantom{*}\hphantom{*} & \textbf{245.11 ± 35.72}** & 160.38 ± 31.67\hphantom{*}\hphantom{*} & 6.29 ± 1.23**\\
\multirow{-4}{*}{\raggedright\arraybackslash MLP} & Wachter & 384.76 ± 16.52\hphantom{*}\hphantom{*} & 359.21 ± 42.03\hphantom{*}\hphantom{*} & 196.33 ± 33.11\hphantom{*}\hphantom{*} & 87.52 ± 53.98\hphantom{*}\hphantom{*}\\
\multirow{-4}{*}{\raggedright\arraybackslash MLP} & Wachter & 384.76 ± 16.52\hphantom{*}\hphantom{*} & 359.21 ± 42.03\hphantom{*}\hphantom{*} & 171.64 ± 36.71\hphantom{*}\hphantom{*} & 26.40 ± 1.54\hphantom{*}\hphantom{*}\\
\cmidrule{1-6}
& ECCCo & \textbf{31.43 ± 3.91}** & 490.88 ± 27.19\hphantom{*}\hphantom{*} & \textbf{76.32 ± 14.56}** & 22.99 ± 8.31\hphantom{*}\hphantom{*}\\
& ECCCo & \textbf{31.43 ± 3.91}** & 490.88 ± 27.19\hphantom{*}\hphantom{*} & \textbf{62.06 ± 14.56}** & 18.38 ± 3.74**\\
& REVISE & 337.21 ± 11.68*\hphantom{*} & \textbf{244.84 ± 37.17}** & 184.04 ± 29.13\hphantom{*}\hphantom{*} & \textbf{5.25 ± 1.31}**\\
& REVISE & 337.21 ± 11.68*\hphantom{*} & \textbf{244.84 ± 37.17}** & 153.48 ± 30.61\hphantom{*}\hphantom{*} & \textbf{4.80 ± 0.80}**\\
& Schut & 344.60 ± 13.64*\hphantom{*} & 252.53 ± 37.92** & 214.74 ± 34.33\hphantom{*}\hphantom{*} & 6.18 ± 1.17**\\
& Schut & 344.60 ± 13.64*\hphantom{*} & 252.53 ± 37.92** & 166.85 ± 28.33\hphantom{*}\hphantom{*} & 5.86 ± 0.71**\\
\multirow{-4}{*}{\raggedright\arraybackslash MLP Ensemble} & Wachter & 358.51 ± 13.18\hphantom{*}\hphantom{*} & 352.63 ± 39.93\hphantom{*}\hphantom{*} & 193.41 ± 35.45\hphantom{*}\hphantom{*} & 12.71 ± 4.90\hphantom{*}\hphantom{*}\\
\multirow{-4}{*}{\raggedright\arraybackslash MLP Ensemble} & Wachter & 358.51 ± 13.18\hphantom{*}\hphantom{*} & 352.63 ± 39.93\hphantom{*}\hphantom{*} & 150.78 ± 26.59\hphantom{*}\hphantom{*} & 73.51 ± 33.64\hphantom{*}\hphantom{*}\\
\bottomrule
\end{tabular}}
\end{table}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment