diff --git a/artifacts/results/images/mnist_benchmark.png b/artifacts/results/images/mnist_benchmark.png new file mode 100644 index 0000000000000000000000000000000000000000..18f0323b214e9e2e371a564d64db2dc44d51f4ef Binary files /dev/null and b/artifacts/results/images/mnist_benchmark.png differ diff --git a/artifacts/results/mnist_benchmark.csv b/artifacts/results/mnist_benchmark.csv new file mode 100644 index 0000000000000000000000000000000000000000..8dc98fd1d7a1d0f1801b70398a39838f9099ea9b --- /dev/null +++ b/artifacts/results/mnist_benchmark.csv @@ -0,0 +1,801 @@ +sample,variable,value,ce,dataname,generator,model,target,factual +1,distance,6.777590274810791," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +1,distance_from_energy,142.5549774169922," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +1,distance_from_targets,121.48783874511719," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +1,redundancy,0.0," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +1,validity,0.0," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +2,distance,7.9763593673706055," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +2,distance_from_energy,104.6924819946289," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +2,distance_from_targets,93.2911148071289," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +2,redundancy,0.0," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +2,validity,0.0," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +3,distance,8.455801963806152," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +3,distance_from_energy,103.70753479003906," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +3,distance_from_targets,102.11791229248047," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +3,redundancy,0.0," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +3,validity,0.0," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +4,distance,9.339488983154297," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +4,distance_from_energy,122.48368835449219," +Convergence: âŒ",MNIST,revise,JEM 
Ensemble,1,0 +4,distance_from_targets,115.36590576171875," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +4,redundancy,0.0," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +4,validity,0.0," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +5,distance,7.928236961364746," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +5,distance_from_energy,125.93316650390625," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +5,distance_from_targets,119.22042846679688," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +5,redundancy,0.0012755102040816326," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +5,validity,0.0," +Convergence: âŒ",MNIST,revise,JEM Ensemble,1,0 +1,distance,4.242640495300293," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +1,distance_from_energy,145.15542602539062," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +1,distance_from_targets,143.36795043945312," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +1,redundancy,0.9770408163265306," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +1,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +2,distance,5.196152210235596," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +2,distance_from_energy,173.82566833496094," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +2,distance_from_targets,164.2250213623047," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +2,redundancy,0.9655612244897959," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +2,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +3,distance,7.211102485656738," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +3,distance_from_energy,154.64845275878906," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +3,distance_from_targets,184.35992431640625," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +3,redundancy,0.9336734693877551," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +3,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +4,distance,6.480740547180176," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 
+4,distance_from_energy,201.22898864746094," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +4,distance_from_targets,207.16464233398438," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +4,redundancy,0.9464285714285714," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +4,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +5,distance,6.164413928985596," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +5,distance_from_energy,181.07154846191406," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +5,distance_from_targets,192.42422485351562," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +5,redundancy,0.951530612244898," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +5,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM Ensemble,1,0 +1,distance,4.971918106079102," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +1,distance_from_energy,96.36383056640625," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +1,distance_from_targets,108.70105743408203," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +1,redundancy,0.0," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +1,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +2,distance,8.110333442687988," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +2,distance_from_energy,74.99360656738281," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +2,distance_from_targets,96.28961181640625," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +2,redundancy,0.1913265306122449," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +2,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +3,distance,7.8416266441345215," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +3,distance_from_energy,102.18379211425781," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +3,distance_from_targets,101.49235534667969," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +3,redundancy,0.19642857142857142," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +3,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +4,distance,9.177770614624023," +Convergence: 
✅",MNIST,eccco,JEM Ensemble,1,0 +4,distance_from_energy,86.6290054321289," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +4,distance_from_targets,104.81550598144531," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +4,redundancy,0.17091836734693877," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +4,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +5,distance,8.959502220153809," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +5,distance_from_energy,66.32455444335938," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +5,distance_from_targets,100.0269546508789," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +5,redundancy,0.1875," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +5,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM Ensemble,1,0 +1,distance,2.6398472785949707," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +1,distance_from_energy,188.9833984375," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +1,distance_from_targets,173.45201110839844," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +1,redundancy,0.0," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +1,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +2,distance,3.383842945098877," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +2,distance_from_energy,207.63113403320312," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +2,distance_from_targets,194.04270935058594," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +2,redundancy,0.0," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +2,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +3,distance,5.309839248657227," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +3,distance_from_energy,228.83859252929688," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +3,distance_from_targets,222.18597412109375," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +3,redundancy,0.011479591836734694," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +3,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +4,distance,4.59068489074707," 
+Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +4,distance_from_energy,244.65087890625," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +4,distance_from_targets,241.43231201171875," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +4,redundancy,0.0," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +4,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +5,distance,4.684568881988525," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +5,distance_from_energy,240.7984619140625," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +5,distance_from_targets,230.89126586914062," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +5,redundancy,0.01020408163265306," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +5,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM Ensemble,1,0 +6,distance,8.719939231872559," +Convergence: âŒ",MNIST,revise,MLP,1,0 +6,distance_from_energy,458.403564453125," +Convergence: âŒ",MNIST,revise,MLP,1,0 +6,distance_from_targets,87.7504653930664," +Convergence: âŒ",MNIST,revise,MLP,1,0 +6,redundancy,0.0," +Convergence: âŒ",MNIST,revise,MLP,1,0 +6,validity,0.0," +Convergence: âŒ",MNIST,revise,MLP,1,0 +7,distance,9.099820137023926," +Convergence: ✅",MNIST,revise,MLP,1,0 +7,distance_from_energy,446.17236328125," +Convergence: ✅",MNIST,revise,MLP,1,0 +7,distance_from_targets,85.84215545654297," +Convergence: ✅",MNIST,revise,MLP,1,0 +7,redundancy,0.0," +Convergence: ✅",MNIST,revise,MLP,1,0 +7,validity,1.0," +Convergence: ✅",MNIST,revise,MLP,1,0 +8,distance,9.351598739624023," +Convergence: ✅",MNIST,revise,MLP,1,0 +8,distance_from_energy,447.8080139160156," +Convergence: ✅",MNIST,revise,MLP,1,0 +8,distance_from_targets,116.43446350097656," +Convergence: ✅",MNIST,revise,MLP,1,0 +8,redundancy,0.0," +Convergence: ✅",MNIST,revise,MLP,1,0 +8,validity,1.0," +Convergence: ✅",MNIST,revise,MLP,1,0 +9,distance,8.3056001663208," +Convergence: âŒ",MNIST,revise,MLP,1,0 +9,distance_from_energy,432.466552734375," +Convergence: âŒ",MNIST,revise,MLP,1,0 
+9,distance_from_targets,101.77315521240234," +Convergence: âŒ",MNIST,revise,MLP,1,0 +9,redundancy,0.0," +Convergence: âŒ",MNIST,revise,MLP,1,0 +9,validity,0.0," +Convergence: âŒ",MNIST,revise,MLP,1,0 +10,distance,7.884433269500732," +Convergence: âŒ",MNIST,revise,MLP,1,0 +10,distance_from_energy,437.5025329589844," +Convergence: âŒ",MNIST,revise,MLP,1,0 +10,distance_from_targets,85.83161926269531," +Convergence: âŒ",MNIST,revise,MLP,1,0 +10,redundancy,0.0," +Convergence: âŒ",MNIST,revise,MLP,1,0 +10,validity,0.0," +Convergence: âŒ",MNIST,revise,MLP,1,0 +6,distance,4.0," +Convergence: ✅",MNIST,greedy,MLP,1,0 +6,distance_from_energy,476.1180725097656," +Convergence: ✅",MNIST,greedy,MLP,1,0 +6,distance_from_targets,189.96182250976562," +Convergence: ✅",MNIST,greedy,MLP,1,0 +6,redundancy,0.9795918367346939," +Convergence: ✅",MNIST,greedy,MLP,1,0 +6,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP,1,0 +7,distance,4.0," +Convergence: ✅",MNIST,greedy,MLP,1,0 +7,distance_from_energy,478.76177978515625," +Convergence: ✅",MNIST,greedy,MLP,1,0 +7,distance_from_targets,186.25125122070312," +Convergence: ✅",MNIST,greedy,MLP,1,0 +7,redundancy,0.9795918367346939," +Convergence: ✅",MNIST,greedy,MLP,1,0 +7,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP,1,0 +8,distance,4.358899116516113," +Convergence: ✅",MNIST,greedy,MLP,1,0 +8,distance_from_energy,460.4898376464844," +Convergence: ✅",MNIST,greedy,MLP,1,0 +8,distance_from_targets,240.55148315429688," +Convergence: ✅",MNIST,greedy,MLP,1,0 +8,redundancy,0.9757653061224489," +Convergence: ✅",MNIST,greedy,MLP,1,0 +8,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP,1,0 +9,distance,4.242640495300293," +Convergence: ✅",MNIST,greedy,MLP,1,0 +9,distance_from_energy,505.6572265625," +Convergence: ✅",MNIST,greedy,MLP,1,0 +9,distance_from_targets,198.39793395996094," +Convergence: ✅",MNIST,greedy,MLP,1,0 +9,redundancy,0.9770408163265306," +Convergence: ✅",MNIST,greedy,MLP,1,0 +9,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP,1,0 
+10,distance,4.123105525970459," +Convergence: ✅",MNIST,greedy,MLP,1,0 +10,distance_from_energy,465.98388671875," +Convergence: ✅",MNIST,greedy,MLP,1,0 +10,distance_from_targets,174.0707550048828," +Convergence: ✅",MNIST,greedy,MLP,1,0 +10,redundancy,0.9783163265306123," +Convergence: ✅",MNIST,greedy,MLP,1,0 +10,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP,1,0 +6,distance,2.3986656665802," +Convergence: ✅",MNIST,eccco,MLP,1,0 +6,distance_from_energy,421.09893798828125," +Convergence: ✅",MNIST,eccco,MLP,1,0 +6,distance_from_targets,196.3191375732422," +Convergence: ✅",MNIST,eccco,MLP,1,0 +6,redundancy,0.0," +Convergence: ✅",MNIST,eccco,MLP,1,0 +6,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP,1,0 +7,distance,2.5710692405700684," +Convergence: ✅",MNIST,eccco,MLP,1,0 +7,distance_from_energy,408.6450500488281," +Convergence: ✅",MNIST,eccco,MLP,1,0 +7,distance_from_targets,194.28594970703125," +Convergence: ✅",MNIST,eccco,MLP,1,0 +7,redundancy,0.011479591836734694," +Convergence: ✅",MNIST,eccco,MLP,1,0 +7,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP,1,0 +8,distance,2.715510129928589," +Convergence: ✅",MNIST,eccco,MLP,1,0 +8,distance_from_energy,396.1204833984375," +Convergence: ✅",MNIST,eccco,MLP,1,0 +8,distance_from_targets,245.9937286376953," +Convergence: ✅",MNIST,eccco,MLP,1,0 +8,redundancy,0.007653061224489796," +Convergence: ✅",MNIST,eccco,MLP,1,0 +8,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP,1,0 +9,distance,2.6266322135925293," +Convergence: ✅",MNIST,eccco,MLP,1,0 +9,distance_from_energy,434.64111328125," +Convergence: ✅",MNIST,eccco,MLP,1,0 +9,distance_from_targets,204.14910888671875," +Convergence: ✅",MNIST,eccco,MLP,1,0 +9,redundancy,0.00510204081632653," +Convergence: ✅",MNIST,eccco,MLP,1,0 +9,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP,1,0 +10,distance,2.4240455627441406," +Convergence: ✅",MNIST,eccco,MLP,1,0 +10,distance_from_energy,422.4525451660156," +Convergence: ✅",MNIST,eccco,MLP,1,0 +10,distance_from_targets,180.05462646484375," 
+Convergence: ✅",MNIST,eccco,MLP,1,0 +10,redundancy,0.0," +Convergence: ✅",MNIST,eccco,MLP,1,0 +10,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP,1,0 +6,distance,2.088252305984497," +Convergence: ✅",MNIST,wachter,MLP,1,0 +6,distance_from_energy,457.7002868652344," +Convergence: ✅",MNIST,wachter,MLP,1,0 +6,distance_from_targets,191.50486755371094," +Convergence: ✅",MNIST,wachter,MLP,1,0 +6,redundancy,0.0," +Convergence: ✅",MNIST,wachter,MLP,1,0 +6,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP,1,0 +7,distance,2.110544443130493," +Convergence: ✅",MNIST,wachter,MLP,1,0 +7,distance_from_energy,438.66021728515625," +Convergence: ✅",MNIST,wachter,MLP,1,0 +7,distance_from_targets,185.83787536621094," +Convergence: ✅",MNIST,wachter,MLP,1,0 +7,redundancy,0.0," +Convergence: ✅",MNIST,wachter,MLP,1,0 +7,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP,1,0 +8,distance,2.516822576522827," +Convergence: ✅",MNIST,wachter,MLP,1,0 +8,distance_from_energy,424.76153564453125," +Convergence: ✅",MNIST,wachter,MLP,1,0 +8,distance_from_targets,239.91163635253906," +Convergence: ✅",MNIST,wachter,MLP,1,0 +8,redundancy,0.012755102040816327," +Convergence: ✅",MNIST,wachter,MLP,1,0 +8,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP,1,0 +9,distance,2.464142322540283," +Convergence: ✅",MNIST,wachter,MLP,1,0 +9,distance_from_energy,450.7071228027344," +Convergence: ✅",MNIST,wachter,MLP,1,0 +9,distance_from_targets,198.46058654785156," +Convergence: ✅",MNIST,wachter,MLP,1,0 +9,redundancy,0.0," +Convergence: ✅",MNIST,wachter,MLP,1,0 +9,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP,1,0 +10,distance,2.2826294898986816," +Convergence: ✅",MNIST,wachter,MLP,1,0 +10,distance_from_energy,452.281494140625," +Convergence: ✅",MNIST,wachter,MLP,1,0 +10,distance_from_targets,175.62435913085938," +Convergence: ✅",MNIST,wachter,MLP,1,0 +10,redundancy,0.0," +Convergence: ✅",MNIST,wachter,MLP,1,0 +10,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP,1,0 +11,distance,8.70202350616455," +Convergence: 
âŒ",MNIST,revise,MLP Ensemble,1,0 +11,distance_from_energy,419.6873474121094," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +11,distance_from_targets,99.92082977294922," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +11,redundancy,0.0012755102040816326," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +11,validity,0.0," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +12,distance,9.599431037902832," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +12,distance_from_energy,432.94207763671875," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +12,distance_from_targets,101.373291015625," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +12,redundancy,0.0," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +12,validity,0.0," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +13,distance,8.03868579864502," +Convergence: ✅",MNIST,revise,MLP Ensemble,1,0 +13,distance_from_energy,428.8260803222656," +Convergence: ✅",MNIST,revise,MLP Ensemble,1,0 +13,distance_from_targets,84.50151062011719," +Convergence: ✅",MNIST,revise,MLP Ensemble,1,0 +13,redundancy,0.002551020408163265," +Convergence: ✅",MNIST,revise,MLP Ensemble,1,0 +13,validity,1.0," +Convergence: ✅",MNIST,revise,MLP Ensemble,1,0 +14,distance,9.208531379699707," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +14,distance_from_energy,446.864013671875," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +14,distance_from_targets,101.66757202148438," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +14,redundancy,0.0," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +14,validity,0.0," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +15,distance,8.781783103942871," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +15,distance_from_energy,423.536865234375," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +15,distance_from_targets,91.89868927001953," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +15,redundancy,0.0012755102040816326," +Convergence: âŒ",MNIST,revise,MLP Ensemble,1,0 +15,validity,0.0," +Convergence: 
âŒ",MNIST,revise,MLP Ensemble,1,0 +11,distance,4.582575798034668," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +11,distance_from_energy,466.7583923339844," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +11,distance_from_targets,207.3255615234375," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +11,redundancy,0.9732142857142857," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +11,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +12,distance,4.582575798034668," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +12,distance_from_energy,485.16845703125," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +12,distance_from_targets,224.75144958496094," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +12,redundancy,0.9732142857142857," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +12,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +13,distance,3.872983455657959," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +13,distance_from_energy,466.527099609375," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +13,distance_from_targets,171.34530639648438," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +13,redundancy,0.9808673469387755," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +13,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +14,distance,4.242640495300293," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +14,distance_from_energy,467.6905212402344," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +14,distance_from_targets,206.57318115234375," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +14,redundancy,0.9770408163265306," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +14,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +15,distance,5.0," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +15,distance_from_energy,469.6475524902344," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +15,distance_from_targets,207.16424560546875," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +15,redundancy,0.9681122448979592," +Convergence: 
✅",MNIST,greedy,MLP Ensemble,1,0 +15,validity,1.0," +Convergence: ✅",MNIST,greedy,MLP Ensemble,1,0 +11,distance,2.9991979598999023," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +11,distance_from_energy,390.73455810546875," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +11,distance_from_targets,220.1835479736328," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +11,redundancy,0.0," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +11,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +12,distance,3.1859054565429688," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +12,distance_from_energy,393.7857971191406," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +12,distance_from_targets,236.0286865234375," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +12,redundancy,0.015306122448979591," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +12,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +13,distance,2.1722793579101562," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +13,distance_from_energy,409.6536560058594," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +13,distance_from_targets,178.91244506835938," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +13,redundancy,0.01020408163265306," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +13,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +14,distance,2.4659271240234375," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +14,distance_from_energy,399.98175048828125," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +14,distance_from_targets,214.7428741455078," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +14,redundancy,0.0," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +14,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +15,distance,3.231407642364502," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +15,distance_from_energy,390.91986083984375," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +15,distance_from_targets,220.17723083496094," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 
+15,redundancy,0.003826530612244898," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +15,validity,1.0," +Convergence: ✅",MNIST,eccco,MLP Ensemble,1,0 +11,distance,2.7152156829833984," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +11,distance_from_energy,408.47833251953125," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +11,distance_from_targets,211.25404357910156," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +11,redundancy,0.006377551020408163," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +11,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +12,distance,2.797426462173462," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +12,distance_from_energy,435.64996337890625," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +12,distance_from_targets,229.1556854248047," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +12,redundancy,0.0," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +12,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +13,distance,1.9998995065689087," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +13,distance_from_energy,419.06854248046875," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +13,distance_from_targets,173.11288452148438," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +13,redundancy,0.05357142857142857," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +13,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +14,distance,2.2875308990478516," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +14,distance_from_energy,423.68280029296875," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +14,distance_from_targets,208.61831665039062," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +14,redundancy,0.0," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +14,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +15,distance,2.9437382221221924," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +15,distance_from_energy,421.8553161621094," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 
+15,distance_from_targets,208.0811309814453," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +15,redundancy,0.05994897959183673," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +15,validity,1.0," +Convergence: ✅",MNIST,wachter,MLP Ensemble,1,0 +16,distance,5.215779781341553," +Convergence: ✅",MNIST,revise,JEM,1,0 +16,distance_from_energy,100.11660766601562," +Convergence: ✅",MNIST,revise,JEM,1,0 +16,distance_from_targets,121.23953247070312," +Convergence: ✅",MNIST,revise,JEM,1,0 +16,redundancy,0.0," +Convergence: ✅",MNIST,revise,JEM,1,0 +16,validity,1.0," +Convergence: ✅",MNIST,revise,JEM,1,0 +17,distance,8.366628646850586," +Convergence: ✅",MNIST,revise,JEM,1,0 +17,distance_from_energy,86.1596908569336," +Convergence: ✅",MNIST,revise,JEM,1,0 +17,distance_from_targets,96.00230407714844," +Convergence: ✅",MNIST,revise,JEM,1,0 +17,redundancy,0.0," +Convergence: ✅",MNIST,revise,JEM,1,0 +17,validity,1.0," +Convergence: ✅",MNIST,revise,JEM,1,0 +18,distance,7.347404956817627," +Convergence: âŒ",MNIST,revise,JEM,1,0 +18,distance_from_energy,90.30267333984375," +Convergence: âŒ",MNIST,revise,JEM,1,0 +18,distance_from_targets,103.30117797851562," +Convergence: âŒ",MNIST,revise,JEM,1,0 +18,redundancy,0.0," +Convergence: âŒ",MNIST,revise,JEM,1,0 +18,validity,0.0," +Convergence: âŒ",MNIST,revise,JEM,1,0 +19,distance,8.839747428894043," +Convergence: âŒ",MNIST,revise,JEM,1,0 +19,distance_from_energy,125.14628601074219," +Convergence: âŒ",MNIST,revise,JEM,1,0 +19,distance_from_targets,147.10423278808594," +Convergence: âŒ",MNIST,revise,JEM,1,0 +19,redundancy,0.0," +Convergence: âŒ",MNIST,revise,JEM,1,0 +19,validity,0.0," +Convergence: âŒ",MNIST,revise,JEM,1,0 +20,distance,9.20132064819336," +Convergence: âŒ",MNIST,revise,JEM,1,0 +20,distance_from_energy,103.3006820678711," +Convergence: âŒ",MNIST,revise,JEM,1,0 +20,distance_from_targets,108.06597900390625," +Convergence: âŒ",MNIST,revise,JEM,1,0 +20,redundancy,0.0," +Convergence: âŒ",MNIST,revise,JEM,1,0 +20,validity,0.0," 
+Convergence: âŒ",MNIST,revise,JEM,1,0 +16,distance,3.872983455657959," +Convergence: ✅",MNIST,greedy,JEM,1,0 +16,distance_from_energy,153.08375549316406," +Convergence: ✅",MNIST,greedy,JEM,1,0 +16,distance_from_targets,154.8096160888672," +Convergence: ✅",MNIST,greedy,JEM,1,0 +16,redundancy,0.9808673469387755," +Convergence: ✅",MNIST,greedy,JEM,1,0 +16,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM,1,0 +17,distance,4.690415859222412," +Convergence: ✅",MNIST,greedy,JEM,1,0 +17,distance_from_energy,162.777587890625," +Convergence: ✅",MNIST,greedy,JEM,1,0 +17,distance_from_targets,174.16055297851562," +Convergence: ✅",MNIST,greedy,JEM,1,0 +17,redundancy,0.9719387755102041," +Convergence: ✅",MNIST,greedy,JEM,1,0 +17,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM,1,0 +18,distance,4.123105525970459," +Convergence: ✅",MNIST,greedy,JEM,1,0 +18,distance_from_energy,159.30421447753906," +Convergence: ✅",MNIST,greedy,JEM,1,0 +18,distance_from_targets,170.12925720214844," +Convergence: ✅",MNIST,greedy,JEM,1,0 +18,redundancy,0.9783163265306123," +Convergence: ✅",MNIST,greedy,JEM,1,0 +18,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM,1,0 +19,distance,5.385164737701416," +Convergence: ✅",MNIST,greedy,JEM,1,0 +19,distance_from_energy,215.79312133789062," +Convergence: ✅",MNIST,greedy,JEM,1,0 +19,distance_from_targets,238.88656616210938," +Convergence: ✅",MNIST,greedy,JEM,1,0 +19,redundancy,0.9630102040816326," +Convergence: ✅",MNIST,greedy,JEM,1,0 +19,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM,1,0 +20,distance,5.385164737701416," +Convergence: ✅",MNIST,greedy,JEM,1,0 +20,distance_from_energy,205.99270629882812," +Convergence: ✅",MNIST,greedy,JEM,1,0 +20,distance_from_targets,212.05014038085938," +Convergence: ✅",MNIST,greedy,JEM,1,0 +20,redundancy,0.9630102040816326," +Convergence: ✅",MNIST,greedy,JEM,1,0 +20,validity,1.0," +Convergence: ✅",MNIST,greedy,JEM,1,0 +16,distance,4.361510753631592," +Convergence: ✅",MNIST,eccco,JEM,1,0 
+16,distance_from_energy,80.49446105957031," +Convergence: ✅",MNIST,eccco,JEM,1,0 +16,distance_from_targets,105.89717864990234," +Convergence: ✅",MNIST,eccco,JEM,1,0 +16,redundancy,0.18622448979591838," +Convergence: ✅",MNIST,eccco,JEM,1,0 +16,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM,1,0 +17,distance,4.892685413360596," +Convergence: ✅",MNIST,eccco,JEM,1,0 +17,distance_from_energy,114.7015380859375," +Convergence: ✅",MNIST,eccco,JEM,1,0 +17,distance_from_targets,123.17688751220703," +Convergence: ✅",MNIST,eccco,JEM,1,0 +17,redundancy,0.1989795918367347," +Convergence: ✅",MNIST,eccco,JEM,1,0 +17,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM,1,0 +18,distance,4.267268180847168," +Convergence: ✅",MNIST,eccco,JEM,1,0 +18,distance_from_energy,97.6572265625," +Convergence: ✅",MNIST,eccco,JEM,1,0 +18,distance_from_targets,119.03404998779297," +Convergence: ✅",MNIST,eccco,JEM,1,0 +18,redundancy,0.0," +Convergence: ✅",MNIST,eccco,JEM,1,0 +18,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM,1,0 +19,distance,7.968230247497559," +Convergence: ✅",MNIST,eccco,JEM,1,0 +19,distance_from_energy,104.6467056274414," +Convergence: ✅",MNIST,eccco,JEM,1,0 +19,distance_from_targets,130.1501922607422," +Convergence: ✅",MNIST,eccco,JEM,1,0 +19,redundancy,0.15306122448979592," +Convergence: ✅",MNIST,eccco,JEM,1,0 +19,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM,1,0 +20,distance,6.165253162384033," +Convergence: ✅",MNIST,eccco,JEM,1,0 +20,distance_from_energy,97.56226348876953," +Convergence: ✅",MNIST,eccco,JEM,1,0 +20,distance_from_targets,125.54220581054688," +Convergence: ✅",MNIST,eccco,JEM,1,0 +20,redundancy,0.0," +Convergence: ✅",MNIST,eccco,JEM,1,0 +20,validity,1.0," +Convergence: ✅",MNIST,eccco,JEM,1,0 +16,distance,1.8214280605316162," +Convergence: ✅",MNIST,wachter,JEM,1,0 +16,distance_from_energy,158.41860961914062," +Convergence: ✅",MNIST,wachter,JEM,1,0 +16,distance_from_targets,158.55406188964844," +Convergence: ✅",MNIST,wachter,JEM,1,0 +16,redundancy,0.0," 
+Convergence: ✅",MNIST,wachter,JEM,1,0 +16,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM,1,0 +17,distance,3.2385175228118896," +Convergence: ✅",MNIST,wachter,JEM,1,0 +17,distance_from_energy,192.31753540039062," +Convergence: ✅",MNIST,wachter,JEM,1,0 +17,distance_from_targets,194.8053741455078," +Convergence: ✅",MNIST,wachter,JEM,1,0 +17,redundancy,0.0," +Convergence: ✅",MNIST,wachter,JEM,1,0 +17,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM,1,0 +18,distance,2.148487091064453," +Convergence: ✅",MNIST,wachter,JEM,1,0 +18,distance_from_energy,176.5944061279297," +Convergence: ✅",MNIST,wachter,JEM,1,0 +18,distance_from_targets,176.7031707763672," +Convergence: ✅",MNIST,wachter,JEM,1,0 +18,redundancy,0.06505102040816327," +Convergence: ✅",MNIST,wachter,JEM,1,0 +18,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM,1,0 +19,distance,3.6989712715148926," +Convergence: ✅",MNIST,wachter,JEM,1,0 +19,distance_from_energy,241.1081085205078," +Convergence: ✅",MNIST,wachter,JEM,1,0 +19,distance_from_targets,249.22320556640625," +Convergence: ✅",MNIST,wachter,JEM,1,0 +19,redundancy,0.025510204081632654," +Convergence: ✅",MNIST,wachter,JEM,1,0 +19,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM,1,0 +20,distance,3.367253065109253," +Convergence: ✅",MNIST,wachter,JEM,1,0 +20,distance_from_energy,218.31698608398438," +Convergence: ✅",MNIST,wachter,JEM,1,0 +20,distance_from_targets,218.95526123046875," +Convergence: ✅",MNIST,wachter,JEM,1,0 +20,redundancy,0.0," +Convergence: ✅",MNIST,wachter,JEM,1,0 +20,validity,1.0," +Convergence: ✅",MNIST,wachter,JEM,1,0 diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 82d559fddefc8f283e66c3366ec4ff6fb9e9f48b..72ea2bc5c11a1a96ee6e5e9464292a947d993885 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -148,9 +148,10 @@ _retrain = false _regen = false # Data: -n_obs = 1000 +n_obs = 10000 counterfactual_data = load_mnist(n_obs) counterfactual_data.X = pre_process.(counterfactual_data.X) 
+counterfactual_data.generative_model = vae X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X = table(permutedims(X)) x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1) @@ -186,7 +187,7 @@ sampler = ConditionalSampler( input_size=(input_dim,), batch_size=10, ) -α = [1.0,1.0,1e-1] # penalty strengths +α = [1.0,1.0,1e-2] # penalty strengths ``` ```{julia} @@ -258,6 +259,7 @@ end ```{julia} # Plot generated samples: +n_regen = 150 if _regen for (mod_name, mod) in model_dict if ECCCo._has_sampler(mod) @@ -272,12 +274,11 @@ if _regen opt = ImproperSGLD() f(x) = logits(mod, x) - n_iter = 200 _w = 1500 plts = [] neach = 10 for i in 1:10 - x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i) + x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i) plts_i = [] for j in 1:size(x, 2) xj = x[:,j] @@ -321,7 +322,7 @@ model_performance ```{julia} function _plot_eccco_mnist( x::Union{AbstractArray, Int}=x_factual, target::Int=target; - λ=[0.1,0.1,0.1], + λ=[0.5,0.1,0.5], temp=0.1,η=0.01, plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"], opt = Flux.Optimise.Adam(η), @@ -498,7 +499,7 @@ plts = plts[plt_order] plts = [p1, plts...] 
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) display(plt) -savefig(plt, joinpath(output_images_path, "mnist_eccco_benchmark.png")) +savefig(plt, joinpath(output_images_path, "mnist_all_generators.png")) ``` ## Benchmark @@ -509,7 +510,50 @@ measures = [ CounterfactualExplanations.distance, ECCCo.distance_from_energy, ECCCo.distance_from_targets, - CounterfactualExplanations.validity, - CounterfactualExplanations.redudancy, + CounterfactualExplanations.Evaluation.validity, + CounterfactualExplanations.Evaluation.redundancy, ] + +bmk = benchmark( + counterfactual_data; + models=model_dict, + generators=generator_dict, + measure=measures, + suppress_training=true, dataname="MNIST", + n_individuals=5, + factual=0, target=1, + initialization=:identity, +) +CSV.write(joinpath(output_path, "mnist_benchmark.csv"), bmk()) +``` + + +```{julia} +@chain bmk() begin + @group_by(dataname, generator, model, variable) + @summarize(mean=mean(value),sd=std(value)) + @ungroup + @filter(variable == "distance_from_energy") +end +``` + + +```{julia} +df = @chain bmk() begin + @filter(variable in [ + "distance_from_energy", + "distance_from_targets", + "distance",]) + @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable)) + @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable)) + @mutate(variable = ifelse.(variable .== "distance", "Cost", variable)) +end +plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) * + mapping(:generator, :value, row=:variable, col=:model, color=:generator) +plt = draw( + plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120), + facet=(; linkyaxes=:minimal) +) +display(plt) +save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5) ``` \ No newline at end of file diff --git a/paper/paper.pdf b/paper/paper.pdf index 
3b93d70fe1c3e7aac79f8f04500e1726d1e9db7f..d5b4209b0dad3fa1980d19c179fd14fc66c31852 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/paper/paper.tex b/paper/paper.tex index 509548a838344c977f205b4ce8b75c5f9f748e03..bcd699752790231fcf75e52502cd72f45b8a790a 100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -246,7 +246,7 @@ The first two terms in Equation~\ref{eq:eccco} correspond to the counterfactual \begin{minipage}[c]{0.40\textwidth} \centering \includegraphics[width=\textwidth]{../artifacts/results/images/eccco_illustration.png} - \captionof{figure}{Vector fields indicating the direction of gradients with respect to the different components of the ECCCo objective (Equation~\ref{eq:eccco}).} \label{fig:eccco} + \captionof{figure}{[PLACEHOLDER] Vector fields indicating the direction of gradients with respect to the different components of the ECCCo objective (Equation~\ref{eq:eccco}).} \label{fig:eccco} \end{minipage} \hfill \begin{minipage}[c]{0.50\textwidth} @@ -271,7 +271,7 @@ The first two terms in Equation~\ref{eq:eccco} correspond to the counterfactual \begin{minipage}[c]{\textwidth} \includegraphics[width=\textwidth]{../artifacts/results/images/mnist_eccco.png} - \captionof{figure}{Original image (left) and ECCCos for turning an 8 (eight) into a 3 (three) for different Black Boxes from left to right: Multi-Layer Perceptron (MLP), Ensemble of MLPs, Joint Energy Model (JEM), Ensemble of JEMs.}\label{fig:eccco-mnist} + \captionof{figure}{[SUBJECT TO CHANGE] Original image (left) and ECCCos for turning an 8 (eight) into a 3 (three) for different Black Boxes from left to right: Multi-Layer Perceptron (MLP), Ensemble of MLPs, Joint Energy Model (JEM), Ensemble of JEMs.}\label{fig:eccco-mnist} \end{minipage} \medskip @@ -282,7 +282,7 @@ Finally, we search counterfactuals through gradient descent.
Let $\mathcal{L}(\m Figure~\ref{fig:eccco-mnist} presents ECCCos for the MNIST example from Section~\ref{background} for various Black Box models of increasing complexity from left to right: a simple Multi-Layer Perceptron (MLP); an Ensemble of MLPs, each of the same architecture as the single MLP; a Joint Energy Model (JEM) based on the same MLP architecture; and finally, an Ensemble of these JEMs. Since Deep Ensembles have an improved capacity for predictive uncertainty quantification and JEMs are explicitly trained to learn plausible representations of the input data, it is intuitive to see that the plausibility of counterfactuals visibly improves from left to right. This provides some first anecdotal evidence that ECCCos achieve plausibility while maintaining faithfulness to the Black Box. -\section{Experiments}\label{conformity} +\section{Empirical Analysis}\label{emp} In this section, we bolster our anecdotal findings from the previous section through rigorous empirical analysis. We first briefly describe our evaluation framework and data, before presenting and discussing our results. @@ -308,9 +308,22 @@ This measure is straightforward to compute and should be less sensitive to outli As noted by \citet{guidotti2022counterfactual}, these distance-based measures are simplistic and more complex alternative measures may ultimately be more appropriate for the task. For example, we considered using statistical divergence measures instead. This would involve generating not one but many counterfactuals and comparing the generated empirical distribution to the target distributions in Definitions~\ref{def:plausible} and~\ref{def:conformal}. While this approach is potentially more rigorous, generating enough counterfactuals is not always practical. 
-\section{Experiments} +\subsection{Data} +\subsection{Results} +\begin{figure} + \includegraphics[width=\textwidth]{../artifacts/results/images/mnist_benchmark.png} + \caption{[SUBJECT TO CHANGE] Original image (left) and ECCCos for turning an 8 (eight) into a 3 (three) for different Black Boxes from left to right: Multi-Layer Perceptron (MLP), Ensemble of MLPs, Joint Energy Model (JEM), Ensemble of JEMs.}\label{fig:mnist-benchmark} +\end{figure} + +\section{Discussion} + +\subsection{Key Insights} + +Consistent with the findings in \citet{schut2021generating}, we have demonstrated that predictive uncertainty estimates can be leveraged to generate plausible counterfactuals. Interestingly, \citet{schut2021generating} point out that this finding --- as intuitive as it is --- may be linked to a positive connection between the generative task and predictive uncertainty quantification. In particular, \citet{grathwohl2020your} demonstrate that their proposed method for integrating the generative objective in training yields models that have improved predictive uncertainty quantification. Since neither \citet{schut2021generating} nor we have employed any surrogate generative models, our findings seem to indicate that the positive connection found in \citet{grathwohl2020your} is bidirectional. + +\subsection{Limitations} \begin{itemize} \item BatchNorm does not seem compatible with JEM @@ -323,10 +336,7 @@ As noted by \citet{guidotti2022counterfactual}, these distance-based measures ar \item For MNIST it seems that ECCCo is better at reducing pixel values than increasing them (better at erasing than writing) \end{itemize} -\section{Discussion} - -Consistent with the findings in \citet{schut2021generating}, we have demonstrated that predictive uncertainty estimates can be leveraged to generate plausible counterfactuals.
Interestingly, \citet{schut2021generating} point out that this finding --- as intuitive as it is --- may be linked to a positive connection between the generative task and predictive uncertainty quantification. In particular, \citet{grathwohl2020your} demonstrate that their proposed method for integrating the generative objective in training yields models that have improved predictive uncertainty quantification. Since neither \citet{schut2021generating} nor we have employed any surrogate generative models, our findings seem to indicate that the positive connection found in \citet{grathwohl2020your} is bidirectional. - +\section{Conclusion} \medskip diff --git a/src/penalties.jl b/src/penalties.jl index 92b007d06e219bf2bb3b69952d286619a0c5f8a7..3afe195ff82a60f33934bf1c5a5bd584f592f478 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -47,7 +47,7 @@ function distance_from_energy( _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) end sampler = _dict[:energy_sampler] - push!(conditional_samples, rand(sampler, 100; from_buffer=from_buffer)) + push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer)) end x′ = CounterfactualExplanations.counterfactual(ce) loss = map(eachslice(x′, dims=ndims(x′))) do x @@ -70,10 +70,9 @@ function distance_from_targets( target_samples = ce.data.X[:,target_idx] |> X -> X[:,rand(1:end,n)] x′ = CounterfactualExplanations.counterfactual(ce) - loss = map(eachslice(x′, dims=3)) do x - x = Matrix(x) + loss = map(eachslice(x′, dims=ndims(x′))) do x Δ = map(eachcol(target_samples)) do xsample - norm(x - xsample) + norm(x - xsample, 1) end return mean(Δ) end