# Deps:
using Chain: @chain
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
using CounterfactualExplanations.DataPreprocessing: train_test_split
using CounterfactualExplanations.Evaluation: benchmark, evaluate, Benchmark
using CounterfactualExplanations.Generators: JSMADescent
using CounterfactualExplanations.Models: load_mnist_mlp, load_fashion_mnist_mlp, train, probs
using CounterfactualExplanations.Objectives
using CounterfactualExplanations.Parallelization
using CSV
using Dates
using DataFrames
using Distributions: Normal, Distribution, Categorical, Uniform
using ECCCo
using Flux
using Flux.Optimise: Optimiser, Descent, Adam, ClipValue
using JointEnergyModels
using LazyArtifacts
using Logging
using Metalhead
using MLJ: TunedModel, Grid, CV, fitted_params, report
using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table, machine, fit!, Supervised
using MLJEnsembles
using MLJFlux
using Random
using Serialization
using Statistics
import MPI
import MultivariateStats

# Fix the global RNG so experiments are reproducible:
Random.seed!(2023)

ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # avoid command prompt and just download data

# Scripts:
include("experiment.jl")
include("grid_search.jl")
include("data/data.jl")
include("models/models.jl")
include("model_tuning.jl")
include("benchmarking/benchmarking.jl")
include("post_processing/post_processing.jl")
include("utils.jl")
include("save_best.jl")

# Number of counterfactuals, optionally supplied as `n_individuals=<Int>` on the
# command line (falls back to 100 when absent):
n_ind_specified = false
_n_ind_pos = findfirst(contains("n_individuals="), ARGS)
if _n_ind_pos !== nothing
    n_ind_specified = true
    n_individuals = parse(Int, replace(ARGS[_n_ind_pos], "n_individuals=" => ""))
else
    n_individuals = 100
end

"Number of individuals to use in benchmarking."
const N_IND = n_individuals

"Boolean flag to check if number of individuals was specified."
const N_IND_SPECIFIED = n_ind_specified

# Number of tasks per process, optionally supplied as `n_each=<Int|nothing>`
# (defaults to 32; the literal string "nothing" maps to `nothing`):
_n_each_pos = findfirst(contains("n_each="), ARGS)
if _n_each_pos !== nothing
    _n_each_raw = replace(ARGS[_n_each_pos], "n_each=" => "")
    if _n_each_raw == "nothing"
        n_each = nothing
    else
        n_each = parse(Int, _n_each_raw)
    end
else
    n_each = 32
end

"Number of objects to pass to each process."
const N_EACH = n_each

# Number of benchmark runs, optionally supplied as `n_runs=<Int>` (defaults to 1):
_n_runs_pos = findfirst(contains("n_runs="), ARGS)
if _n_runs_pos !== nothing
    n_runs = parse(Int, replace(ARGS[_n_runs_pos], "n_runs=" => ""))
else
    n_runs = 1
end

"Number of benchmark runs."
const N_RUNS = n_runs

# Parallelization (resolved from the `threaded`/`mpi` command-line flags):
plz = nothing
if "threaded" ∈ ARGS
    const USE_THREADS = true
    plz = ThreadsParallelizer()
else
    const USE_THREADS = false
end
if "mpi" ∈ ARGS
    # NOTE: MPI.Init() must run before any other MPI call (Comm_rank, MPIParallelizer).
    MPI.Init()
    const USE_MPI = true
    plz = MPIParallelizer(MPI.COMM_WORLD; threaded = USE_THREADS, n_each = N_EACH)
    if MPI.Comm_rank(MPI.COMM_WORLD) != 0
        # Silence all non-root ranks to avoid duplicated log output:
        global_logger(NullLogger())
    else
        @info "Multi-processing using MPI. Disabling logging on non-root processes."
        if USE_THREADS
            @info "Multi-threading using $(Threads.nthreads()) threads."
            if Threads.threadid() != 1
                global_logger(NullLogger())
            end
        end
    end
else
    const USE_MPI = false
end
const PLZ = plz

