diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index 9b8a8f9676d49f0f22054b0f13ca9127e5f98e11..683a8b015bd4b58c5e0fe24833d836a8acc4ccf8 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -276,9 +276,9 @@ version = "0.5.0" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "dbeca245b0680f5393b4e6c40dcead7230ab0b3b" +git-tree-sha1 = "01b0594d8907485ed894bc59adfc0a24a9cde7a3" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.54.0" +version = "1.55.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] @@ -433,12 +433,12 @@ uuid = "150eb455-5306-5404-9cee-2592286d6298" version = "0.6.3" [[deps.CounterfactualExplanations]] -deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] -git-tree-sha1 = "082d60e57d67cf5302ca007ccad053270d401f5f" +deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PkgTemplates", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] +git-tree-sha1 = "1f1861d5b947a7af85c2172e3791c03453b6a183" repo-rev = "main" repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.29" +version = "0.1.30" [deps.CounterfactualExplanations.extensions] MPIExt = "MPI" @@ -543,9 +543,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "9e11104e7b41a8a5f04e8694467fc1f94a135bd7" +git-tree-sha1 = "3d5873f811f582873bb9871fc9c451784d5dc8c7" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.101" +version = "0.25.102" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -674,9 +674,9 @@ version = "1.16.1" [[deps.FilePathsBase]] deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12" +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.20" +version = "0.9.21" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" @@ -804,9 +804,9 @@ version = "0.6.1" [[deps.GeoInterface]] deps = ["Extents"] -git-tree-sha1 = "bb198ff907228523f3dee1070ceee63b9359b6ab" +git-tree-sha1 = "d53480c0793b13341c40199190f92c611aa2e93c" uuid = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" -version = "1.3.1" +version = "1.3.2" [[deps.GeometryBasics]] deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] @@ -845,9 +845,9 @@ version = "1.3.14+0" [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "1cf1d7dcb4bc32d7b4a5add4232db3750c27ecb4" +git-tree-sha1 = "899050ace26649433ef1af25bc17a815b3db52b7" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.8.0" +version = "1.9.0" [[deps.Grisu]] git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" @@ -856,9 +856,9 @@ version = "1.0.2" [[deps.HDF5]] deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"] -git-tree-sha1 = "ec7df74b7b2022e8252a8bfd4ec23411491adc3b" +git-tree-sha1 = "26407bd1c60129062cec9da63dc7d08251544d53" uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" -version = "0.17.0" +version = "0.17.1" weakdeps = ["MPI"] [deps.HDF5.extensions] @@ -1196,15 +1196,15 @@ version = "3.0.0+1" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "a9d2ce1d5007b1e8f6c5b89c5a31ff8bd146db5c" +git-tree-sha1 = "4ea2928a96acfcf8589e6cd1429eff2a3a82c366" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.2.1" +version = "6.3.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "7ca6850ae880cc99b59b88517545f91a52020afa" +git-tree-sha1 = "e7c01b69bcbcb93fd4cbc3d0fea7d229541e18d2" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.25+0" +version = "0.0.26+0" [[deps.LLVMOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1481,9 +1481,9 @@ version = "0.4.3" [[deps.MPI]] deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "PkgVersion", "PrecompileTools", "Requires", "Serialization", "Sockets"] -git-tree-sha1 = "df53d0e1e0dbebf2315f4cd35e13e52ad43416c2" +git-tree-sha1 = "b4d8707e42b693720b54f0b3434abee6dd4d947a" uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195" -version = "0.20.15" +version = "0.20.16" [deps.MPI.extensions] AMDGPUExt = "AMDGPU" @@ -1567,9 +1567,9 @@ version = "0.1.4" [[deps.MicrosoftMPI_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "a8027af3d1743b3bfae34e54872359fdebb31422" +git-tree-sha1 = "a7023883872e52bc29bcaac74f19adf39347d2d5" uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf" -version = "10.1.3+4" +version = "10.1.4+0" [[deps.Missings]] deps = ["DataAPI"] @@ -1615,18 +1615,20 @@ version = "0.3.0" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "6e4e90c2e2ef091ef50b91af65fa4bb09c3d0728" +git-tree-sha1 = "3bc568de99214f72a76c7773ade218819afcc36e" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.6" +version = "0.9.7" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" + NNlibEnzymeCoreExt = "EnzymeCore" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NPZ]] @@ -1728,9 +1730,9 @@ version = "0.3.1" [[deps.OpenMPI_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"] -git-tree-sha1 = "f3080f4212a8ba2ceb10a34b938601b862094314" +git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762" uuid = "fe0851c0-eecd-5654-98d4-656369965a5c" -version = "4.1.5+0" +version = "4.1.6+0" [[deps.OpenSSL]] deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] @@ -1774,9 +1776,9 @@ version = "10.42.0+0" [[deps.PDMats]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "bf6085e8bd7735e68c210c6e5d81f9a6fe192060" +git-tree-sha1 = "b7c4f29f93b548caa58f703580f4d79ab753c8ac" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.19" +version = "0.11.21" [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] @@ -2238,9 +2240,9 @@ weakdeps = ["OffsetArrays", "StaticArrays"] [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "d5fb407ec3179063214bc6277712928ba78459e2" +git-tree-sha1 = "0adf069a2a490c47273727e029371b31d44b72b2" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.4" +version = "1.6.5" weakdeps = ["Statistics"] [deps.StaticArrays.extensions] @@ -2373,9 +2375,9 @@ version = "0.5.2" [[deps.TiffImages]] deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] -git-tree-sha1 = "b7dc44cb005a7ef743b8fe98970afef003efdce7" +git-tree-sha1 = "7fd97bd1c5b1ff53a291cbd351d1d3d6ff4da5a5" uuid = "731e570b-9d59-4bfa-96dc-6df516fadf69" -version = "0.6.6" +version = "0.6.7" [[deps.TiledIteration]] deps = ["OffsetArrays", "StaticArrayInterface"] @@ -2528,9 +2530,9 @@ version = "1.6.1" [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"] -git-tree-sha1 = "04a51d15436a572301b5abbb9d099713327e9fc4" +git-tree-sha1 = "24b81b59bd35b3c42ab84fa589086e19be919916" uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.10.4+0" +version = "2.11.5+0" [[deps.XSLT_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index 9c7938150ffe02eab6c7755256444e5567e3b125..4a68251fa6c640d9b47930a4156fc584dafee7d5 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -47,6 +47,34 @@ include("post_processing/post_processing.jl") include("utils.jl") include("save_best.jl") +# Number of counterfactuals: +n_ind_specified = false +if any(contains.(ARGS, "n_individuals=")) + n_ind_specified = true + n_individuals = + ARGS[findall(contains.(ARGS, "n_individuals="))][1] |> + x -> replace(x, "n_individuals=" => "") |> x -> parse(Int, x) +else + n_individuals = 100 +end + +"Number of individuals to use in benchmarking." +const N_IND = n_individuals + +"Boolean flag to check if number of individuals was specified." +const N_IND_SPECIFIED = n_ind_specified + +if any(contains.(ARGS, "n_each=")) + n_each = + ARGS[findall(contains.(ARGS, "n_each="))][1] |> + x -> replace(x, "n_each=" => "") |> x -> parse(Int, x) +else + n_each = 16 +end + +"Number of objects to pass to each process." +const N_EACH = n_each + # Parallelization: plz = nothing @@ -60,7 +88,7 @@ end if "mpi" ∈ ARGS MPI.Init() const USE_MPI = true - plz = MPIParallelizer(MPI.COMM_WORLD; threaded = USE_THREADS) + plz = MPIParallelizer(MPI.COMM_WORLD; threaded = USE_THREADS, n_each = N_EACH) if MPI.Comm_rank(MPI.COMM_WORLD) != 0 global_logger(NullLogger()) else @@ -135,22 +163,6 @@ const TEST_SIZE = 0.2 "Boolean flag to check if upload was specified." const UPLOAD = "upload" ∈ ARGS -n_ind_specified = false -if any(contains.(ARGS, "n_individuals=")) - n_ind_specified = true - n_individuals = - ARGS[findall(contains.(ARGS, "n_individuals="))][1] |> - x -> replace(x, "n_individuals=" => "") |> x -> parse(Int, x) -else - n_individuals = 100 -end - -"Number of individuals to use in benchmarking." -const N_IND = n_individuals - -"Boolean flag to check if number of individuals was specified." -const N_IND_SPECIFIED = n_ind_specified - "Boolean flag to check if grid search was specified." const GRID_SEARCH = "grid_search" ∈ ARGS