Commit 4fc205c2 authored by pat-alt

slowly slowly

parent fa0bf925
1 merge request: !7669 initial run including fmnist lenet and new method
@@ -44,25 +44,25 @@ end
 Run the benchmarking procedure.
 """
-function run_benchmark(exp::Experiment, model_dict::Dict)
-    n_individuals = exp.n_individuals
-    dataname = exp.dataname
-    counterfactual_data = exp.counterfactual_data
-    generator_dict = exp.generators
-    measures = exp.ce_measures
-    parallelizer = exp.parallelizer
+function run_benchmark(exper::Experiment, model_dict::Dict)
+    n_individuals = exper.n_individuals
+    dataname = exper.dataname
+    counterfactual_data = exper.counterfactual_data
+    generator_dict = exper.generators
+    measures = exper.ce_measures
+    parallelizer = exper.parallelizer
     # Benchmark generators:
     if isnothing(generator_dict)
         generator_dict = default_generators(;
-            Λ=exp.Λ,
-            Λ_Δ=exp.Λ_Δ,
-            use_variants=exp.use_variants,
-            use_class_loss=exp.use_class_loss,
-            opt=exp.opt,
-            nsamples=exp.nsamples,
-            nmin=exp.nmin,
+            Λ=exper.Λ,
+            Λ_Δ=exper.Λ_Δ,
+            use_variants=exper.use_variants,
+            use_class_loss=exper.use_class_loss,
+            opt=exper.opt,
+            nsamples=exper.nsamples,
+            nmin=exper.nmin,
         )
     end
...
-function _prepare_data(exp::Experiment)
+function _prepare_data(exper::Experiment)
     # Unpack data:
-    counterfactual_data = exp.counterfactual_data
-    min_batch_size = exp.min_batch_size
-    sampling_batch_size = exp.sampling_batch_size
-    𝒟x = exp.𝒟x
+    counterfactual_data = exper.counterfactual_data
+    min_batch_size = exper.min_batch_size
+    sampling_batch_size = exper.sampling_batch_size
+    𝒟x = exper.𝒟x
     # Data parameters:
     X, _ = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
@@ -27,17 +27,17 @@ function _prepare_data(exp::Experiment)
     return X, labels, n_obs, batch_size, sampler
 end
-function meta_data(exp::Experiment)
-    _, _, n_obs, batch_size, _ = _prepare_data(exp::Experiment)
+function meta_data(exper::Experiment)
+    _, _, n_obs, batch_size, _ = _prepare_data(exper::Experiment)
     return n_obs, batch_size
 end
-function prepare_data(exp::Experiment)
-    X, labels, _, _, sampler = _prepare_data(exp::Experiment)
+function prepare_data(exper::Experiment)
+    X, labels, _, _, sampler = _prepare_data(exper::Experiment)
     return X, labels, sampler
 end
-function batch_size(exp::Experiment)
-    _, _, _, batch_size, _ = _prepare_data(exp::Experiment)
+function batch_size(exper::Experiment)
+    _, _, _, batch_size, _ = _prepare_data(exper::Experiment)
     return batch_size
 end
\ No newline at end of file
@@ -35,65 +35,68 @@ Base.@kwdef struct Experiment
     nmin::Union{Nothing,Int} = nothing
     finaliser::Function = Flux.softmax
     loss::Function = Flux.Losses.crossentropy
+    train_parallel::Bool = false
 end
 "A container to hold the results of an experiment."
 mutable struct ExperimentOutcome
-    exp::Experiment
+    exper::Experiment
     model_dict::Union{Nothing, Dict}
     generator_dict::Union{Nothing, Dict}
     bmk::Union{Nothing, Benchmark}
 end
 """
-    train_models!(outcome::ExperimentOutcome, exp::Experiment)
+    train_models!(outcome::ExperimentOutcome, exper::Experiment)
-Train the models specified by `exp` and store them in `outcome`.
+Train the models specified by `exper` and store them in `outcome`.
 """
-function train_models!(outcome::ExperimentOutcome, exp::Experiment)
-    model_dict = prepare_models(exp)
+function train_models!(outcome::ExperimentOutcome, exper::Experiment)
+    model_dict = prepare_models(exper)
     outcome.model_dict = model_dict
-    meta_model_performance(outcome)
+    if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+        meta_model_performance(outcome)
+    end
 end
 """
-    benchmark!(outcome::ExperimentOutcome, exp::Experiment)
+    benchmark!(outcome::ExperimentOutcome, exper::Experiment)
-Benchmark the models specified by `exp` and store the results in `outcome`.
+Benchmark the models specified by `exper` and store the results in `outcome`.
 """
-function benchmark!(outcome::ExperimentOutcome, exp::Experiment)
-    bmk, generator_dict = run_benchmark(exp, outcome.model_dict)
+function benchmark!(outcome::ExperimentOutcome, exper::Experiment)
+    bmk, generator_dict = run_benchmark(exper, outcome.model_dict)
     outcome.bmk = bmk
     outcome.generator_dict = generator_dict
 end
 """
-    run_experiment(exp::Experiment)
+    run_experiment(exper::Experiment)
-Run the experiment specified by `exp`.
+Run the experiment specified by `exper`.
 """
-function run_experiment(exp::Experiment; save_output::Bool=true, only_models::Bool=ONLY_MODELS)
+function run_experiment(exper::Experiment; save_output::Bool=true, only_models::Bool=ONLY_MODELS)
     # Setup
-    @info "All results will be saved to $(exp.output_path)."
-    isdir(exp.output_path) || mkdir(exp.output_path)
-    @info "All parameter choices will be saved to $(exp.params_path)."
-    isdir(exp.params_path) || mkdir(exp.params_path)
-    outcome = ExperimentOutcome(exp, nothing, nothing, nothing)
+    @info "All results will be saved to $(exper.output_path)."
+    isdir(exper.output_path) || mkdir(exper.output_path)
+    @info "All parameter choices will be saved to $(exper.params_path)."
+    isdir(exper.params_path) || mkdir(exper.params_path)
+    outcome = ExperimentOutcome(exper, nothing, nothing, nothing)
     # Models
-    train_models!(outcome, exp)
+    train_models!(outcome, exper)
     # Return if only models are needed:
     !only_models || return outcome
     # Benchmark
-    benchmark!(outcome, exp)
+    benchmark!(outcome, exper)
     # Save data:
-    if save_output
-        Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_outcome.jls"), outcome)
-        Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_bmk.jls"), outcome.bmk)
+    if save_output && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+        Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_outcome.jls"), outcome)
+        Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_bmk.jls"), outcome.bmk)
         meta(outcome; save_output=true)
     end
@@ -108,19 +111,19 @@ Overload the `run_experiment` function to allow for passing in `CounterfactualData`
 """
 function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...)
     # Parameters:
-    exp = Experiment(;
+    exper = Experiment(;
         counterfactual_data=counterfactual_data,
         test_data=test_data,
         kwargs...
     )
-    return run_experiment(exp)
+    return run_experiment(exper)
 end
 # Pre-trained models:
-function pretrained_path(exp::Experiment)
-    if isfile(joinpath(exp.output_path, "$(exp.save_name)_models.jls"))
-        @info "Found local pre-trained models in $(exp.output_path) and using those."
-        return exp.output_path
+function pretrained_path(exper::Experiment)
+    if isfile(joinpath(exper.output_path, "$(exper.save_name)_models.jls"))
+        @info "Found local pre-trained models in $(exper.output_path) and using those."
+        return exper.output_path
     else
         @info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions."
         return joinpath(LATEST_ARTIFACT_PATH, "results")
...
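For reference, the `run_experiment` overload above forwards its keyword arguments to the `Experiment` constructor. A minimal usage sketch, not part of the diff: the toy data and the `dataname` keyword are illustrative assumptions, and the `CounterfactualData(X, y)` constructor is the standard one from `CounterfactualExplanations.jl`.

```julia
using CounterfactualExplanations

# Hypothetical toy data: 2 features, 100 observations, binary labels.
X = randn(2, 100)
y = rand([0, 1], 100)

# Wrap train and test splits as `CounterfactualData` containers:
counterfactual_data = CounterfactualData(X[:, 1:80], y[1:80])
test_data = CounterfactualData(X[:, 81:end], y[81:end])

# Keyword arguments are forwarded to the `Experiment` constructor:
outcome = run_experiment(counterfactual_data, test_data; dataname="toy")
```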
@@ -2,58 +2,63 @@ include("additional_models.jl")
 include("default_models.jl")
 include("train_models.jl")
-function prepare_models(exp::Experiment)
+function prepare_models(exper::Experiment)
     # Unpack data:
-    X, labels, sampler = prepare_data(exp::Experiment)
+    X, labels, sampler = prepare_data(exper::Experiment)
     # Training:
-    if !exp.use_pretrained
-        if isnothing(exp.builder)
+    if !exper.use_pretrained
+        if isnothing(exper.builder)
            builder = default_builder()
         else
-            builder = exp.builder
+            builder = exper.builder
         end
         # Default models:
-        if isnothing(exp.models)
+        if isnothing(exper.models)
            @info "Using default models."
            models = default_models(;
                sampler=sampler,
                builder=builder,
-                batch_size=batch_size(exp),
-                sampling_steps=exp.sampling_steps,
-                α=exp.α,
-                n_ens=exp.n_ens,
-                use_ensembling=exp.use_ensembling,
-                finaliser=exp.finaliser,
-                loss=exp.loss,
-                epochs=exp.epochs,
+                batch_size=batch_size(exper),
+                sampling_steps=exper.sampling_steps,
+                α=exper.α,
+                n_ens=exper.n_ens,
+                use_ensembling=exper.use_ensembling,
+                finaliser=exper.finaliser,
+                loss=exper.loss,
+                epochs=exper.epochs,
            )
         end
         # Additional models:
-        if !isnothing(exp.additional_models)
+        if !isnothing(exper.additional_models)
            @info "Using additional models."
            add_models = Dict{Any,Any}()
-            for (k, mod) in exp.additional_models
+            for (k, mod) in exper.additional_models
                add_models[k] = mod(;
-                    batch_size=batch_size(exp),
-                    finaliser=exp.finaliser,
-                    loss=exp.loss,
-                    epochs=exp.epochs,
+                    batch_size=batch_size(exper),
+                    finaliser=exper.finaliser,
+                    loss=exper.loss,
+                    epochs=exper.epochs,
                )
            end
            models = merge(models, add_models)
         end
         @info "Training models."
-        model_dict = train_models(models, X, labels; cov=exp.coverage)
+        model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage)
     else
         @info "Loading pre-trained models."
-        model_dict = Serialization.deserialize(joinpath(pretrained_path(exp), "$(exp.save_name)_models.jls"))
+        model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))
+        if is_multi_processed(exper)
+            MPI.Barrier(exper.parallelizer.comm)
+        end
     end
     # Save models:
-    @info "Saving models to $(joinpath(exp.output_path, "$(exp.save_name)_models.jls"))."
-    Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_models.jls"), model_dict)
+    if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+        @info "Saving models to $(joinpath(exper.output_path, "$(exper.save_name)_models.jls"))."
+        Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict)
+    end
     return model_dict
 end
\ No newline at end of file
+using CounterfactualExplanations: AbstractParallelizer
 """
     train_models(models::Dict)
 Trains all models in a dictionary and returns a dictionary of `ConformalModel` objects.
 """
-function train_models(models::Dict, X, y; kwargs...)
-    model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models)
+function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractParallelizer}=nothing, train_parallel::Bool=false, kwargs...)
+    if is_multi_processed(parallelizer) && train_parallel
+        # Split models into groups of approximately equal size:
+        model_list = [(key, value) for (key, value) in models]
+        x = split_obs(model_list, parallelizer.n_proc)
+        x = MPI.scatter(x, parallelizer.comm)
+        # Train models:
+        model_dict = Dict()
+        for (mod_name, model) in x
+            model_dict[mod_name] = _train(model, X, y; mod_name=mod_name, verbose=false, kwargs...)
+        end
+        MPI.Barrier(parallelizer.comm)
+        output = MPI.gather(output, parallelizer.comm)
+        # Collect output from all processes in rank 0:
+        if parallelizer.rank == 0
+            output = merge(output...)
+        else
+            output = nothing
+        end
+        # Broadcast output to all processes:
+        model_dict = MPI.bcast(output, parallelizer.comm; root=0)
+        MPI.Barrier(parallelizer.comm)
+    else
+        model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models)
+    end
     return model_dict
 end
@@ -20,11 +45,15 @@ end
 Trains a model and returns a `ConformalModel` object.
 """
-function _train(model, X, y; cov, method=:simple_inductive, mod_name="model")
+function _train(model, X, y; cov, method=:simple_inductive, mod_name="model", verbose::Bool=true)
     conf_model = conformal_model(model; method=method, coverage=cov)
     mach = machine(conf_model, X, y)
     @info "Begin training $mod_name."
-    fit!(mach)
+    if verbose
+        fit!(mach)
+    else
+        fit!(mach, verbosity=0)
+    end
     @info "Finished training $mod_name."
     M = ECCCo.ConformalModel(mach.model, mach.fitresult)
     return M
...
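A note on the `train_models` hunk above: in the parallel branch the gathered variable is read before it is assigned (`output = MPI.gather(output, parallelizer.comm)`), so the per-rank training results are never actually collected. Below is a minimal sketch of the intended scatter–train–gather–broadcast pattern, assuming the per-rank result is the local `model_dict`; the names follow the diff, and `split_obs` and `_train` are taken as given from the repository.

```julia
import MPI

# Split the (name, model) pairs into one chunk per process (helper from the diff):
model_list = [(key, value) for (key, value) in models]
chunks = split_obs(model_list, parallelizer.n_proc)

# Each rank receives its own chunk of models to train:
local_chunk = MPI.scatter(chunks, parallelizer.comm)

# Train only the models assigned to this rank:
model_dict = Dict()
for (mod_name, model) in local_chunk
    model_dict[mod_name] = _train(model, X, y; mod_name=mod_name, verbose=false)
end
MPI.Barrier(parallelizer.comm)

# Gather the per-rank dictionaries on rank 0 and merge them there:
gathered = MPI.gather(model_dict, parallelizer.comm)
output = parallelizer.rank == 0 ? merge(gathered...) : nothing

# Broadcast the merged dictionary back to every rank:
model_dict = MPI.bcast(output, parallelizer.comm; root=0)
MPI.Barrier(parallelizer.comm)
```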
""" """
meta(exp::Experiment) meta(exper::Experiment)
Extract and save meta data about the experiment. Extract and save meta data about the experiment.
""" """
...@@ -21,28 +21,28 @@ Extract and save meta data about the data and models in `outcome.model_dict`. ...@@ -21,28 +21,28 @@ Extract and save meta data about the data and models in `outcome.model_dict`.
function meta_model(outcome::ExperimentOutcome; save_output::Bool=false) function meta_model(outcome::ExperimentOutcome; save_output::Bool=false)
# Unpack: # Unpack:
exp = outcome.exp exper = outcome.exper
n_obs, batch_size = meta_data(exp) n_obs, batch_size = meta_data(exper)
model_dict = outcome.model_dict model_dict = outcome.model_dict
params = DataFrame( params = DataFrame(
Dict( Dict(
:n_obs => Int.(round(n_obs / 10) * 10), :n_obs => Int.(round(n_obs / 10) * 10),
:batch_size => batch_size, :batch_size => batch_size,
:dataname => exp.dataname, :dataname => exper.dataname,
:sgld_batch_size => exp.sampling_batch_size, :sgld_batch_size => exper.sampling_batch_size,
:epochs => exp.epochs, :epochs => exper.epochs,
:n_hidden => exp.n_hidden, :n_hidden => exper.n_hidden,
:n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
:activation => string(exp.activation), :activation => string(exper.activation),
:n_ens => exp.n_ens, :n_ens => exper.n_ens,
:lambda => string(exp.α[3]), :lambda => string(exper.α[3]),
:jem_sampling_steps => exp.sampling_steps, :jem_sampling_steps => exper.sampling_steps,
) )
) )
if save_output if save_output
save_path = joinpath(exp.params_path, "$(exp.save_name)_model_params.csv") save_path = joinpath(exper.params_path, "$(exper.save_name)_model_params.csv")
@info "Saving model parameters to $(save_path)." @info "Saving model parameters to $(save_path)."
CSV.write(save_path, params) CSV.write(save_path, params)
end end
...@@ -54,10 +54,10 @@ end ...@@ -54,10 +54,10 @@ end
function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false) function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false)
# Unpack: # Unpack:
exp = outcome.exp exper = outcome.exper
generator_dict = outcome.generator_dict generator_dict = outcome.generator_dict
Λ = exp.Λ Λ = exper.Λ
Λ_Δ = exp.Λ_Δ Λ_Δ = exper.Λ_Δ
# Output: # Output:
opt = first(values(generator_dict)).opt opt = first(values(generator_dict)).opt
...@@ -65,19 +65,19 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false) ...@@ -65,19 +65,19 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false)
Dict( Dict(
:opt => string(typeof(opt)), :opt => string(typeof(opt)),
:eta => opt.eta, :eta => opt.eta,
:dataname => exp.dataname, :dataname => exper.dataname,
:lambda_1 => string(Λ[1]), :lambda_1 => string(Λ[1]),
:lambda_2 => string(Λ[2]), :lambda_2 => string(Λ[2]),
:lambda_3 => string(Λ[3]), :lambda_3 => string(Λ[3]),
:lambda_1_Δ => string(Λ_Δ[1]), :lambda_1_Δ => string(Λ_Δ[1]),
:lambda_2_Δ => string(Λ_Δ[2]), :lambda_2_Δ => string(Λ_Δ[2]),
:lambda_3_Δ => string(Λ_Δ[3]), :lambda_3_Δ => string(Λ_Δ[3]),
:n_individuals => exp.n_individuals, :n_individuals => exper.n_individuals,
) )
) )
if save_output if save_output
save_path = joinpath(exp.params_path, "$(exp.save_name)_generator_params.csv") save_path = joinpath(exper.params_path, "$(exper.save_name)_generator_params.csv")
@info "Saving generator parameters to $(save_path)." @info "Saving generator parameters to $(save_path)."
CSV.write(save_path, generator_params) CSV.write(save_path, generator_params)
end end
...@@ -93,18 +93,18 @@ Compute and save the model performance for the models in `outcome.model_dict`. ...@@ -93,18 +93,18 @@ Compute and save the model performance for the models in `outcome.model_dict`.
function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Nothing,Dict}=nothing, save_output::Bool=false) function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Nothing,Dict}=nothing, save_output::Bool=false)
# Unpack: # Unpack:
exp = outcome.exp exper = outcome.exper
measures = isnothing(measures) ? exp.model_measures : measures measures = isnothing(measures) ? exper.model_measures : measures
model_dict = outcome.model_dict model_dict = outcome.model_dict
# Model performance: # Model performance:
model_performance = DataFrame() model_performance = DataFrame()
for (mod_name, model) in model_dict for (mod_name, model) in model_dict
# Test performance: # Test performance:
_perf = CounterfactualExplanations.Models.model_evaluation(model, exp.test_data, measure=collect(values(measures))) _perf = CounterfactualExplanations.Models.model_evaluation(model, exper.test_data, measure=collect(values(measures)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measures))) _perf = DataFrame([[p] for p in _perf], collect(keys(measures)))
_perf.mod_name .= mod_name _perf.mod_name .= mod_name
_perf.dataname .= exp.dataname _perf.dataname .= exper.dataname
model_performance = vcat(model_performance, _perf) model_performance = vcat(model_performance, _perf)
end end
...@@ -112,10 +112,10 @@ function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Noth ...@@ -112,10 +112,10 @@ function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Noth
println(model_performance) println(model_performance)
if save_output if save_output
save_path = joinpath(exp.params_path, "$(exp.save_name)_model_performance.jls") save_path = joinpath(exper.params_path, "$(exper.save_name)_model_performance.jls")
@info "Saving model performance to $(save_path)." @info "Saving model performance to $(save_path)."
Serialization.serialize(save_path, model_performance) Serialization.serialize(save_path, model_performance)
save_path = joinpath(exp.params_path, "$(exp.save_name)_model_performance.csv") save_path = joinpath(exper.params_path, "$(exper.save_name)_model_performance.csv")
@info "Saving model performance to $(save_path)." @info "Saving model performance to $(save_path)."
CSV.write(save_path, model_performance) CSV.write(save_path, model_performance)
end end
......
include("meta_data.jl") include("meta_data.jl")
include("artifacts.jl") include("artifacts.jl")
\ No newline at end of file include("results.jl")
\ No newline at end of file
""" """
aggregate(outcome::ExperimentOutcome; measure::String="distance_from_targets") aggregate_results(outcome::ExperimentOutcome; measure::String="distance_from_targets")
Function to quickly aggregate benchmarking results for a given measure. Function to quickly aggregate benchmarking results for a given measure.
""" """
function aggregate(outcome::ExperimentOutcome; measure::String="distance_from_targets") function aggregate_results(outcome::ExperimentOutcome; measure::String="distance_from_targets")
df = @chain outcome.bmk() begin df = @chain outcome.bmk() begin
@group_by(generator, model) @group_by(generator, model)
@filter(variable == measure) @filter(variable == measure)
......
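A hypothetical call to the renamed helper, using the default measure shown in the signature above (the `outcome` object is a deserialized `ExperimentOutcome` as elsewhere in the diff):

```julia
# Aggregate benchmark results by generator and model for one measure:
df = aggregate_results(outcome; measure="distance_from_targets")
```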
@@ -33,12 +33,12 @@ include("data/data.jl")
 include("models/models.jl")
 include("benchmarking/benchmarking.jl")
 include("post_processing/post_processing.jl")
+include("utils.jl")
 # Parallelization:
 plz = nothing
 if "threaded" ∈ ARGS
-    @info "Multi-threading using $(Threads.nthreads()) threads."
     const USE_THREADS = true
     plz = ThreadsParallelizer()
 else
@@ -46,14 +46,17 @@ else
 end
 if "mpi" ∈ ARGS
-    @info "Multi-processing using MPI."
     import MPI
     MPI.Init()
     const USE_MPI = true
-    plz = MPIParallelizer(MPI.COMM_WORLD, USE_THREADS)
+    plz = MPIParallelizer(MPI.COMM_WORLD; threaded=USE_THREADS)
     if MPI.Comm_rank(MPI.COMM_WORLD) != 0
-        @info "Disabling logging on non-root processes."
         global_logger(NullLogger())
+    else
+        @info "Multi-processing using MPI. Disabling logging on non-root processes."
+        if USE_THREADS
+            @info "Multi-threading using $(Threads.nthreads()) threads."
+        end
     end
 else
     const USE_MPI = false
...
+is_multi_processed(exper::Experiment) = isa(exper.parallelizer, Base.get_extension(CounterfactualExplanations, :MPIExt).MPIParallelizer)
+is_multi_processed(parallelizer::AbstractParallelizer) = isa(parallelizer, Base.get_extension(CounterfactualExplanations, :MPIExt).MPIParallelizer)
\ No newline at end of file
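Tying the new pieces together: the setup script above builds a parallelizer (`plz`), the `Experiment` struct gained a `train_parallel` flag in this commit, and the `is_multi_processed` helpers just added are used to guard rank-0-only work. A sketch of how they might be wired up, with placeholder data variables that are not part of the diff:

```julia
# `counterfactual_data` and `test_data` are placeholders for prepared data containers.
exper = Experiment(;
    counterfactual_data=counterfactual_data,
    test_data=test_data,
    dataname="linearly_separable",
    parallelizer=plz,      # nothing, ThreadsParallelizer(), or an MPIParallelizer
    train_parallel=true,   # new field introduced in this commit
)
run_experiment(exper)

# Rank-0 guard pattern used throughout the diff:
if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
    @info "Only the root process reaches this branch under MPI."
end
```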
@@ -37,7 +37,7 @@ if pre_proc
 function boxcox(x)
     transf = MLJ.UnivariateBoxCoxTransformer()
-    x = exp.(x)
+    x = exper.(x)
     mach = machine(transf, x)
     fit!(mach)
     z = MLJ.transform(mach, x)
...
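One thing to flag in the hunk above: `exp.(x)` here is the broadcast of Julia's exponential function `Base.exp`, not the `exp::Experiment` variable being renamed elsewhere, so changing it to `exper.(x)` would presumably error at runtime since `exper` is not a function in this scope. A sketch of the line as it likely should remain:

```julia
x = exp.(x)  # elementwise Base.exp, unrelated to the Experiment struct
```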
@@ -9,7 +9,7 @@ dataname = "linearly_separable"
 outcome = Serialization.deserialize(joinpath(DEFAULT_OUTPUT_PATH, "$(dataname)_outcome.jls"))
 # Unpack
-exp = outcome.exp
+exper = outcome.exper
 model_dict = outcome.model_dict
 generator_dict = outcome.generator_dict
 bmk = outcome.bmk
@@ -19,7 +19,7 @@ bmk = outcome.bmk
 Random.seed!(2023)
 # Unpack
-counterfactual_data = exp.counterfactual_data
+counterfactual_data = exper.counterfactual_data
 X, labels = counterfactual_data.X, counterfactual_data.output_encoder.labels
 M = model_dict["MLP"]
 gen = filter(((k,v),) -> k in ["ECCCo", "ECCCo-Δ"], generator_dict)
...
@@ -138,7 +138,7 @@ To keep things consistent with the architecture of `CounterfactualExplanations.jl`
 Let $\hat{p}_i$ denote the estimated softmax output for feature $i$. Then in the multi-class case the following formula can be applied:
 ```math
-\beta_i x_i = \log (\hat{p}_i) + \log (\sum_i \exp(\hat{p}_i))
+\beta_i x_i = \log (\hat{p}_i) + \log (\sum_i \exper(\hat{p}_i))
 ```
 For a short derivation, see here: https://math.stackexchange.com/questions/2786600/invert-the-softmax-function.
...
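The same over-eager rename appears to have reached the math block above, where the exponential `\exp` becomes `\exper`. Setting that aside, the inversion can be sanity-checked numerically: softmax only determines the logits up to an additive constant, and `log.(p)` recovers them up to that shared shift. A small check using nothing beyond Base Julia:

```julia
# Recover logits from softmax outputs, up to an additive constant.
softmax(z) = exp.(z) ./ sum(exp.(z))

z = [1.0, 2.0, 3.0]   # original logits
p = softmax(z)        # softmax outputs
z_rec = log.(p)       # recovered logits, shifted by -log(sum(exp.(z)))

# The recovered values differ from the originals only by a constant shift:
@assert all(isapprox.(z_rec .- z_rec[1], z .- z[1]; atol=1e-12))
```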