Skip to content
Snippets Groups Projects
Commit 136467dc authored by pat-alt's avatar pat-alt
Browse files

some work on configurable loss function

parent 8981107c
No related branches found
No related tags found
No related merge requests found
......@@ -12,15 +12,15 @@ version = "1.2.1"
[[deps.Accessors]]
deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"]
git-tree-sha1 = "b9661b900b50ba475145b311a9a0ef9d2a9c85ea"
git-tree-sha1 = "beabc31fa319f9de4d16372bff31b4801e43d32c"
uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
version = "0.1.26"
version = "0.1.28"
[[deps.Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "0310e08cb19f5da31d08341c6120c047598f5b9c"
deps = ["LinearAlgebra", "Requires"]
git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.5.0"
version = "3.6.1"
[[deps.ArgCheck]]
git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4"
......@@ -49,20 +49,25 @@ git-tree-sha1 = "5ba6c757e8feccf03a1554dfaf3e26b3cfc7fd5e"
uuid = "68821587-b530-5797-8361-c406ea357684"
version = "3.5.1+1"
[[deps.ArrayInterfaceCore]]
deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d"
uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2"
version = "0.1.29"
[[deps.ArrayInterface]]
deps = ["Adapt", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "ec9c36854b569323551a6faf2f31fda15e3459a7"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "7.2.0"
[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
[[deps.BFloat16s]]
deps = ["LinearAlgebra", "Printf", "Random", "Test"]
git-tree-sha1 = "a598ecb0d717092b5539dbbe890c98bac842b072"
git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66"
uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
version = "0.2.0"
version = "0.4.2"
[[deps.BSON]]
git-tree-sha1 = "86e9781ac28f4e80e9b98f7f96eae21891332ac2"
uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
version = "0.3.6"
[[deps.BangBang]]
deps = ["Compat", "ConstructionBase", "Future", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables", "ZygoteRules"]
......@@ -112,10 +117,34 @@ uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
version = "0.10.9"
[[deps.CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "6717cb9a3425ebb7b31ca4f832823615d175f64a"
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions"]
git-tree-sha1 = "edff14c60784c8f7191a62a23b15a421185bc8a8"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.13.1"
version = "4.0.1"
[[deps.CUDA_Driver_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"]
git-tree-sha1 = "75d7896d1ec079ef10d3aee8f3668c11354c03a1"
uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc"
version = "0.2.0+0"
[[deps.CUDA_Runtime_Discovery]]
deps = ["Libdl"]
git-tree-sha1 = "58dd8ec29f54f08c04b052d2c2fa6760b4f4b3a4"
uuid = "1af6417a-86b4-443c-805f-a4643ffb695f"
version = "0.1.1"
[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
git-tree-sha1 = "d3e6ccd30f84936c1a3a53d622d85d7d3f9b9486"
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
version = "0.2.3+2"
[[deps.CUDNN_jll]]
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
git-tree-sha1 = "57011df4fce448828165e566af9befa2ea94350a"
uuid = "62b44479-cb7b-5706-934f-f13b2eb2e645"
version = "8.6.0+3"
[[deps.Cairo_jll]]
deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"]
......@@ -143,9 +172,9 @@ version = "0.1.9"
[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
git-tree-sha1 = "fdde4d8a31cf82b1d136cf6cb53924e8744a832b"
git-tree-sha1 = "7d20c2fb8ab838e41069398685e7b6b5f89ed85b"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.47.0"
version = "1.48.0"
[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
......@@ -155,9 +184,9 @@ version = "1.15.7"
[[deps.ChangesOfVariables]]
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
git-tree-sha1 = "844b061c104c408b24537482469400af6075aae4"
git-tree-sha1 = "485193efd2176b88e6622a39a246f8c5b600e74e"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.5"
version = "0.1.6"
[[deps.Chemfiles]]
deps = ["BinaryProvider", "Chemfiles_jll", "DocStringExtensions", "Libdl"]
......@@ -234,16 +263,16 @@ uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3"
version = "0.3.2"
[[deps.ConformalPrediction]]
deps = ["CategoricalArrays", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "Statistics"]
deps = ["CategoricalArrays", "ChainRules", "Flux", "LinearAlgebra", "MLJBase", "MLJModelInterface", "NaturalSort", "Plots", "Statistics"]
path = "../ConformalPrediction.jl"
uuid = "98bfc277-1877-43dc-819b-a3e38c30242f"
version = "0.1.5"
[[deps.ConstructionBase]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "fb21ddd70a051d882a1686a5a550990bbe371a95"
git-tree-sha1 = "89a9db8d28102b094992472d333674bd1a83ce2a"
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
version = "1.4.1"
version = "1.5.1"
[[deps.ContextVariablesX]]
deps = ["Compat", "Logging", "UUIDs"]
......@@ -260,7 +289,7 @@ version = "0.6.2"
deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "PkgTemplates", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "Statistics", "StatsBase", "Tables", "UMAP"]
path = "../CounterfactualExplanations.jl"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
version = "0.1.6"
version = "0.1.7"
[[deps.Crayons]]
git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
......@@ -322,9 +351,9 @@ version = "1.1.0"
[[deps.DiffRules]]
deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "c5b6685d53f933c11404a3ae9822afe30d522494"
git-tree-sha1 = "a4ad7ef19d2cdc2eff57abbbe68032b1cd0bd8f8"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.12.2"
version = "1.13.0"
[[deps.Distances]]
deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"]
......@@ -338,9 +367,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Distributions]]
deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "74911ad88921455c6afcad1eefa12bd7b1724631"
git-tree-sha1 = "d71264a7b9a95dca3b8fff4477d94a837346c545"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.80"
version = "0.25.84"
[[deps.DocStringExtensions]]
deps = ["LibGit2"]
......@@ -416,10 +445,10 @@ uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.13.7"
[[deps.FiniteDiff]]
deps = ["ArrayInterfaceCore", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"]
git-tree-sha1 = "04ed1f0029b6b3af88343e439b995141cb0d0b8d"
deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"]
git-tree-sha1 = "ed1b56934a2f7a65035976985da71b6a65b4f2cf"
uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
version = "2.17.0"
version = "2.18.0"
[[deps.FixedPointNumbers]]
deps = ["Statistics"]
......@@ -429,9 +458,9 @@ version = "0.8.4"
[[deps.Flux]]
deps = ["Adapt", "CUDA", "ChainRulesCore", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "NNlibCUDA", "OneHotArrays", "Optimisers", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "Zygote"]
git-tree-sha1 = "c258e51850ac1fdc465f62380a61995d4a66d603"
git-tree-sha1 = "4ff3a1d7b0dd38f2fc38e813bc801f817639c1f2"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.13.12"
version = "0.13.13"
[[deps.FoldsThreads]]
deps = ["Accessors", "FunctionWrappers", "InitialValues", "SplittablesBase", "Transducers"]
......@@ -453,9 +482,9 @@ version = "0.4.2"
[[deps.ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "a69dd6db8a809f78846ff259298678f0d6212180"
git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.34"
version = "0.10.35"
[[deps.FreeType2_jll]]
deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"]
......@@ -476,9 +505,9 @@ version = "1.1.3"
[[deps.Functors]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "61fa9cf802d35fe1b5b8ea9fbaac4b8f020d19b1"
git-tree-sha1 = "7ed0833a55979d3d2658a60b901469748a6b9a7c"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.4.2"
version = "0.4.3"
[[deps.Future]]
deps = ["Random"]
......@@ -492,15 +521,15 @@ version = "3.3.8+0"
[[deps.GPUArrays]]
deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
git-tree-sha1 = "4dfaff044eb2ce11a897fecd85538310e60b91e6"
git-tree-sha1 = "a28f752ffab0ccd6660fc7af5ad1c9ad176f45f7"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "8.6.2"
version = "8.6.3"
[[deps.GPUArraysCore]]
deps = ["Adapt"]
git-tree-sha1 = "57f7cde02d7a53c9d1d28443b9f11ac5fbe7ebc9"
git-tree-sha1 = "1cd7f0af1aa58abc02ea1d872953a97359cb87fa"
uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
version = "0.1.3"
version = "0.1.4"
[[deps.GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
......@@ -562,9 +591,9 @@ version = "1.0.2"
[[deps.HDF5]]
deps = ["Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires", "UUIDs"]
git-tree-sha1 = "b5df7c3cab3a00c33c2e09c6bd23982a75e2fbb2"
git-tree-sha1 = "3dab31542b3da9f25a6a1d11159d4af8fdce7d67"
uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
version = "0.16.13"
version = "0.16.14"
[[deps.HDF5_jll]]
deps = ["Artifacts", "JLLWrappers", "LibCURL_jll", "Libdl", "OpenSSL_jll", "Pkg", "Zlib_jll"]
......@@ -609,10 +638,10 @@ uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534"
version = "0.9.4"
[[deps.ImageShow]]
deps = ["Base64", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"]
git-tree-sha1 = "b563cf9ae75a635592fc73d3eb78b86220e55bd8"
deps = ["Base64", "ColorSchemes", "FileIO", "ImageBase", "ImageCore", "OffsetArrays", "StackViews"]
git-tree-sha1 = "ce28c68c900eed3cdbfa418be66ed053e54d4f56"
uuid = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
version = "0.3.6"
version = "0.3.7"
[[deps.Inflate]]
git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428"
......@@ -657,9 +686,9 @@ uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
version = "1.2.0"
[[deps.IrrationalConstants]]
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
version = "0.1.1"
version = "0.2.2"
[[deps.IteratorInterfaceExtensions]]
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
......@@ -697,10 +726,10 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
version = "1.12.0"
[[deps.JpegTurbo_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "b53380851c6e6664204efb2e62cd24fa5c47e4ba"
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc"
uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8"
version = "2.1.2+0"
version = "2.1.91+0"
[[deps.JuliaVariables]]
deps = ["MLStyle", "NameResolution"]
......@@ -733,9 +762,9 @@ version = "4.16.0"
[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg", "TOML"]
git-tree-sha1 = "771bfe376249626d3ca12bcd58ba243d3f961576"
git-tree-sha1 = "7718cf44439c676bc0ec66a87099f41015a522d6"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.16+0"
version = "0.0.16+2"
[[deps.LZO_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -851,9 +880,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
[[deps.LogExpFunctions]]
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "680e733c3a0a9cea9e935c8c2184aea6a63fa0b5"
git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.21"
version = "0.3.23"
[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
......@@ -890,15 +919,15 @@ version = "0.7.9"
[[deps.MLJBase]]
deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "f6667db64f84c5031e3f4e48b5da80e1dd39429d"
git-tree-sha1 = "6f3a7338e787cbf3460f035c21ee2547f71f8007"
uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
version = "0.21.5"
version = "0.21.6"
[[deps.MLJFlux]]
deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "ProgressMeter", "Random", "Statistics", "Tables"]
git-tree-sha1 = "a47257705ebca405a25320b111345a978925bcd5"
deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0"
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
version = "0.2.7"
version = "0.2.9"
[[deps.MLJModelInterface]]
deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
......@@ -908,9 +937,9 @@ version = "1.8.0"
[[deps.MLJModels]]
deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "638c84dea26eb91a538afb52cef11d77556e6807"
git-tree-sha1 = "1d445497ca058dbc0dbc7528b778707893edb969"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
version = "0.16.3"
version = "0.16.4"
[[deps.MLStyle]]
git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
......@@ -960,6 +989,12 @@ git-tree-sha1 = "c13304c81eec1ed3af7fc20e75fb6b26092a1102"
uuid = "442fdcdd-2543-5da2-b0f3-8c86c306513e"
version = "0.3.2"
[[deps.Metalhead]]
deps = ["Artifacts", "BSON", "Flux", "Functors", "LazyArtifacts", "MLUtils", "NNlib", "Random", "Statistics"]
git-tree-sha1 = "0e95f91cc5f23610f8f270d7397f307b21e19d2b"
uuid = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
version = "0.7.4"
[[deps.MicroCollections]]
deps = ["BangBang", "InitialValues", "Setfield"]
git-tree-sha1 = "4d5917a26ca33c66c8e5ca3247bd163624d35493"
......@@ -1011,15 +1046,15 @@ version = "7.8.3"
[[deps.NNlib]]
deps = ["Adapt", "ChainRulesCore", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
git-tree-sha1 = "ddf38a5d9140bc8c08ea6158484a455ca3efdd2d"
git-tree-sha1 = "33ad5a19dc6730d592d8ce91c14354d758e53b0e"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.8.18"
version = "0.8.19"
[[deps.NNlibCUDA]]
deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
git-tree-sha1 = "b05a082b08a3af0e5c576883bc6dfb6513e7e478"
deps = ["Adapt", "CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics", "cuDNN"]
git-tree-sha1 = "f94a9684394ff0d325cc12b06da7032d8be01aaf"
uuid = "a00861dc-f156-4864-bf3c-e6376f28a68d"
version = "0.2.6"
version = "0.2.7"
[[deps.NPZ]]
deps = ["FileIO", "ZipFile"]
......@@ -1029,9 +1064,9 @@ version = "0.4.3"
[[deps.NaNMath]]
deps = ["OpenLibm_jll"]
git-tree-sha1 = "a7c3d1da1189a1c2fe843a3bfa04d18d20eb3211"
git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "1.0.1"
version = "1.0.2"
[[deps.NameResolution]]
deps = ["PrettyPrint"]
......@@ -1120,9 +1155,9 @@ version = "2.0.2"
[[deps.Optimisers]]
deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "e657acef119cc0de2a8c0762666d3b64727b053b"
git-tree-sha1 = "e5a1825d3d53aa4ad4fb42bd4927011ad4a78c3d"
uuid = "3bd65402-5787-11e9-1adc-39752487f4e2"
version = "0.2.14"
version = "0.2.15"
[[deps.Opus_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -1142,9 +1177,9 @@ version = "10.40.0+0"
[[deps.PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "cf494dca75a69712a72b80bc48f59dcf3dea63ec"
git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.16"
version = "0.11.17"
[[deps.PaddedViews]]
deps = ["OffsetArrays"]
......@@ -1206,9 +1241,9 @@ version = "1.3.4"
[[deps.Plots]]
deps = ["Base64", "Contour", "Dates", "Downloads", "FFMPEG", "FixedPointNumbers", "GR", "JLFzf", "JSON", "LaTeXStrings", "Latexify", "LinearAlgebra", "Measures", "NaNMath", "Pkg", "PlotThemes", "PlotUtils", "Preferences", "Printf", "REPL", "Random", "RecipesBase", "RecipesPipeline", "Reexport", "RelocatableFolders", "Requires", "Scratch", "Showoff", "SnoopPrecompile", "SparseArrays", "Statistics", "StatsBase", "UUIDs", "UnicodeFun", "Unzip"]
git-tree-sha1 = "8ac949bd0ebc46a44afb1fdca1094554a84b086e"
git-tree-sha1 = "da1d3fb7183e38603fcdd2061c47979d91202c97"
uuid = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
version = "1.38.5"
version = "1.38.6"
[[deps.PooledArrays]]
deps = ["DataAPI", "Future"]
......@@ -1356,9 +1391,9 @@ version = "1.1.1"
[[deps.SentinelArrays]]
deps = ["Dates", "Random"]
git-tree-sha1 = "c02bd3c9c3fc8463d3591a62a378f90d2d8ab0f3"
git-tree-sha1 = "77d3c4726515dca71f6d80fbb5e251088defe305"
uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
version = "1.3.17"
version = "1.3.18"
[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
......@@ -1422,9 +1457,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[[deps.SpecialFunctions]]
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "d75bda01f8c31ebb72df80a46c88b25d1c79c56d"
git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "2.1.7"
version = "2.2.0"
[[deps.SplittablesBase]]
deps = ["Setfield", "Test"]
......@@ -1440,9 +1475,9 @@ version = "0.1.1"
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
git-tree-sha1 = "67d3e75e8af8089ea34ce96974d5468d4a008ca6"
git-tree-sha1 = "2d7d9e1ddadc8407ffd460e24218e37ef52dd9a3"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.5.15"
version = "1.5.16"
[[deps.StaticArraysCore]]
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
......@@ -1473,9 +1508,9 @@ version = "0.33.21"
[[deps.StatsFuns]]
deps = ["ChainRulesCore", "HypergeometricFunctions", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "ab6083f09b3e617e34a956b43e9d51b824206932"
git-tree-sha1 = "5aa6250a781e567388f3285fb4b0f214a501b4d5"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "1.1.1"
version = "1.2.1"
[[deps.Strided]]
deps = ["LinearAlgebra", "TupleTools"]
......@@ -1801,6 +1836,12 @@ git-tree-sha1 = "8c1a8e4dfacb1fd631745552c8db35d0deb09ea0"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.2"
[[deps.cuDNN]]
deps = ["CEnum", "CUDA", "CUDNN_jll"]
git-tree-sha1 = "c0ffcb38d1e8c0bbcd3dab2559cf9c369130b2f2"
uuid = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
version = "1.0.1"
[[deps.fzf_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "868e669ccb12ba16eaf50cb2957ee2ff61261c56"
......
......@@ -3,6 +3,7 @@ using CCE
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
using CounterfactualExplanations.Objectives
using Flux
using MLJBase
using MLJFlux
......@@ -107,18 +108,62 @@ p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, t
plot(p1, p2, size=(800,320))
```
### Penalizing Size
```{julia}
#| output: true
#| echo: false
#| label: fig-ce
#| fig-cap: "Comparison of counterfactuals produced using different generators."
opt = Descent(0.01)
ordered_names = [
"Generic (γ=0.5)",
"Generic (γ=0.9)",
"Conformal (λ₂=1)",
"Conformal (λ₂=10)"
]
loss_fun = Objectives.logitbinarycrossentropy
# Generators:
generators = Dict(
ordered_names[1] => GenericGenerator(opt = opt, decision_threshold=0.5),
ordered_names[2] => GenericGenerator(opt = opt, decision_threshold=0.9),
ordered_names[3] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,1]),
ordered_names[4] => CCE.ConformalGenerator(loss=loss_fun, opt=opt, λ=[0.1,10]),
)
counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity) for (name, gen) in generators])
# Plots:
plts = []
for name ∈ ordered_names
ce = counterfactuals[name]
plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0)
plts = vcat(plts..., plt)
end
_n = length(generators)
img_size = 300
plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n))
```
### Configurable Classification Loss
```{julia}
#| output: true
#| echo: false
#| label: fig-ce
#| fig-cap: "Comparison of counterfactuals produced using different generators."
opt = Descent(0.01)
ordered_names = [
"Generic (γ=0.5)",
"Generic (γ=0.9)",
"Conformal (λ₂=1)",
"Conformal (λ₂=10)"
]
loss_fun = Objectives.logitbinarycrossentropy
# Generators:
generators = Dict(
......
using CategoricalArrays
using CounterfactualExplanations
using CounterfactualExplanations.Generators
using CounterfactualExplanations.Models: binary_to_onehot
using CounterfactualExplanations.Objectives
using Flux
using LinearAlgebra
using Parameters
......@@ -7,7 +10,7 @@ using SliceMap
using Statistics
mutable struct ConformalGenerator <: AbstractGradientBasedGenerator
loss::Union{Nothing,Symbol} # loss function
loss::Union{Nothing,Function} # loss function
complexity::Function # complexity function
λ::Union{AbstractFloat,AbstractVector} # strength of penalty
decision_threshold::Union{Nothing,AbstractFloat}
......@@ -26,13 +29,12 @@ end
end
"""
ConformalGenerator(
;
loss::Symbol=:logitbinarycrossentropy,
ConformalGenerator(;
loss::Union{Nothing,Function}=conformal_training_loss,
complexity::Function=norm,
λ::AbstractFloat=0.1,
opt::Flux.Optimise.AbstractOptimiser=Flux.Optimise.Descent(),
τ::AbstractFloat=1e-5
λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0],
decision_threshold=nothing,
kwargs...
)
An outer constructor method that instantiates a generic generator.
......@@ -43,7 +45,7 @@ generator = ConformalGenerator()
```
"""
function ConformalGenerator(;
loss::Union{Nothing,Symbol}=nothing,
loss::Union{Nothing,Function}=conformal_training_loss,
complexity::Function=norm,
λ::Union{AbstractFloat,AbstractVector}=[0.1, 1.0],
decision_threshold=nothing,
......@@ -53,6 +55,32 @@ function ConformalGenerator(;
ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ, params.κ, params.temp)
end
@doc raw"""
conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...)
A configurable classification loss function for Conformal Predictors.
"""
function conformal_training_loss(counterfactual_explanation::AbstractCounterfactualExplanation; kwargs...)
conf_model = counterfactual_explanation.M.model
fitresult = counterfactual_explanation.M.fitresult
X = CounterfactualExplanations.decode_state(counterfactual_explanation)
y = counterfactual_explanation.target_encoded[:,:,1]
if counterfactual_explanation.M.likelihood == :classification_binary
y = binary_to_onehot(y)
end
y = permutedims(y)
generator = counterfactual_explanation.generator
loss = SliceMap.slicemap(X, dims=(1, 2)) do x
x = Matrix(x)
ConformalPrediction.classification_loss(
conf_model, fitresult, x, y;
temp=generator.temp
)
end
loss = mean(loss)
return loss
end
"""
set_size_penalty(
generator::ConformalGenerator,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment