From 9f274e83212768e189f4b1407bbb2efdddcbbd9c Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Wed, 13 Sep 2023 08:13:08 +0200
Subject: [PATCH] fun times

---
 experiments/Manifest.toml    | 52 ++++++++++++++--------------------
 experiments/circles.jl       |  4 ++-
 experiments/gmsc.jl          |  3 +-
 experiments/models/models.jl |  7 +++--
 experiments/moons.jl         |  3 +-
 src/penalties.jl             | 54 ++++++++++++++++++++----------------
 6 files changed, 63 insertions(+), 60 deletions(-)

diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml
index d7bbde7e..b8f01c6e 100644
--- a/experiments/Manifest.toml
+++ b/experiments/Manifest.toml
@@ -248,9 +248,9 @@ version = "1.0.5"
 
 [[deps.CairoMakie]]
 deps = ["Base64", "Cairo", "Colors", "FFTW", "FileIO", "FreeType", "GeometryBasics", "LinearAlgebra", "Makie", "PrecompileTools", "SHA"]
-git-tree-sha1 = "30562a68ded3dabe80109caf6b4de73a48ac27bc"
+git-tree-sha1 = "696e7931bd6f5c773418452cbe5fd241cb85ac2a"
 uuid = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
-version = "0.10.8"
+version = "0.10.9"
 
 [[deps.Cairo_jll]]
 deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
@@ -508,13 +508,9 @@ version = "0.1.2"
 
 [[deps.DelaunayTriangulation]]
 deps = ["DataStructures", "EnumX", "ExactPredicates", "Random", "SimpleGraphs"]
-git-tree-sha1 = "a1d8532de83f8ce964235eff1edeff9581144d02"
+git-tree-sha1 = "bea7984f7e09aeb28a3b071c420a0186cb4fabad"
 uuid = "927a84f5-c5f4-47a5-9785-b46e178433df"
-version = "0.7.2"
-weakdeps = ["MakieCore"]
-
-    [deps.DelaunayTriangulation.extensions]
-    DelaunayTriangulationMakieCoreExt = "MakieCore"
+version = "0.8.8"
 
 [[deps.DelimitedFiles]]
 deps = ["Mmap"]
@@ -647,12 +643,6 @@ git-tree-sha1 = "5e1e4c53fa39afe63a7d356e30452249365fba99"
 uuid = "411431e0-e8b7-467b-b5e0-f676ba4f2910"
 version = "0.1.1"
 
-[[deps.FFMPEG]]
-deps = ["FFMPEG_jll"]
-git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8"
-uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a"
-version = "0.4.1"
-
 [[deps.FFMPEG_jll]]
 deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"]
 git-tree-sha1 = "466d45dc38e15794ec7d5d63ec03d776a9aff36e"
@@ -1331,9 +1321,9 @@ version = "2023.2.0+0"
 
 [[deps.MLDatasets]]
 deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"]
-git-tree-sha1 = "41922968c0aaca46baa5d658d3a173828313e2d0"
+git-tree-sha1 = "10bc70e4c875f1b2ca65cef3ef9ebe705ef936b5"
 uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458"
-version = "0.7.12"
+version = "0.7.13"
 
 [[deps.MLFlowClient]]
 deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"]
@@ -1451,16 +1441,16 @@ uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 version = "0.5.11"
 
 [[deps.Makie]]
-deps = ["Animations", "Base64", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG", "FileIO", "FixedPointNumbers", "Formatting", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "Match", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Setfield", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"]
-git-tree-sha1 = "e81675589ba7199a82443e87fc52e17eeceac2e8"
+deps = ["Animations", "Base64", "CRC32c", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG_jll", "FileIO", "FixedPointNumbers", "Formatting", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "Match", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Setfield", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"]
+git-tree-sha1 = "ecc334efc4a8a68800776b0d85ab7bb2fff63f7a"
 uuid = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
-version = "0.19.8"
+version = "0.19.9"
 
 [[deps.MakieCore]]
 deps = ["Observables"]
-git-tree-sha1 = "f56b09c8b964919373d61750c6d8d4d2c602a2be"
+git-tree-sha1 = "1efb1166dd9398f2ccf6d728f896658c9c84733e"
 uuid = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
-version = "0.6.5"
+version = "0.6.6"
 
 [[deps.MappedArrays]]
 git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e"
@@ -1772,10 +1762,10 @@ uuid = "8b842266-38fa-440a-9b57-31493939ab85"
 version = "0.1.4"
 
 [[deps.Pango_jll]]
-deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg"]
-git-tree-sha1 = "84a314e3926ba9ec66ac097e3635e270986b0f10"
+deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl"]
+git-tree-sha1 = "4745216e94f71cb768d58330b059c9b76f32cb66"
 uuid = "36c8627f-9965-5494-a995-c6b170f724f3"
-version = "1.50.9+0"
+version = "1.50.14+0"
 
 [[deps.Parameters]]
 deps = ["OrderedCollections", "UnPack"]
@@ -1911,9 +1901,9 @@ version = "0.1.4"
 
 [[deps.ProgressMeter]]
 deps = ["Distributed", "Printf"]
-git-tree-sha1 = "ae36206463b2395804f2787ffe172f44452b538d"
+git-tree-sha1 = "00099623ffee15972c16111bcf84c58a0051257c"
 uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
-version = "1.8.0"
+version = "1.9.0"
 
 [[deps.QOI]]
 deps = ["ColorTypes", "FileIO", "FixedPointNumbers"]
@@ -2170,10 +2160,10 @@ uuid = "171d559e-b47b-412a-8079-5efa626c420e"
 version = "0.1.15"
 
 [[deps.StableHashTraits]]
-deps = ["CRC32c", "Compat", "Dates", "SHA", "Tables", "TupleTools", "UUIDs"]
-git-tree-sha1 = "0b8b801b8f03a329a4e86b44c5e8a7d7f4fe10a3"
+deps = ["Compat", "SHA", "Tables", "TupleTools"]
+git-tree-sha1 = "19df33ca14f24a3ad2df9e89124bd5f5cc8467a2"
 uuid = "c5dd0088-6c3f-4803-b00e-f31a60c170fa"
-version = "0.3.1"
+version = "1.0.1"
 
 [[deps.StableRNGs]]
 deps = ["Random", "Test"]
@@ -2324,9 +2314,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [[deps.TidierData]]
 deps = ["Chain", "Cleaner", "DataFrames", "MacroTools", "Reexport", "ShiftedArrays", "Statistics", "StatsBase"]
-git-tree-sha1 = "208ee18374d1aff392d7c50f2753568922bdf68c"
+git-tree-sha1 = "3a34d01b775181ca86f4315ea5895228b1742b81"
 uuid = "fe2206b3-d496-4ee9-a338-6a095c4ece80"
-version = "0.12.0"
+version = "0.12.1"
 
 [[deps.TidierPlots]]
 deps = ["AlgebraOfGraphics", "CairoMakie", "DataFrames", "Makie", "MarketData", "PalmerPenguins", "Reexport"]
diff --git a/experiments/circles.jl b/experiments/circles.jl
index aee790da..a7dc16dc 100644
--- a/experiments/circles.jl
+++ b/experiments/circles.jl
@@ -2,10 +2,12 @@ n_obs = Int(1000 / (1.0 - TEST_SIZE))
 counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=TEST_SIZE)
 run_experiment(
     counterfactual_data, test_data; 
+    epochs=100,
     dataname="Circles",
     n_hidden=32,
     α=[1.0, 1.0, 1e-2],
     sampling_steps=20,
-    Λ=[0.25, 0.75, 0.75],
+    Λ=[0.1, 0.1, 0.1],
     opt=Flux.Optimise.Descent(0.01),
+    activation=Flux.swish,
 )
\ No newline at end of file
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index c4641200..aecdbd38 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -16,12 +16,13 @@ n_ind = N_IND_SPECIFIED ? N_IND : 10
 run_experiment(
     counterfactual_data, test_data; 
     dataname="GMSC",
+    epochs=100,
     builder = builder,
     α=[1.0, 1.0, 1e-1],
     sampling_batch_size=10,
     sampling_steps = 30,
     use_ensembling = true,
-    Λ=[0.1, 0.5, 0.5],
+    Λ=[0.1, 0.1, 0.1],
     opt = Flux.Optimise.Descent(0.05),
     n_individuals = n_ind,
     use_variants = false, 
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 022b6937..900e95b9 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -47,9 +47,12 @@ function prepare_models(exper::Experiment)
         @info "Training models."
         model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage)
     else
-        @info "Loading pre-trained models."
-        model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))
+        if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+            @info "Loading pre-trained models."
+            model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))
+        end
         if is_multi_processed(exper)
