Skip to content
Snippets Groups Projects
Commit 4577b4a4 authored by Pat Alt's avatar Pat Alt
Browse files

bloody hell

parent ed785f05
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -444,11 +444,9 @@ version = "0.6.2"
[[deps.CounterfactualExplanations]]
deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"]
git-tree-sha1 = "9bcb579703041d8708b179e55c119f150c5565bc"
repo-rev = "main"
repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git"
git-tree-sha1 = "30cf711962736a6bc5ffc6c7d1b6be6d11d306d9"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
version = "0.1.23"
version = "0.1.24"
[deps.CounterfactualExplanations.extensions]
MPIExt = "MPI"
......
......@@ -16,8 +16,8 @@ test_data = load_fashion_mnist_test()
model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
# Tuning parameters:
tuning_params = DEFAULT_GENERATOR_TUNING[2:end]
push!(tuning_params.Λ, [0.1, 0.1, 3.0])
tuning_params = DEFAULT_GENERATOR_TUNING
tuning_params = (; tuning_params..., Λ=[tuning_params.Λ[2:end]..., [0.1, 0.1, 3.0]])
# Additional models:
add_models = Dict(
......@@ -39,7 +39,8 @@ params = (
epochs=10,
nsamples=10,
nmin=1,
niter_eccco=100
niter_eccco=100,
Λ = [0.1, 0.1, 3.0]
)
if !GRID_SEARCH
......
......@@ -25,6 +25,11 @@ function grid_search(
tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)]
grid = Iterators.product(tuning_params...)
outcomes = Dict{Any,Any}()
# Save:
if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes)
end
# Search:
counter = 1
......@@ -44,7 +49,7 @@ function grid_search(
end
# Save:
if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes)
end
end
\ No newline at end of file
......@@ -16,8 +16,8 @@ test_data = load_mnist_test()
model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
# Tuning parameters:
tuning_params = DEFAULT_GENERATOR_TUNING[2:end]
push!(tuning_params.Λ, [0.1, 0.1, 3.0])
tuning_params = DEFAULT_GENERATOR_TUNING
tuning_params = (; tuning_params..., Λ=[tuning_params.Λ[2:end]..., [0.1, 0.1, 3.0]])
# Additional models:
add_models = Dict(
......@@ -39,7 +39,8 @@ params = (
epochs=10,
nsamples=10,
nmin=1,
niter_eccco=100
niter_eccco=100,
Λ=[0.1, 0.1, 3.0]
)
if !GRID_SEARCH
......
......@@ -76,8 +76,8 @@ function prepare_models(exper::Experiment; save_models::Bool=true)
# Save models:
if save_models && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
@info "Saving models to $(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))."
Serialization.serialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), model_dict)
@info "Saving models to $(joinpath(exper.output_path , "$(exper.save_name)_models.jls"))."
Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict)
end
return model_dict
......
......@@ -145,7 +145,7 @@ const N_IND_SPECIFIED = n_ind_specified
const GRID_SEARCH = "grid_search" ARGS
"Generator tuning parameters."
const DEFAULT_GENERATOR_TUNING = (
DEFAULT_GENERATOR_TUNING = (
nsamples=[10, 100],
niter_eccco=[10, 100],
Λ=[
......@@ -160,13 +160,13 @@ const DEFAULT_GENERATOR_TUNING = (
const TUNE_MODEL = "tune_model" ARGS
"Model tuning parameters for small datasets."
const DEFAULT_MODEL_TUNING_SMALL = (
DEFAULT_MODEL_TUNING_SMALL = (
n_hidden=[16, 32, 64],
n_layers=[1, 2, 3],
)
"Model tuning parameters for large datasets."
const DEFAULT_MODEL_TUNING_LARGE = (
DEFAULT_MODEL_TUNING_LARGE = (
n_hidden=[32, 64, 128, 512],
n_layers=[2, 3, 5],
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment