Skip to content
Snippets Groups Projects
Commit ce41cb23 authored by Pat Alt's avatar Pat Alt
Browse files

trying to sort out the sampling :sob:

parent 657b4610
No related branches found
No related tags found
No related merge requests found
......@@ -49,20 +49,26 @@ uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
version = "0.5.4"
[[deps.Arpack_jll]]
deps = ["Libdl", "OpenBLAS_jll", "Pkg"]
git-tree-sha1 = "e214a9b9bd1b4e1b4f15b22c0994862b66af7ff7"
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "OpenBLAS_jll", "Pkg"]
git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e"
uuid = "68821587-b530-5797-8361-c406ea357684"
version = "3.5.0+3"
version = "3.5.1+1"
[[deps.ArrayInterface]]
deps = ["Adapt", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "a89acc90c551067cd84119ff018619a1a76c6277"
git-tree-sha1 = "38911c7737e123b28182d89027f4216cfc8a9da7"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.2.1"
version = "7.4.3"
[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[deps.Atomix]]
deps = ["UnsafeAtomics"]
git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be"
uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
version = "0.1.0"
[[deps.AxisAlgorithms]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"]
git-tree-sha1 = "66771c8d21c8ff5e3a93379480a2307ac36863f7"
......@@ -76,9 +82,9 @@ uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.4.2"
[[deps.BSON]]
git-tree-sha1 = "86e9781ac28f4e80e9b98f7f96eae21891332ac2"
git-tree-sha1 = "2208958832d6e1b59e49f53697483a84ca8d664e"
uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
version = "0.3.6"
version = "0.3.7"
[[deps.BangBang]]
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
......@@ -94,11 +100,10 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e"
uuid = "9718e550-a3fa-408a-8086-8db961cd8217"
version = "0.1.1"
[[deps.BinaryProvider]]
deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.10"
[[deps.BitFlags]]
git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d"
uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35"
version = "0.1.7"
[[deps.BufferedStreams]]
git-tree-sha1 = "bb065b14d7f941b8617bc323063dbe79f55d16ea"
......@@ -123,34 +128,34 @@ uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
version = "0.10.9"
[[deps.CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions"]
git-tree-sha1 = "edff14c60784c8f7191a62a23b15a421185bc8a8"
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "8547829ee0da896ce48a24b8d2f4bf929cf3e45e"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "4.0.1"
version = "4.1.4"
[[deps.CUDA_Driver_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
git-tree-sha1 = "75d7896d1ec079ef10d3aee8f3668c11354c03a1"
git-tree-sha1 = "498f45593f6ddc0adff64a9310bb6710e851781b"
uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
version = "0.2.0+0"
version = "0.5.0+1"
[[deps.CUDA_Runtime_Discovery]]
deps = ["Libdl"]
git-tree-sha1 = "58dd8ec29f54f08c04b052d2c2fa6760b4f4b3a4"
git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536"
uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
version = "0.1.1"
version = "0.2.2"
[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
git-tree-sha1 = "d3e6ccd30f84936c1a3a53d622d85d7d3f9b9486"
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "81eed046f28a0cdd0dc1f61d00a49061b7cc9433"
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
version = "0.2.3+2"
version = "0.5.0+2"
[[deps.CUDNN_jll]]
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
git-tree-sha1 = "aafe89dfde54011993c4029d3be3e037fd63db07"
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "2918fbffb50e3b7a0b9127617587afa76d4276e8"
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
version = "8.6.0+5"
version = "8.8.1+0"
[[deps.Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
......@@ -214,9 +219,9 @@ version = "0.14.4"
[[deps.CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "d57c99cc7e637165c81b30eb268eabe156a45c49"
git-tree-sha1 = "d730914ef30a06732bdd9f763f6cc32e92ffbff1"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "1.2.2"
version = "1.3.1"
[[deps.CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
......@@ -320,10 +325,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.14.0"
[[deps.DataDeps]]
deps = ["BinaryProvider", "HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"]
git-tree-sha1 = "e299d8267135ef2f9c941a764006697082c1e7e8"
deps = ["HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"]
git-tree-sha1 = "bc0a264d3e7b3eeb0b6fc9f6481f970697f29805"
uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
version = "0.7.8"
version = "0.7.10"
[[deps.DataFrames]]
deps = ["Compat", "DataAPI", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SnoopPrecompile", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
......@@ -381,9 +386,9 @@ version = "1.13.0"
[[deps.Distances]]
deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "3258d0659f812acde79e8a74b11f17ac06d0ca04"
git-tree-sha1 = "49eba9ad9f7ead780bfb7ee319f962c811c6d3b2"
uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
version = "0.10.7"
version = "0.10.8"
[[deps.Distributed]]
deps = ["Random", "Serialization", "Sockets"]
......@@ -391,9 +396,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Distributions]]
deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "da9e1a9058f8d3eec3a8c9fe4faacfb89180066b"
git-tree-sha1 = "13027f188d26206b9e7b863036f87d2f2e7d013a"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.86"
version = "0.25.87"
[[deps.DocStringExtensions]]
deps = ["LibGit2"]
......@@ -476,15 +481,15 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
[[deps.FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
git-tree-sha1 = "3b245d1e50466ca0c9529e2033a3c92387c59c2f"
git-tree-sha1 = "7072f1e3e5a8be51d525d64f63d3ec1287ff2790"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.13.9"
version = "0.13.11"
[[deps.FiniteDiff]]
deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"]
git-tree-sha1 = "ed1b56934a2f7a65035976985da71b6a65b4f2cf"
git-tree-sha1 = "03fcb1c42ec905d15b305359603888ec3e65f886"
uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
version = "2.18.0"
version = "2.19.0"
[[deps.FixedPointNumbers]]
deps = ["Statistics"]
......@@ -493,10 +498,10 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
version = "0.8.4"
[[deps.Flux]]
deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"]
git-tree-sha1 = "4ff3a1d7b0dd38f2fc38e813bc801f817639c1f2"
deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote", "cuDNN"]
git-tree-sha1 = "3f6f32ec0bfd80be0cb65907cf74ec796a632012"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.13.13"
version = "0.13.15"
[[deps.FoldsThreads]]
deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
......@@ -541,9 +546,9 @@ version = "1.1.3"
[[deps.Functors]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "7ed0833a55979d3d2658a60b901469748a6b9a7c"
git-tree-sha1 = "478f8c3145bb91d82c2cf20433e8c1b30df454cc"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.4.3"
version = "0.4.4"
[[deps.Future]]
deps = ["Random"]
......@@ -557,9 +562,9 @@ version = "3.3.8+0"
[[deps.GPUArrays]]
deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
git-tree-sha1 = "7a2e790b1e2e6f648cfb25c4500c5de1f7b375ef"
git-tree-sha1 = "9ade6983c3dbbd492cf5729f865fe030d1541463"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "8.6.5"
version = "8.6.6"
[[deps.GPUArraysCore]]
deps = ["Adapt"]
......@@ -568,22 +573,22 @@ uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
version = "0.1.4"
[[deps.GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "19d693666a304e8c371798f4900f7435558c7cde"
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.17.3"
version = "0.19.3"
[[deps.GR]]
deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"]
git-tree-sha1 = "4423d87dc2d3201f3f1768a29e807ddc8cc867ef"
git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362"
uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
version = "0.71.8"
version = "0.72.1"
[[deps.GR_jll]]
deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"]
git-tree-sha1 = "3657eb348d44575cc5560c80d7e55b812ff6ffe1"
git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3"
uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
version = "0.71.8+0"
version = "0.72.1+1"
[[deps.GZip]]
deps = ["Libdl"]
......@@ -604,9 +609,9 @@ uuid = "7746bdde-850d-59dc-9ae8-88ece973131d"
version = "2.74.0+2"
[[deps.Glob]]
git-tree-sha1 = "4df9f7e06108728ebf00a0a11edee4b29a482bb2"
git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496"
uuid = "c27321d9-0574-5035-807b-f59d2c89b15c"
version = "1.3.0"
version = "1.3.1"
[[deps.Graphics]]
deps = ["Colors", "LinearAlgebra", "NaNMath"]
......@@ -638,10 +643,10 @@ uuid = "0234f1f7-429e-5d53-9886-15a909be8d59"
version = "1.12.2+2"
[[deps.HTTP]]
deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"]
git-tree-sha1 = "0fa77022fe4b511826b39c894c90daf5fce3334a"
deps = ["Base64", "CodecZlib", "Dates", "IniFile", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"]
git-tree-sha1 = "37e4657cd56b11abe3d10cd4a1ec5fbdb4180263"
uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3"
version = "0.9.17"
version = "1.7.4"
[[deps.HarfBuzz_jll]]
deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"]
......@@ -656,16 +661,16 @@ uuid = "eafb193a-b7ab-5a9e-9068-77385905fa72"
version = "0.5.2"
[[deps.HypergeometricFunctions]]
deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions", "Test"]
git-tree-sha1 = "709d864e3ed6e3545230601f94e11ebc65994641"
deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
git-tree-sha1 = "432b5b03176f8182bd6841fbfc42c718506a2d5f"
uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
version = "0.3.11"
version = "0.3.15"
[[deps.IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "2af2fe19f0d5799311a6491267a14817ad9fbd20"
git-tree-sha1 = "0ade27f0c49cebd8db2523c4eeccf779407cf12c"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.4.8"
version = "0.4.9"
[[deps.ImageBase]]
deps = ["ImageCore", "Reexport"]
......@@ -708,9 +713,9 @@ version = "1.4.0"
[[deps.IntelOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c"
git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83"
uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0"
version = "2018.0.3+2"
version = "2023.1.0+0"
[[deps.InteractiveUtils]]
deps = ["Markdown"]
......@@ -769,9 +774,9 @@ version = "1.4.1"
[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.3"
version = "0.21.4"
[[deps.JSON3]]
deps = ["Dates", "Mmap", "Parsers", "SnoopPrecompile", "StructTypes", "UUIDs"]
......@@ -780,7 +785,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
version = "1.12.0"
[[deps.JointEnergyModels]]
deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
path = "../JointEnergyModels.jl"
uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
version = "0.1.0"
......@@ -802,6 +807,12 @@ git-tree-sha1 = "4aeebbfcf0615641ec4b0782b73b638eeeabd62e"
uuid = "5cadff95-7770-533d-a838-a1bf817ee6e0"
version = "0.3.0"
[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "SnoopPrecompile", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "976231af02176082fb266a9f96a59da51fcacf20"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
version = "0.9.2"
[[deps.KernelDensity]]
deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"]
git-tree-sha1 = "9816b296736292a80b9a3200eb7fbb57aaa3917a"
......@@ -822,15 +833,15 @@ version = "3.0.0+1"
[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "1c614dfbecbaee4897b506bba2b432bf0d21f2ed"
git-tree-sha1 = "a8960cae30b42b66dd41808beb76490519f6f9e2"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "4.17.0"
version = "5.0.0"
[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "e46e3a40daddcbe851f86db0ec4a4a3d4badf800"
git-tree-sha1 = "09b7505cc0b1cee87e5d4a26eea61d2e1b0dcd35"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.19+0"
version = "0.0.21+0"
[[deps.LZO_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -953,11 +964,17 @@ version = "0.3.23"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
[[deps.LoggingExtras]]
deps = ["Dates", "Logging"]
git-tree-sha1 = "cedb76b37bc5a6c702ade66be44f831fa23c681e"
uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
version = "1.0.0"
[[deps.LossFunctions]]
deps = ["InteractiveUtils", "Markdown", "RecipesBase"]
git-tree-sha1 = "53cd63a12f06a43eef6f4aafb910ac755c122be7"
deps = ["CategoricalArrays", "Markdown"]
git-tree-sha1 = "d4c7ff8c7281943371e1725000fd538a699024d0"
uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
version = "0.8.0"
version = "0.9.0"
[[deps.LsqFit]]
deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"]
......@@ -985,9 +1002,9 @@ version = "0.7.9"
[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "6f3a7338e787cbf3460f035c21ee2547f71f8007"
git-tree-sha1 = "e99e84c7ca1696b38e2a8823f53ac9cc775599dd"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "0.21.6"
version = "0.21.9"
[[deps.MLJEnsembles]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"]
......@@ -1009,9 +1026,9 @@ version = "1.8.0"
[[deps.MLJModels]]
deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "1d445497ca058dbc0dbc7528b778707893edb969"
git-tree-sha1 = "21acf47dc53ccc3d68e38ac7629756cd09b599f5"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
version = "0.16.4"
version = "0.16.6"
[[deps.MLStyle]]
git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
......@@ -1122,10 +1139,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c"
version = "7.8.3"
[[deps.NNlib]]
deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e"
deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
git-tree-sha1 = "99e6dbb50d8a96702dc60954569e9fe7291cc55d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.8.19"
version = "0.8.20"
[[deps.NNlibCUDA]]
deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"]
......@@ -1164,9 +1181,9 @@ version = "0.3.5"
[[deps.NearestNeighborModels]]
deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "727b8f1c3f9fec6b1a805ba9bef72c73758eda02"
git-tree-sha1 = "c2179f9d8de066c481b889a1426068c5831bb10b"
uuid = "636a865e-7cf4-491e-846c-de09b730eb36"
version = "0.2.1"
version = "0.2.2"
[[deps.NearestNeighbors]]
deps = ["Distances", "StaticArrays"]
......@@ -1211,6 +1228,12 @@ deps = ["Artifacts", "Libdl"]
uuid = "05823500-19ac-5b8b-9628-191a04bc5112"
version = "0.8.1+0"
[[deps.OpenSSL]]
deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"]
git-tree-sha1 = "5b3e170ea0724f1e3ed6018c5b006c190f80e87d"
uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c"
version = "1.3.5"
[[deps.OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "9ff31d101d987eb9d66bd8b176ac7c277beccd09"
......@@ -1231,9 +1254,9 @@ version = "2.0.2"
[[deps.Optimisers]]
deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "e5a1825d3d53aa4ad4fb42bd4927011ad4a78c3d"
git-tree-sha1 = "6a01f65dd8583dee82eecc2a19b0ff21521aa749"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.2.15"
version = "0.2.18"
[[deps.Opus_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -1242,9 +1265,9 @@ uuid = "91d4177d-7536-5919-b921-800302f37372"
version = "1.3.2+0"
[[deps.OrderedCollections]]
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.4.1"
version = "1.6.0"
[[deps.PCRE2_jll]]
deps = ["Artifacts", "Libdl"]
......@@ -1259,9 +1282,9 @@ version = "0.11.17"
[[deps.PaddedViews]]
deps = ["OffsetArrays"]
git-tree-sha1 = "03a7a85b76381a3d04c7a1656039197e70eda03d"
git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f"
uuid = "5432bcbf-9aad-5242-b902-cca2824c8663"
version = "0.5.11"
version = "0.5.12"
[[deps.Parameters]]
deps = ["OrderedCollections", "UnPack"]
......@@ -1299,9 +1322,9 @@ version = "1.8.0"
[[deps.PkgTemplates]]
deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"]
git-tree-sha1 = "e93643cc634d7551f68b63739d205d22216cec7c"
git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3"
uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
version = "0.7.32"
version = "0.7.33"
[[deps.PlotThemes]]
deps = ["PlotUtils", "Statistics"]
......@@ -1317,9 +1340,9 @@ version = "1.3.4"
[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Preferences", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SnoopPrecompile", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"]
git-tree-sha1 = "da1d3fb7183e38603fcdd2061c47979d91202c97"
git-tree-sha1 = "5434b0ee344eaf2854de251f326df8720f6a7b55"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
version = "1.38.6"
version = "1.38.10"
[[deps.PooledArrays]]
deps = ["DataAPI", "Future"]
......@@ -1501,6 +1524,11 @@ git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de"
uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f"
version = "1.0.3"
[[deps.SimpleBufferStream]]
git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1"
uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7"
version = "1.1.0"
[[deps.SimpleTraits]]
deps = ["InteractiveUtils", "MacroTools"]
git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231"
......@@ -1552,9 +1580,9 @@ version = "0.1.1"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
git-tree-sha1 = "6aa098ef1012364f2ede6b17bf358c7f1fbe90d4"
git-tree-sha1 = "63e84b7fdf5021026d0f17f76af7c57772313d99"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.5.17"
version = "1.5.21"
[[deps.StaticArraysCore]]
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
......@@ -1573,9 +1601,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[[deps.StatsAPI]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6"
git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.5.0"
version = "1.6.0"
[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
......@@ -1686,9 +1714,9 @@ version = "0.2.20"
[[deps.TranscodingStreams]]
deps = ["Random", "Test"]
git-tree-sha1 = "94f38103c984f89cf77c402f2a68dbd870f8165f"
git-tree-sha1 = "0b829474fed270a4b0ab07117dce9b9a2fa7581a"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.11"
version = "0.9.12"
[[deps.Transducers]]
deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
......@@ -1732,14 +1760,25 @@ version = "0.4.1"
[[deps.UnicodePlots]]
deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase"]
git-tree-sha1 = "a5bcfc23e352f499a1a46f428d0d3d7fb9e4fc11"
git-tree-sha1 = "2825e58f6ec3cab889dfa2c824f8d89b9f7ee731"
uuid = "b8865327-cd53-5732-bb35-84acbb429228"
version = "3.4.1"
version = "3.5.1"
[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
version = "0.2.1"
[[deps.UnsafeAtomicsLLVM]]
deps = ["LLVM", "UnsafeAtomics"]
git-tree-sha1 = "ea37e6066bf194ab78f4e747f5245261f17a7175"
uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
version = "0.1.2"
[[deps.Unzip]]
git-tree-sha1 = "34db80951901073501137bdbc3d5a8e7bbd06670"
git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78"
uuid = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d"
version = "0.1.2"
version = "0.2.0"
[[deps.Wayland_jll]]
deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"]
......@@ -1927,27 +1966,27 @@ version = "1.2.12+3"
[[deps.Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "c6edfe154ad7b313c01aceca188c05c835c67360"
git-tree-sha1 = "49ce682769cd5de6c72dcf1b94ed7790cd08974c"
uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
version = "1.5.4+0"
version = "1.5.5+0"
[[deps.Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SnoopPrecompile", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "4df8f470806a45a8630ac8f597304821dc8e8838"
git-tree-sha1 = "987ae5554ca90e837594a0f30325eeb5e7303d1e"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.59"
version = "0.6.60"
[[deps.ZygoteRules]]
deps = ["MacroTools"]
git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0"
deps = ["ChainRulesCore", "MacroTools"]
git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.2"
version = "0.2.3"
[[deps.cuDNN]]
deps = ["CEnum", "CUDA", "CUDNN_jll"]
git-tree-sha1 = "c0ffcb38d1e8c0bbcd3dab2559cf9c369130b2f2"
git-tree-sha1 = "3aa15aba7aad5be8b9b3c1b77a9b81e3e1357280"
uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
version = "1.0.1"
version = "1.0.2"
[[deps.fzf_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......
......@@ -2,7 +2,7 @@
julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "bb24fa6d048fab99674a941d85c45a034b033aae"
project_hash = "b4b125f21013c1ac841e2bb761cd8922630f9f03"
[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
......@@ -156,9 +156,9 @@ version = "0.10.9"
[[deps.CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "6591ddc73adb429b9d97145c8197a0ac81664ab4"
git-tree-sha1 = "8547829ee0da896ce48a24b8d2f4bf929cf3e45e"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "4.1.3"
version = "4.1.4"
[[deps.CUDA_Driver_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
......@@ -168,9 +168,9 @@ version = "0.5.0+1"
[[deps.CUDA_Runtime_Discovery]]
deps = ["Libdl"]
git-tree-sha1 = "6c8fceaaa6850dea627288ac3bb86fdcdf05e326"
git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536"
uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
version = "0.2.0"
version = "0.2.2"
[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
......@@ -275,9 +275,9 @@ version = "0.14.4"
[[deps.CodeTracking]]
deps = ["InteractiveUtils", "UUIDs"]
git-tree-sha1 = "0683f086e2ef8e2fdacd3f246b35c59e7088b283"
git-tree-sha1 = "d730914ef30a06732bdd9f763f6cc32e92ffbff1"
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
version = "1.3.0"
version = "1.3.1"
[[deps.CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
......@@ -475,9 +475,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Distributions]]
deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "da9e1a9058f8d3eec3a8c9fe4faacfb89180066b"
git-tree-sha1 = "13027f188d26206b9e7b863036f87d2f2e7d013a"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.86"
version = "0.25.87"
[[deps.DocStringExtensions]]
deps = ["LibGit2"]
......@@ -601,9 +601,9 @@ version = "0.8.4"
[[deps.Flux]]
deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote", "cuDNN"]
git-tree-sha1 = "e657a9aad824de4211606f113edd0b50d5e1f6db"
git-tree-sha1 = "3f6f32ec0bfd80be0cb65907cf74ec796a632012"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.13.14"
version = "0.13.15"
[[deps.FoldsThreads]]
deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
......@@ -694,21 +694,21 @@ version = "0.1.4"
[[deps.GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "590d394bad1055b798b2f9b308327ba871b7badf"
git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.19.0"
version = "0.19.3"
[[deps.GR]]
deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"]
git-tree-sha1 = "4423d87dc2d3201f3f1768a29e807ddc8cc867ef"
git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362"
uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
version = "0.71.8"
version = "0.72.1"
[[deps.GR_jll]]
deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"]
git-tree-sha1 = "3657eb348d44575cc5560c80d7e55b812ff6ffe1"
git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3"
uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
version = "0.71.8+0"
version = "0.72.1+1"
[[deps.GZip]]
deps = ["Libdl"]
......@@ -812,9 +812,9 @@ version = "0.5.2"
[[deps.HypergeometricFunctions]]
deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
git-tree-sha1 = "d926e9c297ef4607866e8ef5df41cde1a642917f"
git-tree-sha1 = "432b5b03176f8182bd6841fbfc42c718506a2d5f"
uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a"
version = "0.3.14"
version = "0.3.15"
[[deps.IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
......@@ -963,9 +963,9 @@ version = "0.1.5"
[[deps.IntelOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c"
git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83"
uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0"
version = "2018.0.3+2"
version = "2023.1.0+0"
[[deps.InteractiveUtils]]
deps = ["Markdown"]
......@@ -1041,9 +1041,9 @@ version = "1.4.1"
[[deps.JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.3"
version = "0.21.4"
[[deps.JSON3]]
deps = ["Dates", "Mmap", "Parsers", "SnoopPrecompile", "StructTypes", "UUIDs"]
......@@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
version = "1.12.0"
[[deps.JointEnergyModels]]
deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
path = "../../JointEnergyModels.jl"
uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
version = "0.1.0"
......@@ -1082,9 +1082,9 @@ version = "0.3.0"
[[deps.KernelAbstractions]]
deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "SnoopPrecompile", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
git-tree-sha1 = "350a880e80004f4d5d82a17f737d8fcdc56c3462"
git-tree-sha1 = "976231af02176082fb266a9f96a59da51fcacf20"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
version = "0.9.1"
version = "0.9.2"
[[deps.KernelDensity]]
deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"]
......@@ -1250,10 +1250,10 @@ uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36"
version = "1.0.0"
[[deps.LossFunctions]]
deps = ["InteractiveUtils", "Markdown", "RecipesBase"]
git-tree-sha1 = "f27330f931944ecee340f004302db724c1985955"
deps = ["CategoricalArrays", "Markdown"]
git-tree-sha1 = "d4c7ff8c7281943371e1725000fd538a699024d0"
uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
version = "0.8.1"
version = "0.9.0"
[[deps.LsqFit]]
deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"]
......@@ -1281,9 +1281,9 @@ version = "0.7.9"
[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "37a311b0cd581764fc460f6632e6219dc32f9427"
git-tree-sha1 = "e99e84c7ca1696b38e2a8823f53ac9cc775599dd"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "0.21.8"
version = "0.21.9"
[[deps.MLJEnsembles]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"]
......@@ -1305,9 +1305,9 @@ version = "1.8.0"
[[deps.MLJModels]]
deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "865558dcdb963789ba82651b990d0fb9c5e8dd59"
git-tree-sha1 = "21acf47dc53ccc3d68e38ac7629756cd09b599f5"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
version = "0.16.5"
version = "0.16.6"
[[deps.MLStyle]]
git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
......@@ -1360,9 +1360,9 @@ version = "1.2.0"
[[deps.MathTeXEngine]]
deps = ["AbstractTrees", "Automa", "DataStructures", "FreeTypeAbstraction", "GeometryBasics", "LaTeXStrings", "REPL", "RelocatableFolders", "Test", "UnicodeFun"]
git-tree-sha1 = "64890e1e8087b71c03bd6b8af99b49c805b2a78d"
git-tree-sha1 = "8f52dbaa1351ce4cb847d95568cb29e62a307d93"
uuid = "0a4f8689-d25c-4efe-a92b-7142dfc1aa53"
version = "0.5.5"
version = "0.5.6"
[[deps.MbedTLS]]
deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"]
......@@ -1453,10 +1453,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c"
version = "7.8.3"
[[deps.NNlib]]
deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e"
deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
git-tree-sha1 = "99e6dbb50d8a96702dc60954569e9fe7291cc55d"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.8.19"
version = "0.8.20"
[[deps.NNlibCUDA]]
deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"]
......@@ -1562,9 +1562,9 @@ version = "0.8.1+0"
[[deps.OpenSSL]]
deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"]
git-tree-sha1 = "6503b77492fd7fcb9379bf73cd31035670e3c509"
git-tree-sha1 = "5b3e170ea0724f1e3ed6018c5b006c190f80e87d"
uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c"
version = "1.3.3"
version = "1.3.5"
[[deps.OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -1586,9 +1586,9 @@ version = "2.0.2"
[[deps.Optimisers]]
deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "4b214125921ec010160ddb39931885e0a6585639"
git-tree-sha1 = "6a01f65dd8583dee82eecc2a19b0ff21521aa749"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.2.17"
version = "0.2.18"
[[deps.Opus_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -1626,9 +1626,9 @@ version = "0.5.0"
[[deps.PaddedViews]]
deps = ["OffsetArrays"]
git-tree-sha1 = "03a7a85b76381a3d04c7a1656039197e70eda03d"
git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f"
uuid = "5432bcbf-9aad-5242-b902-cca2824c8663"
version = "0.5.11"
version = "0.5.12"
[[deps.Pango_jll]]
deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg"]
......@@ -1672,9 +1672,9 @@ version = "1.8.0"
[[deps.PkgTemplates]]
deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"]
git-tree-sha1 = "e93643cc634d7551f68b63739d205d22216cec7c"
git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3"
uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
version = "0.7.32"
version = "0.7.33"
[[deps.PkgVersion]]
deps = ["Pkg"]
......@@ -1696,9 +1696,9 @@ version = "1.3.4"
[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Preferences", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SnoopPrecompile", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"]
git-tree-sha1 = "f49a45a239e13333b8b936120fe6d793fe58a972"
git-tree-sha1 = "5434b0ee344eaf2854de251f326df8720f6a7b55"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
version = "1.38.8"
version = "1.38.10"
[[deps.PolygonOps]]
git-tree-sha1 = "77b3d3605fc1cd0b42d95eba87dfcd2bf67d5ff6"
......@@ -2023,9 +2023,9 @@ version = "0.1.1"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
git-tree-sha1 = "b8d897fe7fa688e93aef573711cb207c08c9e11e"
git-tree-sha1 = "63e84b7fdf5021026d0f17f76af7c57772313d99"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.5.19"
version = "1.5.21"
[[deps.StaticArraysCore]]
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
......@@ -2062,9 +2062,9 @@ version = "1.3.0"
[[deps.StatsModels]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsBase", "StatsFuns", "Tables"]
git-tree-sha1 = "06a230063087c11910e9bbd17ccbf5af792a27a4"
git-tree-sha1 = "51cdf1afd9d78552e7a08536930d7abc3b288a5c"
uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
version = "0.7.0"
version = "0.7.1"
[[deps.StatsPlots]]
deps = ["AbstractFFTs", "Clustering", "DataStructures", "DataValues", "Distributions", "Interpolations", "KernelDensity", "LinearAlgebra", "MultivariateStats", "NaNMath", "Observables", "Plots", "RecipesBase", "RecipesPipeline", "Reexport", "StatsBase", "TableOperations", "Tables", "Widgets"]
......@@ -2151,9 +2151,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[deps.Tidier]]
deps = ["Chain", "Cleaner", "DataFrames", "MacroTools", "Reexport", "ShiftedArrays", "Statistics"]
git-tree-sha1 = "6c01fc23066480d998d3eabba0f999c46e88ed73"
git-tree-sha1 = "107614fa8ae68ca12a193df719ec2438e86ab97b"
uuid = "f0413319-3358-4bb0-8e7c-0c83523a93bd"
version = "0.7.1"
version = "0.7.4"
[[deps.TiffImages]]
deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"]
......@@ -2181,9 +2181,9 @@ version = "0.2.20"
[[deps.TranscodingStreams]]
deps = ["Random", "Test"]
git-tree-sha1 = "94f38103c984f89cf77c402f2a68dbd870f8165f"
git-tree-sha1 = "0b829474fed270a4b0ab07117dce9b9a2fa7581a"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.9.11"
version = "0.9.12"
[[deps.Transducers]]
deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"]
......@@ -2232,9 +2232,9 @@ version = "0.4.1"
[[deps.UnicodePlots]]
deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase"]
git-tree-sha1 = "a5bcfc23e352f499a1a46f428d0d3d7fb9e4fc11"
git-tree-sha1 = "2825e58f6ec3cab889dfa2c824f8d89b9f7ee731"
uuid = "b8865327-cd53-5732-bb35-84acbb429228"
version = "3.4.1"
version = "3.5.1"
[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
......@@ -2438,15 +2438,15 @@ version = "1.2.12+3"
[[deps.Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "c6edfe154ad7b313c01aceca188c05c835c67360"
git-tree-sha1 = "49ce682769cd5de6c72dcf1b94ed7790cd08974c"
uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
version = "1.5.4+0"
version = "1.5.5+0"
[[deps.Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SnoopPrecompile", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "4df8f470806a45a8630ac8f597304821dc8e8838"
git-tree-sha1 = "987ae5554ca90e837594a0f30325eeb5e7303d1e"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.59"
version = "0.6.60"
[[deps.ZygoteRules]]
deps = ["ChainRulesCore", "MacroTools"]
......
......@@ -15,10 +15,10 @@ end
```{julia}
# Data:
n_obs = 10000
n_obs = 1000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = pre_process.(X)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
......@@ -32,18 +32,22 @@ First, let's create a couple of image classifier architectures:
# Model parameters:
epochs = 100
batch_size = minimum([Int(round(n_obs/10)), 128])
n_hidden = 32
n_hidden = 50
activation = Flux.swish
# builder = MLJFlux.@builder Flux.Chain(
# Dense(n_in, n_hidden, activation),
# Dense(n_hidden, n_hidden, activation),
# Dense(n_hidden, n_hidden, activation),
# # BatchNorm(n_hidden, activation),
# # Dense(n_hidden, n_hidden),
# # BatchNorm(n_hidden, activation),
# Dense(n_hidden, n_out),
# )
builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation)
builder = MLJFlux.@builder Flux.Chain(
Dense(n_in, n_hidden, activation),
# Dense(n_hidden, n_hidden, activation),
# Dense(n_hidden, n_hidden, activation),
# Dense(n_in, n_hidden),
# BatchNorm(n_hidden, activation),
# Dense(n_hidden, n_hidden),
# BatchNorm(n_hidden, activation),
Dense(n_hidden, n_out),
)
# builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation)
# builder = MLJFlux.MLP(
# hidden=(
# n_hidden,
......@@ -52,7 +56,7 @@ builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation)
# ),
# σ=activation
# )
α = [1.0,1.0,1e-1]
α = [1.0,1.0,1e-2]
# Simple MLP:
mlp = NeuralNetworkClassifier(
......@@ -85,7 +89,7 @@ jem = JointEnergyClassifier(
)
# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=5)
mlp_ens = EnsembleModel(model=mlp, n=50)
```
```{julia}
......@@ -99,7 +103,7 @@ M = ECCCo.ConformalModel(mach.model, mach.fitresult)
```{julia}
if mach.model.model isa JointEnergyModels.JointEnergyClassifier
jem = mach.model.model.jem
n_iter = 500
n_iter = 200
_w = 1500
plts = []
neach = 10
......@@ -121,6 +125,7 @@ end
```{julia}
test_data = load_mnist_test()
test_data.X = pre_process.(test_data.X)
f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
println("F1 score (test): $(round(f1,digits=3))")
```
......@@ -153,9 +158,9 @@ ce_jsma = generate_counterfactual(
# ECCCo:
λ=[0.0,1.0]
temp=0.01
temp=0.5
# Generate counterfactual using ECCCo generator:
# Generate counterfactual using CCE generator:
generator = CCEGenerator(
λ=λ,
temp=temp,
......@@ -168,7 +173,7 @@ ce_conformal = generate_counterfactual(
converge_when=:generator_conditions,
)
# Generate counterfactual using ECCCo generator:
# Generate counterfactual using CCE generator:
generator = CCEGenerator(
λ=λ,
temp=temp,
......@@ -191,7 +196,7 @@ p1 = Plots.plot(
plts = [p1]
ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma]
_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"]
_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"]
for x in zip(ces, _names)
ce, _name = (x[1],x[2])
x = CounterfactualExplanations.counterfactual(ce)
......@@ -221,6 +226,8 @@ factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.5
T = 100
η=0.1
# Generate counterfactual using generic generator:
generator = GenericGenerator(opt=Flux.Optimise.Adam(),)
ce_wachter = generate_counterfactual(
......@@ -229,7 +236,7 @@ ce_wachter = generate_counterfactual(
initialization=:identity,
)
generator = GreedyGenerator(η=1.0)
generator = GreedyGenerator(η=η)
ce_jsma = generate_counterfactual(
x, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
......@@ -237,8 +244,8 @@ ce_jsma = generate_counterfactual(
)
# ECCCo:
λ=[0.0,1.0,1.0]
temp=0.01
λ=[0.0,0.0,10.0]
temp=0.5
# Generate counterfactual using ECCCo generator:
generator = ECCCoGenerator(
......@@ -257,7 +264,7 @@ ce_conformal = generate_counterfactual(
generator = ECCCoGenerator(
λ=λ,
temp=temp,
opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
opt=CounterfactualExplanations.Generators.JSMADescent(η=η),
)
ce_conformal_jsma = generate_counterfactual(
x, target, counterfactual_data, M, generator;
......@@ -292,7 +299,30 @@ for x in zip(ces, _names)
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(www_path, "cce_mnist.png"))
savefig(plt, joinpath(www_path, "eccco_mnist.png"))
```
```{julia}
if M.model.model isa JointEnergyModels.JointEnergyClassifier
jem = M.model.model.jem
n_iter = 200
_w = 1500
plts = []
neach = 10
for i in 1:10
x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)
plts_i = []
for j in 1:size(x, 2)
xj = x[:,j]
xj = reshape(xj, (n_digits, n_digits))
plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
end
plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
plts = [plts..., plt]
end
plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
display(plt)
end
```
## Benchmark
......
No preview for this file type
......@@ -280,6 +280,13 @@ In order to generate prediction sets $C_{\theta}(f(\mathbf{Z}^\prime);\alpha)$ f
\section{Experiments}
\begin{itemize}
\item BatchNorm does not seem compatible with JEM
\item Coverage and temperature impacts CCE in somewhat unpredictable ways
\item It seems that models that are not explicitly trained for generative task, still learn it implictly
\item Batch size seems to impact quality of generated samples
\end{itemize}
\section{Discussion}
Consistent with the findings in \citet{schut2021generating}, we have demonstrated that predictive uncertainty estimates can be leveraged to generate plausible counterfactuals. Interestingly, \citet{schut2021generating} point out that this finding --- as intuitive as it is --- may be linked to a positive connection between the generative task and predictive uncertainty quantification. In particular, \citet{grathwohl2020your} demonstrate that their proposed method for integrating the generative objective in training yields models that have improved predictive uncertainty quantification. Since neither \citet{schut2021generating} nor we have employed any surrogate generative models, our findings seem to indicate that the positive connection found in \citet{grathwohl2020your} is bidirectional.
......
......@@ -9,6 +9,7 @@ include("losses.jl")
include("generator.jl")
include("sampling.jl")
export ECCCoGenerator, EnergySampler, set_size_penalty, distance_from_energy
export CCEGenerator, ECCCoGenerator, EnergySampler
export set_size_penalty, distance_from_energy
end
\ No newline at end of file
using CounterfactualExplanations.Objectives
"Constructor for `ECCCoGenerator`."
function ECCCoGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=1.0, temp::Real=0.05, kwargs...)
"Constructor for `CCEGenerator`."
function CCEGenerator(;
λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0],
κ::Real=1.0,
temp::Real=0.5,
kwargs...
)
function _set_size_penalty(ce::AbstractCounterfactualExplanation)
return ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
end
......@@ -11,7 +16,7 @@ function ECCCoGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1,
end
"Constructor for `ECECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator."
function ECECCCoGenerator(;
function ECCCoGenerator(;
λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0, 1.0],
κ::Real=1.0,
temp::Real=0.5,
......
......@@ -36,20 +36,19 @@ end
function distance_from_energy(
ce::AbstractCounterfactualExplanation;
n::Int=10000, from_buffer=true, agg=mean, kwargs...
n::Int=1, niter=200, from_buffer=true, agg=mean, kwargs...
)
conditional_samples = []
ignore_derivatives() do
_dict = ce.params
if !(:energy_sampler collect(keys(_dict)))
_dict[:energy_sampler] = ECCCo.EnergySampler(ce; kwargs...)
_dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=1000, kwargs...)
end
sampler = _dict[:energy_sampler]
push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
end
x′ = CounterfactualExplanations.counterfactual(ce)
loss = map(eachslice(x′, dims=3)) do x
x = Matrix(x)
loss = map(eachslice(x′, dims=ndims(x′))) do x
Δ = map(eachcol(conditional_samples[1])) do xsample
norm(x - xsample)
end
......@@ -63,7 +62,7 @@ end
function distance_from_targets(
ce::AbstractCounterfactualExplanation;
n::Int=10000, agg=mean
n::Int=1000, agg=mean
)
target_idx = ce.data.output_encoder.labels .== ce.target
target_samples = ce.data.X[:,target_idx] |>
......
......@@ -42,42 +42,29 @@ function EnergySampler(
)
@assert y data.y_levels || y 1:length(data.y_levels)
K = length(data.y_levels)
𝒟x = Normal()
𝒟y = Categorical(ones(K) ./ K)
sampler = ConditionalSampler(𝒟x, 𝒟y)
if model.model.model isa JointEnergyClassifier
sampler = model.model.model.jem.sampler
else
K = length(data.y_levels)
input_size = size(selectdim(data.X, ndims(data.X), 1))
𝒟x = Uniform(extrema(data.X)...)
𝒟y = Categorical(ones(K) ./ K)
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
end
yidx = get_target_index(data.y_levels, y)
# Initiate:
energy_sampler = EnergySampler(model, data, sampler, opt, nothing, nothing)
# Generate samples:
generate_samples!(energy_sampler, nsamples, yidx; niter=niter)
chain = model.model.model.jem.chain
rule = model.model.model.jem.sampling_rule
energy_sampler.sampler(chain, rule; niter=niter, n_samples=nsamples, y=yidx)
return energy_sampler
end
"""
generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
Generates `n` samples from `EnergySampler` for conditioning value `y`.
"""
function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
X = e.sampler(e.model, e.opt, (size(e.data.X, 1), n); niter=niter, y=y)
X = X[:,map(x -> !any(isnan.(x)), eachcol(X))]
return X
end
"""
generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
Generates `n` samples from `EnergySampler` for conditioning value `y`. Assigns samples and conditioning value to `EnergySampler`.
"""
function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
e.buffer = generate_samples(e,n,y;niter=niter)
e.yidx = y
end
"""
EnergySampler(
ce::CounterfactualExplanation;
......@@ -105,12 +92,13 @@ end
Overloads the `rand` method to randomly draw `n` samples from `EnergySampler`.
"""
function Base.rand(sampler::EnergySampler, n::Int=100; from_buffer=true, niter::Int=100)
ntotal = size(sampler.buffer, 2)
ntotal = size(sampler.sampler.buffer)[end]
idx = rand(1:ntotal, n)
if from_buffer
X = sampler.buffer[:, idx]
X = sampler.sampler.buffer[:, idx]
else
X = generate_samples(sampler, n, sampler.yidx; niter=niter)
chain = sampler.model.fitresult[1]
X = sampler.sampler(chain, sampler.opt; niter=niter, n_samples=n, y=sampler.yidx)
end
return X
end
www/cce_mnist.png

22.9 KiB | W: | H:

www/cce_mnist.png

22.8 KiB | W: | H:

www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
  • 2-up
  • Swipe
  • Onion skin
www/eccco_mnist.png

20.1 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment