Skip to content
Snippets Groups Projects
Commit 04f4e9aa authored by pat-alt's avatar pat-alt
Browse files

just saving meta data directly

parent 0c5b87c8
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -65,6 +65,8 @@ function run_experiment(exp::Experiment; save_output::Bool=true)
# Save data:
if save_output
Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_outcome.jls"), outcome)
Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_bmk.jls"), bmk)
meta(outcome::ExperimentOutcome; save_output::Bool=true)
end
return outcome
......
using CounterfactualExplanations.Objectives
using CounterfactualExplanations.Generators: GradientBasedGenerator
"Constructor for `CCEGenerator`."
function CCEGenerator(;
λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0],
κ::Real=1.0,
temp::Real=0.1,
kwargs...
)
function _set_size_penalty(ce::AbstractCounterfactualExplanation)
return ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
end
_penalties = [Objectives.distance_l2, _set_size_penalty]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
end
"Constructor for `ECECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator."
function ECCCoGenerator(;
λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.2,0.4,0.4],
......@@ -45,11 +30,11 @@ function ECCCoGenerator(;
_set_size_penalty = (ce::AbstractCounterfactualExplanation) -> ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
# Energy penalty
_energy_penalty = function (ce::AbstractCounterfactualExplanation)
_energy_penalty = function(ce::AbstractCounterfactualExplanation)
if use_energy_delta
return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin, kwargs...)
return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin)
else
return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin)
end
end
......@@ -58,18 +43,4 @@ function ECCCoGenerator(;
# Generator
return GradientBasedGenerator(; loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, kwargs...)
end
"Constructor for `EnergyDrivenGenerator`."
function EnergyDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...)
_penalties = [Objectives.distance_l2, ECCCo.distance_from_energy]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
end
"Constructor for `TargetDrivenGenerator`."
function TargetDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...)
_penalties = [Objectives.distance_l2, ECCCo.distance_from_targets]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment