diff --git a/Manifest.toml b/Manifest.toml index 4531e744c4232042cd1891c52c5235f2fdd795ca..06b80abd0a09ec27e3d9f94fbf4dc29493113df6 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -49,20 +49,26 @@ uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" version = "0.5.4" [[deps.Arpack_jll]] -deps = ["Libdl", "OpenBLAS_jll", "Pkg"] -git-tree-sha1 = "e214a9b9bd1b4e1b4f15b22c0994862b66af7ff7" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "OpenBLAS_jll", "Pkg"] +git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e" uuid = "68821587-b530-5797-8361-c406ea357684" -version = "3.5.0+3" +version = "3.5.1+1" [[deps.ArrayInterface]] deps = ["Adapt", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "a89acc90c551067cd84119ff018619a1a76c6277" +git-tree-sha1 = "38911c7737e123b28182d89027f4216cfc8a9da7" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.2.1" +version = "7.4.3" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +[[deps.Atomix]] +deps = ["UnsafeAtomics"] +git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +version = "0.1.0" + [[deps.AxisAlgorithms]] deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] git-tree-sha1 = "66771c8d21c8ff5e3a93379480a2307ac36863f7" @@ -76,9 +82,9 @@ uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" version = "0.4.2" [[deps.BSON]] -git-tree-sha1 = "86e9781ac28f4e80e9b98f7f96eae21891332ac2" +git-tree-sha1 = "2208958832d6e1b59e49f53697483a84ca8d664e" uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -version = "0.3.6" +version = "0.3.7" [[deps.BangBang]] deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] @@ -94,11 +100,10 @@ git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" uuid = "9718e550-a3fa-408a-8086-8db961cd8217" version = "0.1.1" -[[deps.BinaryProvider]] -deps = ["Libdl", "Logging", "SHA"] -git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058" -uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.10" +[[deps.BitFlags]] +git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.7" [[deps.BufferedStreams]] git-tree-sha1 = "bb065b14d7f941b8617bc323063dbe79f55d16ea" @@ -123,34 +128,34 @@ uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" version = "0.10.9" [[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions"] -git-tree-sha1 = "edff14c60784c8f7191a62a23b15a421185bc8a8" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "8547829ee0da896ce48a24b8d2f4bf929cf3e45e" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "4.0.1" +version = "4.1.4" [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "75d7896d1ec079ef10d3aee8f3668c11354c03a1" +git-tree-sha1 = "498f45593f6ddc0adff64a9310bb6710e851781b" uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.2.0+0" +version = "0.5.0+1" [[deps.CUDA_Runtime_Discovery]] deps = ["Libdl"] -git-tree-sha1 = "58dd8ec29f54f08c04b052d2c2fa6760b4f4b3a4" +git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536" uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" -version = "0.1.1" +version = "0.2.2" [[deps.CUDA_Runtime_jll]] -deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] -git-tree-sha1 = "d3e6ccd30f84936c1a3a53d622d85d7d3f9b9486" +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "81eed046f28a0cdd0dc1f61d00a49061b7cc9433" uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.2.3+2" +version = "0.5.0+2" [[deps.CUDNN_jll]] -deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] -git-tree-sha1 = "aafe89dfde54011993c4029d3be3e037fd63db07" +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] +git-tree-sha1 = "2918fbffb50e3b7a0b9127617587afa76d4276e8" uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" -version = "8.6.0+5" +version = "8.8.1+0" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] @@ -214,9 +219,9 @@ version = "0.14.4" [[deps.CodeTracking]] deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "d57c99cc7e637165c81b30eb268eabe156a45c49" +git-tree-sha1 = "d730914ef30a06732bdd9f763f6cc32e92ffbff1" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.2.2" +version = "1.3.1" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -320,10 +325,10 @@ uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" version = "1.14.0" [[deps.DataDeps]] -deps = ["BinaryProvider", "HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"] -git-tree-sha1 = "e299d8267135ef2f9c941a764006697082c1e7e8" +deps = ["HTTP", "Libdl", "Reexport", "SHA", "p7zip_jll"] +git-tree-sha1 = "bc0a264d3e7b3eeb0b6fc9f6481f970697f29805" uuid = "124859b0-ceae-595e-8997-d05f6a7a8dfe" -version = "0.7.8" +version = "0.7.10" [[deps.DataFrames]] deps = ["Compat", "DataAPI", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SnoopPrecompile", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] @@ -381,9 +386,9 @@ version = "1.13.0" [[deps.Distances]] deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "3258d0659f812acde79e8a74b11f17ac06d0ca04" +git-tree-sha1 = "49eba9ad9f7ead780bfb7ee319f962c811c6d3b2" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.7" +version = "0.10.8" [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] @@ -391,9 +396,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "da9e1a9058f8d3eec3a8c9fe4faacfb89180066b" +git-tree-sha1 = "13027f188d26206b9e7b863036f87d2f2e7d013a" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.86" +version = "0.25.87" [[deps.DocStringExtensions]] deps = ["LibGit2"] @@ -476,15 +481,15 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "3b245d1e50466ca0c9529e2033a3c92387c59c2f" +git-tree-sha1 = "7072f1e3e5a8be51d525d64f63d3ec1287ff2790" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.13.9" +version = "0.13.11" [[deps.FiniteDiff]] deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "ed1b56934a2f7a65035976985da71b6a65b4f2cf" +git-tree-sha1 = "03fcb1c42ec905d15b305359603888ec3e65f886" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.18.0" +version = "2.19.0" [[deps.FixedPointNumbers]] deps = ["Statistics"] @@ -493,10 +498,10 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.4" [[deps.Flux]] -deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"] -git-tree-sha1 = "4ff3a1d7b0dd38f2fc38e813bc801f817639c1f2" +deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote", "cuDNN"] +git-tree-sha1 = "3f6f32ec0bfd80be0cb65907cf74ec796a632012" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.13" +version = "0.13.15" [[deps.FoldsThreads]] deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] @@ -541,9 +546,9 @@ version = "1.1.3" [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "7ed0833a55979d3d2658a60b901469748a6b9a7c" +git-tree-sha1 = "478f8c3145bb91d82c2cf20433e8c1b30df454cc" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.3" +version = "0.4.4" [[deps.Future]] deps = ["Random"] @@ -557,9 +562,9 @@ version = "3.3.8+0" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "7a2e790b1e2e6f648cfb25c4500c5de1f7b375ef" +git-tree-sha1 = "9ade6983c3dbbd492cf5729f865fe030d1541463" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.6.5" +version = "8.6.6" [[deps.GPUArraysCore]] deps = ["Adapt"] @@ -568,22 +573,22 @@ uuid = "46192b85-c4d5-4398-a991-12ede77f4527" version = "0.1.4" [[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "19d693666a304e8c371798f4900f7435558c7cde" +deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] +git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.17.3" +version = "0.19.3" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"] -git-tree-sha1 = "4423d87dc2d3201f3f1768a29e807ddc8cc867ef" +git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.71.8" +version = "0.72.1" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "3657eb348d44575cc5560c80d7e55b812ff6ffe1" +git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.71.8+0" +version = "0.72.1+1" [[deps.GZip]] deps = ["Libdl"] @@ -604,9 +609,9 @@ uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" version = "2.74.0+2" [[deps.Glob]] -git-tree-sha1 = "4df9f7e06108728ebf00a0a11edee4b29a482bb2" +git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" -version = "1.3.0" +version = "1.3.1" [[deps.Graphics]] deps = ["Colors", "LinearAlgebra", "NaNMath"] @@ -638,10 +643,10 @@ uuid = "0234f1f7-429e-5d53-9886-15a909be8d59" version = "1.12.2+2" [[deps.HTTP]] -deps = ["Base64", "Dates", "IniFile", "Logging", "MbedTLS", "NetworkOptions", "Sockets", "URIs"] -git-tree-sha1 = "0fa77022fe4b511826b39c894c90daf5fce3334a" +deps = ["Base64", "CodecZlib", "Dates", "IniFile", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "37e4657cd56b11abe3d10cd4a1ec5fbdb4180263" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "0.9.17" +version = "1.7.4" [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] @@ -656,16 +661,16 @@ uuid = "eafb193a-b7ab-5a9e-9068-77385905fa72" version = "0.5.2" [[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions", "Test"] -git-tree-sha1 = "709d864e3ed6e3545230601f94e11ebc65994641" +deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "432b5b03176f8182bd6841fbfc42c718506a2d5f" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.11" +version = "0.3.15" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "2af2fe19f0d5799311a6491267a14817ad9fbd20" +git-tree-sha1 = "0ade27f0c49cebd8db2523c4eeccf779407cf12c" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.8" +version = "0.4.9" [[deps.ImageBase]] deps = ["ImageCore", "Reexport"] @@ -708,9 +713,9 @@ version = "1.4.0" [[deps.IntelOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c" +git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83" uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2018.0.3+2" +version = "2023.1.0+0" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -769,9 +774,9 @@ version = "1.4.1" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.3" +version = "0.21.4" [[deps.JSON3]] deps = ["Dates", "Mmap", "Parsers", "SnoopPrecompile", "StructTypes", "UUIDs"] @@ -780,7 +785,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" [[deps.JointEnergyModels]] -deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] +deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] path = "../JointEnergyModels.jl" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" version = "0.1.0" @@ -802,6 +807,12 @@ git-tree-sha1 = "4aeebbfcf0615641ec4b0782b73b638eeeabd62e" uuid = "5cadff95-7770-533d-a838-a1bf817ee6e0" version = "0.3.0" +[[deps.KernelAbstractions]] +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "SnoopPrecompile", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "976231af02176082fb266a9f96a59da51fcacf20" +uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" +version = "0.9.2" + [[deps.KernelDensity]] deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] git-tree-sha1 = "9816b296736292a80b9a3200eb7fbb57aaa3917a" @@ -822,15 +833,15 @@ version = "3.0.0+1" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "1c614dfbecbaee4897b506bba2b432bf0d21f2ed" +git-tree-sha1 = "a8960cae30b42b66dd41808beb76490519f6f9e2" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "4.17.0" +version = "5.0.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "e46e3a40daddcbe851f86db0ec4a4a3d4badf800" +git-tree-sha1 = "09b7505cc0b1cee87e5d4a26eea61d2e1b0dcd35" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.19+0" +version = "0.0.21+0" [[deps.LZO_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -953,11 +964,17 @@ version = "0.3.23" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "cedb76b37bc5a6c702ade66be44f831fa23c681e" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.0" + [[deps.LossFunctions]] -deps = ["InteractiveUtils", "Markdown", "RecipesBase"] -git-tree-sha1 = "53cd63a12f06a43eef6f4aafb910ac755c122be7" +deps = ["CategoricalArrays", "Markdown"] +git-tree-sha1 = "d4c7ff8c7281943371e1725000fd538a699024d0" uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" -version = "0.8.0" +version = "0.9.0" [[deps.LsqFit]] deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"] @@ -985,9 +1002,9 @@ version = "0.7.9" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "6f3a7338e787cbf3460f035c21ee2547f71f8007" +git-tree-sha1 = "e99e84c7ca1696b38e2a8823f53ac9cc775599dd" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "0.21.6" +version = "0.21.9" [[deps.MLJEnsembles]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] @@ -1009,9 +1026,9 @@ version = "1.8.0" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "1d445497ca058dbc0dbc7528b778707893edb969" +git-tree-sha1 = "21acf47dc53ccc3d68e38ac7629756cd09b599f5" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.4" +version = "0.16.6" [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" @@ -1122,10 +1139,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" version = "7.8.3" [[deps.NNlib]] -deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "99e6dbb50d8a96702dc60954569e9fe7291cc55d" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.19" +version = "0.8.20" [[deps.NNlibCUDA]] deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] @@ -1164,9 +1181,9 @@ version = "0.3.5" [[deps.NearestNeighborModels]] deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "727b8f1c3f9fec6b1a805ba9bef72c73758eda02" +git-tree-sha1 = "c2179f9d8de066c481b889a1426068c5831bb10b" uuid = "636a865e-7cf4-491e-846c-de09b730eb36" -version = "0.2.1" +version = "0.2.2" [[deps.NearestNeighbors]] deps = ["Distances", "StaticArrays"] @@ -1211,6 +1228,12 @@ deps = ["Artifacts", "Libdl"] uuid = "05823500-19ac-5b8b-9628-191a04bc5112" version = "0.8.1+0" +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "5b3e170ea0724f1e3ed6018c5b006c190f80e87d" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.3.5" + [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "9ff31d101d987eb9d66bd8b176ac7c277beccd09" @@ -1231,9 +1254,9 @@ version = "2.0.2" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "e5a1825d3d53aa4ad4fb42bd4927011ad4a78c3d" +git-tree-sha1 = "6a01f65dd8583dee82eecc2a19b0ff21521aa749" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.15" +version = "0.2.18" [[deps.Opus_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1242,9 +1265,9 @@ uuid = "91d4177d-7536-5919-b921-800302f37372" version = "1.3.2+0" [[deps.OrderedCollections]] -git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.4.1" +version = "1.6.0" [[deps.PCRE2_jll]] deps = ["Artifacts", "Libdl"] @@ -1259,9 +1282,9 @@ version = "0.11.17" [[deps.PaddedViews]] deps = ["OffsetArrays"] -git-tree-sha1 = "03a7a85b76381a3d04c7a1656039197e70eda03d" +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.11" +version = "0.5.12" [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] @@ -1299,9 +1322,9 @@ version = "1.8.0" [[deps.PkgTemplates]] deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"] -git-tree-sha1 = "e93643cc634d7551f68b63739d205d22216cec7c" +git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3" uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" -version = "0.7.32" +version = "0.7.33" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] @@ -1317,9 +1340,9 @@ version = "1.3.4" [[deps.Plots]] deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Preferences", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SnoopPrecompile", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"] -git-tree-sha1 = "da1d3fb7183e38603fcdd2061c47979d91202c97" +git-tree-sha1 = "5434b0ee344eaf2854de251f326df8720f6a7b55" uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -version = "1.38.6" +version = "1.38.10" [[deps.PooledArrays]] deps = ["DataAPI", "Future"] @@ -1501,6 +1524,11 @@ git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" version = "1.0.3" +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + [[deps.SimpleTraits]] deps = ["InteractiveUtils", "MacroTools"] git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" @@ -1552,9 +1580,9 @@ version = "0.1.1" [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "6aa098ef1012364f2ede6b17bf358c7f1fbe90d4" +git-tree-sha1 = "63e84b7fdf5021026d0f17f76af7c57772313d99" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.17" +version = "1.5.21" [[deps.StaticArraysCore]] git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" @@ -1573,9 +1601,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [[deps.StatsAPI]] deps = ["LinearAlgebra"] -git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6" +git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.5.0" +version = "1.6.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] @@ -1686,9 +1714,9 @@ version = "0.2.20" [[deps.TranscodingStreams]] deps = ["Random", "Test"] -git-tree-sha1 = "94f38103c984f89cf77c402f2a68dbd870f8165f" +git-tree-sha1 = "0b829474fed270a4b0ab07117dce9b9a2fa7581a" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.11" +version = "0.9.12" [[deps.Transducers]] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] @@ -1732,14 +1760,25 @@ version = "0.4.1" [[deps.UnicodePlots]] deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase"] -git-tree-sha1 = "a5bcfc23e352f499a1a46f428d0d3d7fb9e4fc11" +git-tree-sha1 = "2825e58f6ec3cab889dfa2c824f8d89b9f7ee731" uuid = "b8865327-cd53-5732-bb35-84acbb429228" -version = "3.4.1" +version = "3.5.1" + +[[deps.UnsafeAtomics]] +git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" +version = "0.2.1" + +[[deps.UnsafeAtomicsLLVM]] +deps = ["LLVM", "UnsafeAtomics"] +git-tree-sha1 = "ea37e6066bf194ab78f4e747f5245261f17a7175" +uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" +version = "0.1.2" [[deps.Unzip]] -git-tree-sha1 = "34db80951901073501137bdbc3d5a8e7bbd06670" +git-tree-sha1 = "ca0969166a028236229f63514992fc073799bb78" uuid = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d" -version = "0.1.2" +version = "0.2.0" [[deps.Wayland_jll]] deps = ["Artifacts", "Expat_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg", "XML2_jll"] @@ -1927,27 +1966,27 @@ version = "1.2.12+3" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "c6edfe154ad7b313c01aceca188c05c835c67360" +git-tree-sha1 = "49ce682769cd5de6c72dcf1b94ed7790cd08974c" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.4+0" +version = "1.5.5+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SnoopPrecompile", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "4df8f470806a45a8630ac8f597304821dc8e8838" +git-tree-sha1 = "987ae5554ca90e837594a0f30325eeb5e7303d1e" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.59" +version = "0.6.60" [[deps.ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" +deps = ["ChainRulesCore", "MacroTools"] +git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.2" +version = "0.2.3" [[deps.cuDNN]] deps = ["CEnum", "CUDA", "CUDNN_jll"] -git-tree-sha1 = "c0ffcb38d1e8c0bbcd3dab2559cf9c369130b2f2" +git-tree-sha1 = "3aa15aba7aad5be8b9b3c1b77a9b81e3e1357280" uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" -version = "1.0.1" +version = "1.0.2" [[deps.fzf_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index 6dbe102ca9eb0869445e2279efae73b9f686cf9d..400facb7be5d41186be3476bfe291d7420d08ee1 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "bb24fa6d048fab99674a941d85c45a034b033aae" +project_hash = "b4b125f21013c1ac841e2bb761cd8922630f9f03" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -156,9 +156,9 @@ version = "0.10.9" [[deps.CUDA]] deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "6591ddc73adb429b9d97145c8197a0ac81664ab4" +git-tree-sha1 = "8547829ee0da896ce48a24b8d2f4bf929cf3e45e" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "4.1.3" +version = "4.1.4" [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] @@ -168,9 +168,9 @@ version = "0.5.0+1" [[deps.CUDA_Runtime_Discovery]] deps = ["Libdl"] -git-tree-sha1 = "6c8fceaaa6850dea627288ac3bb86fdcdf05e326" +git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536" uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" -version = "0.2.0" +version = "0.2.2" [[deps.CUDA_Runtime_jll]] deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] @@ -275,9 +275,9 @@ version = "0.14.4" [[deps.CodeTracking]] deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "0683f086e2ef8e2fdacd3f246b35c59e7088b283" +git-tree-sha1 = "d730914ef30a06732bdd9f763f6cc32e92ffbff1" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.3.0" +version = "1.3.1" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -475,9 +475,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "da9e1a9058f8d3eec3a8c9fe4faacfb89180066b" +git-tree-sha1 = "13027f188d26206b9e7b863036f87d2f2e7d013a" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.86" +version = "0.25.87" [[deps.DocStringExtensions]] deps = ["LibGit2"] @@ -601,9 +601,9 @@ version = "0.8.4" [[deps.Flux]] deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote", "cuDNN"] -git-tree-sha1 = "e657a9aad824de4211606f113edd0b50d5e1f6db" +git-tree-sha1 = "3f6f32ec0bfd80be0cb65907cf74ec796a632012" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.14" +version = "0.13.15" [[deps.FoldsThreads]] deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] @@ -694,21 +694,21 @@ version = "0.1.4" [[deps.GPUCompiler]] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "590d394bad1055b798b2f9b308327ba871b7badf" +git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.19.0" +version = "0.19.3" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"] -git-tree-sha1 = "4423d87dc2d3201f3f1768a29e807ddc8cc867ef" +git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.71.8" +version = "0.72.1" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "3657eb348d44575cc5560c80d7e55b812ff6ffe1" +git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.71.8+0" +version = "0.72.1+1" [[deps.GZip]] deps = ["Libdl"] @@ -812,9 +812,9 @@ version = "0.5.2" [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "d926e9c297ef4607866e8ef5df41cde1a642917f" +git-tree-sha1 = "432b5b03176f8182bd6841fbfc42c718506a2d5f" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.14" +version = "0.3.15" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] @@ -963,9 +963,9 @@ version = "0.1.5" [[deps.IntelOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "d979e54b71da82f3a65b62553da4fc3d18c9004c" +git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83" uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2018.0.3+2" +version = "2023.1.0+0" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -1041,9 +1041,9 @@ version = "1.4.1" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.3" +version = "0.21.4" [[deps.JSON3]] deps = ["Dates", "Mmap", "Parsers", "SnoopPrecompile", "StructTypes", "UUIDs"] @@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" [[deps.JointEnergyModels]] -deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] +deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] path = "../../JointEnergyModels.jl" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" version = "0.1.0" @@ -1082,9 +1082,9 @@ version = "0.3.0" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "SnoopPrecompile", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "350a880e80004f4d5d82a17f737d8fcdc56c3462" +git-tree-sha1 = "976231af02176082fb266a9f96a59da51fcacf20" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.1" +version = "0.9.2" [[deps.KernelDensity]] deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] @@ -1250,10 +1250,10 @@ uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.0.0" [[deps.LossFunctions]] -deps = ["InteractiveUtils", "Markdown", "RecipesBase"] -git-tree-sha1 = "f27330f931944ecee340f004302db724c1985955" +deps = ["CategoricalArrays", "Markdown"] +git-tree-sha1 = "d4c7ff8c7281943371e1725000fd538a699024d0" uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" -version = "0.8.1" +version = "0.9.0" [[deps.LsqFit]] deps = ["Distributions", "ForwardDiff", "LinearAlgebra", "NLSolversBase", "OptimBase", "Random", "StatsBase"] @@ -1281,9 +1281,9 @@ version = "0.7.9" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "37a311b0cd581764fc460f6632e6219dc32f9427" +git-tree-sha1 = "e99e84c7ca1696b38e2a8823f53ac9cc775599dd" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "0.21.8" +version = "0.21.9" [[deps.MLJEnsembles]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] @@ -1305,9 +1305,9 @@ version = "1.8.0" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "865558dcdb963789ba82651b990d0fb9c5e8dd59" +git-tree-sha1 = "21acf47dc53ccc3d68e38ac7629756cd09b599f5" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.5" +version = "0.16.6" [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" @@ -1360,9 +1360,9 @@ version = "1.2.0" [[deps.MathTeXEngine]] deps = ["AbstractTrees", "Automa", "DataStructures", "FreeTypeAbstraction", "GeometryBasics", "LaTeXStrings", "REPL", "RelocatableFolders", "Test", "UnicodeFun"] -git-tree-sha1 = "64890e1e8087b71c03bd6b8af99b49c805b2a78d" +git-tree-sha1 = "8f52dbaa1351ce4cb847d95568cb29e62a307d93" uuid = "0a4f8689-d25c-4efe-a92b-7142dfc1aa53" -version = "0.5.5" +version = "0.5.6" [[deps.MbedTLS]] deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] @@ -1453,10 +1453,10 @@ uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" version = "7.8.3" [[deps.NNlib]] -deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] +git-tree-sha1 = "99e6dbb50d8a96702dc60954569e9fe7291cc55d" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.19" +version = "0.8.20" [[deps.NNlibCUDA]] deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] @@ -1562,9 +1562,9 @@ version = "0.8.1+0" [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "6503b77492fd7fcb9379bf73cd31035670e3c509" +git-tree-sha1 = "5b3e170ea0724f1e3ed6018c5b006c190f80e87d" uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.3.3" +version = "1.3.5" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1586,9 +1586,9 @@ version = "2.0.2" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "4b214125921ec010160ddb39931885e0a6585639" +git-tree-sha1 = "6a01f65dd8583dee82eecc2a19b0ff21521aa749" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.17" +version = "0.2.18" [[deps.Opus_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1626,9 +1626,9 @@ version = "0.5.0" [[deps.PaddedViews]] deps = ["OffsetArrays"] -git-tree-sha1 = "03a7a85b76381a3d04c7a1656039197e70eda03d" +git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.11" +version = "0.5.12" [[deps.Pango_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1672,9 +1672,9 @@ version = "1.8.0" [[deps.PkgTemplates]] deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"] -git-tree-sha1 = "e93643cc634d7551f68b63739d205d22216cec7c" +git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3" uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" -version = "0.7.32" +version = "0.7.33" [[deps.PkgVersion]] deps = ["Pkg"] @@ -1696,9 +1696,9 @@ version = "1.3.4" [[deps.Plots]] deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Preferences", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SnoopPrecompile", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"] -git-tree-sha1 = "f49a45a239e13333b8b936120fe6d793fe58a972" +git-tree-sha1 = "5434b0ee344eaf2854de251f326df8720f6a7b55" uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -version = "1.38.8" +version = "1.38.10" [[deps.PolygonOps]] git-tree-sha1 = "77b3d3605fc1cd0b42d95eba87dfcd2bf67d5ff6" @@ -2023,9 +2023,9 @@ version = "0.1.1" [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "b8d897fe7fa688e93aef573711cb207c08c9e11e" +git-tree-sha1 = "63e84b7fdf5021026d0f17f76af7c57772313d99" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.19" +version = "1.5.21" [[deps.StaticArraysCore]] git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" @@ -2062,9 +2062,9 @@ version = "1.3.0" [[deps.StatsModels]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsBase", "StatsFuns", "Tables"] -git-tree-sha1 = "06a230063087c11910e9bbd17ccbf5af792a27a4" +git-tree-sha1 = "51cdf1afd9d78552e7a08536930d7abc3b288a5c" uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d" -version = "0.7.0" +version = "0.7.1" [[deps.StatsPlots]] deps = ["AbstractFFTs", "Clustering", "DataStructures", "DataValues", "Distributions", "Interpolations", "KernelDensity", "LinearAlgebra", "MultivariateStats", "NaNMath", "Observables", "Plots", "RecipesBase", "RecipesPipeline", "Reexport", "StatsBase", "TableOperations", "Tables", "Widgets"] @@ -2151,9 +2151,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.Tidier]] deps = ["Chain", "Cleaner", "DataFrames", "MacroTools", "Reexport", "ShiftedArrays", "Statistics"] -git-tree-sha1 = "6c01fc23066480d998d3eabba0f999c46e88ed73" +git-tree-sha1 = "107614fa8ae68ca12a193df719ec2438e86ab97b" uuid = "f0413319-3358-4bb0-8e7c-0c83523a93bd" -version = "0.7.1" +version = "0.7.4" [[deps.TiffImages]] deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] @@ -2181,9 +2181,9 @@ version = "0.2.20" [[deps.TranscodingStreams]] deps = ["Random", "Test"] -git-tree-sha1 = "94f38103c984f89cf77c402f2a68dbd870f8165f" +git-tree-sha1 = "0b829474fed270a4b0ab07117dce9b9a2fa7581a" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.11" +version = "0.9.12" [[deps.Transducers]] deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] @@ -2232,9 +2232,9 @@ version = "0.4.1" [[deps.UnicodePlots]] deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "Printf", "Requires", "SnoopPrecompile", "SparseArrays", "StaticArrays", "StatsBase"] -git-tree-sha1 = "a5bcfc23e352f499a1a46f428d0d3d7fb9e4fc11" +git-tree-sha1 = "2825e58f6ec3cab889dfa2c824f8d89b9f7ee731" uuid = "b8865327-cd53-5732-bb35-84acbb429228" -version = "3.4.1" +version = "3.5.1" [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" @@ -2438,15 +2438,15 @@ version = "1.2.12+3" [[deps.Zstd_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "c6edfe154ad7b313c01aceca188c05c835c67360" +git-tree-sha1 = "49ce682769cd5de6c72dcf1b94ed7790cd08974c" uuid = "3161d3a3-bdf6-5164-811a-617609db77b4" -version = "1.5.4+0" +version = "1.5.5+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SnoopPrecompile", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "4df8f470806a45a8630ac8f597304821dc8e8838" +git-tree-sha1 = "987ae5554ca90e837594a0f30325eeb5e7303d1e" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.59" +version = "0.6.60" [[deps.ZygoteRules]] deps = ["ChainRulesCore", "MacroTools"] diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 5e7f7aa615f8b9bf3f7c549c7c61f349961e3c7c..8a7f179f1dc24a1e55975659fc4c2a5b45d6ffa6 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -15,10 +15,10 @@ end ```{julia} # Data: -n_obs = 10000 +n_obs = 1000 counterfactual_data = load_mnist(n_obs) +counterfactual_data.X = pre_process.(counterfactual_data.X) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) -X = pre_process.(X) X = table(permutedims(X)) labels = counterfactual_data.output_encoder.labels input_dim, n_obs = size(counterfactual_data.X) @@ -32,18 +32,22 @@ First, let's create a couple of image classifier architectures: # Model parameters: epochs = 100 batch_size = minimum([Int(round(n_obs/10)), 128]) -n_hidden = 32 +n_hidden = 50 activation = Flux.swish -# builder = MLJFlux.@builder Flux.Chain( -# Dense(n_in, n_hidden, activation), -# Dense(n_hidden, n_hidden, activation), -# Dense(n_hidden, n_hidden, activation), -# # BatchNorm(n_hidden, activation), -# # Dense(n_hidden, n_hidden), -# # BatchNorm(n_hidden, activation), -# Dense(n_hidden, n_out), -# ) -builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation) +builder = MLJFlux.@builder Flux.Chain( + + Dense(n_in, n_hidden, activation), + # Dense(n_hidden, n_hidden, activation), + # Dense(n_hidden, n_hidden, activation), + + # Dense(n_in, n_hidden), + # BatchNorm(n_hidden, activation), + # Dense(n_hidden, n_hidden), + # BatchNorm(n_hidden, activation), + + Dense(n_hidden, n_out), +) +# builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation) # builder = MLJFlux.MLP( # hidden=( # n_hidden, @@ -52,7 +56,7 @@ builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.1, σ=activation) # ), # σ=activation # ) -α = [1.0,1.0,1e-1] +α = [1.0,1.0,1e-2] # Simple MLP: mlp = NeuralNetworkClassifier( @@ -85,7 +89,7 @@ jem = JointEnergyClassifier( ) # Deep Ensemble: -mlp_ens = EnsembleModel(model=mlp, n=5) +mlp_ens = EnsembleModel(model=mlp, n=50) ``` ```{julia} @@ -99,7 +103,7 @@ M = ECCCo.ConformalModel(mach.model, mach.fitresult) ```{julia} if mach.model.model isa JointEnergyModels.JointEnergyClassifier jem = mach.model.model.jem - n_iter = 500 + n_iter = 200 _w = 1500 plts = [] neach = 10 @@ -121,6 +125,7 @@ end ```{julia} test_data = load_mnist_test() +test_data.X = pre_process.(test_data.X) f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) println("F1 score (test): $(round(f1,digits=3))") ``` @@ -153,9 +158,9 @@ ce_jsma = generate_counterfactual( # ECCCo: λ=[0.0,1.0] -temp=0.01 +temp=0.5 -# Generate counterfactual using ECCCo generator: +# Generate counterfactual using CCE generator: generator = CCEGenerator( λ=λ, temp=temp, @@ -168,7 +173,7 @@ ce_conformal = generate_counterfactual( converge_when=:generator_conditions, ) -# Generate counterfactual using ECCCo generator: +# Generate counterfactual using CCE generator: generator = CCEGenerator( λ=λ, temp=temp, @@ -191,7 +196,7 @@ p1 = Plots.plot( plts = [p1] ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] -_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"] +_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"] for x in zip(ces, _names) ce, _name = (x[1],x[2]) x = CounterfactualExplanations.counterfactual(ce) @@ -221,6 +226,8 @@ factual = predict_label(M, counterfactual_data, x)[1] γ = 0.5 T = 100 +η=0.1 + # Generate counterfactual using generic generator: generator = GenericGenerator(opt=Flux.Optimise.Adam(),) ce_wachter = generate_counterfactual( @@ -229,7 +236,7 @@ ce_wachter = generate_counterfactual( initialization=:identity, ) -generator = GreedyGenerator(η=1.0) +generator = GreedyGenerator(η=η) ce_jsma = generate_counterfactual( x, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, @@ -237,8 +244,8 @@ ce_jsma = generate_counterfactual( ) # ECCCo: -λ=[0.0,1.0,1.0] -temp=0.01 +λ=[0.0,0.0,10.0] +temp=0.5 # Generate counterfactual using ECCCo generator: generator = ECCCoGenerator( @@ -257,7 +264,7 @@ ce_conformal = generate_counterfactual( generator = ECCCoGenerator( λ=λ, temp=temp, - opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0), + opt=CounterfactualExplanations.Generators.JSMADescent(η=η), ) ce_conformal_jsma = generate_counterfactual( x, target, counterfactual_data, M, generator; @@ -292,7 +299,30 @@ for x in zip(ces, _names) end plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) display(plt) -savefig(plt, joinpath(www_path, "cce_mnist.png")) +savefig(plt, joinpath(www_path, "eccco_mnist.png")) +``` + +```{julia} +if M.model.model isa JointEnergyModels.JointEnergyClassifier + jem = M.model.model.jem + n_iter = 200 + _w = 1500 + plts = [] + neach = 10 + for i in 1:10 + x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i) + plts_i = [] + for j in 1:size(x, 2) + xj = x[:,j] + xj = reshape(xj, (n_digits, n_digits)) + plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] + end + plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) + plts = [plts..., plt] + end + plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1)) + display(plt) +end ``` ## Benchmark diff --git a/paper/paper.pdf b/paper/paper.pdf index 2ba71c2968b35cca31816ccb45be28b071acf7fa..222cec3cdb4a1bfb1966a554996d478b458fc381 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/paper/paper.tex b/paper/paper.tex index bd18cf2a571a04e33ac4f7085bbdb62132d7190c..6ee1a43bf68b58fb0a7c4c743d1a728573ad951c 100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -280,6 +280,13 @@ In order to generate prediction sets $C_{\theta}(f(\mathbf{Z}^\prime);\alpha)$ f \section{Experiments} +\begin{itemize} + \item BatchNorm does not seem compatible with JEM + \item Coverage and temperature impacts CCE in somewhat unpredictable ways + \item It seems that models that are not explicitly trained for generative task, still learn it implictly + \item Batch size seems to impact quality of generated samples +\end{itemize} + \section{Discussion} Consistent with the findings in \citet{schut2021generating}, we have demonstrated that predictive uncertainty estimates can be leveraged to generate plausible counterfactuals. Interestingly, \citet{schut2021generating} point out that this finding --- as intuitive as it is --- may be linked to a positive connection between the generative task and predictive uncertainty quantification. In particular, \citet{grathwohl2020your} demonstrate that their proposed method for integrating the generative objective in training yields models that have improved predictive uncertainty quantification. Since neither \citet{schut2021generating} nor we have employed any surrogate generative models, our findings seem to indicate that the positive connection found in \citet{grathwohl2020your} is bidirectional. diff --git a/src/ECCCo.jl b/src/ECCCo.jl index 82c901d72e83b0b8ef6abf2ca43d5801c1403489..33135f6352d0a3f0a53b20dd5ac1f64e64ae6c70 100644 --- a/src/ECCCo.jl +++ b/src/ECCCo.jl @@ -9,6 +9,7 @@ include("losses.jl") include("generator.jl") include("sampling.jl") -export ECCCoGenerator, EnergySampler, set_size_penalty, distance_from_energy +export CCEGenerator, ECCCoGenerator, EnergySampler +export set_size_penalty, distance_from_energy end \ No newline at end of file diff --git a/src/generator.jl b/src/generator.jl index 2dc3e460c8bfed0846ecd56f90f6792fd99f3280..72952aa31de89bf7d1a8d297a8337abe97cf9a16 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -1,7 +1,12 @@ using CounterfactualExplanations.Objectives -"Constructor for `ECCCoGenerator`." -function ECCCoGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=1.0, temp::Real=0.05, kwargs...) +"Constructor for `CCEGenerator`." +function CCEGenerator(; + λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], + κ::Real=1.0, + temp::Real=0.5, + kwargs... +) function _set_size_penalty(ce::AbstractCounterfactualExplanation) return ECCCo.set_size_penalty(ce; κ=κ, temp=temp) end @@ -11,7 +16,7 @@ function ECCCoGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, end "Constructor for `ECECCCoGenerator`: Energy Constrained Conformal Counterfactual Explanation Generator." -function ECECCCoGenerator(; +function ECCCoGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0, 1.0], κ::Real=1.0, temp::Real=0.5, diff --git a/src/penalties.jl b/src/penalties.jl index 81b953fe5884a9fec638a2db01df73edcdd3176a..95c96b4109cba3bf5f40d4c8f7ae03e44bfefceb 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -36,20 +36,19 @@ end function distance_from_energy( ce::AbstractCounterfactualExplanation; - n::Int=10000, from_buffer=true, agg=mean, kwargs... + n::Int=1, niter=200, from_buffer=true, agg=mean, kwargs... ) conditional_samples = [] ignore_derivatives() do _dict = ce.params if !(:energy_sampler ∈ collect(keys(_dict))) - _dict[:energy_sampler] = ECCCo.EnergySampler(ce; kwargs...) + _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=1000, kwargs...) end sampler = _dict[:energy_sampler] push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer)) end x′ = CounterfactualExplanations.counterfactual(ce) - loss = map(eachslice(x′, dims=3)) do x - x = Matrix(x) + loss = map(eachslice(x′, dims=ndims(x′))) do x Δ = map(eachcol(conditional_samples[1])) do xsample norm(x - xsample) end @@ -63,7 +62,7 @@ end function distance_from_targets( ce::AbstractCounterfactualExplanation; - n::Int=10000, agg=mean + n::Int=1000, agg=mean ) target_idx = ce.data.output_encoder.labels .== ce.target target_samples = ce.data.X[:,target_idx] |> diff --git a/src/sampling.jl b/src/sampling.jl index 4f57fc408631300cee03fea0bc28a071635a3d32..50a170283a65f6e0ad64faae58084913898f0a95 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -42,42 +42,29 @@ function EnergySampler( ) @assert y ∈ data.y_levels || y ∈ 1:length(data.y_levels) - K = length(data.y_levels) - ð’Ÿx = Normal() - ð’Ÿy = Categorical(ones(K) ./ K) - sampler = ConditionalSampler(ð’Ÿx, ð’Ÿy) + + if model.model.model isa JointEnergyClassifier + sampler = model.model.model.jem.sampler + else + K = length(data.y_levels) + input_size = size(selectdim(data.X, ndims(data.X), 1)) + ð’Ÿx = Uniform(extrema(data.X)...) + ð’Ÿy = Categorical(ones(K) ./ K) + sampler = ConditionalSampler(ð’Ÿx, ð’Ÿy; input_size=input_size) + end yidx = get_target_index(data.y_levels, y) # Initiate: energy_sampler = EnergySampler(model, data, sampler, opt, nothing, nothing) # Generate samples: - generate_samples!(energy_sampler, nsamples, yidx; niter=niter) + chain = model.model.model.jem.chain + rule = model.model.model.jem.sampling_rule + energy_sampler.sampler(chain, rule; niter=niter, n_samples=nsamples, y=yidx) return energy_sampler end -""" - generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100) - -Generates `n` samples from `EnergySampler` for conditioning value `y`. -""" -function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100) - X = e.sampler(e.model, e.opt, (size(e.data.X, 1), n); niter=niter, y=y) - X = X[:,map(x -> !any(isnan.(x)), eachcol(X))] - return X -end - -""" - generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100) - -Generates `n` samples from `EnergySampler` for conditioning value `y`. Assigns samples and conditioning value to `EnergySampler`. -""" -function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100) - e.buffer = generate_samples(e,n,y;niter=niter) - e.yidx = y -end - """ EnergySampler( ce::CounterfactualExplanation; @@ -105,12 +92,13 @@ end Overloads the `rand` method to randomly draw `n` samples from `EnergySampler`. """ function Base.rand(sampler::EnergySampler, n::Int=100; from_buffer=true, niter::Int=100) - ntotal = size(sampler.buffer, 2) + ntotal = size(sampler.sampler.buffer)[end] idx = rand(1:ntotal, n) if from_buffer - X = sampler.buffer[:, idx] + X = sampler.sampler.buffer[:, idx] else - X = generate_samples(sampler, n, sampler.yidx; niter=niter) + chain = sampler.model.fitresult[1] + X = sampler.sampler(chain, sampler.opt; niter=niter, n_samples=n, y=sampler.yidx) end return X end diff --git a/www/cce_mnist.png b/www/cce_mnist.png index 55ff8d5650856abe43798dcd34a10da9de1c34f0..c96db716f561b5792813f3ace77acb7a51180819 100644 Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ diff --git a/www/eccco_mnist.png b/www/eccco_mnist.png new file mode 100644 index 0000000000000000000000000000000000000000..81cecbe4d0f947bf55ef08cd7f2db35f318337c2 Binary files /dev/null and b/www/eccco_mnist.png differ