diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index d5ce1cefc51ec584ee743bd7a5c9e33a91c7de1f..c04ca57ef70b11f8d4239ff7fc307fdb25d28318 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -1,53 +1,66 @@ - - ```{julia} include("notebooks/setup.jl") eval(setup_notebooks) ``` +# MNIST + ```{julia} # Data: -counterfactual_data = load_mnist() +counterfactual_data = load_mnist(1000) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) +X = table(permutedims(X)) labels = counterfactual_data.output_encoder.labels input_dim, n_obs = size(counterfactual_data.X) -M = load_mnist_mlp() - -# Target: -factual_label = 8 -x = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) -target = 3 -factual = predict_label(M, counterfactual_data, x)[1] -γ = 0.9 -T = 50 ``` ```{julia} -builder = MLJFlux.@builder M.model -clf = NeuralNetworkClassifier(builder=builder, epochs=100) +epochs = 100 +clf = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=relu), epochs=epochs) conf_model = conformal_model(clf; method=:simple_inductive) mach = machine(conf_model, X, labels) fit!(mach) ``` ```{julia} -# Search: +M = CCE.ConformalModel(mach.model, mach.fitresult) +``` + +```{julia} +test_data = load_mnist_test() +f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) +println("F1 score (test): $(round(f1,digits=3))") +``` + +```{julia} +# Set up search: +factual_label = 8 +x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +target = 3 +factual = predict_label(M, counterfactual_data, x)[1] +γ = 0.9 +T = 100 + +# Generate counterfactual using generic generator: generator = GenericGenerator() ce_wachter = generate_counterfactual( x, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, ) -generator = deepcopy(generator) |> gen -> @objective(gen, _ + 0.001distance + 1.0set_size_penalty) + +# Generate counterfactual using CCE generator: +generator = CCEGenerator(λ=[0.0,10.0], temp=0.01, opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0)) ce_conformal = generate_counterfactual( x, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, + converge_when=:generator_conditions, ) ``` ```{julia} -p1 = plot( +p1 = Plots.plot( convert2image(MNIST, reshape(x,28,28)), axis=nothing, size=(img_height, img_height), @@ -55,13 +68,13 @@ p1 = plot( ) plts = [p1] -ces = zip([ce_wachter,ce_jsma]) +ces = zip([ce_wachter,ce_conformal]) counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces)) phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces)) for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3)) ce, _phat = (x[1],x[2]) _title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))" - plt = plot( + plt = Plots.plot( convert2image(MNIST, reshape(ce,28,28)), axis=nothing, size=(img_height, img_height), @@ -69,6 +82,7 @@ for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3)) ) plts = [plts..., plt] end -plt = plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) -savefig(plt, joinpath(www_path, "you_may_not_like_it.png")) +plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) +display(plt) +savefig(plt, joinpath(www_path, "cce_mnist.png")) ``` \ No newline at end of file diff --git a/www/cce_mnist.png b/www/cce_mnist.png new file mode 100644 index 0000000000000000000000000000000000000000..85eed69096b075c1519480191c1693cf0ceaf950 Binary files /dev/null and b/www/cce_mnist.png differ