Skip to content
Snippets Groups Projects
Commit b36ec996 authored by Pat Alt's avatar Pat Alt
Browse files

batching

parent d3f7291e
No related branches found
No related tags found
1 merge request!8985 overshooting
......@@ -276,9 +276,9 @@ version = "0.5.0"
[[deps.ChainRules]]
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"]
git-tree-sha1 = "dbeca245b0680f5393b4e6c40dcead7230ab0b3b"
git-tree-sha1 = "01b0594d8907485ed894bc59adfc0a24a9cde7a3"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.54.0"
version = "1.55.0"
[[deps.ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
......@@ -433,12 +433,12 @@ uuid = "150eb455-5306-5404-9cee-2592286d6298"
version = "0.6.3"
[[deps.CounterfactualExplanations]]
deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"]
git-tree-sha1 = "082d60e57d67cf5302ca007ccad053270d401f5f"
deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PkgTemplates", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"]
git-tree-sha1 = "1f1861d5b947a7af85c2172e3791c03453b6a183"
repo-rev = "main"
repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
version = "0.1.29"
version = "0.1.30"
[deps.CounterfactualExplanations.extensions]
MPIExt = "MPI"
......@@ -543,9 +543,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[deps.Distributions]]
deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "9e11104e7b41a8a5f04e8694467fc1f94a135bd7"
git-tree-sha1 = "3d5873f811f582873bb9871fc9c451784d5dc8c7"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.101"
version = "0.25.102"
[deps.Distributions.extensions]
DistributionsChainRulesCoreExt = "ChainRulesCore"
......@@ -674,9 +674,9 @@ version = "1.16.1"
[[deps.FilePathsBase]]
deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"]
git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12"
git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa"
uuid = "48062228-2e41-5def-b9a4-89aafe57970f"
version = "0.9.20"
version = "0.9.21"
[[deps.FileWatching]]
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
......@@ -804,9 +804,9 @@ version = "0.6.1"
[[deps.GeoInterface]]
deps = ["Extents"]
git-tree-sha1 = "bb198ff907228523f3dee1070ceee63b9359b6ab"
git-tree-sha1 = "d53480c0793b13341c40199190f92c611aa2e93c"
uuid = "cf35fbd7-0cd7-5166-be24-54bfbe79505f"
version = "1.3.1"
version = "1.3.2"
[[deps.GeometryBasics]]
deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"]
......@@ -845,9 +845,9 @@ version = "1.3.14+0"
[[deps.Graphs]]
deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"]
git-tree-sha1 = "1cf1d7dcb4bc32d7b4a5add4232db3750c27ecb4"
git-tree-sha1 = "899050ace26649433ef1af25bc17a815b3db52b7"
uuid = "86223c79-3864-5bf0-83f7-82e725a168b6"
version = "1.8.0"
version = "1.9.0"
[[deps.Grisu]]
git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2"
......@@ -856,9 +856,9 @@ version = "1.0.2"
[[deps.HDF5]]
deps = ["Compat", "HDF5_jll", "Libdl", "MPIPreferences", "Mmap", "Preferences", "Printf", "Random", "Requires", "UUIDs"]
git-tree-sha1 = "ec7df74b7b2022e8252a8bfd4ec23411491adc3b"
git-tree-sha1 = "26407bd1c60129062cec9da63dc7d08251544d53"
uuid = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
version = "0.17.0"
version = "0.17.1"
weakdeps = ["MPI"]
[deps.HDF5.extensions]
......@@ -1196,15 +1196,15 @@ version = "3.0.0+1"
[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "a9d2ce1d5007b1e8f6c5b89c5a31ff8bd146db5c"
git-tree-sha1 = "4ea2928a96acfcf8589e6cd1429eff2a3a82c366"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "6.2.1"
version = "6.3.0"
[[deps.LLVMExtra_jll]]
deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "7ca6850ae880cc99b59b88517545f91a52020afa"
git-tree-sha1 = "e7c01b69bcbcb93fd4cbc3d0fea7d229541e18d2"
uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab"
version = "0.0.25+0"
version = "0.0.26+0"
[[deps.LLVMOpenMP_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
......@@ -1481,9 +1481,9 @@ version = "0.4.3"
[[deps.MPI]]
deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "PkgVersion", "PrecompileTools", "Requires", "Serialization", "Sockets"]
git-tree-sha1 = "df53d0e1e0dbebf2315f4cd35e13e52ad43416c2"
git-tree-sha1 = "b4d8707e42b693720b54f0b3434abee6dd4d947a"
uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195"
version = "0.20.15"
version = "0.20.16"
[deps.MPI.extensions]
AMDGPUExt = "AMDGPU"
......@@ -1567,9 +1567,9 @@ version = "0.1.4"
[[deps.MicrosoftMPI_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
git-tree-sha1 = "a8027af3d1743b3bfae34e54872359fdebb31422"
git-tree-sha1 = "a7023883872e52bc29bcaac74f19adf39347d2d5"
uuid = "9237b28f-5490-5468-be7b-bb81f5f5e6cf"
version = "10.1.3+4"
version = "10.1.4+0"
[[deps.Missings]]
deps = ["DataAPI"]
......@@ -1615,18 +1615,20 @@ version = "0.3.0"
[[deps.NNlib]]
deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"]
git-tree-sha1 = "6e4e90c2e2ef091ef50b91af65fa4bb09c3d0728"
git-tree-sha1 = "3bc568de99214f72a76c7773ade218819afcc36e"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.9.6"
version = "0.9.7"
[deps.NNlib.extensions]
NNlibAMDGPUExt = "AMDGPU"
NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
NNlibCUDAExt = "CUDA"
NNlibEnzymeCoreExt = "EnzymeCore"
[deps.NNlib.weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
[[deps.NPZ]]
......@@ -1728,9 +1730,9 @@ version = "0.3.1"
[[deps.OpenMPI_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "MPIPreferences", "TOML"]
git-tree-sha1 = "f3080f4212a8ba2ceb10a34b938601b862094314"
git-tree-sha1 = "e25c1778a98e34219a00455d6e4384e017ea9762"
uuid = "fe0851c0-eecd-5654-98d4-656369965a5c"
version = "4.1.5+0"
version = "4.1.6+0"
[[deps.OpenSSL]]
deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"]
......@@ -1774,9 +1776,9 @@ version = "10.42.0+0"
[[deps.PDMats]]
deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "bf6085e8bd7735e68c210c6e5d81f9a6fe192060"
git-tree-sha1 = "b7c4f29f93b548caa58f703580f4d79ab753c8ac"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.19"
version = "0.11.21"
[[deps.PNGFiles]]
deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"]
......@@ -2238,9 +2240,9 @@ weakdeps = ["OffsetArrays", "StaticArrays"]
[[deps.StaticArrays]]
deps = ["LinearAlgebra", "Random", "StaticArraysCore"]
git-tree-sha1 = "d5fb407ec3179063214bc6277712928ba78459e2"
git-tree-sha1 = "0adf069a2a490c47273727e029371b31d44b72b2"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.6.4"
version = "1.6.5"
weakdeps = ["Statistics"]
[deps.StaticArrays.extensions]
......@@ -2373,9 +2375,9 @@ version = "0.5.2"
[[deps.TiffImages]]
deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"]
git-tree-sha1 = "b7dc44cb005a7ef743b8fe98970afef003efdce7"
git-tree-sha1 = "7fd97bd1c5b1ff53a291cbd351d1d3d6ff4da5a5"
uuid = "731e570b-9d59-4bfa-96dc-6df516fadf69"
version = "0.6.6"
version = "0.6.7"
[[deps.TiledIteration]]
deps = ["OffsetArrays", "StaticArrayInterface"]
......@@ -2528,9 +2530,9 @@ version = "1.6.1"
[[deps.XML2_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Zlib_jll"]
git-tree-sha1 = "04a51d15436a572301b5abbb9d099713327e9fc4"
git-tree-sha1 = "24b81b59bd35b3c42ab84fa589086e19be919916"
uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a"
version = "2.10.4+0"
version = "2.11.5+0"
[[deps.XSLT_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"]
......
......@@ -47,6 +47,34 @@ include("post_processing/post_processing.jl")
include("utils.jl")
include("save_best.jl")
# Number of counterfactuals:
n_ind_specified = false
if any(contains.(ARGS, "n_individuals="))
n_ind_specified = true
n_individuals =
ARGS[findall(contains.(ARGS, "n_individuals="))][1] |>
x -> replace(x, "n_individuals=" => "") |> x -> parse(Int, x)
else
n_individuals = 100
end
"Number of individuals to use in benchmarking."
const N_IND = n_individuals
"Boolean flag to check if number of individuals was specified."
const N_IND_SPECIFIED = n_ind_specified
if any(contains.(ARGS, "n_each="))
n_each =
ARGS[findall(contains.(ARGS, "n_each="))][1] |>
x -> replace(x, "n_each=" => "") |> x -> parse(Int, x)
else
n_each = 16
end
"Number of objects to pass to each process."
const N_EACH = n_each
# Parallelization:
plz = nothing
......@@ -60,7 +88,7 @@ end
if "mpi" ARGS
MPI.Init()
const USE_MPI = true
plz = MPIParallelizer(MPI.COMM_WORLD; threaded = USE_THREADS)
plz = MPIParallelizer(MPI.COMM_WORLD; threaded = USE_THREADS, n_each = N_EACH)
if MPI.Comm_rank(MPI.COMM_WORLD) != 0
global_logger(NullLogger())
else
......@@ -135,22 +163,6 @@ const TEST_SIZE = 0.2
"Boolean flag to check if upload was specified."
const UPLOAD = "upload" ARGS
n_ind_specified = false
if any(contains.(ARGS, "n_individuals="))
n_ind_specified = true
n_individuals =
ARGS[findall(contains.(ARGS, "n_individuals="))][1] |>
x -> replace(x, "n_individuals=" => "") |> x -> parse(Int, x)
else
n_individuals = 100
end
"Number of individuals to use in benchmarking."
const N_IND = n_individuals
"Boolean flag to check if number of individuals was specified."
const N_IND_SPECIFIED = n_ind_specified
"Boolean flag to check if grid search was specified."
const GRID_SEARCH = "grid_search" ARGS
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment