diff --git a/artifacts/results/cal_housing_model_performance.csv b/artifacts/results/cal_housing_model_performance.csv deleted file mode 100644 index c60f968a240706b4a774108f1ae72ec19f20e72b..0000000000000000000000000000000000000000 --- a/artifacts/results/cal_housing_model_performance.csv +++ /dev/null @@ -1,5 +0,0 @@ -acc,precision,f1score,mod_name -0.8668,0.8676755973746118,0.8667206508066643,JEM Ensemble -0.8692,0.8695976279120128,0.8691648105674503,MLP -0.8776,0.8776241679467486,0.877598041568665,MLP Ensemble -0.866,0.8660114781199538,0.8659989494317635,JEM diff --git a/artifacts/results/circles_model_performance.csv b/artifacts/results/circles_model_performance.csv index 30613afe00a7a089d1c149adf070dc0f75e4eead..db9d66eaa716c2fb164f7ef5ed09656b9cbf0b16 100644 --- a/artifacts/results/circles_model_performance.csv +++ b/artifacts/results/circles_model_performance.csv @@ -1,3 +1,3 @@ -acc,precision,f1score,mod_name -1.0,1.0,1.0,MLP -0.976,0.9770992366412213,0.9759861680327868,JEM +acc,precision,f1score,mod_name,dataname +1.0,1.0,1.0,MLP,Circles +0.976,0.9770992366412213,0.9759861680327868,JEM,Circles diff --git a/artifacts/results/circles_model_performance.jls b/artifacts/results/circles_model_performance.jls index a4a2bf3ae155af36350064bc9e6ab7f6b8dbf192..2c3465efbd6fb2f778ee19a3e50680c3d0ba126c 100644 Binary files a/artifacts/results/circles_model_performance.jls and b/artifacts/results/circles_model_performance.jls differ diff --git a/artifacts/results/fashion_mnist_model_performance.csv b/artifacts/results/fashion_mnist_model_performance.csv deleted file mode 100644 index 5a0e92b7c12944df90d2a99b072b69345753d83d..0000000000000000000000000000000000000000 --- a/artifacts/results/fashion_mnist_model_performance.csv +++ /dev/null @@ -1,5 +0,0 @@ -acc,precision,f1score,mod_name -0.7591,0.7630959766537089,0.759309436325091,JEM Ensemble -0.7801,0.7866657893024347,0.7817885884965358,MLP -0.7821,0.7899289201498965,0.7844916726299729,MLP Ensemble 
-0.7689,0.7769300969892394,0.7709547228694767,JEM diff --git a/artifacts/results/gmsc_model_performance.csv b/artifacts/results/gmsc_model_performance.csv index 0011835c2ee8257f3468709f3f945b599862131d..2e09b87a612d06a555add7637ba2d78abd85d851 100644 --- a/artifacts/results/gmsc_model_performance.csv +++ b/artifacts/results/gmsc_model_performance.csv @@ -1,5 +1,5 @@ -acc,precision,f1score,mod_name -0.738,0.7510506151794907,0.734550214588796,JEM Ensemble -0.741,0.746854398924903,0.73945523005902,MLP -0.732,0.7377867790550845,0.7303595072420607,MLP Ensemble -0.72,0.7544599896365386,0.7101881293057764,JEM +acc,precision,f1score,mod_name,dataname +0.7336923997606224,0.7502647333049264,0.7292095166233554,JEM Ensemble,GMSC +0.7474566128067026,0.7540904211718942,0.745797428709856,MLP,GMSC +0.7495511669658887,0.7534403970041377,0.7485866360360287,MLP Ensemble,GMSC +0.734590065828845,0.7523243523092521,0.7298431609679099,JEM,GMSC diff --git a/artifacts/results/gmsc_model_performance.jls b/artifacts/results/gmsc_model_performance.jls index d6088129d9ce90e35ac2cd03638a9b211efac9ca..cc64dcbbef82c039d41db58d55559b1c0d4c100f 100644 Binary files a/artifacts/results/gmsc_model_performance.jls and b/artifacts/results/gmsc_model_performance.jls differ diff --git a/artifacts/results/gmsc_models.jls b/artifacts/results/gmsc_models.jls index f075946026ee81584aaf78a9d68334a3c1214a2f..4daad24e3aa640c9ad65265ed50624161ce40517 100644 Binary files a/artifacts/results/gmsc_models.jls and b/artifacts/results/gmsc_models.jls differ diff --git a/artifacts/results/linearly_separable_model_performance.csv b/artifacts/results/linearly_separable_model_performance.csv index 77f36eb55b38a6b87d14179716c4f88d3369260d..b553ee647749944dee7df1fafed939a9231f9659 100644 --- a/artifacts/results/linearly_separable_model_performance.csv +++ b/artifacts/results/linearly_separable_model_performance.csv @@ -1,3 +1,3 @@ -acc,precision,f1score,mod_name -0.992,0.992,0.992,MLP 
-0.992,0.9921259842519685,0.9919994879672299,JEM +acc,precision,f1score,mod_name,dataname +0.992,0.992,0.992,MLP,Linearly Separable +0.992,0.9921259842519685,0.9919994879672299,JEM,Linearly Separable diff --git a/artifacts/results/linearly_separable_model_performance.jls b/artifacts/results/linearly_separable_model_performance.jls index 1aae630be0fb6da6f8d1e07ba15c7bb57b9ae70b..2da7edaff6f46dceb83f9dfeb037d1a89ca2faa2 100644 Binary files a/artifacts/results/linearly_separable_model_performance.jls and b/artifacts/results/linearly_separable_model_performance.jls differ diff --git a/artifacts/results/mnist_model_performance.csv b/artifacts/results/mnist_model_performance.csv index 22447ffbc890c0a5c2aade2a2da2cf03ad04d255..715920f312de0a4bbec461f8b2ef6d4ee67c9f41 100644 --- a/artifacts/results/mnist_model_performance.csv +++ b/artifacts/results/mnist_model_performance.csv @@ -1,5 +1,5 @@ -acc,precision,f1score,mod_name -0.8956999999999999,0.8958883425639675,0.8943865992125204,JEM Ensemble -0.9476,0.947323276242302,0.9470241845063315,MLP -0.9495,0.9491766613039271,0.9489687025103262,MLP Ensemble -0.8314,0.8427605598842418,0.8272255498224308,JEM +acc,precision,f1score,mod_name,dataname +0.8956999999999999,0.8958883425639675,0.8943865992125204,JEM Ensemble,MNIST +0.9476,0.947323276242302,0.9470241845063315,MLP,MNIST +0.9495,0.9491766613039271,0.9489687025103262,MLP Ensemble,MNIST +0.8314,0.8427605598842418,0.8272255498224308,JEM,MNIST diff --git a/artifacts/results/mnist_model_performance.jls b/artifacts/results/mnist_model_performance.jls index 4b212fbb622814dc04bf5609194877827f2aaf04..49a7500e881a6ac8f064a9e82d432c1f19684d5f 100644 Binary files a/artifacts/results/mnist_model_performance.jls and b/artifacts/results/mnist_model_performance.jls differ diff --git a/artifacts/results/moons_model_performance.csv b/artifacts/results/moons_model_performance.csv index 29a75b9e8e1f213a070ecbd921ee2c616d65a740..9384197167b404b91c9083cd46ab924a3dc01031 100644 --- 
a/artifacts/results/moons_model_performance.csv +++ b/artifacts/results/moons_model_performance.csv @@ -1,3 +1,3 @@ -acc,precision,f1score,mod_name -1.0,1.0,1.0,MLP -0.9967948717948718,0.9968152866242038,0.9967948388687424,JEM +acc,precision,f1score,mod_name,dataname +1.0,1.0,1.0,MLP,Moons +0.9967948717948718,0.9968152866242038,0.9967948388687424,JEM,Moons diff --git a/artifacts/results/moons_model_performance.jls b/artifacts/results/moons_model_performance.jls index 4a103b22be9af2c1714fc1c08fbf3a4c3a157617..2e3a2c0adb69210bbdaf538ac1aea34e529e7fd0 100644 Binary files a/artifacts/results/moons_model_performance.jls and b/artifacts/results/moons_model_performance.jls differ diff --git a/notebooks/circles.qmd b/notebooks/circles.qmd index 6b6aa0cc894317c6d462f8090e75279e7b073d1c..107f0e9f1f273d811191ea62dde66ad3ae91101d 100644 --- a/notebooks/circles.qmd +++ b/notebooks/circles.qmd @@ -148,6 +148,7 @@ for (mod_name, model) in model_dict _perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure))) _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) _perf.mod_name .= mod_name + _perf.dataname .= "Circles" model_performance = vcat(model_performance, _perf) end Serialization.serialize(joinpath(output_path,"circles_model_performance.jls"), model_performance) diff --git a/notebooks/gmsc.qmd b/notebooks/gmsc.qmd index b469a56a3f4149b5aaa89ea1cc2ea8160eb59a0f..41131a500bc3bc87685a21cbfb39f16b4d06f897 100644 --- a/notebooks/gmsc.qmd +++ b/notebooks/gmsc.qmd @@ -11,7 +11,7 @@ _retrain = true # Data: test_size = 0.2 -counterfactual_data, test_data = train_test_split(load_gmsc(); test_size=test_size) +counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=test_size) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X = table(permutedims(X)) labels = counterfactual_data.output_encoder.labels @@ -115,6 +115,25 @@ else end ``` +```{julia} +params = DataFrame( + 
Dict( + :n_obs => Int.(round(n_obs/10)*10), + :epochs => epochs, + :batch_size => batch_size, + :n_hidden => n_hidden, + :n_layers => length(model_dict["MLP"].fitresult[1][1])-1, + :activation => string(activation), + :n_ens => n_ens, + :lambda => string(α[3]), + :jem_sampling_steps => jem.sampling_steps, + :sgld_batch_size => sampler.batch_size, + :dataname => "GMSC", + ) +) +CSV.write(joinpath(params_path, "gmsc.csv"), params) +``` + ```{julia} # Evaluate models: @@ -129,6 +148,7 @@ for (mod_name, mod) in model_dict _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) _perf.mod_name .= mod_name + _perf.dataname .= "GMSC" model_performance = vcat(model_performance, _perf) end Serialization.serialize(joinpath(output_path,"gmsc_model_performance.jls"), model_performance) diff --git a/notebooks/linearly_separable.qmd b/notebooks/linearly_separable.qmd index c81c9b1ddd5f6a1ff57cf39a6859ddd68da1568c..fea25855ddab36a8eed5b6c233fc23bbee487f11 100644 --- a/notebooks/linearly_separable.qmd +++ b/notebooks/linearly_separable.qmd @@ -151,6 +151,7 @@ for (mod_name, model) in model_dict _perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure))) _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) _perf.mod_name .= mod_name + _perf.dataname .= "Linearly Separable" model_performance = vcat(model_performance, _perf) end Serialization.serialize(joinpath(output_path,"linearly_separable_model_performance.jls"), model_performance) diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 78de409ace6d779d7bc3cc0fe2ccca4fabb9af33..afe54d16141b3738fc8788e00b7c05c324883975 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -346,6 +346,7 @@ for (mod_name, mod) in model_dict _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) _perf = 
DataFrame([[p] for p in _perf], collect(keys(measure))) _perf.mod_name .= mod_name + _perf.dataname .= "MNIST" model_performance = vcat(model_performance, _perf) end Serialization.serialize(joinpath(output_path,"mnist_model_performance.jls"), model_performance) diff --git a/notebooks/moons.qmd b/notebooks/moons.qmd index 263c6382ab5505732943bac0f92bbef33d8b720a..6d1c716301253c23c5b4d1be6b3a21312233e068 100644 --- a/notebooks/moons.qmd +++ b/notebooks/moons.qmd @@ -148,6 +148,7 @@ for (mod_name, model) in model_dict _perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure))) _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) _perf.mod_name .= mod_name + _perf.dataname .= "Moons" model_performance = vcat(model_performance, _perf) end Serialization.serialize(joinpath(output_path,"moons_model_performance.jls"), model_performance) diff --git a/notebooks/tables.Rmd b/notebooks/tables.Rmd index 421a3be7d662c9a325e4bd188072eeb65b76907e..185837755ec264d29b4d5149d0d9ce2b1007c31c 100644 --- a/notebooks/tables.Rmd +++ b/notebooks/tables.Rmd @@ -291,7 +291,7 @@ kbl( ```{r} files <- list.files("artifacts/params/") -dt <- lapply(files, function(x) { +dt <- lapply(files[grepl("\\.csv$", files)], function(x) { fread(file.path("artifacts/params/", x)) }) dt <- Reduce(function(x,y) {rbind(x,y, fill=TRUE)}, dt) @@ -304,6 +304,8 @@ setcolorder( "jem_sampling_steps", "sgld_batch_size", "lambda" ) ) +dt[,dataname:=factor(dataname, levels=c("Linearly Separable", "Moons", "Circles", "MNIST", "GMSC"))] +dt <- dt[order(dataname)] dt_ebm <- dt[,.(dataname, jem_sampling_steps, sgld_batch_size, lambda)] col_names <- c( "Dataset", @@ -338,4 +340,55 @@ kbl( kable_paper(full_width = F) %>% add_header_above(header) %>% save_kable("paper/contents/table_params.tex") +``` + +```{r} +files <- list.files("artifacts/params/generator") +dt <- lapply(files[grepl("\\.csv$", files)], function(x) { + fread(file.path("artifacts/params/generator", x)) +}) +dt <- 
Reduce(function(x,y) {rbind(x,y, fill=TRUE)}, dt) +dt <- dt[,.(dataname,eta,λ1,λ2,λ3)] +dt[,dataname:=factor(dataname, levels=c("Linearly Separable", "Moons", "Circles", "MNIST", "GMSC"))] +dt <- dt[order(dataname)] +col_names <- c( + "Dataset", + "$\\eta$", "$\\lambda_1$", "$\\lambda_2$", "$\\lambda_3$" +) +kbl( + dt, caption = "Generator hyperparameters. \\label{tab:genparams} \\newline", + align = "r", col.names=col_names, booktabs = T, escape=F, + format="latex" +) %>% + kable_styling(font_size = 8) %>% + kable_paper(full_width = F) %>% + save_kable("paper/contents/table_gen_params.tex") +``` + +```{r} +files <- list.files("artifacts/results/") +dt <- lapply(files[grepl("_model_performance\\.csv$", files)], function(x) { + fread(file.path("artifacts/results/", x)) +}) +dt <- Reduce(function(x,y) {rbind(x,y, fill=TRUE)}, dt) +dt[,dataname:=factor(dataname, levels=c("Linearly Separable", "Moons", "Circles", "MNIST", "GMSC"))] +dt <- dt[order(dataname,mod_name)] +setcolorder( + dt, + c( + "dataname", "mod_name", + "acc", "precision", "f1score" + ) +) +col_names <- c("Dataset", "Model", "Accuracy", "Precision", "F1-Score") +kbl( + dt, caption = "Various standard performance metrics for our different models grouped by dataset. 
\\label{tab:perf} \\newline", + align = "r", col.names=col_names, booktabs = T, escape=F, + format="latex", digits=2 +) %>% + kable_styling(font_size = 8) %>% + kable_paper(full_width = F) %>% + add_header_above(c(" "=2, "Performance Metrics" = 3)) %>% + collapse_rows(columns = 1, latex_hline = "custom", valign = "top", custom_latex_hline = 1) %>% + save_kable("paper/contents/table_perf.tex") ``` \ No newline at end of file diff --git a/paper/contents/table-real-world.tex b/paper/contents/table-real-world.tex index 771132ad5210ac98e6c67aa3e3f32516bf78f794..793e1a428acc4678c5048f22d849d3fefc1cc619 100644 --- a/paper/contents/table-real-world.tex +++ b/paper/contents/table-real-world.tex @@ -9,37 +9,37 @@ \cmidrule(l{3pt}r{3pt}){3-4} \cmidrule(l{3pt}r{3pt}){5-6} Model & Generator & Unfaithfulness ↓ & Implausibility ↓ & Unfaithfulness ↓ & Implausibility ↓\\ \midrule - & ECCCo & \textbf{19.27 ± 5.02}** & 314.54 ± 32.54*\hphantom{*} & \textbf{79.18 ± 13.01}** & 19.67 ± 6.27**\\ + & ECCCo & \textbf{19.27 ± 5.02}** & 314.54 ± 32.54*\hphantom{*} & \textbf{61.26 ± 13.46}** & 19.87 ± 2.37**\\ - & REVISE & 188.54 ± 26.22*\hphantom{*} & \textbf{254.32 ± 41.55}** & 186.05 ± 31.81\hphantom{*}\hphantom{*} & \textbf{5.38 ± 1.89}**\\ + & REVISE & 188.54 ± 26.22*\hphantom{*} & \textbf{254.32 ± 41.55}** & 152.22 ± 22.93\hphantom{*}\hphantom{*} & \textbf{5.01 ± 0.63}**\\ - & Schut & 199.70 ± 28.43\hphantom{*}\hphantom{*} & 273.01 ± 39.60** & 185.40 ± 38.43\hphantom{*}\hphantom{*} & 6.54 ± 0.98**\\ + & Schut & 199.70 ± 28.43\hphantom{*}\hphantom{*} & 273.01 ± 39.60** & 161.85 ± 21.18\hphantom{*}\hphantom{*} & 5.44 ± 0.97**\\ -\multirow{-4}{*}{\raggedright\arraybackslash JEM} & Wachter & 222.81 ± 26.22\hphantom{*}\hphantom{*} & 361.38 ± 39.55\hphantom{*}\hphantom{*} & 188.81 ± 41.72\hphantom{*}\hphantom{*} & 71.97 ± 60.09\hphantom{*}\hphantom{*}\\ +\multirow{-4}{*}{\raggedright\arraybackslash JEM} & Wachter & 222.81 ± 26.22\hphantom{*}\hphantom{*} & 361.38 ± 
39.55\hphantom{*}\hphantom{*} & 170.58 ± 22.20\hphantom{*}\hphantom{*} & 78.21 ± 74.12\hphantom{*}\hphantom{*}\\ \cmidrule{1-6} - & ECCCo & \textbf{15.99 ± 3.06}** & 294.72 ± 30.75** & \textbf{79.65 ± 11.83}** & 17.81 ± 5.44**\\ + & ECCCo & \textbf{15.99 ± 3.06}** & 294.72 ± 30.75** & \textbf{61.25 ± 14.17}** & 18.66 ± 3.05**\\ - & REVISE & 173.05 ± 20.38** & \textbf{246.20 ± 37.74}** & 204.14 ± 36.13\hphantom{*}\hphantom{*} & \textbf{4.90 ± 0.95}**\\ + & REVISE & 173.05 ± 20.38** & \textbf{246.20 ± 37.74}** & 158.77 ± 24.61\hphantom{*}\hphantom{*} & \textbf{4.77 ± 0.56}**\\ - & Schut & 186.91 ± 22.98*\hphantom{*} & 264.68 ± 37.58** & 186.24 ± 36.18\hphantom{*}\hphantom{*} & 6.35 ± 1.22**\\ + & Schut & 186.91 ± 22.98*\hphantom{*} & 264.68 ± 37.58** & 154.90 ± 38.82\hphantom{*}\hphantom{*} & 5.49 ± 0.72**\\ -\multirow{-4}{*}{\raggedright\arraybackslash JEM Ensemble} & Wachter & 217.37 ± 23.93\hphantom{*}\hphantom{*} & 362.91 ± 39.40\hphantom{*}\hphantom{*} & 184.05 ± 23.11\hphantom{*}\hphantom{*} & 61.40 ± 48.29\hphantom{*}\hphantom{*}\\ +\multirow{-4}{*}{\raggedright\arraybackslash JEM Ensemble} & Wachter & 217.37 ± 23.93\hphantom{*}\hphantom{*} & 362.91 ± 39.40\hphantom{*}\hphantom{*} & 147.47 ± 30.47\hphantom{*}\hphantom{*} & 80.44 ± 44.94\hphantom{*}\hphantom{*}\\ \cmidrule{1-6} - & ECCCo & \textbf{41.95 ± 6.50}** & 591.58 ± 36.24\hphantom{*}\hphantom{*} & \textbf{80.51 ± 16.59}** & 23.43 ± 6.09**\\ + & ECCCo & \textbf{41.95 ± 6.50}** & 591.58 ± 36.24\hphantom{*}\hphantom{*} & \textbf{58.19 ± 15.18}** & 19.17 ± 3.76*\hphantom{*}\\ - & REVISE & 365.69 ± 14.90*\hphantom{*} & 245.36 ± 39.69** & 180.18 ± 30.75\hphantom{*}\hphantom{*} & \textbf{5.05 ± 1.05}**\\ + & REVISE & 365.69 ± 14.90*\hphantom{*} & 245.36 ± 39.69** & 171.10 ± 21.55\hphantom{*}\hphantom{*} & \textbf{4.99 ± 1.03}**\\ - & Schut & 371.12 ± 19.99\hphantom{*}\hphantom{*} & \textbf{245.11 ± 35.72}** & 199.88 ± 45.58\hphantom{*}\hphantom{*} & 7.25 ± 1.88**\\ + & Schut & 371.12 ± 
19.99\hphantom{*}\hphantom{*} & \textbf{245.11 ± 35.72}** & 160.38 ± 31.67\hphantom{*}\hphantom{*} & 6.29 ± 1.23**\\ -\multirow{-4}{*}{\raggedright\arraybackslash MLP} & Wachter & 384.76 ± 16.52\hphantom{*}\hphantom{*} & 359.21 ± 42.03\hphantom{*}\hphantom{*} & 196.33 ± 33.11\hphantom{*}\hphantom{*} & 87.52 ± 53.98\hphantom{*}\hphantom{*}\\ +\multirow{-4}{*}{\raggedright\arraybackslash MLP} & Wachter & 384.76 ± 16.52\hphantom{*}\hphantom{*} & 359.21 ± 42.03\hphantom{*}\hphantom{*} & 171.64 ± 36.71\hphantom{*}\hphantom{*} & 26.40 ± 1.54\hphantom{*}\hphantom{*}\\ \cmidrule{1-6} - & ECCCo & \textbf{31.43 ± 3.91}** & 490.88 ± 27.19\hphantom{*}\hphantom{*} & \textbf{76.32 ± 14.56}** & 22.99 ± 8.31\hphantom{*}\hphantom{*}\\ + & ECCCo & \textbf{31.43 ± 3.91}** & 490.88 ± 27.19\hphantom{*}\hphantom{*} & \textbf{62.06 ± 14.56}** & 18.38 ± 3.74**\\ - & REVISE & 337.21 ± 11.68*\hphantom{*} & \textbf{244.84 ± 37.17}** & 184.04 ± 29.13\hphantom{*}\hphantom{*} & \textbf{5.25 ± 1.31}**\\ + & REVISE & 337.21 ± 11.68*\hphantom{*} & \textbf{244.84 ± 37.17}** & 153.48 ± 30.61\hphantom{*}\hphantom{*} & \textbf{4.80 ± 0.80}**\\ - & Schut & 344.60 ± 13.64*\hphantom{*} & 252.53 ± 37.92** & 214.74 ± 34.33\hphantom{*}\hphantom{*} & 6.18 ± 1.17**\\ + & Schut & 344.60 ± 13.64*\hphantom{*} & 252.53 ± 37.92** & 166.85 ± 28.33\hphantom{*}\hphantom{*} & 5.86 ± 0.71**\\ -\multirow{-4}{*}{\raggedright\arraybackslash MLP Ensemble} & Wachter & 358.51 ± 13.18\hphantom{*}\hphantom{*} & 352.63 ± 39.93\hphantom{*}\hphantom{*} & 193.41 ± 35.45\hphantom{*}\hphantom{*} & 12.71 ± 4.90\hphantom{*}\hphantom{*}\\ +\multirow{-4}{*}{\raggedright\arraybackslash MLP Ensemble} & Wachter & 358.51 ± 13.18\hphantom{*}\hphantom{*} & 352.63 ± 39.93\hphantom{*}\hphantom{*} & 150.78 ± 26.59\hphantom{*}\hphantom{*} & 73.51 ± 33.64\hphantom{*}\hphantom{*}\\ \bottomrule \end{tabular}} \end{table} diff --git a/paper/contents/table_ebm_params.tex b/paper/contents/table_ebm_params.tex index 
2136f26e00905f1caebd372ee4e71d5d0b7b919f..e7fe1e0907384a1c0cf18ccdcef83720d6b91d45 100644 --- a/paper/contents/table_ebm_params.tex +++ b/paper/contents/table_ebm_params.tex @@ -7,11 +7,11 @@ \toprule Dataset & SGLD Steps & Batch Size & $\lambda$\\ \midrule -Circles & 20 & 100 & 0.01\\ -GMSC & 30 & 10 & 0.10\\ Linearly Separable & 30 & 50 & 0.10\\ -MNIST & 25 & 10 & 0.01\\ Moons & 30 & 10 & 0.10\\ +Circles & 20 & 100 & 0.01\\ +MNIST & 25 & 10 & 0.01\\ +GMSC & 30 & 10 & 0.10\\ \bottomrule \end{tabular} \end{table} diff --git a/paper/contents/table_gen_params.tex b/paper/contents/table_gen_params.tex new file mode 100644 index 0000000000000000000000000000000000000000..84b89401bdaebddd0a0c92778fe91d5cc0122d2b --- /dev/null +++ b/paper/contents/table_gen_params.tex @@ -0,0 +1,17 @@ +\begin{table} + +\caption{Generator hyperparameters. \label{tab:genparams} \newline} +\centering +\fontsize{8}{10}\selectfont +\begin{tabular}[t]{rrrrr} +\toprule +Dataset & $\eta$ & $\lambda_1$ & $\lambda_2$ & $\lambda_3$\\ +\midrule +Linearly Separable & 0.01 & 0.25 & 0.75 & 0.75\\ +Moons & 0.05 & 0.25 & 0.75 & 0.75\\ +Circles & 0.01 & 0.25 & 0.75 & 0.75\\ +MNIST & 0.10 & 0.10 & 0.25 & 0.25\\ +GMSC & 0.05 & 0.10 & 0.50 & 0.50\\ +\bottomrule +\end{tabular} +\end{table} diff --git a/paper/contents/table_params.tex b/paper/contents/table_params.tex index 47ba04a01a622806ff6b24f7b2b21599f9482248..f9098c3853bf8d60465bc9d4b01ffd57d7fa36bc 100644 --- a/paper/contents/table_params.tex +++ b/paper/contents/table_params.tex @@ -9,11 +9,11 @@ \cmidrule(l{3pt}r{3pt}){3-6} \cmidrule(l{3pt}r{3pt}){7-8} Dataset & Sample Size & Hidden Units & Hidden Layers & Activation & Ensemble Size & Epochs & Batch Size\\ \midrule -Circles & 1000 & 32 & 3 & swish & 5 & 100 & 100\\ -GMSC & 50000 & 128 & 2 & swish & 5 & 100 & 250\\ Linearly Separable & 1000 & 16 & 3 & swish & 5 & 100 & 100\\ -MNIST & 10000 & 128 & 1 & swish & 5 & 100 & 128\\ Moons & 2500 & 32 & 3 & relu & 5 & 500 & 128\\ +Circles & 1000 & 32 & 3 & swish 
& 5 & 100 & 100\\ +MNIST & 10000 & 128 & 1 & swish & 5 & 100 & 128\\ +GMSC & 4000 & 128 & 2 & swish & 5 & 100 & 250\\ \bottomrule \end{tabular}} \end{table} diff --git a/paper/contents/table_perf.tex b/paper/contents/table_perf.tex new file mode 100644 index 0000000000000000000000000000000000000000..4abcf79b72dca99b16922c20f3d1f7da6156f2a6 --- /dev/null +++ b/paper/contents/table_perf.tex @@ -0,0 +1,41 @@ +\begin{table} + +\caption{Various standard performance metrics for our different models grouped by dataset. \label{tab:perf} \newline} +\centering +\fontsize{8}{10}\selectfont +\begin{tabular}[t]{rrrrr} +\toprule +\multicolumn{2}{c}{ } & \multicolumn{3}{c}{Performance Metrics} \\ +\cmidrule(l{3pt}r{3pt}){3-5} +Dataset & Model & Accuracy & Precision & F1-Score\\ +\midrule + & JEM & 0.99 & 0.99 & 0.99\\ + +\multirow[t]{-2}{*}{\raggedleft\arraybackslash Linearly Separable} & MLP & 0.99 & 0.99 & 0.99\\ +\cmidrule{1-5} + & JEM & 1.00 & 1.00 & 1.00\\ + +\multirow[t]{-2}{*}{\raggedleft\arraybackslash Moons} & MLP & 1.00 & 1.00 & 1.00\\ +\cmidrule{1-5} + & JEM & 0.98 & 0.98 & 0.98\\ + +\multirow[t]{-2}{*}{\raggedleft\arraybackslash Circles} & MLP & 1.00 & 1.00 & 1.00\\ +\cmidrule{1-5} + & JEM & 0.83 & 0.84 & 0.83\\ + + & JEM Ensemble & 0.90 & 0.90 & 0.89\\ + + & MLP & 0.95 & 0.95 & 0.95\\ + +\multirow[t]{-4}{*}{\raggedleft\arraybackslash MNIST} & MLP Ensemble & 0.95 & 0.95 & 0.95\\ +\cmidrule{1-5} + & JEM & 0.72 & 0.75 & 0.71\\ + + & JEM Ensemble & 0.74 & 0.75 & 0.73\\ + + & MLP & 0.74 & 0.75 & 0.74\\ + +\multirow[t]{-4}{*}{\raggedleft\arraybackslash GMSC} & MLP Ensemble & 0.73 & 0.74 & 0.73\\ +\bottomrule +\end{tabular} +\end{table} diff --git a/paper/paper.pdf b/paper/paper.pdf index 39a4a39ea354e1e07792b09e7bb43d7ab301d8af..7ab5c32ad4a1eea064f9f200fcb4b60df01cc2d3 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/paper/paper.tex b/paper/paper.tex index 9664fd5ec904b740590fcf84911f291c8937b329..9bf9fbc21a4cae6f52613e0328df35a78be69865 
100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -439,7 +439,15 @@ The core part of our code base is integrated into a larger ecosystem of \texttt{ \subsection{Experimental Setup}\label{app:setup} -Table~\ref{tab:params} provides an overview of all parameters related to our experiments. The \textit{GMSC} data were randomly undersampled for balancing purposes and all features were standardized. \textit{MNIST} data was also randomly undersampled for reasons outlined below. Pixel values were preprocessed to fall in the range of $[-1,1]$ and a small Gaussian noise component ($\sigma=0.03$) was added to training samples following common practices in the EBM literature. +Table~\ref{tab:params} provides an overview of all parameters related to our experiments. The \textit{GMSC} data were randomly undersampled for balancing purposes and all features were standardized. The \textit{MNIST} data were also randomly undersampled for reasons outlined below. Pixel values were preprocessed to fall in the range of $[-1,1]$ and a small Gaussian noise component ($\sigma=0.03$) was added to training samples following common practice in the EBM literature. Table~\ref{tab:perf} shows standard evaluation metrics measuring the predictive performance of our different models grouped by dataset. These measures were computed on the respective test sets. + +Table~\ref{tab:genparams} summarises our hyperparameter choices for the counterfactual generators, where $\eta$ denotes the learning rate used for Stochastic Gradient Descent (SGD) and $\lambda_1$, $\lambda_2$, $\lambda_3$ represent the chosen penalty strengths (Equations~\ref{eq:general} and~\ref{eq:eccco}). Here $\lambda_1$ also refers to the chosen penalty for the distance from factual values that applies to both \textit{Wachter} and \textit{REVISE}, but not to \textit{Schut}, which is penalty-free. \textit{Schut} is also the only generator that uses JSMA instead of SGD for optimization. 
+ +\import{contents/}{table_params.tex} + +\import{contents/}{table_perf.tex} + +\import{contents/}{table_gen_params.tex} \subsubsection{Compute} @@ -471,14 +479,14 @@ We have summarised the system information below: \item Memory: 32 GB \end{itemize} -\import{contents/}{table_params.tex} + \subsection{Results}\label{app:results} Figure~\ref{fig:mnist-eccco} shows examples of counterfactuals for \textit{MNIST} data where the underlying model is our \textit{JEM Ensemble}. Original images are shown on the diagonal and the corresponding counterfactuals are plotted across rows. \begin{figure} \centering - \includegraphics[width=0.75\textwidth]{../artifacts/results/images/mnist_eccco_all_digits.png} + \includegraphics[width=0.9\textwidth]{../artifacts/results/images/mnist_eccco_all_digits.png} \caption{Counterfactuals for \textit{MNIST} data and our \textit{JEM Ensemble}. Original images are shown on the diagonal with the corresponding counterfactuals plotted across rows.}\label{fig:mnist-eccco} \end{figure}