Skip to content
Snippets Groups Projects
Commit 9c8f92fa authored by pat-alt's avatar pat-alt
Browse files

method extensions setup

parent 57d27a94
No related branches found
No related tags found
No related merge requests found
name: CI
on:
push:
branches:
- main
tags: ['*']
pull_request:
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- '1.8'
- 'nightly'
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
name: CompatHelper
on:
schedule:
- cron: 0 0 * * *
workflow_dispatch:
jobs:
CompatHelper:
runs-on: ubuntu-latest
steps:
- name: Pkg.add("CompatHelper")
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- name: CompatHelper.main()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
name: TagBot
on:
issue_comment:
types:
- created
workflow_dispatch:
jobs:
TagBot:
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'
runs-on: ubuntu-latest
steps:
- uses: JuliaRegistries/TagBot@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
ssh: ${{ secrets.DOCUMENTER_KEY }}
/.quarto/
/Manifest.toml
\ No newline at end of file
......@@ -2,7 +2,7 @@
julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "35134a49496d4921bd108806ab39e4d23fb7a6c5"
project_hash = "7dee1a21a38b8ebe5b8194926ba2e4b4df21760e"
[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
......
......@@ -9,9 +9,11 @@ CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat]
......
module CCE
# Write your package code here.
include("model.jl")
include("ConformalGenerator.jl")
end
......@@ -49,14 +49,14 @@ function ConformalGenerator(;
end
"""
gradient_penalty(
set_size_penalty(
generator::ConformalGenerator,
counterfactual_explanation::AbstractCounterfactualExplanation,
)
Additional penalty for ConformalGenerator.
"""
function gradient_penalty(
function set_size_penalty(
generator::ConformalGenerator,
counterfactual_explanation::AbstractCounterfactualExplanation,
)
......
using ConformalPrediction
using CounterfactualExplanations.Models
using Flux
using MLUtils
using SliceMap
using Statistics
"""
Models.ConformalModel <: AbstractDifferentiableJuliaModel
Constructor for models trained in `Flux.jl`.
"""
struct Models.ConformalModel <: AbstractDifferentiableJuliaModel
model::ConformalPrediction.ConformalProbabilisticSet
likelihood::Symbol
function FluxModel(model, likelihood)
if likelihood [:classification_binary, :classification_multi]
new(model, likelihood)
else
throw(
ArgumentError(
"`type` should be in `[:classification_binary,:classification_multi]`",
),
)
end
end
end
# Outer constructor method:
function Models.ConformalModel(model; likelihood::Symbol=:classification_binary)
Models.ConformalModel(model, likelihood)
end
# Methods
function logits(M::Models.ConformalModel, X::AbstractArray)
return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2))
end
function probs(M::Models.ConformalModel, X::AbstractArray)
if M.likelihood == :classification_binary
output = σ.(logits(M, X))
elseif M.likelihood == :classification_multi
output = softmax(logits(M, X))
end
return output
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment