From 9f274e83212768e189f4b1407bbb2efdddcbbd9c Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Wed, 13 Sep 2023 08:13:08 +0200 Subject: [PATCH] fun times --- experiments/Manifest.toml | 52 ++++++++++++++-------------------- experiments/circles.jl | 4 ++- experiments/gmsc.jl | 3 +- experiments/models/models.jl | 7 +++-- experiments/moons.jl | 3 +- src/penalties.jl | 54 ++++++++++++++++++++---------------- 6 files changed, 63 insertions(+), 60 deletions(-) diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index d7bbde7e..b8f01c6e 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -248,9 +248,9 @@ version = "1.0.5" [[deps.CairoMakie]] deps = ["Base64", "Cairo", "Colors", "FFTW", "FileIO", "FreeType", "GeometryBasics", "LinearAlgebra", "Makie", "PrecompileTools", "SHA"] -git-tree-sha1 = "30562a68ded3dabe80109caf6b4de73a48ac27bc" +git-tree-sha1 = "696e7931bd6f5c773418452cbe5fd241cb85ac2a" uuid = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -version = "0.10.8" +version = "0.10.9" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] @@ -508,13 +508,9 @@ version = "0.1.2" [[deps.DelaunayTriangulation]] deps = ["DataStructures", "EnumX", "ExactPredicates", "Random", "SimpleGraphs"] -git-tree-sha1 = "a1d8532de83f8ce964235eff1edeff9581144d02" +git-tree-sha1 = "bea7984f7e09aeb28a3b071c420a0186cb4fabad" uuid = "927a84f5-c5f4-47a5-9785-b46e178433df" -version = "0.7.2" -weakdeps = ["MakieCore"] - - [deps.DelaunayTriangulation.extensions] - DelaunayTriangulationMakieCoreExt = "MakieCore" +version = "0.8.8" [[deps.DelimitedFiles]] deps = ["Mmap"] @@ -647,12 +643,6 @@ git-tree-sha1 = "5e1e4c53fa39afe63a7d356e30452249365fba99" uuid = "411431e0-e8b7-467b-b5e0-f676ba4f2910" version = "0.1.1" -[[deps.FFMPEG]] -deps = ["FFMPEG_jll"] -git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" -uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" -version = "0.4.1" - [[deps.FFMPEG_jll]] deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] git-tree-sha1 = "466d45dc38e15794ec7d5d63ec03d776a9aff36e" @@ -1331,9 +1321,9 @@ version = "2023.2.0+0" [[deps.MLDatasets]] deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "41922968c0aaca46baa5d658d3a173828313e2d0" +git-tree-sha1 = "10bc70e4c875f1b2ca65cef3ef9ebe705ef936b5" uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.12" +version = "0.7.13" [[deps.MLFlowClient]] deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] @@ -1451,16 +1441,16 @@ uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.11" [[deps.Makie]] -deps = ["Animations", "Base64", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG", "FileIO", "FixedPointNumbers", "Formatting", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "Match", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Setfield", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"] -git-tree-sha1 = "e81675589ba7199a82443e87fc52e17eeceac2e8" +deps = ["Animations", "Base64", "CRC32c", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG_jll", "FileIO", "FixedPointNumbers", "Formatting", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "Match", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Setfield", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"] +git-tree-sha1 = "ecc334efc4a8a68800776b0d85ab7bb2fff63f7a" uuid = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" -version = "0.19.8" +version = "0.19.9" [[deps.MakieCore]] deps = ["Observables"] -git-tree-sha1 = "f56b09c8b964919373d61750c6d8d4d2c602a2be" +git-tree-sha1 = "1efb1166dd9398f2ccf6d728f896658c9c84733e" uuid = "20f20a25-4f0e-4fdf-b5d1-57303727442b" -version = "0.6.5" +version = "0.6.6" [[deps.MappedArrays]] git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" @@ -1772,10 +1762,10 @@ uuid = "8b842266-38fa-440a-9b57-31493939ab85" version = "0.1.4" [[deps.Pango_jll]] -deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "84a314e3926ba9ec66ac097e3635e270986b0f10" +deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl"] +git-tree-sha1 = "4745216e94f71cb768d58330b059c9b76f32cb66" uuid = "36c8627f-9965-5494-a995-c6b170f724f3" -version = "1.50.9+0" +version = "1.50.14+0" [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] @@ -1911,9 +1901,9 @@ version = "0.1.4" [[deps.ProgressMeter]] deps = ["Distributed", "Printf"] -git-tree-sha1 = "ae36206463b2395804f2787ffe172f44452b538d" +git-tree-sha1 = "00099623ffee15972c16111bcf84c58a0051257c" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.8.0" +version = "1.9.0" [[deps.QOI]] deps = ["ColorTypes", "FileIO", "FixedPointNumbers"] @@ -2170,10 +2160,10 @@ uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" [[deps.StableHashTraits]] -deps = ["CRC32c", "Compat", "Dates", "SHA", "Tables", "TupleTools", "UUIDs"] -git-tree-sha1 = "0b8b801b8f03a329a4e86b44c5e8a7d7f4fe10a3" +deps = ["Compat", "SHA", "Tables", "TupleTools"] +git-tree-sha1 = "19df33ca14f24a3ad2df9e89124bd5f5cc8467a2" uuid = "c5dd0088-6c3f-4803-b00e-f31a60c170fa" -version = "0.3.1" +version = "1.0.1" [[deps.StableRNGs]] deps = ["Random", "Test"] @@ -2324,9 +2314,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TidierData]] deps = ["Chain", "Cleaner", "DataFrames", "MacroTools", "Reexport", "ShiftedArrays", "Statistics", "StatsBase"] -git-tree-sha1 = "208ee18374d1aff392d7c50f2753568922bdf68c" +git-tree-sha1 = "3a34d01b775181ca86f4315ea5895228b1742b81" uuid = "fe2206b3-d496-4ee9-a338-6a095c4ece80" -version = "0.12.0" +version = "0.12.1" [[deps.TidierPlots]] deps = ["AlgebraOfGraphics", "CairoMakie", "DataFrames", "Makie", "MarketData", "PalmerPenguins", "Reexport"] diff --git a/experiments/circles.jl b/experiments/circles.jl index aee790da..a7dc16dc 100644 --- a/experiments/circles.jl +++ b/experiments/circles.jl @@ -2,10 +2,12 @@ n_obs = Int(1000 / (1.0 - TEST_SIZE)) counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=TEST_SIZE) run_experiment( counterfactual_data, test_data; + epochs=100, dataname="Circles", n_hidden=32, α=[1.0, 1.0, 1e-2], sampling_steps=20, - Λ=[0.25, 0.75, 0.75], + Λ=[0.1, 0.1, 0.1], opt=Flux.Optimise.Descent(0.01), + activation=Flux.swish, ) \ No newline at end of file diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl index c4641200..aecdbd38 100644 --- a/experiments/gmsc.jl +++ b/experiments/gmsc.jl @@ -16,12 +16,13 @@ n_ind = N_IND_SPECIFIED ? N_IND : 10 run_experiment( counterfactual_data, test_data; dataname="GMSC", + epochs=100, builder = builder, α=[1.0, 1.0, 1e-1], sampling_batch_size=10, sampling_steps = 30, use_ensembling = true, - Λ=[0.1, 0.5, 0.5], + Λ=[0.1, 0.1, 0.1], opt = Flux.Optimise.Descent(0.05), n_individuals = n_ind, use_variants = false, diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 022b6937..900e95b9 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -47,9 +47,12 @@ function prepare_models(exper::Experiment) @info "Training models." model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage) else - @info "Loading pre-trained models." - model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls")) + if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + @info "Loading pre-trained models." + model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls")) + end if is_multi_processed(exper) + model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root=0) MPI.Barrier(exper.parallelizer.comm) end end diff --git a/experiments/moons.jl b/experiments/moons.jl index 02034380..8b096983 100644 --- a/experiments/moons.jl +++ b/experiments/moons.jl @@ -8,6 +8,7 @@ run_experiment( activation = Flux.relu, sampling_batch_size=10, sampling_steps=30, - Λ=[0.25, 0.75, 0.75], + Λ=[0.1, 0.1, 0.1], opt=Flux.Optimise.Descent(0.05), + α=[1.0, 1.0, 1e-1] ) \ No newline at end of file diff --git a/src/penalties.jl b/src/penalties.jl index 78ff7df3..030a22e6 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -46,30 +46,35 @@ function energy_delta( choose_random=false, nmin::Int=25, return_conditionals=false, - reg_strength=0.5, + reg_strength=0.1, kwargs... ) - nmin = minimum([nmin, n]) - - @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither." - + # nmin = minimum([nmin, n]) + + # @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither." + + # conditional_samples = [] + # ignore_derivatives() do + # _dict = ce.params + # if !(:energy_sampler ∈ collect(keys(_dict))) + # _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) + # end + # eng_sampler = _dict[:energy_sampler] + # if choose_lowest_energy + # nmin = minimum([nmin, size(eng_sampler.buffer)[end]]) + # xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin) + # push!(conditional_samples, xmin) + # elseif choose_random + # push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer)) + # else + # push!(conditional_samples, eng_sampler.buffer) + # end + # end conditional_samples = [] - ignore_derivatives() do - _dict = ce.params - if !(:energy_sampler ∈ collect(keys(_dict))) - _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) - end - eng_sampler = _dict[:energy_sampler] - if choose_lowest_energy - nmin = minimum([nmin, size(eng_sampler.buffer)[end]]) - xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin) - push!(conditional_samples, xmin) - elseif choose_random - push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer)) - else - push!(conditional_samples, eng_sampler.buffer) - end + ignore_derivatives() do + xsampled = ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...) + push!(conditional_samples, xsampled) end xgenerated = conditional_samples[1] # conditional samples @@ -79,16 +84,17 @@ function energy_delta( # Generative loss: gen_loss = E(xproposed) .- E(xgenerated) - gen_loss = reduce((x, y) -> x + y, gen_loss) / n # aggregate over samples + gen_loss = reduce((x, y) -> x + y, gen_loss) / length(gen_loss) # aggregate over samples # Regularization loss: reg_loss = E(xgenerated).^2 .+ E(xproposed).^2 - reg_loss = reduce((x, y) -> x + y, reg_loss) / n # aggregate over samples + reg_loss = reduce((x, y) -> x + y, reg_loss) / length(reg_loss) # aggregate over samples - if return_conditionals + if !return_conditionals + return gen_loss + reg_strength * reg_loss + else return conditional_samples[1] end - return gen_loss + reg_strength * reg_loss end -- GitLab