diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 0000000000000000000000000000000000000000..0421f56fc401cb895c0039b51cd90aefad61335c --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,35 @@ +name: CI +on: + push: + branches: + - main + tags: ['*'] + pull_request: +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} +jobs: + test: + name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + version: + - '1.8' + - 'nightly' + os: + - ubuntu-latest + arch: + - x64 + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@v1 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + - uses: julia-actions/cache@v1 + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml new file mode 100644 index 0000000000000000000000000000000000000000..cba9134c670f0708cf98c92f7fdef055a6c7f5d3 --- /dev/null +++ b/.github/workflows/CompatHelper.yml @@ -0,0 +1,16 @@ +name: CompatHelper +on: + schedule: + - cron: 0 0 * * * + workflow_dispatch: +jobs: + CompatHelper: + runs-on: ubuntu-latest + steps: + - name: Pkg.add("CompatHelper") + run: julia -e 'using Pkg; Pkg.add("CompatHelper")' + - name: CompatHelper.main() + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} + run: julia -e 'using CompatHelper; CompatHelper.main()' diff --git a/.github/workflows/TagBot.yml b/.github/workflows/TagBot.yml new file mode 100644 index 0000000000000000000000000000000000000000..f49313b662013f43aac7de2c738e1163a9715ff4 --- /dev/null +++ b/.github/workflows/TagBot.yml @@ -0,0 +1,15 @@ +name: TagBot +on: + issue_comment: + types: + - created + workflow_dispatch: +jobs: + TagBot: + if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot' + runs-on: ubuntu-latest + steps: + - uses: JuliaRegistries/TagBot@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} + ssh: ${{ secrets.DOCUMENTER_KEY }} diff --git a/.gitignore b/.gitignore index 075b2542afb820ca0c990f02a196dfbb35c41a3a..fca7ef8844ea952e9d11059bbaa1f8d379539e24 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ /.quarto/ + +/Manifest.toml \ No newline at end of file diff --git a/Manifest.toml b/Manifest.toml index a069a7cb011d5a5e0262b73f2d294fe47e6ab9a8..f445e33da81b4a0f39f5402beb88bcb6fec7654d 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "35134a49496d4921bd108806ab39e4d23fb7a6c5" +project_hash = "7dee1a21a38b8ebe5b8194926ba2e4b4df21760e" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] diff --git a/Project.toml b/Project.toml index a653749ff9751d23b839584f5e078c7e56149e2c..d6df84732c414063cc6b317e040cc4929c292518 100644 --- a/Project.toml +++ b/Project.toml @@ -9,9 +9,11 @@ CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] diff --git a/src/CCE.jl b/src/CCE.jl index 8932e48f5446f20fdc03a8ef616d68e974618620..58f70b9cdc1699e50aa93f9a2a413a221ed405e5 100644 --- a/src/CCE.jl +++ b/src/CCE.jl @@ -1,5 +1,6 @@ module CCE -# Write your package code here. +include("model.jl") +include("ConformalGenerator.jl") end diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl index 88eb862bfde6075ed38d3bd980e08c2c6b57bbcd..c772b69a64248905d8531fabf67383f26af41ed9 100644 --- a/src/ConformalGenerator.jl +++ b/src/ConformalGenerator.jl @@ -49,14 +49,14 @@ function ConformalGenerator(; end """ - gradient_penalty( + set_size_penalty( generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation, ) Additional penalty for ConformalGenerator. """ -function gradient_penalty( +function set_size_penalty( generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation, ) diff --git a/src/model.jl b/src/model.jl new file mode 100644 index 0000000000000000000000000000000000000000..1a1d3d815de55ef4bb6f14d8ff6d8c4f328623cf --- /dev/null +++ b/src/model.jl @@ -0,0 +1,46 @@ +using ConformalPrediction +using CounterfactualExplanations.Models +using Flux +using MLUtils +using SliceMap +using Statistics + +""" + Models.ConformalModel <: AbstractDifferentiableJuliaModel + +Constructor for models trained in `Flux.jl`. +""" +struct Models.ConformalModel <: AbstractDifferentiableJuliaModel + model::ConformalPrediction.ConformalProbabilisticSet + likelihood::Symbol + function FluxModel(model, likelihood) + if likelihood ∈ [:classification_binary, :classification_multi] + new(model, likelihood) + else + throw( + ArgumentError( + "`type` should be in `[:classification_binary,:classification_multi]`", + ), + ) + end + end +end + +# Outer constructor method: +function Models.ConformalModel(model; likelihood::Symbol=:classification_binary) + Models.ConformalModel(model, likelihood) +end + +# Methods +function logits(M::Models.ConformalModel, X::AbstractArray) + return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2)) +end + +function probs(M::Models.ConformalModel, X::AbstractArray) + if M.likelihood == :classification_binary + output = σ.(logits(M, X)) + elseif M.likelihood == :classification_multi + output = softmax(logits(M, X)) + end + return output +end \ No newline at end of file