# Constants:
const LATEST_VERSION = "1.9.3"
const ARTIFACT_NAME = "results_aaai"
const ARTIFACT_TOML = LazyArtifacts.find_artifacts_toml(".")
const ARTIFACT_HASH = artifact_hash(ARTIFACT_NAME, ARTIFACT_TOML)
const LATEST_ARTIFACT_PATH = joinpath(artifact_path(ARTIFACT_HASH), ARTIFACT_NAME)

# Output path: taken from `output_path=<path>` when given, a temp dir when
# running interactively, and a time-stamped directory otherwise.
time_stamped = false
if any(contains.(ARGS, "output_path"))
    @assert sum(contains.(ARGS, "output_path")) == 1 "Only one output path can be specified."
    _out_arg = ARGS[findall(contains.(ARGS, "output_path"))][1]
    _path = replace(_out_arg, "output_path=" => "")
elseif isinteractive()
    @info "You are running experiments interactively. By default, results will be saved in a temporary directory."
    _path = tempdir()
else
    # NOTE(review): "HH:MM" puts colons in the directory name — fine on Unix,
    # but illegal in Windows paths; confirm target platforms.
    timestamp = Dates.format(now(), "yyyy-mm-dd@HH:MM")
    time_stamped = true
    _path = "$(pwd())/results_$(timestamp)"
end

"Default output path."
const DEFAULT_OUTPUT_PATH = _path
const TIME_STAMPED = time_stamped

"Boolean flag to only train models."
const ONLY_MODELS = "only_models" ∈ ARGS

"Boolean flag to retrain models."
const RETRAIN = "retrain" ∈ ARGS || ONLY_MODELS

"Default model performance measures."
const MODEL_MEASURES = Dict(
    :f1score => multiclass_f1score,
    :acc => accuracy,
    :precision => multiclass_precision,
)

"Default coverage rate."
const DEFAULT_COVERAGE = 0.95

"The default benchmarking measures."
const CE_MEASURES = [
    CounterfactualExplanations.distance,
    ECCCo.distance_from_energy,
    ECCCo.distance_from_energy_l2,
    ECCCo.distance_from_targets,
    ECCCo.distance_from_targets_l2,
    CounterfactualExplanations.Evaluation.validity,
    CounterfactualExplanations.Evaluation.redundancy,
    ECCCo.set_size_penalty,
]

"Test set proportion."
const TEST_SIZE = 0.2

"Boolean flag to check if upload was specified."
const UPLOAD = "upload" ∈ ARGS

"Boolean flag to check if grid search was specified."
const GRID_SEARCH = "grid_search" ∈ ARGS

"Generator tuning parameters."
DEFAULT_GENERATOR_TUNING = (
    Λ = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.1, 0.1, 0.5]],
    reg_strength = [0.0, 0.1, 0.5],
    opt = [Descent(0.01), Descent(0.05)],
    decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
)

"Generator tuning parameters for large datasets."
DEFAULT_GENERATOR_TUNING_LARGE = (
    Λ = [[0.1, 0.1, 0.1], [0.1, 0.1, 0.2], [0.1, 0.1, 0.5]],
    reg_strength = [0.0, 0.1, 0.5],
    opt = [Descent(0.01), Descent(0.05)],
    decay = [(0.0, 1), (0.01, 1), (0.1, 1)],
)

"Boolean flag to check if model tuning was specified."
const TUNE_MODEL = "tune_model" ∈ ARGS

"Model tuning parameters for small datasets."
DEFAULT_MODEL_TUNING_SMALL = (n_hidden = [16, 32, 64], n_layers = [1, 2, 3])

"Model tuning parameters for large datasets."
DEFAULT_MODEL_TUNING_LARGE = (n_hidden = [32, 64, 128, 512], n_layers = [2, 3, 5])

# `const` added for consistency with the other CLI flags in this file
# (UPLOAD, GRID_SEARCH, TUNE_MODEL, ...), which are all declared const.
"Boolean flag to check if store counterfactual explanations was specified."
const STORE_CE = "store_ce" ∈ ARGS

# Typo fixed in docstring ("chech" -> "check").
"Boolean flag to check if best outcome from grid search should be used."
const FROM_GRID_SEARCH = "from_grid" ∈ ARGS