Skip to content
Snippets Groups Projects
Commit c381b504 authored by Pat Alt's avatar Pat Alt
Browse files

MNIST still crashing :cry:

parent 7412c6e4
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -125,21 +125,21 @@ end
Overload the `run_experiment` function to allow for passing in `CounterfactualData` objects and other keyword arguments.
"""
function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...)
function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; save_output::Bool=true, kwargs...)
# Parameters:
exper = Experiment(;
counterfactual_data=counterfactual_data,
test_data=test_data,
kwargs...
)
return run_experiment(exper)
return run_experiment(exper; save_output=save_output)
end
# Pre-trained models:
function pretrained_path(exper::Experiment)
if isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls"))
@info "Found local pre-trained models in $(DEFAULT_OUTPUT_PATH) and using those."
return exper.output_path
return DEFAULT_OUTPUT_PATH
else
@info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions."
Pkg.Artifacts.download_artifact(ARTIFACT_HASH, ARTIFACT_TOML)
......
......@@ -29,9 +29,9 @@ function grid_search(
for tuning_params in grid
outcome = run_experiment(
counterfactual_data, test_data;
save_output=false,
dataname=dataname,
output_path=grid_search_path,
save_output=false,
tuning_params...,
kwargs...,
)
......
......@@ -76,8 +76,8 @@ function prepare_models(exper::Experiment)
# Save models:
if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
@info "Saving models to $(joinpath(exper.output_path, "$(exper.save_name)_models.jls"))."
Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict)
@info "Saving models to $(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))."
Serialization.serialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), model_dict)
end
return model_dict
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment