Skip to content
Snippets Groups Projects
Commit 560bf547 authored by pat-alt's avatar pat-alt
Browse files

ufff

parent adf59a74
No related branches found
No related tags found
No related merge requests found
......@@ -47,4 +47,8 @@ generator = CCE.ConformalGenerator()
x = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))
y = predict_label(M, counterfactual_data, x)[1]
target = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y][1]
```
\ No newline at end of file
```
```{julia}
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
```
......@@ -59,18 +59,22 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
yhat = SliceMap.slicemap(X, dims=(1, 2)) do x
conf_model = M.model
fitresult = M.fitresult
x = MLJBase.table(permutedims(x))
= MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...)
= map() do pp
L = .decoder.classes
probas = pdf.(pp, L)
return probas
# x = MLJBase.table(permutedims(x))
# p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...)
# p̂ = map(p̂) do pp
# L = p̂.decoder.classes
# probas = pdf.(pp, L)
# return probas
# end
= fitresult[1](x)
if size(, 2) > 1
= reduce(hcat, )
end
= reduce(hcat, )
= reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol())))
if M.likelihood == :classification_binary
= reduce(hcat, (map(y -> y[2] - y[1], eachcol())))
end
= ndims() > 1 ? : permutedims([])
return
end
return yhat
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment