diff --git a/Manifest.toml b/Manifest.toml index ec1a4d53558645c9348900c816cedd9ca74b12c3..d739fbc5a7b3621b2443c61f1ce2690bd435891b 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -12,15 +12,15 @@ version = "1.2.1" [[deps.Accessors]] deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"] -git-tree-sha1 = "b9661b900b50ba475145b311a9a0ef9d2a9c85ea" +git-tree-sha1 = "beabc31fa319f9de4d16372bff31b4801e43d32c" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.26" +version = "0.1.28" [[deps.Adapt]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "0310e08cb19f5da31d08341c6120c047598f5b9c" +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.5.0" +version = "3.6.1" [[deps.ArgCheck]] git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" @@ -49,20 +49,25 @@ git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e" uuid = "68821587-b530-5797-8361-c406ea357684" version = "3.5.1+1" -[[deps.ArrayInterfaceCore]] -deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d" -uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.29" +[[deps.ArrayInterface]] +deps = ["Adapt", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "ec9c36854b569323551a6faf2f31fda15e3459a7" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "7.2.0" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" [[deps.BFloat16s]] deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072" +git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.2.0" +version = "0.4.2" + +[[deps.BSON]] +git-tree-sha1 = "86e9781ac28f4e80e9b98f7f96eae21891332ac2" +uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +version = "0.3.6" [[deps.BangBang]] deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"] @@ -112,10 +117,34 @@ uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" version = "0.10.9" [[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"] -git-tree-sha1 = "6717cb9a3425ebb7b31ca4f832823615d175f64a" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions"] +git-tree-sha1 = "edff14c60784c8f7191a62a23b15a421185bc8a8" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "3.13.1" +version = "4.0.1" + +[[deps.CUDA_Driver_jll]] +deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] +git-tree-sha1 = "75d7896d1ec079ef10d3aee8f3668c11354c03a1" +uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" +version = "0.2.0+0" + +[[deps.CUDA_Runtime_Discovery]] +deps = ["Libdl"] +git-tree-sha1 = "58dd8ec29f54f08c04b052d2c2fa6760b4f4b3a4" +uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" +version = "0.1.1" + +[[deps.CUDA_Runtime_jll]] +deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "d3e6ccd30f84936c1a3a53d622d85d7d3f9b9486" +uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" +version = "0.2.3+2" + +[[deps.CUDNN_jll]] +deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] +git-tree-sha1 = "57011df4fce448828165e566af9befa2ea94350a" +uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645" +version = "8.6.0+3" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] @@ -143,9 +172,9 @@ version = "0.1.9" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] -git-tree-sha1 = "fdde4d8a31cf82b1d136cf6cb53924e8744a832b" +git-tree-sha1 = "7d20c2fb8ab838e41069398685e7b6b5f89ed85b" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.47.0" +version = "1.48.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] @@ -155,9 +184,9 @@ version = "1.15.7" [[deps.ChangesOfVariables]] deps = ["ChainRulesCore", "LinearAlgebra", "Test"] -git-tree-sha1 = "844b061c104c408b24537482469400af6075aae4" +git-tree-sha1 = "485193efd2176b88e6622a39a246f8c5b600e74e" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.5" +version = "0.1.6" [[deps.Chemfiles]] deps = ["BinaryProvider", "Chemfiles_jll", "DocStringExtensions", "Libdl"] @@ -234,16 +263,16 @@ uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" version = "0.3.2" [[deps.ConformalPrediction]] -deps = ["CategoricalArrays", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "Statistics"] +deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "Statistics"] path = "../ConformalPrediction.jl" uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" version = "0.1.5" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] -git-tree-sha1 = "fb21ddd70a051d882a1686a5a550990bbe371a95" +git-tree-sha1 = "89a9db8d28102b094992472d333674bd1a83ce2a" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.4.1" +version = "1.5.1" [[deps.ContextVariablesX]] deps = ["Compat", "Logging", "UUIDs"] @@ -260,7 +289,7 @@ version = "0.6.2" deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "PkgTemplates", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "Statistics", "StatsBase", "Tables", "UMAP"] path = "../CounterfactualExplanations.jl" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.6" +version = "0.1.7" [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" @@ -322,9 +351,9 @@ version = "1.1.0" [[deps.DiffRules]] deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "c5b6685d53f933c11404a3ae9822afe30d522494" +git-tree-sha1 = "a4ad7ef19d2cdc2eff57abbbe68032b1cd0bd8f8" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.12.2" +version = "1.13.0" [[deps.Distances]] deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] @@ -338,9 +367,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "74911ad88921455c6afcad1eefa12bd7b1724631" +git-tree-sha1 = "d71264a7b9a95dca3b8fff4477d94a837346c545" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.80" +version = "0.25.84" [[deps.DocStringExtensions]] deps = ["LibGit2"] @@ -416,10 +445,10 @@ uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" version = "0.13.7" [[deps.FiniteDiff]] -deps = ["ArrayInterfaceCore", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "04ed1f0029b6b3af88343e439b995141cb0d0b8d" +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"] +git-tree-sha1 = "ed1b56934a2f7a65035976985da71b6a65b4f2cf" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.17.0" +version = "2.18.0" [[deps.FixedPointNumbers]] deps = ["Statistics"] @@ -429,9 +458,9 @@ version = "0.8.4" [[deps.Flux]] deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"] -git-tree-sha1 = "c258e51850ac1fdc465f62380a61995d4a66d603" +git-tree-sha1 = "4ff3a1d7b0dd38f2fc38e813bc801f817639c1f2" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.13.12" +version = "0.13.13" [[deps.FoldsThreads]] deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"] @@ -453,9 +482,9 @@ version = "0.4.2" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "a69dd6db8a809f78846ff259298678f0d6212180" +git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.34" +version = "0.10.35" [[deps.FreeType2_jll]] deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] @@ -476,9 +505,9 @@ version = "1.1.3" [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "61fa9cf802d35fe1b5b8ea9fbaac4b8f020d19b1" +git-tree-sha1 = "7ed0833a55979d3d2658a60b901469748a6b9a7c" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.2" +version = "0.4.3" [[deps.Future]] deps = ["Random"] @@ -492,15 +521,15 @@ version = "3.3.8+0" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "4dfaff044eb2ce11a897fecd85538310e60b91e6" +git-tree-sha1 = "a28f752ffab0ccd6660fc7af5ad1c9ad176f45f7" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.6.2" +version = "8.6.3" [[deps.GPUArraysCore]] deps = ["Adapt"] -git-tree-sha1 = "57f7cde02d7a53c9d1d28443b9f11ac5fbe7ebc9" +git-tree-sha1 = "1cd7f0af1aa58abc02ea1d872953a97359cb87fa" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.3" +version = "0.1.4" [[deps.GPUCompiler]] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"] @@ -562,9 +591,9 @@ version = "1.0.2" [[deps.HDF5]] deps = ["Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires", "UUIDs"] -git-tree-sha1 = "b5df7c3cab3a00c33c2e09c6bd23982a75e2fbb2" +git-tree-sha1 = "3dab31542b3da9f25a6a1d11159d4af8fdce7d67" uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.16.13" +version = "0.16.14" [[deps.HDF5_jll]] deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"] @@ -609,10 +638,10 @@ uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" version = "0.9.4" [[deps.ImageShow]] -deps = ["Base64", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] -git-tree-sha1 = "b563cf9ae75a635592fc73d3eb78b86220e55bd8" +deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"] +git-tree-sha1 = "ce28c68c900eed3cdbfa418be66ed053e54d4f56" uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31" -version = "0.3.6" +version = "0.3.7" [[deps.Inflate]] git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428" @@ -657,9 +686,9 @@ uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" version = "1.2.0" [[deps.IrrationalConstants]] -git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151" +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.1.1" +version = "0.2.2" [[deps.IteratorInterfaceExtensions]] git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" @@ -697,10 +726,10 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" [[deps.JpegTurbo_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "b53380851c6e6664204efb2e62cd24fa5c47e4ba" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc" uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "2.1.2+0" +version = "2.1.91+0" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -733,9 +762,9 @@ version = "4.16.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"] -git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576" +git-tree-sha1 = "7718cf44439c676bc0ec66a87099f41015a522d6" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.16+0" +version = "0.0.16+2" [[deps.LZO_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -851,9 +880,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.LogExpFunctions]] deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "680e733c3a0a9cea9e935c8c2184aea6a63fa0b5" +git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.21" +version = "0.3.23" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -890,15 +919,15 @@ version = "0.7.9" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "f6667db64f84c5031e3f4e48b5da80e1dd39429d" +git-tree-sha1 = "6f3a7338e787cbf3460f035c21ee2547f71f8007" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "0.21.5" +version = "0.21.6" [[deps.MLJFlux]] -deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "ProgressMeter", "Random", "Statistics", "Tables"] -git-tree-sha1 = "a47257705ebca405a25320b111345a978925bcd5" +deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] +git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" -version = "0.2.7" +version = "0.2.9" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] @@ -908,9 +937,9 @@ version = "1.8.0" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "638c84dea26eb91a538afb52cef11d77556e6807" +git-tree-sha1 = "1d445497ca058dbc0dbc7528b778707893edb969" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.3" +version = "0.16.4" [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" @@ -960,6 +989,12 @@ git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102" uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e" version = "0.3.2" +[[deps.Metalhead]] +deps = ["Artifacts", "BSON", "Flux", "Functors", "LazyArtifacts", "MLUtils", "NNlib", "Random", "Statistics"] +git-tree-sha1 = "0e95f91cc5f23610f8f270d7397f307b21e19d2b" +uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" +version = "0.7.4" + [[deps.MicroCollections]] deps = ["BangBang", "InitialValues", "Setfield"] git-tree-sha1 = "4d5917a26ca33c66c8e5ca3247bd163624d35493" @@ -1011,15 +1046,15 @@ version = "7.8.3" [[deps.NNlib]] deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "ddf38a5d9140bc8c08ea6158484a455ca3efdd2d" +git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.8.18" +version = "0.8.19" [[deps.NNlibCUDA]] -deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"] -git-tree-sha1 = "b05a082b08a3af0e5c576883bc6dfb6513e7e478" +deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"] +git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf" uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d" -version = "0.2.6" +version = "0.2.7" [[deps.NPZ]] deps = ["FileIO", "ZipFile"] @@ -1029,9 +1064,9 @@ version = "0.4.3" [[deps.NaNMath]] deps = ["OpenLibm_jll"] -git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211" +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.1" +version = "1.0.2" [[deps.NameResolution]] deps = ["PrettyPrint"] @@ -1120,9 +1155,9 @@ version = "2.0.2" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "e657acef119cc0de2a8c0762666d3b64727b053b" +git-tree-sha1 = "e5a1825d3d53aa4ad4fb42bd4927011ad4a78c3d" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.2.14" +version = "0.2.15" [[deps.Opus_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1142,9 +1177,9 @@ version = "10.40.0+0" [[deps.PDMats]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "cf494dca75a69712a72b80bc48f59dcf3dea63ec" +git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.16" +version = "0.11.17" [[deps.PaddedViews]] deps = ["OffsetArrays"] @@ -1206,9 +1241,9 @@ version = "1.3.4" [[deps.Plots]] deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Preferences", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SnoopPrecompile", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"] -git-tree-sha1 = "8ac949bd0ebc46a44afb1fdca1094554a84b086e" +git-tree-sha1 = "da1d3fb7183e38603fcdd2061c47979d91202c97" uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -version = "1.38.5" +version = "1.38.6" [[deps.PooledArrays]] deps = ["DataAPI", "Future"] @@ -1356,9 +1391,9 @@ version = "1.1.1" [[deps.SentinelArrays]] deps = ["Dates", "Random"] -git-tree-sha1 = "c02bd3c9c3fc8463d3591a62a378f90d2d8ab0f3" +git-tree-sha1 = "77d3c4726515dca71f6d80fbb5e251088defe305" uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.3.17" +version = "1.3.18" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -1422,9 +1457,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.SpecialFunctions]] deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d" +git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.1.7" +version = "2.2.0" [[deps.SplittablesBase]] deps = ["Setfield", "Test"] @@ -1440,9 +1475,9 @@ version = "0.1.1" [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "67d3e75e8af8089ea34ce96974d5468d4a008ca6" +git-tree-sha1 = "2d7d9e1ddadc8407ffd460e24218e37ef52dd9a3" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.15" +version = "1.5.16" [[deps.StaticArraysCore]] git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" @@ -1473,9 +1508,9 @@ version = "0.33.21" [[deps.StatsFuns]] deps = ["ChainRulesCore", "HypergeometricFunctions", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "ab6083f09b3e617e34a956b43e9d51b824206932" +git-tree-sha1 = "5aa6250a781e567388f3285fb4b0f214a501b4d5" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.1.1" +version = "1.2.1" [[deps.Strided]] deps = ["LinearAlgebra", "TupleTools"] @@ -1801,6 +1836,12 @@ git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0" uuid = "700de1a5-db45-46bc-99cf-38207098b444" version = "0.2.2" +[[deps.cuDNN]] +deps = ["CEnum", "CUDA", "CUDNN_jll"] +git-tree-sha1 = "c0ffcb38d1e8c0bbcd3dab2559cf9c369130b2f2" +uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" +version = "1.0.1" + [[deps.fzf_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "868e669ccb12ba16eaf50cb2957ee2ff61261c56" diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index a05f147e01a35e68f551d636608751980e09cff0..015fbe0d03bf0a73a962ea731533f0efc4d19990 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -3,6 +3,7 @@ using CCE using ConformalPrediction using CounterfactualExplanations using CounterfactualExplanations.Data +using CounterfactualExplanations.Objectives using Flux using MLJBase using MLJFlux @@ -107,18 +108,62 @@ p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, t plot(p1, p2, size=(800,320)) ``` +### Penalizing Size + +```{julia} +#| output: true +#| echo: false +#| label: fig-ce +#| fig-cap: "Comparison of counterfactuals produced using different generators." + +opt = Descent(0.01) +ordered_names = [ + "Generic (γ=0.5)", + "Generic (γ=0.9)", + "Conformal (λ₂=1)", + "Conformal (λ₂=10)" +] +loss_fun = Objectives.logitbinarycrossentropy + +# Generators: +generators = Dict( + ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5), + ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9), + ordered_names[3] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,1]), + ordered_names[4] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,10]), +) + +counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators]) + +# Plots: +plts = [] +for name ∈ ordered_names + ce = counterfactuals[name] + plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0) + plts = vcat(plts..., plt) +end +_n = length(generators) +img_size = 300 +plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n)) +``` + + +### Configurable Classification Loss + ```{julia} #| output: true #| echo: false #| label: fig-ce #| fig-cap: "Comparison of counterfactuals produced using different generators." +opt = Descent(0.01) ordered_names = [ "Generic (γ=0.5)", "Generic (γ=0.9)", "Conformal (λ₂=1)", "Conformal (λ₂=10)" ] +loss_fun = Objectives.logitbinarycrossentropy # Generators: generators = Dict( diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl index 056387b1b70ceda402115b36d92d0a03d4b5cbac..280e483c84c5025af68d0656ca524a7ee1c12585 100644 --- a/src/ConformalGenerator.jl +++ b/src/ConformalGenerator.jl @@ -1,5 +1,8 @@ +using CategoricalArrays using CounterfactualExplanations using CounterfactualExplanations.Generators +using CounterfactualExplanations.Models: binary_to_onehot +using CounterfactualExplanations.Objectives using Flux using LinearAlgebra using Parameters @@ -7,7 +10,7 @@ using SliceMap using Statistics mutable struct ConformalGenerator <: AbstractGradientBasedGenerator - loss::Union{Nothing,Symbol} # loss function + loss::Union{Nothing,Function} # loss function complexity::Function # complexity function λ::Union{AbstractFloat,AbstractVector} # strength of penalty decision_threshold::Union{Nothing,AbstractFloat} @@ -26,13 +29,12 @@ end end """ - ConformalGenerator( - ; - loss::Symbol=:logitbinarycrossentropy, + ConformalGenerator(; + loss::Union{Nothing,Function}=conformal_training_loss, complexity::Function=norm, - λ::AbstractFloat=0.1, - opt::Flux.Optimise.AbstractOptimiser=Flux.Optimise.Descent(), - τ::AbstractFloat=1e-5 + λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0], + decision_threshold=nothing, + kwargs... ) An outer constructor method that instantiates a generic generator. @@ -43,7 +45,7 @@ generator = ConformalGenerator() ``` """ function ConformalGenerator(; - loss::Union{Nothing,Symbol}=nothing, + loss::Union{Nothing,Function}=conformal_training_loss, complexity::Function=norm, λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0], decision_threshold=nothing, @@ -53,6 +55,32 @@ function ConformalGenerator(; ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ, params.κ, params.temp) end +@doc raw""" + conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...) + +A configurable classification loss function for Conformal Predictors. +""" +function conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...) + conf_model = counterfactual_explanation.M.model + fitresult = counterfactual_explanation.M.fitresult + X = CounterfactualExplanations.decode_state(counterfactual_explanation) + y = counterfactual_explanation.target_encoded[:,:,1] + if counterfactual_explanation.M.likelihood == :classification_binary + y = binary_to_onehot(y) + end + y = permutedims(y) + generator = counterfactual_explanation.generator + loss = SliceMap.slicemap(X, dims=(1, 2)) do x + x = Matrix(x) + ConformalPrediction.classification_loss( + conf_model, fitresult, x, y; + temp=generator.temp + ) + end + loss = mean(loss) + return loss +end + """ set_size_penalty( generator::ConformalGenerator,