Skip to content
Snippets Groups Projects
Commit 80ec8764 authored by pat-alt's avatar pat-alt
Browse files

almost there

parent ad2cb720
No related branches found
No related tags found
No related merge requests found
......@@ -780,7 +780,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
version = "1.12.0"
[[deps.JointEnergyModels]]
deps = ["Distributions", "Flux", "StatsBase"]
deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
path = "../JointEnergyModels.jl"
uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
version = "0.1.0"
......@@ -997,7 +997,7 @@ version = "0.3.2"
[[deps.MLJFlux]]
deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0"
path = "../MLJFlux.jl"
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
version = "0.2.9"
......
......@@ -180,9 +180,9 @@ version = "0.2.0"
[[deps.CUDA_Runtime_jll]]
deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
git-tree-sha1 = "9ac3ffda60eeae5291be20f35ca264eb8e95bbc6"
git-tree-sha1 = "81eed046f28a0cdd0dc1f61d00a49061b7cc9433"
uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
version = "0.5.0+1"
version = "0.5.0+2"
[[deps.CUDNN_jll]]
deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
......@@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
version = "1.12.0"
[[deps.JointEnergyModels]]
deps = ["ChainRulesCore", "Distributions", "Flux", "StatsBase"]
deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
path = "../../JointEnergyModels.jl"
uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
version = "0.1.0"
......@@ -1293,7 +1293,7 @@ version = "0.3.2"
[[deps.MLJFlux]]
deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0"
path = "../../MLJFlux.jl"
uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
version = "0.2.9"
......
......@@ -7,50 +7,80 @@ eval(setup_notebooks)
```{julia}
# Data:
counterfactual_data = load_mnist()
n_obs = 1000
counterfactual_data = load_mnist(n_obs)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
n_digits = Int(sqrt(input_dim))
output_dim = length(unique(labels))
```
First, let's create a couple of image classifier architectures:
```{julia}
# Model parameters:
epochs = 100
batchsize = Int(round(n_obs/10))
batch_size = Int(round(n_obs/10))
n_hidden = 32
activation = Flux.swish
builder = MLJFlux.MLP(hidden=(n_hidden,), σ=activation)
α = [0.33,1.0,1e-1]
# Simple MLP:
mlp = NeuralNetworkClassifier(
builder=MLJFlux.MLP(hidden=(32,), σ=relu),
builder=builder,
epochs=epochs,
batch_size=batchsize,
batch_size=batch_size,
)
# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=5)
# ResNet:
using MLJFlux, Metalhead
builder = MLJFlux.image_builder(ResNet, 18, pretrain=true)
resnet = ImageClassifier(
# Joint Energy Model:
𝒟x = Uniform(-1,1)
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(input_dim,), batch_size=batch_size)
jem = JointEnergyClassifier(
sampler;
builder=builder,
epochs=epochs,
batch_size=batchsize,
batch_size=batch_size,
finaliser=Flux.softmax,
loss=Flux.Losses.crossentropy,
jem_training_params=(α=α,verbosity=10,),
sampling_steps=20,
)
# ResNet Ensemble:
resnet_ens = EnsembleModel(model=resnet, n=5)
# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=5)
```
```{julia}
conf_model = conformal_model(mlp; method=:adaptive_inductive, coverage=.99)
cov = .9
conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov)
mach = machine(conf_model, X, labels)
fit!(mach)
M = CCE.ConformalModel(mach.model, mach.fitresult)
```
```{julia}
jem = mach.model.model.jem
n_iter = 5000
_w = 1500
plts = []
neach = 10
for i in 1:10
x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)
plts_i = []
for j in 1:size(x, 2)
xj = x[:,j]
xj = reshape(xj, (n_digits, n_digits))
plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
end
plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
plts = [plts..., plt]
end
Plots.plot(plts..., size=(_w,_w), layout=(10,1))
```
```{julia}
test_data = load_mnist_test()
f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
......@@ -69,7 +99,7 @@ T = 100
# Generate counterfactual using generic generator:
generator = GenericGenerator()
ce_wachter = generate_counterfactual(
x, target, dt_reduced, M, generator;
x, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
......@@ -81,7 +111,7 @@ generator = CCEGenerator(
# opt=CounterfactualExplanations.Generators.JSMADescent(η=0.5),
)
ce_conformal = generate_counterfactual(
x, target, dt_reduced, M, generator;
x, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
......@@ -115,12 +145,8 @@ display(plt)
savefig(plt, joinpath(www_path, "cce_mnist.png"))
```
## Benchmark
```{julia}
# Benchmark generators:
generators = Dict(
......
......@@ -9,9 +9,6 @@ include("losses.jl")
include("generator.jl")
include("sampling.jl")
using MLJFlux
MLJFlux.reformat(X, ::Type{<:AbstractMatrix}) = X'
export CCEGenerator, EnergySampler, set_size_penalty, distance_from_energy
end
\ No newline at end of file
www/cce_mnist.png

13.4 KiB | W: | H:

www/cce_mnist.png

14.5 KiB | W: | H:

www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
  • 2-up
  • Swipe
  • Onion skin
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment