diff --git a/Manifest.toml b/Manifest.toml
index d16a0962307e2701b4cb33190dcace9d31955db4..4531e744c4232042cd1891c52c5235f2fdd795ca 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -304,7 +304,7 @@ uuid = "d38c429a-6771-53c6-b99e-75d170b6e991"
 version = "0.6.2"
 
 [[deps.CounterfactualExplanations]]
-deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"]
+deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"]
 path = "../CounterfactualExplanations.jl"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
 version = "0.1.9"
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
new file mode 100644
index 0000000000000000000000000000000000000000..bc62edc143f3749077aee8e4207d6539844d299a
--- /dev/null
+++ b/docs/Manifest.toml
@@ -0,0 +1,7 @@
+# This file is machine-generated - editing it directly is not advised
+
+julia_version = "1.8.5"
+manifest_format = "2.0"
+project_hash = "da39a3ee5e6b4b0d3255bfef95601890afd80709"
+
+[deps]
diff --git a/docs/Project.toml b/docs/Project.toml
new file mode 100644
index 0000000000000000000000000000000000000000..81648c0b16f0c3c6b1d2f26cd5994aa97a5b3d22
--- /dev/null
+++ b/docs/Project.toml
@@ -0,0 +1 @@
+[deps]
diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml
index 6dd7ba4c2af64778888eecd9f46414a07292e102..a1c318aaa24efd6ebab22371e0af1b3394f2a008 100644
--- a/notebooks/Manifest.toml
+++ b/notebooks/Manifest.toml
@@ -383,7 +383,7 @@ uuid = "150eb455-5306-5404-9cee-2592286d6298"
 version = "0.6.2"
 
 [[deps.CounterfactualExplanations]]
-deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"]
+deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"]
 path = "../../CounterfactualExplanations.jl"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
 version = "0.1.9"
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 0716c8bcfe509a95540812fd7f3907628b7c6783..07a6fec5496bb6ca4d3590372b22901a3c834dbf 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -22,12 +22,14 @@ First, let's create a couple of image classifier architectures:
 ```{julia}
 # Model parameters:
 epochs = 100
-batch_size = minimum([Int(round(n_obs/10)), 100])
-n_hidden = 32
+batch_size = minimum([Int(round(n_obs/10)), 128])
+n_hidden = 200
 activation = Flux.relu
 builder = MLJFlux.@builder Flux.Chain(
     Dense(n_in, n_hidden),
     BatchNorm(n_hidden, activation),
+    Dense(n_hidden, n_hidden),
+    BatchNorm(n_hidden, activation),
     Dense(n_hidden, n_out),
 )
 # builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation)
@@ -51,7 +53,11 @@ mlp = NeuralNetworkClassifier(
 # Joint Energy Model:
 𝒟x = Uniform(0,1)
 𝒟y = Categorical(ones(output_dim) ./ output_dim)
-sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(input_dim,), batch_size=batch_size)
+sampler = ConditionalSampler(
+    𝒟x, 𝒟y,
+    input_size=(input_dim,),
+    batch_size=10
+)
 jem = JointEnergyClassifier(
     sampler;
     builder=builder,
@@ -73,31 +79,33 @@ mlp_ens = EnsembleModel(model=mlp, n=5)
 
 ```{julia}
 cov = .90
-conf_model = conformal_model(jem; method=:simple_inductive, coverage=cov)
+conf_model = conformal_model(mlp; method=:adaptive_inductive, coverage=cov)
 mach = machine(conf_model, X, labels)
 fit!(mach)
 M = CCE.ConformalModel(mach.model, mach.fitresult)
 ```
 
 ```{julia}
-jem = mach.model.model.jem
-n_iter = 100
-_w = 1500
-plts = []
-neach = 10
-for i in 1:10
-    x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)
-    plts_i = []
-    for j in 1:size(x, 2)
-        xj = x[:,j]
-        xj = reshape(xj, (n_digits, n_digits))
-        plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
+if mach.model.model isa JointEnergyModels.JointEnergyClassifier
+    jem = mach.model.model.jem
+    n_iter = 100
+    _w = 1500
+    plts = []
+    neach = 10
+    for i in 1:10
+        x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)
+        plts_i = []
+        for j in 1:size(x, 2)
+            xj = x[:,j]
+            xj = reshape(xj, (n_digits, n_digits))
+            plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
+        end
+        plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
+        plts = [plts..., plt]
     end
-    plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
-    plts = [plts..., plt]
+    plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
+    display(plt)
 end
-plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
-display(plt)
 ```
 
 ```{julia}
@@ -110,12 +118,12 @@ println("F1 score (test): $(round(f1,digits=3))")
 Random.seed!(1234)
 
 # Set up search:
-factual_label = 9
+factual_label = 4
 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
-target = 4
+target = 9
 factual = predict_label(M, counterfactual_data, x)[1]
 γ = 0.5
-T = 1
+T = 100
 
 # Generate counterfactual using generic generator:
 generator = GenericGenerator()
@@ -125,11 +133,14 @@ ce_wachter = generate_counterfactual(
     initialization=:identity,
 )
 
