From 118da1ae19dee0fcb02dd436ae6fc166a93a587b Mon Sep 17 00:00:00 2001
From: pat-alt <altmeyerpat@gmail.com>
Date: Wed, 18 Oct 2023 17:41:24 +0200
Subject: [PATCH] bloody hell

---
 experiments/Manifest.toml                          |  4 ++--
 experiments/experiment.jl                          |  2 +-
 experiments/grid_search.jl                         | 10 ++++++----
 .../jobscripts/generators/california_housing.sh    | 12 +++++++-----
 .../tuning/generators/california_housing.sh        |  2 +-
 experiments/jobscripts/tuning/generators/gmsc.sh   |  2 +-
 experiments/utils.jl                               |  2 ++
 src/sampling.jl                                    |  4 +++-
 src/utils.jl                                       | 14 ++++++++++++++
 9 files changed, 37 insertions(+), 15 deletions(-)

diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml
index 48d8a3c8..4a9e1575 100644
--- a/experiments/Manifest.toml
+++ b/experiments/Manifest.toml
@@ -438,11 +438,11 @@ version = "0.6.3"
 
 [[deps.CounterfactualExplanations]]
 deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"]
-git-tree-sha1 = "6ac3b59bcd9d8c75dcbe83822f050247d1143334"
+git-tree-sha1 = "7962ccc6f9e41b3f197b1050e950aca9e3630a18"
 repo-rev = "main"
 repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
-version = "0.1.30"
+version = "0.1.31"
 
     [deps.CounterfactualExplanations.extensions]
     MPIExt = "MPI"
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index a375158e..1d733971 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -9,7 +9,7 @@ Base.@kwdef struct Experiment
     use_pretrained::Bool = !RETRAIN
     models::Union{Nothing,Dict} = nothing
     additional_models::Union{Nothing,Dict} = nothing
-    𝒟x::Distribution = Normal()
+    𝒟x::Distribution = ECCCo.prior_sampling_space(counterfactual_data)
     sampling_batch_size::Int = 50
     sampling_steps::Int = 50
     min_batch_size::Int = 128
diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl
index 36904a08..29ee1982 100644
--- a/experiments/grid_search.jl
+++ b/experiments/grid_search.jl
@@ -212,10 +212,10 @@ function best_absolute_outcome(
         higher_is_better = [var ∈ ["validity", "redundancy"] for var in evaluation.variable]
         evaluation.value[higher_is_better] .= -evaluation.value[higher_is_better]
 
-        # Normalise to allow for comparison across measures:
-        evaluation =
-            groupby(evaluation, [:dataname, :variable]) |>
-            x -> transform(x, :value => standardize => :value)
+        # # Normalise to allow for comparison across measures:
+        # evaluation =
+        #     groupby(evaluation, [:dataname, :variable]) |>
+        #     x -> transform(x, :value => standardize => :value)
 
         # Reconstruct outcome with normalised values:
         bmk = CounterfactualExplanations.Evaluation.Benchmark(evaluation)
@@ -239,6 +239,8 @@ function best_absolute_outcome(
         params=outcomes[:df_outcomes].params[best_index],
         outcome=outcomes[:df_outcomes].outcome[best_index],
     )
+
+    return best_outcome
 end
 
 best_absolute_outcome_eccco(outcomes; kwrgs...) =
diff --git a/experiments/jobscripts/generators/california_housing.sh b/experiments/jobscripts/generators/california_housing.sh
index 07d5770e..446cb9d9 100644
--- a/experiments/jobscripts/generators/california_housing.sh
+++ b/experiments/jobscripts/generators/california_housing.sh
@@ -1,14 +1,16 @@
 #!/bin/bash
 
 #SBATCH --job-name="California Housing (ECCCo)"
-#SBATCH --time=3:00:00
-#SBATCH --ntasks=100
-#SBATCH --cpus-per-task=1
+#SBATCH --time=00:30:00
+#SBATCH --ntasks=30
+#SBATCH --cpus-per-task=10
 #SBATCH --partition=compute
-#SBATCH --mem-per-cpu=8GB
+#SBATCH --mem-per-cpu=4GB
 #SBATCH --account=research-eemcs-insy
 #SBATCH --mail-type=END     # Set mail type to 'END' to receive a mail when the job finishes. 
 
 module load 2023r1 openmpi
 
-srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi > experiments/california_housing.log
+source experiments/slurm_header.sh
+
+srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=california_housing output_path=results mpi threaded n_individuals=100 n_runs=5 > experiments/logs/california_housing.log
diff --git a/experiments/jobscripts/tuning/generators/california_housing.sh b/experiments/jobscripts/tuning/generators/california_housing.sh
index 5e8f4f40..e1805b58 100644
--- a/experiments/jobscripts/tuning/generators/california_housing.sh
+++ b/experiments/jobscripts/tuning/generators/california_housing.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 #SBATCH --job-name="Grid-search California Housing (ECCCo)"
-#SBATCH --time=01:30:00
+#SBATCH --time=01:20:00
 #SBATCH --ntasks=30
 #SBATCH --cpus-per-task=10
 #SBATCH --partition=compute
diff --git a/experiments/jobscripts/tuning/generators/gmsc.sh b/experiments/jobscripts/tuning/generators/gmsc.sh
index 9b84593e..848b074f 100644
--- a/experiments/jobscripts/tuning/generators/gmsc.sh
+++ b/experiments/jobscripts/tuning/generators/gmsc.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 #SBATCH --job-name="Grid-search GMSC (ECCCo)"
-#SBATCH --time=01:30:00
+#SBATCH --time=01:20:00
 #SBATCH --ntasks=30
 #SBATCH --cpus-per-task=10
 #SBATCH --partition=compute
diff --git a/experiments/utils.jl b/experiments/utils.jl
index be94743d..5b77538c 100644
--- a/experiments/utils.jl
+++ b/experiments/utils.jl
@@ -1,6 +1,8 @@
 using CounterfactualExplanations.Parallelization: ThreadsParallelizer
+using Distributions: Uniform
 using Flux
 using LinearAlgebra: norm
+using Statistics: mean, std
 
 function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer})
     if isnothing(parallelizer) || isa(parallelizer, ThreadsParallelizer)
diff --git a/src/sampling.jl b/src/sampling.jl
index aedd495b..75776c2c 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -46,8 +46,10 @@ function EnergySampler(
 
     K = length(data.y_levels)
     input_size = size(selectdim(data.X, ndims(data.X), 1))
-    𝒟x = Uniform(extrema(data.X)...)
+    # Prior distribution:
+    𝒟x = prior_sampling_space(data)
     𝒟y = Categorical(ones(K) ./ K)
+    # Sampler:
     sampler = ConditionalSampler(𝒟x, 𝒟y; input_size = input_size)
     yidx = get_target_index(data.y_levels, y)
 
diff --git a/src/utils.jl b/src/utils.jl
index a476fc6c..a2fae145 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -57,3 +57,17 @@ function ssim_dist(x, y)
     y = convert2mnist(y)
     return (1 - assess_ssim(x, y)) / 2
 end
+
+"""
+    prior_sampling_space(data::CounterfactualData; n_std=3)
+
+Define the prior sampling space for the data.
+"""
+function prior_sampling_space(data::CounterfactualData; n_std=3)
+    X = data.X
+    centers = mean(X, dims=2)
+    stds = std(X, dims=2)
+    lower_bound = minimum(centers .- n_std .* stds)[1]
+    upper_bound = maximum(centers .+ n_std .* stds)[1]
+    return Uniform(lower_bound, upper_bound)
+end
\ No newline at end of file
-- 
GitLab