diff --git a/Manifest.toml b/Manifest.toml index 144bf11e5b9b34ebd0baf7785cb8ecdf4f9c2a96..d16a0962307e2701b4cb33190dcace9d31955db4 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -780,7 +780,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" [[deps.JointEnergyModels]] -deps = ["Distributions", "Flux", "StatsBase"] +deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] path = "../JointEnergyModels.jl" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" version = "0.1.0" @@ -997,7 +997,7 @@ version = "0.3.2" [[deps.MLJFlux]] deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] -git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0" +path = "../MLJFlux.jl" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" version = "0.2.9" diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index c38fdac5e7699d8f05b711dc80007a8dca43330c..6dd7ba4c2af64778888eecd9f46414a07292e102 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -180,9 +180,9 @@ version = "0.2.0" [[deps.CUDA_Runtime_jll]] deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "9ac3ffda60eeae5291be20f35ca264eb8e95bbc6" +git-tree-sha1 = "81eed046f28a0cdd0dc1f61d00a49061b7cc9433" uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.5.0+1" +version = "0.5.0+2" [[deps.CUDNN_jll]] deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] @@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" [[deps.JointEnergyModels]] -deps = ["ChainRulesCore", "Distributions", "Flux", "StatsBase"] +deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", 
"StatsBase", "Tables"] path = "../../JointEnergyModels.jl" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" version = "0.1.0" @@ -1293,7 +1293,7 @@ version = "0.3.2" [[deps.MLJFlux]] deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] -git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0" +path = "../../MLJFlux.jl" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" version = "0.2.9" diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 4b2b2d3d5346cae6d1418e4b254099aa420d0ae7..433e3ba30df6509ee002fed887ba04f5fa7bef51 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -7,50 +7,80 @@ eval(setup_notebooks) ```{julia} # Data: -counterfactual_data = load_mnist() +n_obs = 1000 +counterfactual_data = load_mnist(n_obs) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X = table(permutedims(X)) labels = counterfactual_data.output_encoder.labels input_dim, n_obs = size(counterfactual_data.X) +n_digits = Int(sqrt(input_dim)) +output_dim = length(unique(labels)) ``` First, let's create a couple of image classifier architectures: ```{julia} +# Model parameters: epochs = 100 -batchsize = Int(round(n_obs/10)) +batch_size = Int(round(n_obs/10)) +n_hidden = 32 +activation = Flux.swish +builder = MLJFlux.MLP(hidden=(n_hidden,), σ=activation) +α = [0.33,1.0,1e-1] # Simple MLP: mlp = NeuralNetworkClassifier( - builder=MLJFlux.MLP(hidden=(32,), σ=relu), + builder=builder, epochs=epochs, - batch_size=batchsize, + batch_size=batch_size, ) -# Deep Ensemble: -mlp_ens = EnsembleModel(model=mlp, n=5) - -# ResNet: -using MLJFlux, Metalhead -builder = MLJFlux.image_builder(ResNet, 18, pretrain=true) -resnet = ImageClassifier( +# Joint Energy Model: +𝒟x = Uniform(-1,1) +𝒟y = Categorical(ones(output_dim) ./ output_dim) +sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(input_dim,), batch_size=batch_size) +jem = 
JointEnergyClassifier( + sampler; builder=builder, - epochs=epochs, - batch_size=batchsize, + batch_size=batch_size, + finaliser=Flux.softmax, + loss=Flux.Losses.crossentropy, + jem_training_params=(α=α,verbosity=10,), + sampling_steps=20, ) -# ResNet Ensemble: -resnet_ens = EnsembleModel(model=resnet, n=5) +# Deep Ensemble: +mlp_ens = EnsembleModel(model=mlp, n=5) ``` ```{julia} - -conf_model = conformal_model(mlp; method=:adaptive_inductive, coverage=.99) +cov = .9 +conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov) mach = machine(conf_model, X, labels) fit!(mach) M = CCE.ConformalModel(mach.model, mach.fitresult) ``` +```{julia} +jem = mach.model.model.jem +n_iter = 5000 +_w = 1500 +plts = [] +neach = 10 +for i in 1:10 + x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i) + plts_i = [] + for j in 1:size(x, 2) + xj = x[:,j] + xj = reshape(xj, (n_digits, n_digits)) + plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] + end + plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) + plts = [plts..., plt] +end +Plots.plot(plts..., size=(_w,_w), layout=(10,1)) +``` + ```{julia} test_data = load_mnist_test() f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) @@ -69,7 +99,7 @@ T = 100 # Generate counterfactual using generic generator: generator = GenericGenerator() ce_wachter = generate_counterfactual( - x, target, dt_reduced, M, generator; + x, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, ) @@ -81,7 +111,7 @@ generator = CCEGenerator( # opt=CounterfactualExplanations.Generators.JSMADescent(η=0.5), ) ce_conformal = generate_counterfactual( - x, target, dt_reduced, M, generator; + x, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, converge_when=:generator_conditions, @@ -115,12 +145,8 @@ display(plt) savefig(plt, joinpath(www_path, "cce_mnist.png")) ``` - - ## 
Benchmark - - ```{julia} # Benchmark generators: generators = Dict( diff --git a/src/CCE.jl b/src/CCE.jl index 1b678a8121a413e2282e4a0096bdd67a02b0d605..809408f16ccc7b2e5619e2fcb5e163eef83c5501 100644 --- a/src/CCE.jl +++ b/src/CCE.jl @@ -9,9 +9,6 @@ include("losses.jl") include("generator.jl") include("sampling.jl") -using MLJFlux -MLJFlux.reformat(X, ::Type{<:AbstractMatrix}) = X' - export CCEGenerator, EnergySampler, set_size_penalty, distance_from_energy end \ No newline at end of file diff --git a/www/cce_mnist.png b/www/cce_mnist.png index 1fd6cc5db60fa44b4b1a090e151e4fababaef3fc..8af78f6df7b2364c9d6883f4b5dea1046c1ac173 100644 Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