From 147398461f0c82fad829e67eb613020b13b965c1 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Fri, 22 Sep 2023 11:10:50 +0200 Subject: [PATCH] phew --- Project.toml | 1 - experiments/benchmarking/benchmarking.jl | 2 +- experiments/mnist.jl | 2 +- experiments/setup_env.jl | 2 +- src/penalties.jl | 1 - 5 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index cc58ad66..c57589f5 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,6 @@ MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" -MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" PkgTemplates = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index 91dca412..61164009 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -8,7 +8,7 @@ function default_generators(; nsamples::Union{Nothing,Int}=nothing, nmin::Union{Nothing,Int}=nothing, reg_strength::Real=0.5, - dim_reduction::Union{Nothing,Int}=nothing, + dim_reduction::Bool=false, ) @info "Begin benchmarking counterfactual explanations." diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 14cc4244..07c6d11e 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -13,7 +13,7 @@ counterfactual_data.generative_model = vae # Dimensionality reduction: maxout_dim = vae.params.latent_dim -counterfactual_data.dt = MultivariateStats.fit(PCA, counterfactual_data.X; maxoutdim=maxout_dim); +counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); # Test data: test_data = load_mnist_test() diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index 8757410a..568260ea 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -23,12 +23,12 @@ using MLJ: TunedModel, Grid, CV, fitted_params, report using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table, machine, fit!, Supervised using MLJEnsembles using MLJFlux -using MultivariateStats using Random using Serialization using Statistics import MPI +import MultivariateStats Random.seed!(2023) diff --git a/src/penalties.jl b/src/penalties.jl index c4272dd5..6e7dad33 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -4,7 +4,6 @@ using Distances using Flux using Images: assess_ssim using LinearAlgebra: norm -using MultivariateStats using Statistics: mean """ -- GitLab