Skip to content
Snippets Groups Projects
Commit 356d7aa7 authored by pat-alt's avatar pat-alt
Browse files

method extensions setup

parent b141b1f4
No related branches found
No related tags found
No related merge requests found
name: CI
on:
push:
branches:
- main
tags: ['*']
pull_request:
concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- '1.8'
- 'nightly'
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
name: CompatHelper
on:
schedule:
- cron: 0 0 * * *
workflow_dispatch:
jobs:
CompatHelper:
runs-on: ubuntu-latest
steps:
- name: Pkg.add("CompatHelper")
run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- name: CompatHelper.main()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
run: julia -e 'using CompatHelper; CompatHelper.main()'
name: TagBot
on:
issue_comment:
types:
- created
workflow_dispatch:
jobs:
TagBot:
if: github.event_name == 'workflow_dispatch' || github.actor == 'JuliaTagBot'
runs-on: ubuntu-latest
steps:
- uses: JuliaRegistries/TagBot@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
ssh: ${{ secrets.DOCUMENTER_KEY }}
/.quarto/ /.quarto/
/Manifest.toml
\ No newline at end of file
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
julia_version = "1.8.5" julia_version = "1.8.5"
manifest_format = "2.0" manifest_format = "2.0"
project_hash = "35134a49496d4921bd108806ab39e4d23fb7a6c5" project_hash = "7dee1a21a38b8ebe5b8194926ba2e4b4df21760e"
[[deps.AbstractFFTs]] [[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"] deps = ["ChainRulesCore", "LinearAlgebra"]
......
...@@ -9,9 +9,11 @@ CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" ...@@ -9,9 +9,11 @@ CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845" MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
[compat] [compat]
......
module CCE module CCE
# Write your package code here. include("model.jl")
include("ConformalGenerator.jl")
end end
...@@ -49,14 +49,14 @@ function ConformalGenerator(; ...@@ -49,14 +49,14 @@ function ConformalGenerator(;
end end
""" """
gradient_penalty( set_size_penalty(
generator::ConformalGenerator, generator::ConformalGenerator,
counterfactual_explanation::AbstractCounterfactualExplanation, counterfactual_explanation::AbstractCounterfactualExplanation,
) )
Additional penalty for ConformalGenerator. Additional penalty for ConformalGenerator.
""" """
function gradient_penalty( function set_size_penalty(
generator::ConformalGenerator, generator::ConformalGenerator,
counterfactual_explanation::AbstractCounterfactualExplanation, counterfactual_explanation::AbstractCounterfactualExplanation,
) )
......
using ConformalPrediction
using CounterfactualExplanations.Models
using Flux
using MLUtils
using SliceMap
using Statistics
"""
Models.ConformalModel <: AbstractDifferentiableJuliaModel
Constructor for models trained in `Flux.jl`.
"""
struct Models.ConformalModel <: AbstractDifferentiableJuliaModel
model::ConformalPrediction.ConformalProbabilisticSet
likelihood::Symbol
function FluxModel(model, likelihood)
if likelihood [:classification_binary, :classification_multi]
new(model, likelihood)
else
throw(
ArgumentError(
"`type` should be in `[:classification_binary,:classification_multi]`",
),
)
end
end
end
# Outer constructor method:
function Models.ConformalModel(model; likelihood::Symbol=:classification_binary)
Models.ConformalModel(model, likelihood)
end
# Methods
function logits(M::Models.ConformalModel, X::AbstractArray)
return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2))
end
function probs(M::Models.ConformalModel, X::AbstractArray)
if M.likelihood == :classification_binary
output = σ.(logits(M, X))
elseif M.likelihood == :classification_multi
output = softmax(logits(M, X))
end
return output
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment