Skip to content
Snippets Groups Projects
Commit 5eac0ff1 authored by Pat Alt's avatar Pat Alt
Browse files

uh

parent 048e974e
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
...@@ -78,7 +78,7 @@ end ...@@ -78,7 +78,7 @@ end
Run the experiment specified by `exp`. Run the experiment specified by `exp`.
""" """
function run_experiment(exp::Experiment; save_output::Bool=true, only_models::Bool=false) function run_experiment(exp::Experiment; save_output::Bool=true, only_models::Bool=ONLY_MODELS)
# Setup # Setup
@info "All results will be saved to $(exp.output_path)." @info "All results will be saved to $(exp.output_path)."
......
function pre_process(x; noise::Float32=0.03f0)
ϵ = Float32.(randn(size(x)) * noise)
x += ϵ
return x
end
# Training data: # Training data:
n_obs = 10000 n_obs = 10000
counterfactual_data = load_mnist(n_obs) counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X) counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X)
# VAE (trained on full dataset): # VAE (trained on full dataset):
using CounterfactualExplanations.Models: load_mnist_vae using CounterfactualExplanations.Models: load_mnist_vae
......
...@@ -27,18 +27,18 @@ if "circles" in datanames ...@@ -27,18 +27,18 @@ if "circles" in datanames
include("circles.jl") include("circles.jl")
end end
# MNIST
if "mnist" in datanames
@info "Running MNIST experiment."
include("mnist.jl")
end
# GMSC # GMSC
if "gmsc" in datanames if "gmsc" in datanames
@info "Running GMSC experiment." @info "Running GMSC experiment."
include("gmsc.jl") include("gmsc.jl")
end end
# MNIST
if "mnist" in datanames
@info "Running MNIST experiment."
include("mnist.jl")
end
if USE_MPI if USE_MPI
MPI.Finalize() MPI.Finalize()
end end
...@@ -3,6 +3,7 @@ module ECCCo ...@@ -3,6 +3,7 @@ module ECCCo
using CounterfactualExplanations using CounterfactualExplanations
import MLJModelInterface as MMI import MLJModelInterface as MMI
include("utils.jl")
include("model.jl") include("model.jl")
include("sampling.jl") include("sampling.jl")
include("penalties.jl") include("penalties.jl")
......
function pre_process(x; noise::Float32=0.03f0)
ϵ = Float32.(randn(size(x)) * noise)
x += ϵ
return x
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment