diff --git a/Project.toml b/Project.toml index cc58ad66c61be5b5b9ac1e8f8291ed03a889e0e7..c57589f5a50edb2b1bd7ca8a2d4e97e95a0c0b4b 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,6 @@ MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" PkgTemplates = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index 91dca4127a8ef02e12b7b386732039d28e913932..6116400973cf56050bcca3d92deafdafe49dd82b 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -8,7 +8,7 @@ function default_generators(; nsamples::Union{Nothing,Int}=nothing, nmin::Union{Nothing,Int}=nothing, reg_strength::Real=0.5, - dim_reduction::Union{Nothing,Int}=nothing, + dim_reduction::Bool=false, ) @info "Begin benchmarking counterfactual explanations." diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 14cc424470bf9b3ac774402a35e91dc374e335eb..07c6d11e4775381cf128b727138511c092220e2d 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -13,7 +13,7 @@ counterfactual_data.generative_model = vae # Dimensionality reduction: maxout_dim = vae.params.latent_dim -counterfactual_data.dt = MultivariateStats.fit(PCA, counterfactual_data.X; maxoutdim=maxout_dim); +counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); # Test data: test_data = load_mnist_test() diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index 8757410ac1b8d0ce61fac9f14b35ebfa193b05bc..568260ea49f910a6696b8873a713317e38f97b87 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -23,12 +23,12 @@ using MLJ: TunedModel, Grid, CV, fitted_params, report using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table, machine, fit!, Supervised using MLJEnsembles using MLJFlux -using MultivariateStats using Random using Serialization using Statistics import MPI +import MultivariateStats Random.seed!(2023) diff --git a/src/penalties.jl b/src/penalties.jl index c4272dd5ba0827a8ace1e42686c3dd9b0b4e24a0..6e7dad339917507c47467eb66a3983cd1f20a14f 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -4,7 +4,6 @@ using Distances using Flux using Images: assess_ssim using LinearAlgebra: norm -using MultivariateStats using Statistics: mean """