diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index 9c45239d3b594907d78abc306dbbcf901e96ba5b..8a84a6cd9dd4867228672e5351b8743152adc3c4 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -47,4 +47,8 @@ generator = CCE.ConformalGenerator() x = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2))) y = predict_label(M, counterfactual_data, x)[1] target = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y][1] -``` \ No newline at end of file +``` + +```{julia} +ce = generate_counterfactual(x, target, counterfactual_data, M, generator) +``` diff --git a/src/model.jl b/src/model.jl index 1db725d1d3ea680cc31f098539b71ba351127dd5..654af645614f6fd74519dd63812648f7855ba684 100644 --- a/src/model.jl +++ b/src/model.jl @@ -59,18 +59,22 @@ function Models.logits(M::ConformalModel, X::AbstractArray) yhat = SliceMap.slicemap(X, dims=(1, 2)) do x conf_model = M.model fitresult = M.fitresult - x = MLJBase.table(permutedims(x)) - p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...) - p̂ = map(p̂) do pp - L = p̂.decoder.classes - probas = pdf.(pp, L) - return probas + # x = MLJBase.table(permutedims(x)) + # p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...) + # p̂ = map(p̂) do pp + # L = p̂.decoder.classes + # probas = pdf.(pp, L) + # return probas + # end + p̂ = fitresult[1](x) + if size(p̂, 2) > 1 + p̂ = reduce(hcat, p̂) end - p̂ = reduce(hcat, p̂) ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂)))) if M.likelihood == :classification_binary ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ)))) end + ŷ = ndims(ŷ) > 1 ? ŷ : permutedims([ŷ]) return ŷ end return yhat