+            model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root=0)
             MPI.Barrier(exper.parallelizer.comm)
         end
     end
diff --git a/experiments/moons.jl b/experiments/moons.jl
index 02034380..8b096983 100644
--- a/experiments/moons.jl
+++ b/experiments/moons.jl
@@ -8,6 +8,7 @@ run_experiment(
     activation = Flux.relu,
     sampling_batch_size=10,
     sampling_steps=30,
-    Λ=[0.25, 0.75, 0.75],
+    Λ=[0.1, 0.1, 0.1],
     opt=Flux.Optimise.Descent(0.05),
+    α=[1.0, 1.0, 1e-1]
 )
\ No newline at end of file
diff --git a/src/penalties.jl b/src/penalties.jl
index 78ff7df3..030a22e6 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -46,30 +46,35 @@ function energy_delta(
     choose_random=false,
     nmin::Int=25,
     return_conditionals=false,
-    reg_strength=0.5,
+    reg_strength=0.1,
     kwargs...
 )
 
-    nmin = minimum([nmin, n])
-
-    @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
-
+    # nmin = minimum([nmin, n])
+
+    # @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
+
+    # conditional_samples = []
+    # ignore_derivatives() do
+    #     _dict = ce.params
+    #     if !(:energy_sampler ∈ collect(keys(_dict)))
+    #         _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
+    #     end
+    #     eng_sampler = _dict[:energy_sampler]
+    #     if choose_lowest_energy
+    #         nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
+    #         xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
+    #         push!(conditional_samples, xmin)
+    #     elseif choose_random
+    #         push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
+    #     else
+    #         push!(conditional_samples, eng_sampler.buffer)
+    #     end
+    # end
     conditional_samples = []
-    ignore_derivatives() do
-        _dict = ce.params
-        if !(:energy_sampler ∈ collect(keys(_dict)))
-            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
-        end
-        eng_sampler = _dict[:energy_sampler]
-        if choose_lowest_energy
-            nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
-            xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
-            push!(conditional_samples, xmin)
-        elseif choose_random
-            push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
-        else
-            push!(conditional_samples, eng_sampler.buffer)
-        end
+    ignore_derivatives() do 
+        xsampled = ECCCo.EnergySampler(ce; niter=niter, nsamples=ce.num_counterfactuals, kwargs...)
+        push!(conditional_samples, xsampled)
     end
 
     xgenerated = conditional_samples[1]                         # conditional samples
@@ -79,16 +84,17 @@ function energy_delta(
 
     # Generative loss:
     gen_loss = E(xproposed) .- E(xgenerated)
-    gen_loss = reduce((x, y) -> x + y, gen_loss) / n                  # aggregate over samples
+    gen_loss = reduce((x, y) -> x + y, gen_loss) / length(gen_loss)                  # aggregate over samples
 
     # Regularization loss:
     reg_loss = E(xgenerated).^2 .+ E(xproposed).^2
-    reg_loss = reduce((x, y) -> x + y, reg_loss) / n                  # aggregate over samples
+    reg_loss = reduce((x, y) -> x + y, reg_loss) / length(reg_loss)                  # aggregate over samples
 
-    if return_conditionals
+    if !return_conditionals
+        return gen_loss + reg_strength * reg_loss
+    else
         return conditional_samples[1]
     end
-    return gen_loss + reg_strength * reg_loss
 
 end
 
-- 
GitLab