Skip to content
Snippets Groups Projects
Commit 07135cd4 authored by pat-alt's avatar pat-alt
Browse files

shoot me

parent 80ec8764
No related branches found
No related tags found
No related merge requests found
......@@ -25,8 +25,21 @@ epochs = 100
batch_size = Int(round(n_obs/10))
n_hidden = 32
activation = Flux.swish
builder = MLJFlux.MLP(hidden=(n_hidden,), σ=activation)
α = [0.33,1.0,1e-1]
builder = MLJFlux.@builder Flux.Chain(
Dense(n_in, n_hidden),
BatchNorm(n_hidden, activation),
Dense(n_hidden, n_out),
BatchNorm(n_out)
)
# builder = MLJFlux.MLP(
# hidden=(
# n_hidden,
# n_hidden,
# n_hidden,
# ),
# σ=activation
# )
α = [1.0,1.0,1e-1]
# Simple MLP:
mlp = NeuralNetworkClassifier(
......@@ -36,17 +49,18 @@ mlp = NeuralNetworkClassifier(
)
# Joint Energy Model:
𝒟x = Uniform(-1,1)
𝒟x = Uniform(0,1)
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(input_dim,), batch_size=batch_size)
jem = JointEnergyClassifier(
sampler;
builder=builder,
batch_size=batch_size,
finaliser=Flux.softmax,
loss=Flux.Losses.crossentropy,
finaliser=x -> x,
loss=Flux.Losses.logitcrossentropy,
jem_training_params=(α=α,verbosity=10,),
sampling_steps=20,
epochs=epochs,
)
# Deep Ensemble:
......@@ -54,8 +68,8 @@ mlp_ens = EnsembleModel(model=mlp, n=5)
```
```{julia}
cov = .9
conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov)
cov = .90
conf_model = conformal_model(jem; method=:simple_inductive, coverage=cov)
mach = machine(conf_model, X, labels)
fit!(mach)
M = CCE.ConformalModel(mach.model, mach.fitresult)
......@@ -63,7 +77,7 @@ M = CCE.ConformalModel(mach.model, mach.fitresult)
```{julia}
jem = mach.model.model.jem
n_iter = 5000
n_iter = 100
_w = 1500
plts = []
neach = 10
......@@ -78,7 +92,8 @@ for i in 1:10
plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
plts = [plts..., plt]
end
Plots.plot(plts..., size=(_w,_w), layout=(10,1))
plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
display(plt)
```
```{julia}
......@@ -88,12 +103,14 @@ println("F1 score (test): $(round(f1,digits=3))")
```
```{julia}
Random.seed!(1234)
# Set up search:
factual_label = 9
factual_label = 2
x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 4
target = 0
factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.9
γ = 0.5
T = 100
# Generate counterfactual using generic generator:
......@@ -106,9 +123,9 @@ ce_wachter = generate_counterfactual(
# Generate counterfactual using CCE generator:
generator = CCEGenerator(
λ=[0.0,10.0],
temp=0.01,
# opt=CounterfactualExplanations.Generators.JSMADescent(η=0.5),
λ=[0.0,1.0],
temp=0.5,
# opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
)
ce_conformal = generate_counterfactual(
x, target, counterfactual_data, M, generator;
......
......@@ -86,7 +86,11 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
= []
end
= reduce(hcat, )
= reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol())))
if all(0.0 .<= vec() .<= 1.0)
= reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol())))
else
=
end
if M.likelihood == :classification_binary
= reduce(hcat, (map(y -> y[2] - y[1], eachcol())))
end
......
www/cce_mnist.png

14.5 KiB | W: | H:

www/cce_mnist.png

15.7 KiB | W: | H:

www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
www/cce_mnist.png
  • 2-up
  • Swipe
  • Onion skin
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment