Skip to content
Snippets Groups Projects
Commit 04f4e9aa authored by pat-alt's avatar pat-alt
Browse files

just saving meta data directly

parent 0c5b87c8
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
...@@ -65,6 +65,8 @@ function run_experiment(exp::Experiment; save_output::Bool=true) ...@@ -65,6 +65,8 @@ function run_experiment(exp::Experiment; save_output::Bool=true)
# Save data: # Save data:
if save_output if save_output
Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_outcome.jls"), outcome) Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_outcome.jls"), outcome)
Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_bmk.jls"), bmk)
meta(outcome::ExperimentOutcome; save_output::Bool=true)
end end
return outcome return outcome
......
using CounterfactualExplanations.Objectives using CounterfactualExplanations.Objectives
using CounterfactualExplanations.Generators: GradientBasedGenerator using CounterfactualExplanations.Generators: GradientBasedGenerator
"Constructor for `CCEGenerator`."
function CCEGenerator(;
λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0],
κ::Real=1.0,
temp::Real=0.1,
kwargs...
)
function _set_size_penalty(ce::AbstractCounterfactualExplanation)
return ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
end
_penalties = [Objectives.distance_l2, _set_size_penalty]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
end
"Constructor for `ECECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator." "Constructor for `ECECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator."
function ECCCoGenerator(; function ECCCoGenerator(;
λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.2,0.4,0.4], λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.2,0.4,0.4],
...@@ -45,11 +30,11 @@ function ECCCoGenerator(; ...@@ -45,11 +30,11 @@ function ECCCoGenerator(;
_set_size_penalty = (ce::AbstractCounterfactualExplanation) -> ECCCo.set_size_penalty(ce; κ=κ, temp=temp) _set_size_penalty = (ce::AbstractCounterfactualExplanation) -> ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
# Energy penalty # Energy penalty
_energy_penalty = function (ce::AbstractCounterfactualExplanation) _energy_penalty = function(ce::AbstractCounterfactualExplanation)
if use_energy_delta if use_energy_delta
return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin, kwargs...) return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin)
else else
return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...) return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin)
end end
end end
...@@ -58,18 +43,4 @@ function ECCCoGenerator(; ...@@ -58,18 +43,4 @@ function ECCCoGenerator(;
# Generator # Generator
return GradientBasedGenerator(; loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, kwargs...) return GradientBasedGenerator(; loss=loss_fun, penalty=_penalties, λ=λ, opt=opt, kwargs...)
end
"Constructor for `EnergyDrivenGenerator`."
function EnergyDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...)
_penalties = [Objectives.distance_l2, ECCCo.distance_from_energy]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
end
"Constructor for `TargetDrivenGenerator`."
function TargetDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...)
_penalties = [Objectives.distance_l2, ECCCo.distance_from_targets]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return GradientBasedGenerator(; penalty=_penalties, λ=λ, kwargs...)
end end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment