Skip to content
Snippets Groups Projects
utils.jl 1.23 KiB
using CounterfactualExplanations.Parallelization: ThreadsParallelizer
using Distributions: Uniform
using Flux
using LinearAlgebra: norm
using Statistics: mean, std

"""
    is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer})

Return `true` if `parallelizer` is an instance of the `MPIParallelizer` type defined in the
`MPIExt` package extension of `CounterfactualExplanations`, and `false` otherwise.

`nothing` and `ThreadsParallelizer` always yield `false`. If the `MPIExt` extension is not
loaded, no value can be an `MPIParallelizer`, so `false` is returned as well.
"""
function is_multi_processed(parallelizer::Union{Nothing,AbstractParallelizer})
    if isnothing(parallelizer) || isa(parallelizer, ThreadsParallelizer)
        return false
    end
    # `Base.get_extension` returns `nothing` when the extension is not loaded (e.g. its
    # trigger package was never imported); reading `.MPIParallelizer` on `nothing` would
    # throw instead of answering the question.
    mpi_ext = Base.get_extension(CounterfactualExplanations, :MPIExt)
    isnothing(mpi_ext) && return false
    return isa(parallelizer, mpi_ext.MPIParallelizer)
end

"""
    is_multi_processed(exper::Experiment)

Report whether the parallelizer stored in `exper.parallelizer` is multi-processed, by
delegating to `is_multi_processed(::Union{Nothing,AbstractParallelizer})`.
"""
function is_multi_processed(exper::Experiment)
    p = exper.parallelizer
    return is_multi_processed(p)
end

"""
    min_max_scale(x::AbstractArray)

Linearly rescale `x` so that its minimum maps to `0` and its maximum maps to `1`, i.e.
`(x .- minimum(x)) ./ (maximum(x) - minimum(x))`.

Any `NaN` in the result is replaced by zero. This covers, for example, the case where all
entries of `x` are equal, so the range is zero and every element becomes `0/0`. The
replacement zero has the result's own element type, so `Float32` input stays `Float32`
rather than being promoted to `Float64`. An empty `x` yields an empty floating-point array.
"""
function min_max_scale(x::AbstractArray)
    # `minimum`/`maximum` throw on empty collections; there is nothing to scale.
    isempty(x) && return float.(x)
    lo, hi = minimum(x), maximum(x)
    x_norm = (x .- lo) ./ (hi - lo)
    # `replace` promotes the output eltype with the replacement value's type, so a literal
    # `0.0` would turn Float32 results into Float64; use a zero of the result's eltype.
    return replace(x_norm, NaN => zero(eltype(x_norm)))
end

"""
    standardize(x::AbstractArray)

Return the z-scores of `x`: `(x .- mean(x)) ./ std(x)`, using the corrected
(sample) standard deviation from `Statistics.std`.

Any `NaN` in the result is replaced by zero. This covers, for example, a constant `x`,
where the standard deviation is zero and every element becomes `0/0`. The replacement
zero has the result's own element type, so `Float32` input stays `Float32` rather than
being promoted to `Float64`.
"""
function standardize(x::AbstractArray)
    # `mean` matches `sum(x) / length(x)` for floats but promotes integer elements
    # before summing, avoiding integer overflow on large integer inputs.
    x_norm = (x .- mean(x)) ./ std(x)
    # A literal `0.0` replacement would promote Float32 results to Float64.
    return replace(x_norm, NaN => zero(eltype(x_norm)))
end

"""
    get_learning_rate(opt::Flux.Optimise.AbstractOptimiser)

Return the learning rate (the `eta` field) of a Flux legacy optimiser.

If `opt` itself has an `eta` field, that value is returned. Otherwise, if `opt` has an
`os` field (a collection of component optimisers, presumably `Flux.Optimise.Optimiser`),
the `eta` of the first component that has one is returned.

Throws an `ArgumentError` if no learning rate can be found.
"""
function get_learning_rate(opt::Flux.Optimise.AbstractOptimiser)
    hasfield(typeof(opt), :eta) && return opt.eta
    if hasfield(typeof(opt), :os)
        # `findfirst` yields `nothing` when no component carries `eta`; previously the
        # equivalent `findall(...)[1]` raised an opaque BoundsError in that case.
        idx = findfirst(o -> hasfield(typeof(o), :eta), opt.os)
        isnothing(idx) || return opt.os[idx].eta
    end
    throw(ArgumentError("Cannot find learning rate."))
end