+# CCE:
+λ=[0.0,1.0]
+temp=0.5
+
 # Generate counterfactual using CCE generator:
 generator = CCEGenerator(
-    λ=[0.0,1.0],
-    temp=0.5,
-    # opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
+    λ=λ,
+    temp=temp,
 )
 ce_conformal = generate_counterfactual(
     x, target, counterfactual_data, M, generator;
@@ -138,6 +149,19 @@ ce_conformal = generate_counterfactual(
     converge_when=:generator_conditions,
 )
 
+# Generate counterfactual using CCE generator:
+generator = CCEGenerator(
+    λ=λ,
+    temp=temp,
+    opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
+)
+ce_conformal_jsma = generate_counterfactual(
+    x, target, counterfactual_data, M, generator;
+    decision_threshold=γ, max_iter=T,
+    initialization=:identity,
+    converge_when=:generator_conditions,
+)
+
 # Plot:
 p1 = Plots.plot(
     convert2image(MNIST, reshape(x,28,28)),
@@ -147,12 +171,13 @@ p1 = Plots.plot(
 )
 plts = [p1]
 
-ces = zip([ce_wachter,ce_conformal])
+ces = zip([ce_wachter,ce_conformal,ce_conformal_jsma])
+_names = ["Wachter", "CCE", "CCE-JSMA"]
 counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
 phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
-for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
-    ce, _phat = (x[1],x[2])
-    _title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))"
+for x in zip(eachslice(counterfactuals; dims=3), _names, eachslice(phat; dims=3))
+    ce, _name, _phat = (x[1],x[2],x[3])
+    _title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
     plt = Plots.plot(
         convert2image(MNIST, reshape(ce,28,28)),
         axis=nothing,
diff --git a/src/model.jl b/src/model.jl
index c724867b51e2d03c696b991d1d708766525d1cfb..cc7cde3c24186017c91040107d06e1623f187089 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -52,6 +52,7 @@ function ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symb
     end
 
     # Construct model:
+    testmode!(fitresult[1])
     M = ConformalModel(model, fitresult, likelihood)
     return M
 end
diff --git a/src/penalties.jl b/src/penalties.jl
index 09d2c2a28def9c2f4194647f9baf875d7af7c2d1..30ae653c91f21d06eb794ffa3369e9c5cdb6f9a7 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -3,21 +3,21 @@ using LinearAlgebra: norm
 using Statistics: mean
 """
-    set_size_penalty(counterfactual_explanation::AbstractCounterfactualExplanation)
+    set_size_penalty(ce::AbstractCounterfactualExplanation)
 
 Penalty for smooth conformal set size.
 """
 function set_size_penalty(
-    counterfactual_explanation::AbstractCounterfactualExplanation;
+    ce::AbstractCounterfactualExplanation;
     κ::Real=0.0,
     temp::Real=0.05,
     agg=mean
 )
-    conf_model = counterfactual_explanation.M.model
-    fitresult = counterfactual_explanation.M.fitresult
-    X = CounterfactualExplanations.decode_state(counterfactual_explanation)
-    loss = map(eachslice(X, dims=3)) do x
-        x = Matrix(x)
-        if target_probs(counterfactual_explanation, x)[1] >= 0.5
+    conf_model = ce.M.model
+    fitresult = ce.M.fitresult
+    X = CounterfactualExplanations.decode_state(ce)
+    loss = map(eachslice(X, dims=ndims(X))) do x
+        x = ndims(x) == 1 ? x[:,:] : x
+        if target_probs(ce, x)[1] >= 0.5
             l = ConformalPrediction.smooth_size_loss(
                 conf_model, fitresult, x';
                 κ=κ,
@@ -35,19 +35,19 @@ function set_size_penalty(
 end
 
 function distance_from_energy(
-    counterfactual_explanation::AbstractCounterfactualExplanation;
+    ce::AbstractCounterfactualExplanation;
     n::Int=10000, from_buffer=true, agg=mean, kwargs...
 )
     conditional_samples = []
     ignore_derivatives() do
-        _dict = counterfactual_explanation.params
+        _dict = ce.params
         if !(:energy_sampler ∈ collect(keys(_dict)))
-            _dict[:energy_sampler] = CCE.EnergySampler(counterfactual_explanation; kwargs...)
+            _dict[:energy_sampler] = CCE.EnergySampler(ce; kwargs...)
         end
         sampler = _dict[:energy_sampler]
         push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
     end
-    x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
+    x′ = CounterfactualExplanations.counterfactual(ce)
     loss = map(eachslice(x′, dims=3)) do x
         x = Matrix(x)
         Δ = map(eachcol(conditional_samples[1])) do xsample
@@ -62,13 +62,13 @@ function distance_from_energy(
 end
 
 function distance_from_targets(
-    counterfactual_explanation::AbstractCounterfactualExplanation;
+    ce::AbstractCounterfactualExplanation;
     n::Int=10000, agg=mean
 )
-    target_idx = counterfactual_explanation.data.output_encoder.labels .== counterfactual_explanation.target
-    target_samples = counterfactual_explanation.data.X[:,target_idx] |>
+    target_idx = ce.data.output_encoder.labels .== ce.target
+    target_samples = ce.data.X[:,target_idx] |>
         X -> X[:,rand(1:end,n)]
-    x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
+    x′ = CounterfactualExplanations.counterfactual(ce)
     loss = map(eachslice(x′, dims=3)) do x
         x = Matrix(x)
         Δ = map(eachcol(target_samples)) do xsample
diff --git a/www/cce_mnist.png b/www/cce_mnist.png
index ea956a80689729139b58966c11514539991a792a..ff97f8657cd5af3cd5f753e84cb38286614bf6ca 100644
Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