Commit 09634019 authored by pat-alt

Merge branch '69-initial-run-including-fmnist-lenet-and-new-method' of https://github.com/pat-alt/ECCCo.jl into 69-initial-run-including-fmnist-lenet-and-new-method
parents c03e752a bd12b6b2
1 merge request: !76 "69 initial run including fmnist lenet and new method"
@@ -2,12 +2,12 @@ n_obs = Int(1000 / (1.0 - TEST_SIZE))
counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=TEST_SIZE)
run_experiment(
counterfactual_data, test_data;
epochs=100,
dataname="Circles",
n_hidden=32,
α=[1.0, 1.0, 1e-2],
sampling_steps=20,
Λ=[0.25, 0.75, 0.75],
Λ=[0.1, 0.1, 0.1],
opt=Flux.Optimise.Descent(0.01),
nsamples=1,
nmin=1,
activation=Flux.swish,
)
\ No newline at end of file
@@ -16,12 +16,13 @@ n_ind = N_IND_SPECIFIED ? N_IND : 10
run_experiment(
counterfactual_data, test_data;
dataname="GMSC",
epochs=100,
builder = builder,
α=[1.0, 1.0, 1e-1],
sampling_batch_size=10,
sampling_steps = 30,
use_ensembling = true,
Λ=[0.1, 0.5, 0.5],
Λ=[0.1, 0.1, 0.1],
opt = Flux.Optimise.Descent(0.05),
n_individuals = n_ind,
use_variants = false,
......
@@ -47,9 +47,12 @@ function prepare_models(exper::Experiment)
@info "Training models."
model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage)
else
@info "Loading pre-trained models."
model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))
if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
@info "Loading pre-trained models."
model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))
end
if is_multi_processed(exper)
model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root=0)
MPI.Barrier(exper.parallelizer.comm)
end
end
......
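For context, the change to prepare_models above restricts file access to the MPI root: only rank 0 deserializes the pre-trained models, and the result is then broadcast to all other ranks before a barrier. A minimal sketch of that pattern with MPI.jl and Serialization follows; the `path` variable and the surrounding script are illustrative stand-ins, not part of this commit.

using MPI, Serialization

MPI.Init()
comm = MPI.COMM_WORLD
path = "models.jls"  # stand-in path for this sketch

# Only the root rank reads from disk; other ranks start with `nothing`.
model_dict = MPI.Comm_rank(comm) == 0 ? Serialization.deserialize(path) : nothing

# The root's object is sent to every rank, then all ranks synchronize.
model_dict = MPI.bcast(model_dict, comm; root=0)
MPI.Barrier(comm)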
@@ -8,8 +8,7 @@ run_experiment(
activation = Flux.relu,
sampling_batch_size=10,
sampling_steps=30,
Λ=[0.25, 0.75, 0.75],
Λ=[0.1, 0.1, 0.1],
opt=Flux.Optimise.Descent(0.05),
nsamples=1,
nmin=1,
α=[1.0, 1.0, 1e-1]
)
\ No newline at end of file
@@ -46,30 +46,35 @@ function energy_delta(
choose_random=false,
nmin::Int=25,
return_conditionals=false,
reg_strength=0.5,
reg_strength=0.1,
kwargs...
)
nmin = minimum([nmin, n])
@assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
# nmin = minimum([nmin, n])
# @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
# conditional_samples = []
# ignore_derivatives() do
# _dict = ce.params
# if !(:energy_sampler ∈ collect(keys(_dict)))
# _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
# end
# eng_sampler = _dict[:energy_sampler]
# if choose_lowest_energy
# nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
# xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
# push!(conditional_samples, xmin)
# elseif choose_random
# push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
# else
# push!(conditional_samples, eng_sampler.buffer)
# end
# end
conditional_samples = []
ignore_derivatives() do
_dict = ce.params
if !(:energy_sampler ∈ collect(keys(_dict)))
_dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
end
eng_sampler = _dict[:energy_sampler]
if choose_lowest_energy
nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
push!(conditional_samples, xmin)
elseif choose_random
push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
else
push!(conditional_samples, eng_sampler.buffer)
end
ignore_derivatives() do
xsampled = ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...)
push!(conditional_samples, xsampled)
end
xgenerated = conditional_samples[1] # conditional samples
@@ -79,16 +84,17 @@ function energy_delta(
# Generative loss:
gen_loss = E(xproposed) .- E(xgenerated)
gen_loss = reduce((x, y) -> x + y, gen_loss) / n # aggregate over samples
gen_loss = reduce((x, y) -> x + y, gen_loss) / length(gen_loss) # aggregate over samples
# Regularization loss:
reg_loss = E(xgenerated).^2 .+ E(xproposed).^2
reg_loss = reduce((x, y) -> x + y, reg_loss) / n # aggregate over samples
reg_loss = reduce((x, y) -> x + y, reg_loss) / length(reg_loss) # aggregate over samples
if return_conditionals
if !return_conditionals
return gen_loss + reg_strength * reg_loss
else
return conditional_samples[1]
end
return gen_loss + reg_strength * reg_loss
end
......
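Taken together, the revised energy_delta draws fresh conditional samples sized to the number of counterfactuals and returns a generative term (the mean energy gap between proposed and generated samples) plus a squared-energy regularizer scaled by reg_strength, each normalized by the number of samples actually returned rather than a fixed n. A toy sketch of that arithmetic, with a made-up energy function and random inputs; all names here are illustrative, not the package API.

# Stand-in energy: squared L2 norm of each column (one column per sample).
E(X) = vec(sum(X .^ 2; dims=1))

xgenerated = randn(2, 3)   # e.g. samples drawn from the energy sampler
xproposed  = randn(2, 3)   # e.g. the proposed counterfactual(s)
reg_strength = 0.1

# Generative loss: mean energy gap between proposed and generated samples.
gen_loss = sum(E(xproposed) .- E(xgenerated)) / length(E(xproposed))

# Regularization loss: mean squared energies, keeping magnitudes in check.
reg_loss = sum(E(xgenerated) .^ 2 .+ E(xproposed) .^ 2) / length(E(xgenerated))

loss = gen_loss + reg_strength * reg_loss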