diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index 70510d5a1dd59cdcf4de7f85504937b14e0fa558..934a3d38af9291e21ccc704b1382f41fc049265d 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.9.2" manifest_format = "2.0" -project_hash = "a981b871c7220e241da2ed077ce7f07968e32677" +project_hash = "59bc4ed7aabf01af8804eba2ac9b602604180c6c" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] diff --git a/experiments/Project.toml b/experiments/Project.toml index 0c27e4aaf83a4709dee692da80b9185b5882046f..c99239a4bea520d308430124d0f0cb9737ab5a4f 100644 --- a/experiments/Project.toml +++ b/experiments/Project.toml @@ -1,11 +1,14 @@ [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" diff --git a/experiments/circles.jl b/experiments/circles.jl index 0872ba1542bbc763c75af3d3ad0e4c0d49a92589..f67c83a0bef5624fba55bb95a15d0a33c7d1a414 100644 --- a/experiments/circles.jl +++ b/experiments/circles.jl @@ -1,5 +1,5 @@ n_obs = Int(1000 / (1.0 - TEST_SIZE)) -counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); TEST_SIZE=TEST_SIZE) +counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=TEST_SIZE) run_experiment!( counterfactual_data, test_data; dataname="Circles", n_hidden=32, diff --git 
a/experiments/experiment.jl b/experiments/experiment.jl index 3394be9111eb7de39904f0f834b2b17929b493e6..41559715dc002eb3a06bf098bb4249a03c3ee619 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -8,7 +8,7 @@ Base.@kwdef struct Experiment params_path::String = joinpath(output_path, "params") use_pretrained::Bool = true models::Union{Nothing,Dict} = nothing - builder::Union{Nothing,MLJFlux.GenericBuilder} = nothing + builder::Union{Nothing,MLJFlux.Builder} = nothing 𝒟x::Distribution = Normal() sampling_batch_size::Int = 50 min_batch_size::Int = 128 @@ -69,8 +69,9 @@ Overload the `run_experiment` function to allow for passing in `CounterfactualDa """ function run_experiment!(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...) # Parameters: - exp = Experiment( - counterfactual_data, test_data; + exp = Experiment(; + counterfactual_data=counterfactual_data, + test_data=test_data, kwargs... ) return run_experiment!(exp) diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl index baf5e580b07ba2a070fc89aa394c9f9594dec037..62575fd5b3b6d937a6bf611135d69c2f5d5a7fd2 100644 --- a/experiments/gmsc.jl +++ b/experiments/gmsc.jl @@ -1,4 +1,4 @@ -counterfactual_data, test_data = train_test_split(load_gmsc(nothing); TEST_SIZE=TEST_SIZE) +counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE) run_experiment!( counterfactual_data, test_data; dataname="GMSC", n_hidden=128, diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl index c95a050998d8e51f2010b0f030589535f13d429f..f884bd206b59fffefde9ac83349408e862761a83 100644 --- a/experiments/linearly_separable.jl +++ b/experiments/linearly_separable.jl @@ -1,6 +1,6 @@ n_obs = Int(1000 / (1.0 - TEST_SIZE)) counterfactual_data, test_data = train_test_split( load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0)); - TEST_SIZE=TEST_SIZE + test_size=TEST_SIZE ) run_experiment!(counterfactual_data, test_data; 
dataname="Linearly Separable") \ No newline at end of file diff --git a/experiments/models/default_models.jl b/experiments/models/default_models.jl index 4b87ebce15cc741c84a5a720910f095e8c36727c..e57e926f95e7d7c551c08e3200ac88b29fa1a556 100644 --- a/experiments/models/default_models.jl +++ b/experiments/models/default_models.jl @@ -30,7 +30,7 @@ Builds a dictionary of default models for training. """ function default_models(; sampler::AbstractSampler, - builder::MLJFlux.GenericBuilder=default_builder(), + builder::MLJFlux.Builder=default_builder(), epochs::Int=25, batch_size::Int=128, finaliser::Function=Flux.softmax, diff --git a/experiments/models/models.jl b/experiments/models/models.jl index b3e3286d18abdec4b5f3b6b3b4e512680002222c..f9d645bf88295b1a6c51ba5fb2fb42f794d5f05f 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -7,26 +7,24 @@ function prepare_models(exp::Experiment) # Unpack data: X, labels, sampler = prepare_data(exp::Experiment) - # Setup: - if isnothing(exp.builder) - builder = default_builder() - end - if isnothing(exp.models) - @info "Using default models." - models = default_models(; - sampler=sampler, - builder=builder, - batch_size=batch_size(exp) - ) - end - # Training: - if !pretrained + if !exp.use_pretrained + if isnothing(exp.builder) + builder = default_builder() + end + if isnothing(exp.models) + @info "Using default models." + models = default_models(; + sampler=sampler, + builder=builder, + batch_size=batch_size(exp) + ) + end @info "Training models." model_dict = train_models(models, X, labels; coverage=exp.coverage) else @info "Loading pre-trained models." 
- model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(exp.save_name)_models.jls")) + model_dict = Serialization.deserialize(joinpath(pretrained_path(), "results/$(exp.save_name)_models.jls")) end # Save models: diff --git a/experiments/moons.jl b/experiments/moons.jl index 198a605edbf325cea64895e0f979119956100c5b..76aa72c53642f1291bb319121a2ca9ff9a45bf72 100644 --- a/experiments/moons.jl +++ b/experiments/moons.jl @@ -1,5 +1,5 @@ n_obs = Int(2500 / (1.0 - TEST_SIZE)) -counterfactual_data, test_data = train_test_split(load_moons(n_obs); TEST_SIZE=TEST_SIZE) +counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=TEST_SIZE) run_experiment!( counterfactual_data, test_data; dataname="Moons", epochs=500, diff --git a/experiments/post_processing.jl b/experiments/post_processing.jl index 58d17b825fc1bbc096e4a9b2e77473e19372c750..144b84b3e7785a3a8277eadbaa74e454110c6407 100644 --- a/experiments/post_processing.jl +++ b/experiments/post_processing.jl @@ -88,7 +88,7 @@ function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Noth _perf = CounterfactualExplanations.Models.model_evaluation(model, exp.test_data, measure=collect(values(measures))) _perf = DataFrame([[p] for p in _perf], collect(keys(measures))) _perf.mod_name .= mod_name - _perf.dataname .= dataname + _perf.dataname .= exp.dataname model_performance = vcat(model_performance, _perf) end diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index 1fd73899bf645e9c7996d6ccd4ddf31b3264f8b7..f2f84e444ad21bb588beb1f798b64d864d5ee171 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -10,20 +10,23 @@ using CounterfactualExplanations.Generators: JSMADescent using CounterfactualExplanations.Models: load_mnist_mlp, load_fashion_mnist_mlp, train, probs using CounterfactualExplanations.Objectives using CSV -using Distributions +using DataFrames +using Distributions: Normal, Distribution, Categorical using ECCCo +using Flux using 
JointEnergyModels using LazyArtifacts -using MLJBase: multiclass_f1score, accuracy, multiclass_precision +using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table using MLJEnsembles using MLJFlux +using Serialization # Constants: const LATEST_VERSION = "1.8.5" const ARTIFACT_NAME = "results-paper-submission-$(LATEST_VERSION)" artifact_toml = LazyArtifacts.find_artifacts_toml(".") _hash = artifact_hash(ARTIFACT_NAME, artifact_toml) -const LATEST_ARTIFACT_PATH = artifact_path(_hash) +const LATEST_ARTIFACT_PATH = joinpath(artifact_path(_hash), ARTIFACT_NAME) # Pre-trained models: function pretrained_path()