Skip to content
Snippets Groups Projects
Commit 006d7136 authored by Pat Alt's avatar Pat Alt
Browse files

finally more or less sorted

parent d6608211
No related branches found
No related tags found
1 merge request!68Post rebuttal
This diff is collapsed.
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=test_size)
n_obs = Int(1000 / (1.0 - TEST_SIZE))
counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); TEST_SIZE=TEST_SIZE)
run_experiment!(
counterfactual_data, test_data; dataname="Circles",
n_hidden=32,
......
# General setup:
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
test_size = 0.2
# Constants:
const DEFAULT_OUTPUT_PATH = "$(pwd())/results"
const RETRAIN = "retrain" ARGS ? true : false
"Default model performance measures."
const MODEL_MEASURES = Dict(
:f1score => multiclass_f1score,
:acc => accuracy,
:precision => multiclass_precision
)
"Default coverage rate."
const DEFAULT_COVERAGE = 0.95
"The default benchmarking measures."
const CE_MEASURES = [
CounterfactualExplanations.distance,
ECCCo.distance_from_energy,
ECCCo.distance_from_targets,
CounterfactualExplanations.Evaluation.validity,
CounterfactualExplanations.Evaluation.redundancy,
ECCCo.set_size_penalty
]
# Pre-trained models:
function pretrained_path()
@info "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
return joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5")
end
"Sets up the experiment."
Base.@kwdef struct Experiment
counterfactual_data::CounterfactualData
......
counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=test_size)
counterfactual_data, test_data = train_test_split(load_gmsc(nothing); TEST_SIZE=TEST_SIZE)
run_experiment!(
counterfactual_data, test_data; dataname="GMSC",
n_hidden=128,
......
n_obs = Int(1000 / (1.0 - test_size))
n_obs = Int(1000 / (1.0 - TEST_SIZE))
counterfactual_data, test_data = train_test_split(
load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
test_size=test_size
TEST_SIZE=TEST_SIZE
)
run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable")
\ No newline at end of file
n_obs = Int(2500 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=test_size)
n_obs = Int(2500 / (1.0 - TEST_SIZE))
counterfactual_data, test_data = train_test_split(load_moons(n_obs); TEST_SIZE=TEST_SIZE)
run_experiment!(
counterfactual_data, test_data; dataname="Moons",
epochs=500,
......
include("setup.jl")
include("setup_env.jl")
include("experiment.jl")
# User inputs:
if "run-all" in ARGS
......
using Pkg
Pkg.activate(@__DIR__)
# Deps:
using CounterfactualExplanations
using CounterfactualExplanations.Data
using CounterfactualExplanations.DataPreprocessing: train_test_split
using CounterfactualExplanations.Evaluation: benchmark, evaluate, Benchmark
using CounterfactualExplanations.Generators: JSMADescent
using CounterfactualExplanations.Models: load_mnist_mlp, load_fashion_mnist_mlp, train, probs
using CounterfactualExplanations.Objectives
using CSV
using Distributions
using ECCCo
using JointEnergyModels
using LazyArtifacts
using MLJBase: multiclass_f1score, accuracy, multiclass_precision
using MLJEnsembles
using MLJFlux
# Constants:
const LATEST_VERSION = "1.8.5"
const ARTIFACT_NAME = "results-paper-submission-$(LATEST_VERSION)"
artifact_toml = LazyArtifacts.find_artifacts_toml(".")
_hash = artifact_hash(ARTIFACT_NAME, artifact_toml)
const LATEST_ARTIFACT_PATH = artifact_path(_hash)
# Pre-trained models:
function pretrained_path()
@info "Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions."
return LATEST_ARTIFACT_PATH
end
"Default output path."
const DEFAULT_OUTPUT_PATH = "$(pwd())/results"
"Boolean flag to retrain models."
const RETRAIN = "retrain" ARGS ? true : false
"Default model performance measures."
const MODEL_MEASURES = Dict(
:f1score => multiclass_f1score,
:acc => accuracy,
:precision => multiclass_precision
)
"Default coverage rate."
const DEFAULT_COVERAGE = 0.95
"The default benchmarking measures."
const CE_MEASURES = [
CounterfactualExplanations.distance,
ECCCo.distance_from_energy,
ECCCo.distance_from_targets,
CounterfactualExplanations.Evaluation.validity,
CounterfactualExplanations.Evaluation.redundancy,
ECCCo.set_size_penalty
]
"Test set proportion."
const TEST_SIZE = 0.2
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment