Skip to content
Snippets Groups Projects
Commit fc23c9bd authored by Pat Alt's avatar Pat Alt
Browse files

synthetic

parent f19dcece
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -7,7 +7,9 @@ run_experiment(
n_hidden=32,
α=[1.0, 1.0, 1e-2],
sampling_steps=20,
Λ=[0.1, 0.1, 0.1],
opt=Flux.Optimise.Descent(0.01),
activation=Flux.swish,
Λ=[0.1, 0.2, 0.2],
nsamples=100,
niter_eccco=100,
)
\ No newline at end of file
......@@ -11,4 +11,4 @@
module load 2023r1 openmpi
srun julia --project=experiments experiments/run_experiments.jl -- data=german_credit output_path=results retrain mpi > experiments/german_credit.log
srun julia --project=experiments experiments/run_experiments.jl -- data=german_credit output_path=results mpi > experiments/german_credit.log
......@@ -11,4 +11,4 @@
module load 2023r1 openmpi
srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi retrain > experiments/mnist.log
srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi > experiments/mnist.log
......@@ -7,6 +7,6 @@ run_experiment(
counterfactual_data, test_data;
dataname="Linearly Separable",
nsamples=100,
nmin=1,
niter_eccco=30
niter_eccco=100,
Λ=[0.1, 0.2, 0.2],
)
\ No newline at end of file
......@@ -8,7 +8,9 @@ run_experiment(
activation = Flux.relu,
sampling_batch_size=10,
sampling_steps=30,
Λ=[0.1, 0.1, 0.1],
opt=Flux.Optimise.Descent(0.05),
α=[1.0, 1.0, 1e-1]
α=[1.0, 1.0, 1e-1],
nsamples=100,
niter_eccco=100,
Λ = [0.1, 0.2, 0.2],
)
\ No newline at end of file
......@@ -50,27 +50,6 @@ function energy_delta(
kwargs...
)
# nmin = minimum([nmin, n])
# @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
# conditional_samples = []
# ignore_derivatives() do
# _dict = ce.params
# if !(:energy_sampler ∈ collect(keys(_dict)))
# _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
# end
# eng_sampler = _dict[:energy_sampler]
# if choose_lowest_energy
# nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
# xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
# push!(conditional_samples, xmin)
# elseif choose_random
# push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
# else
# push!(conditional_samples, eng_sampler.buffer)
# end
# end
conditional_samples = []
ignore_derivatives() do
_dict = ce.params
......@@ -85,7 +64,7 @@ function energy_delta(
xgenerated = conditional_samples[1] # conditional samples
xproposed = CounterfactualExplanations.decode_state(ce) # current state
t = get_target_index(ce.data.y_levels, ce.target)
E(x) = -logits(ce.M, x)[t,:] # negative logits for target class
E(x) = -logits(ce.M, x)[t,:] # negative logits for taraget class
# Generative loss:
gen_loss = E(xproposed) .- E(xgenerated)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment