Skip to content
Snippets Groups Projects
Commit 118da1ae authored by pat-alt's avatar pat-alt
Browse files

bloody hell

parent ea3f48a7
No related branches found
No related tags found
1 merge request!8985 overshooting
...@@ -438,11 +438,11 @@ version = "0.6.3" ...@@ -438,11 +438,11 @@ version = "0.6.3"
[[deps.CounterfactualExplanations]] [[deps.CounterfactualExplanations]]
deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"]
git-tree-sha1 = "6ac3b59bcd9d8c75dcbe83822f050247d1143334" git-tree-sha1 = "7962ccc6f9e41b3f197b1050e950aca9e3630a18"
repo-rev = "main" repo-rev = "main"
repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git" repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
version = "0.1.30" version = "0.1.31"
[deps.CounterfactualExplanations.extensions] [deps.CounterfactualExplanations.extensions]
MPIExt = "MPI" MPIExt = "MPI"
......
...@@ -9,7 +9,7 @@ Base.@kwdef struct Experiment ...@@ -9,7 +9,7 @@ Base.@kwdef struct Experiment
use_pretrained::Bool = !RETRAIN use_pretrained::Bool = !RETRAIN
models::Union{Nothing,Dict} = nothing models::Union{Nothing,Dict} = nothing
additional_models::Union{Nothing,Dict} = nothing additional_models::Union{Nothing,Dict} = nothing
𝒟x::Distribution = Normal() 𝒟x::Distribution = ECCCo.prior_sampling_space(counterfactual_data)
sampling_batch_size::Int = 50 sampling_batch_size::Int = 50
sampling_steps::Int = 50 sampling_steps::Int = 50
min_batch_size::Int = 128 min_batch_size::Int = 128
......
...@@ -212,10 +212,10 @@ function best_absolute_outcome( ...@@ -212,10 +212,10 @@ function best_absolute_outcome(
higher_is_better = [var ["validity", "redundancy"] for var in evaluation.variable] higher_is_better = [var ["validity", "redundancy"] for var in evaluation.variable]
evaluation.value[higher_is_better] .= -evaluation.value[higher_is_better] evaluation.value[higher_is_better] .= -evaluation.value[higher_is_better]
# Normalise to allow for comparison across measures: # # Normalise to allow for comparison across measures:
evaluation = # evaluation =
groupby(evaluation, [:dataname, :variable]) |> # groupby(evaluation, [:dataname, :variable]) |>
x -> transform(x, :value => standardize => :value) # x -> transform(x, :value => standardize => :value)
# Reconstruct outcome with normalised values: # Reconstruct outcome with normalised values:
bmk = CounterfactualExplanations.Evaluation.Benchmark(evaluation) bmk = CounterfactualExplanations.Evaluation.Benchmark(evaluation)
...@@ -239,6 +239,8 @@ function best_absolute_outcome( ...@@ -239,6 +239,8 @@ function best_absolute_outcome(
params=outcomes[:df_outcomes].params[best_index], params=outcomes[:df_outcomes].params[best_index],
outcome=outcomes[:df_outcomes].outcome[best_index], outcome=outcomes[:df_outcomes].outcome[best_index],
) )
return best_outcome
end end
best_absolute_outcome_eccco(outcomes; kwrgs...) = best_absolute_outcome_eccco(outcomes; kwrgs...) =
......
#!/bin/bash #!/bin/bash
#SBATCH --job-name="California Housing (ECCCo)" #SBATCH --job-name="California Housing (ECCCo)"
#SBATCH --time=3:00:00 #SBATCH --time=00:30:00
#SBATCH --ntasks=100 #SBATCH --ntasks=30
#SBATCH --cpus-per-task=1 #SBATCH --cpus-per-task=10
#SBATCH --partition=compute #SBATCH --partition=compute
#SBATCH --mem-per-cpu=8GB #SBATCH --mem-per-cpu=4GB
#SBATCH --account=research-eemcs-insy #SBATCH --account=research-eemcs-insy
#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. #SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes.
module load 2023r1 openmpi module load 2023r1 openmpi
srun julia --project=experiments experiments/run_experiments.jl -- data=california_housing output_path=results mpi > experiments/california_housing.log source experiments/slurm_header.sh
srun julia --project=experiments --threads $SLURM_CPUS_PER_TASK experiments/run_experiments.jl -- data=california_housing output_path=results mpi threaded n_individuals=100 n_runs=5 > experiments/logs/california_housing.log
#!/bin/bash #!/bin/bash
#SBATCH --job-name="Grid-search California Housing (ECCCo)" #SBATCH --job-name="Grid-search California Housing (ECCCo)"
#SBATCH --time=01:30:00 #SBATCH --time=01:20:00
#SBATCH --ntasks=30 #SBATCH --ntasks=30
#SBATCH --cpus-per-task=10 #SBATCH --cpus-per-task=10
#SBATCH --partition=compute #SBATCH --partition=compute
......
#!/bin/bash #!/bin/bash
#SBATCH --job-name="Grid-search GMSC (ECCCo)" #SBATCH --job-name="Grid-search GMSC (ECCCo)"
#SBATCH --time=01:30:00 #SBATCH --time=01:20:00
#SBATCH --ntasks=30 #SBATCH --ntasks=30
#SBATCH --cpus-per-task=10 #SBATCH --cpus-per-task=10
#SBATCH --partition=compute #SBATCH --partition=compute
......
using CounterfactualExplanations.Parallelization: ThreadsParallelizer using CounterfactualExplanations.Parallelization: ThreadsParallelizer
using Distributions: Uniform
using Flux using Flux
using LinearAlgebra: norm using LinearAlgebra: norm
using Statistics: mean, std
function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer}) function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer})
if isnothing(parallelizer) || isa(parallelizer, ThreadsParallelizer) if isnothing(parallelizer) || isa(parallelizer, ThreadsParallelizer)
......
...@@ -46,8 +46,10 @@ function EnergySampler( ...@@ -46,8 +46,10 @@ function EnergySampler(
K = length(data.y_levels) K = length(data.y_levels)
input_size = size(selectdim(data.X, ndims(data.X), 1)) input_size = size(selectdim(data.X, ndims(data.X), 1))
𝒟x = Uniform(extrema(data.X)...) # Prior distribution:
𝒟x = prior_sampling_space(data)
𝒟y = Categorical(ones(K) ./ K) 𝒟y = Categorical(ones(K) ./ K)
# Sampler:
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size = input_size) sampler = ConditionalSampler(𝒟x, 𝒟y; input_size = input_size)
yidx = get_target_index(data.y_levels, y) yidx = get_target_index(data.y_levels, y)
......
...@@ -57,3 +57,17 @@ function ssim_dist(x, y) ...@@ -57,3 +57,17 @@ function ssim_dist(x, y)
y = convert2mnist(y) y = convert2mnist(y)
return (1 - assess_ssim(x, y)) / 2 return (1 - assess_ssim(x, y)) / 2
end end
"""
prior_sampling_space(data::CounterfactualData; n_std=3)
Define the prior sampling space for the data.
"""
function prior_sampling_space(data::CounterfactualData; n_std=3)
X = data.X
centers = mean(X, dims=2)
stds = std(X, dims=2)
lower_bound = minimum(centers .- n_std .* stds)[1]
upper_bound = maximum(centers .+ n_std .* stds)[1]
return Uniform(lower_bound, upper_bound)
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment