diff --git a/experiments/models/default_models.jl b/experiments/models/default_models.jl
index e57e926f95e7d7c551c08e3200ac88b29fa1a556..53319cd9e65c7b8e9b8f5c9e21c863aecdc4d5c6 100644
--- a/experiments/models/default_models.jl
+++ b/experiments/models/default_models.jl
@@ -31,7 +31,7 @@ Builds a dictionary of default models for training.
 function default_models(;
     sampler::AbstractSampler,
     builder::MLJFlux.Builder=default_builder(),
-    epochs::Int=25,
+    epochs::Int=100,
     batch_size::Int=128,
     finaliser::Function=Flux.softmax,
     loss::Function=Flux.Losses.crossentropy,
diff --git a/notebooks/prototyping.qmd b/notebooks/prototyping.qmd
index 7705d4d122a62d55113f93d04157f008862658f6..a8b2e667683ea0892b26ce39860ad7631e9f2aa5 100644
--- a/notebooks/prototyping.qmd
+++ b/notebooks/prototyping.qmd
@@ -1,111 +1,40 @@
 ```{julia}
-include("$(pwd())/notebooks/setup.jl")
-eval(setup_notebooks)
+include("$(pwd())/experiments/setup_env.jl")
 ```
 
 # Linearly Separable Data
 
 ```{julia}
-# Hyper:
-_retrain = false
+dataname = "linearly_separable"
+outcome = Serialization.deserialize(joinpath(DEFAULT_OUTPUT_PATH, "$(dataname)_outcome.jls"))
 
-# Data:
-test_size = 0.2
-n_obs = Int(1000 / (1.0 - test_size))
-counterfactual_data, test_data = train_test_split(
-    load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.));
-    test_size=test_size
-)
-X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
-X = table(permutedims(X))
-labels = counterfactual_data.output_encoder.labels
-input_dim, n_obs = size(counterfactual_data.X)
-output_dim = length(unique(labels))
-```
-
-First, let's create a couple of image classifier architectures:
-
-```{julia}
-# Model parameters:
-epochs = 100
-bs = minimum([Int(round(n_obs/10)), 128])
-n_hidden = 16
-activation = Flux.swish
-builder = MLJFlux.MLP(
-    hidden=(n_hidden, n_hidden, n_hidden),
-    σ=Flux.swish
-)
-n_ens = 5 # number of models in ensemble
-_loss = Flux.Losses.crossentropy # loss function
-_finaliser = Flux.softmax # finaliser function
-```
-
-```{julia}
-# JEM parameters:
-𝒟x = Normal()
-𝒟y = Categorical(ones(output_dim) ./ output_dim)
-sampler = ConditionalSampler(
-    𝒟x, 𝒟y,
-    input_size=(input_dim,),
-    batch_size=50,
-)
-α = [1.0,1.0,1e-1] # penalty strengths
-```
-
-
-```{julia}
-# Joint Energy Model:
-model = JointEnergyClassifier(
-    sampler;
-    builder=builder,
-    epochs=epochs,
-    batch_size=bs,
-    finaliser=_finaliser,
-    loss=_loss,
-    jem_training_params=(
-        α=α, verbosity=10,
-    ),
-    sampling_steps=30,
-)
-```
-
-```{julia}
-conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95)
-mach = machine(conf_model, X, labels)
-@info "Begin training model."
-fit!(mach)
-@info "Finished training model."
-M = ECCCo.ConformalModel(mach.model, mach.fitresult)
-```
-
-```{julia}
-λ₁ = 0.25
-λ₂ = 0.75
-λ₃ = 0.75
-Λ = [λ₁, λ₂, λ₃]
-
-opt = Flux.Optimise.Descent(0.01)
-use_class_loss = false
-
-# Benchmark generators:
-generator_dict = Dict(
-    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
-    "ECCCo (energy delta)" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true),
-)
+# Unpack
+exp = outcome.exp
+model_dict = outcome.model_dict
+generator_dict = outcome.generator_dict
+bmk = outcome.bmk
 ```
 
 ```{julia}
 Random.seed!(2023)
 
+# Unpack
+counterfactual_data = exp.counterfactual_data
+X, labels = counterfactual_data.X, counterfactual_data.output_encoder.labels
+M = model_dict["MLP"]
+gen = filter(((k,v),) -> k in ["ECCCo", "ECCCo-Δ"], generator_dict)
+
+# Prepare search:
 X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
 factual_label = levels(labels)[2]
-x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
+x_factual = X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))] |>
+    x -> x[:,:]
 target = levels(labels)[1]
 factual = predict_label(M, counterfactual_data, x_factual)[1]
 
 ces = Dict{Any,Any}()
 plts = []
-for (name, generator) in generator_dict
+for (name, generator) in gen
     ce = generate_counterfactual(
         x_factual, target, counterfactual_data, M, generator;
         initialization=:identity,
diff --git a/src/penalties.jl b/src/penalties.jl
index cba61638c58e6007ce3ca35590fcf1eff5699c03..aff3e84e5cfb55f578e6006a7911980831642968 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -76,14 +76,19 @@ function energy_delta(
     xproposed = CounterfactualExplanations.decode_state(ce)   # current state
     t = get_target_index(ce.data.y_levels, ce.target)
     E(x) = -logits(ce.M, x)[t,:]   # negative logits for target class
 
-    _loss = E(xproposed) .- E(xgenerated)
-    _loss = reduce((x, y) -> x + y, _loss) / n   # aggregate over samples
+    # Generative loss:
+    gen_loss = E(xproposed) .- E(xgenerated)
+    gen_loss = reduce((x, y) -> x + y, gen_loss) / n   # aggregate over samples
+
+    # Regularization loss:
+    reg_loss = E(xgenerated).^2 .+ E(xproposed).^2
+    reg_loss = reduce((x, y) -> x + y, reg_loss) / n   # aggregate over samples
 
     if return_conditionals
         return conditional_samples[1]
     end
 
-    return _loss
+    return gen_loss + 0.1reg_loss
 
 end
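For reference, the core of the patched `energy_delta` penalty can be written as a standalone function. This is a minimal sketch, not the repository's API: the name `energy_delta_penalty` and the arguments `E_prop` and `E_gen` are hypothetical stand-ins for the per-sample target-class energies `E(xproposed)` and `E(xgenerated)` computed inside `energy_delta`, and the default regularization weight of `0.1` matches the value hardcoded in the patch.

```julia
using Statistics: mean

# Energy-delta penalty with squared-energy regularization, mirroring the
# patched `energy_delta` in src/penalties.jl. `E_prop` and `E_gen` hold
# negative target-class logits (energies) for the proposed counterfactual
# state and the sampler-generated states, respectively.
function energy_delta_penalty(
    E_prop::AbstractVector, E_gen::AbstractVector; reg_strength=0.1
)
    # Generative loss: push the proposed state's energy below that of the
    # generated samples.
    gen_loss = mean(E_prop .- E_gen)
    # Regularization loss: penalize large energy magnitudes so the generative
    # term cannot be gamed by inflating the overall scale of the logits.
    reg_loss = mean(E_gen .^ 2 .+ E_prop .^ 2)
    return gen_loss + reg_strength * reg_loss
end

energy_delta_penalty([-2.0, -1.5], [-1.0, -0.5])   # -1.0 + 0.1 * 3.75 = -0.625
```

The squared-energy term is a common stabilizer in energy-based training: without it, the difference `E_prop - E_gen` says nothing about absolute energy magnitudes, which are free to grow without bound during counterfactual search.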
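The reworked notebook also changes its contract with the experiment pipeline: instead of building and training models inline, it deserializes a previously saved outcome object. Below is a minimal sketch of that contract using only the `Serialization` stdlib; the field names are taken from the diff, while the struct name `ExperimentOutcome` and the paths are hypothetical.

```julia
using Serialization

# Hypothetical container with the fields the notebook unpacks.
struct ExperimentOutcome
    exp             # experiment setup, including `counterfactual_data`
    model_dict      # trained models keyed by name, e.g. "MLP"
    generator_dict  # counterfactual generators keyed by name, e.g. "ECCCo"
    bmk             # benchmark results
end

# Writer side (end of the experiment run):
# serialize(joinpath(output_path, "linearly_separable_outcome.jls"), outcome)

# Reader side (top of the notebook):
# outcome = deserialize(joinpath(output_path, "linearly_separable_outcome.jls"))
```

Keeping training in `experiments/` and letting the notebook only read serialized outcomes means the prototyping notebook stays cheap to re-run and cannot drift from the models the benchmarks were computed on.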