diff --git a/artifacts/results/mnist_architectures.jls b/artifacts/results/mnist_architectures.jls index c227bf2cec8ce4c773d4a25f0e73440f03486b4c..026c750b06fc065e4216ac565fef8e5435c58644 100644 Binary files a/artifacts/results/mnist_architectures.jls and b/artifacts/results/mnist_architectures.jls differ diff --git a/artifacts/results/mnist_model_performance.csv b/artifacts/results/mnist_model_performance.csv index 574a21ce7d69e0b205f395e8be9186d24f1f05ee..44d942aa14f821ec89456ed42a1e53a3df523abe 100644 --- a/artifacts/results/mnist_model_performance.csv +++ b/artifacts/results/mnist_model_performance.csv @@ -1,5 +1,5 @@ acc,precision,f1score,mod_name,dataname -0.8985,0.8986458930056993,0.8970545599029716,JEM Ensemble,MNIST -0.9402,0.9404816585665723,0.9395994247759306,MLP,MNIST -0.9423,0.9418414616466535,0.9417117851276802,MLP Ensemble,MNIST -0.8054,0.8287830565193942,0.8039356667514831,JEM,MNIST +0.8962,0.8976935889149353,0.8951049347409266,JEM Ensemble,MNIST +0.9311,0.9338952205679255,0.9304980774275494,MLP,MNIST +0.9422,0.9415410195617474,0.9414792949387385,MLP Ensemble,MNIST +0.8170999999999999,0.8342846736086256,0.817181397488819,JEM,MNIST diff --git a/artifacts/results/mnist_model_performance.jls b/artifacts/results/mnist_model_performance.jls index 6d60c57129ae9269ccb884811e21d2240cb9100a..86b9818d85aa498dbb69f224537a2c5eb8699bd6 100644 Binary files a/artifacts/results/mnist_model_performance.jls and b/artifacts/results/mnist_model_performance.jls differ diff --git a/artifacts/results/mnist_models.jls b/artifacts/results/mnist_models.jls index 8c4687d10216f753591d636241c5eef485878d00..ba8d435a72a913b7b25ea3220c3e7c44b09ada19 100644 Binary files a/artifacts/results/mnist_models.jls and b/artifacts/results/mnist_models.jls differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index 3537f90018d2104bc168f3b6d848f7030487fcaa..b09419fb370b53a6a76a4abb4e92c74998d2f547 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index 96a8d445d9ec1ad815bf56539ae121d82f7e742c..a2f53c7a30540d32a04b0f13a7985334afeb6ae0 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl new file mode 100644 index 0000000000000000000000000000000000000000..164e54c3997cdb6ad19b944b112e912c9c722bc2 --- /dev/null +++ b/experiments/benchmarking/benchmarking.jl @@ -0,0 +1,93 @@ +"The default benchmarking measures." +const default_measures = [ + CounterfactualExplanations.distance, + ECCCo.distance_from_energy, + ECCCo.distance_from_targets, + CounterfactualExplanations.Evaluation.validity, + CounterfactualExplanations.Evaluation.redundancy, + ECCCo.set_size_penalty +] + +function default_generators( + Λ::AbstractArray=[0.25, 0.75, 0.75], + Λ_Δ::AbstractArray=[Λ[1], Λ[2], 4.0], + use_variants::Bool=true, + use_class_loss::Bool=false, + opt=Flux.Optimise.Descent(0.01), +) + + @info "Begin benchmarking counterfactual explanations." + λâ‚, λ₂, λ₃ = Λ + λâ‚_Δ, λ₂_Δ, λ₃_Δ = Λ_Δ + + if use_variants + generator_dict = Dict( + "Wachter" => WachterGenerator(λ=λâ‚, opt=opt), + "REVISE" => REVISEGenerator(λ=λâ‚, opt=opt), + "Schut" => GreedyGenerator(), + "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss), + "ECCCo (no CP)" => ECCCoGenerator(λ=[λâ‚, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss), + "ECCCo (no EBM)" => ECCCoGenerator(λ=[λâ‚, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss), + "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true), + "ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λâ‚_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true), + "ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λâ‚_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true), + ) + else + generator_dict = Dict( + "Wachter" => WachterGenerator(λ=λâ‚, opt=opt), + "REVISE" => REVISEGenerator(λ=λâ‚, opt=opt), + "Schut" => GreedyGenerator(), + "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss), + "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true), + ) + end + return generator_dict +end + +""" + run_benchmark( + generators::Union{Nothing, Dict}=nothing, + measures::AbstractArray=default_measures, + ) + +Run the benchmarking procedure. +""" +function run_benchmark(; + n_individuals::Int, + dataname::String, + counterfactual_data::CounterfactualData, + model_dict::Dict, + generators::Union{Nothing, Dict}=nothing, + measures::AbstractArray=default_measures, +) + # Benchmark generators: + if isnothing(generators) + generator_dict = default_generators() + end + + # Run benchmark: + bmks = [] + labels = counterfactual_data.output_encoder.labels + for target in sort(unique(labels)) + for factual in sort(unique(labels)) + if factual == target + continue + end + bmk = benchmark( + counterfactual_data; + models=model_dict, + generators=generator_dict, + measure=measures, + suppress_training=true, dataname=dataname, + n_individuals=n_individuals, + target=target, factual=factual, + initialization=:identity, + converge_when=:generator_conditions + ) + push!(bmks, bmk) + end + end + bmk = reduce(vcat, bmks) + return bmk, generator_dict +end + diff --git a/experiments/data/data.jl b/experiments/data/data.jl new file mode 100644 index 0000000000000000000000000000000000000000..47101aabaf85c319410e32ae23bbaa959de99c95 --- /dev/null +++ b/experiments/data/data.jl @@ -0,0 +1,26 @@ +function prepare_data( + counterfactual_data::CounterfactualData; + ð’Ÿx=Normal(), + min_batch_size=128, + sampling_batch_size=50, +) + X, _ = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) + X = table(permutedims(X)) + labels = counterfactual_data.output_encoder.labels + input_dim, n_obs = size(counterfactual_data.X) + output_dim = length(unique(labels)) + save_name = replace(lowercase(dataname), " " => "_") + + # Model parameters: + batch_size = minimum([Int(round(n_obs / 10)), min_batch_size]) + sampling_batch_size = isnothing(sampling_batch_size) ? batch_size : sampling_batch_size + + # JEM parameters: + ð’Ÿy = Categorical(ones(output_dim) ./ output_dim) + sampler = ConditionalSampler( + ð’Ÿx, ð’Ÿy, + input_size=(input_dim,), + batch_size=sampling_batch_size, + ) + return X, labels, n_obs, save_name, batch_size, sampler +end \ No newline at end of file diff --git a/experiments/mnist.jl b/experiments/mnist.jl index b1847303aac71f4822abdcb43dcaf55c24b45d43..29be88af49d4e531db9c21bf0f83e4a607b3db1e 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -17,6 +17,8 @@ counterfactual_data.generative_model = vae # Test data: test_data = load_mnist_test() +# Models: + # Generators: eccco_generator = ECCCoGenerator( λ=[0.1,0.25,0.25], diff --git a/experiments/models/additional_models.jl b/experiments/models/additional_models.jl new file mode 100644 index 0000000000000000000000000000000000000000..4a3f2824f69205837a6dbdc2ec3e24a26d86d014 --- /dev/null +++ b/experiments/models/additional_models.jl @@ -0,0 +1,50 @@ +""" + LeNetBuilder + +MLJFlux builder for a LeNet-like convolutional neural network. +""" +mutable struct LeNetBuilder + filter_size::Int + channels1::Int + channels2::Int +end + +""" + MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out) + +Overloads the MLJFlux build function for a LeNet-like convolutional neural network. +""" +function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out) + + # Setup: + _n_in = Int(sqrt(n_in)) + k, c1, c2 = b.filter_size, b.channels1, b.channels2 + mod(k, 2) == 1 || error("`filter_size` must be odd. ") + p = div(k - 1, 2) # padding to preserve image size on convolution: + preproc(x) = reshape(x, (_n_in, _n_in, 1, :)) + + # Model: + front = Flux.Chain( + Conv((k, k), 1 => c1, pad=(p, p), relu), + MaxPool((2, 2)), + Conv((k, k), c1 => c2, pad=(p, p), relu), + MaxPool((2, 2)), + Flux.flatten + ) + d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first + back = Flux.Chain( + Dense(d, 120, relu), + Dense(120, 84, relu), + Dense(84, n_out), + ) + chain = Flux.Chain(preproc, front, back) + + return chain +end + +""" + lenet5(builder=LeNetBuilder(5, 6, 16); kwargs...) + +Builds a LeNet-like convolutional neural network. +""" +lenet5(builder=LeNetBuilder(5, 6, 16); kwargs...) = NeuralNetworkClassifier(builder=builder; kwargs...) \ No newline at end of file diff --git a/experiments/models/default_models.jl b/experiments/models/default_models.jl new file mode 100644 index 0000000000000000000000000000000000000000..4b87ebce15cc741c84a5a720910f095e8c36727c --- /dev/null +++ b/experiments/models/default_models.jl @@ -0,0 +1,90 @@ +""" + default_builder(n_hidden::Int=16, activation::Function=Flux.swish) + +Default builder for MLPs. +""" +function default_builder(n_hidden::Int=16, activation::Function=Flux.swish) + builder = MLJFlux.MLP( + hidden=(n_hidden, n_hidden, n_hidden), + σ=activation + ) + return builder +end + +""" + default_models( + builder::GenericBuilder=default_builder(), + epochs::Int=25, + batch_size::Int=128, + finaliser::Function=Flux.softmax, + loss::Function=Flux.Losses.crossentropy, + sampler::AbstractSampler, + α::Float64, + verbosity::Int=10, + sampling_steps::Int=30, + n_ens::Int=5, + use_ensembling::Bool=true, + ) + +Builds a dictionary of default models for training. +""" +function default_models(; + sampler::AbstractSampler, + builder::MLJFlux.GenericBuilder=default_builder(), + epochs::Int=25, + batch_size::Int=128, + finaliser::Function=Flux.softmax, + loss::Function=Flux.Losses.crossentropy, + α::AbstractArray=[1.0, 1.0, 1e-1], + verbosity::Int=10, + sampling_steps::Int=30, + n_ens::Int=5, + use_ensembling::Bool=true, +) + + # Simple MLP: + mlp = NeuralNetworkClassifier( + builder=builder, + epochs=epochs, + batch_size=batch_size, + finaliser=finaliser, + loss=loss, + ) + + # Deep Ensemble: + mlp_ens = EnsembleModel(model=mlp, n=n_ens) + + # Joint Energy Model: + jem = JointEnergyClassifier( + sampler; + builder=builder, + epochs=epochs, + batch_size=batch_size, + finaliser=finaliser, + loss=loss, + jem_training_params=( + α=α, verbosity=verbosity, + ), + sampling_steps=sampling_steps + ) + + # Deep Ensemble of Joint Energy Models: + jem_ens = EnsembleModel(model=jem, n=n_ens) + + # Dictionary of models: + if !use_ensembling + models = Dict( + "MLP" => mlp, + "JEM" => jem, + ) + else + models = Dict( + "MLP" => mlp, + "MLP Ensemble" => mlp_ens, + "JEM" => jem, + "JEM Ensemble" => jem_ens, + ) + end + + return models +end \ No newline at end of file diff --git a/experiments/models/mnist.jl b/experiments/models/mnist.jl new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/experiments/models/models.jl b/experiments/models/models.jl new file mode 100644 index 0000000000000000000000000000000000000000..7adbaa0e00a8f05bb47852800036948d19fcb2b4 --- /dev/null +++ b/experiments/models/models.jl @@ -0,0 +1,3 @@ +include("additional_models.jl") +include("default_models.jl") +include("train_models.jl") \ No newline at end of file diff --git a/experiments/models/train_models.jl b/experiments/models/train_models.jl new file mode 100644 index 0000000000000000000000000000000000000000..429dbe64a8bea57cbfbd4ec5f56a7e0d2a2b3df6 --- /dev/null +++ b/experiments/models/train_models.jl @@ -0,0 +1,40 @@ +""" + train_models(models::Dict) + +Trains all models in a dictionary and returns a dictionary of `ConformalModel` objects. +""" +function train_models(models::Dict, X, y; kwargs...) + model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models) + return model_dict +end + +""" + _train( + model::AbstractModel, + X::AbstractMatrix, + y::AbstractVector, + cov::Float64=0.95, + method::Symbol=:simple_inductive, + mod_name::String="model" + ) + +Trains a model and returns a `ConformalModel` object. +""" +function _train(model, X, y; cov=coverage, method=:simple_inductive, mod_name="model") + conf_model = conformal_model(model; method=method, coverage=cov) + mach = machine(conf_model, X, y) + @info "Begin training $mod_name." + fit!(mach) + @info "Finished training $mod_name." + M = ECCCo.ConformalModel(mach.model, mach.fitresult) + return M +end + +""" + save_models(model_dict::Dict; save_name::String, output_path) + +Helper function to save models. +""" +function save_models(model_dict::Dict; save_name::String, output_path) + Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict) +end \ No newline at end of file diff --git a/experiments/setup.jl b/experiments/setup.jl index 0f5bde9ffbc47225466e47d0ea0a9035206a50ae..c58235cfd5bac452df491676a09aa27ab81d8959 100644 --- a/experiments/setup.jl +++ b/experiments/setup.jl @@ -9,12 +9,24 @@ isdir(params_path) || mkdir(params_path) @info "All parameter choices will be saved to $params_path." test_size = 0.2 +# Constants: +if ENV["RETRAIN"] == "true" + const RETRAIN = true +else + const RETRAIN = false +end + # Artifacts: using LazyArtifacts @warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions." artifact_path = joinpath(artifact"results-paper-submission-1.8.5","results-paper-submission-1.8.5") pretrained_path = joinpath(artifact_path, "results") +# Scripts: +include("models/models.jl") +include("benchmarking/benchmarking.jl") +include("data/data.jl") + function run_experiment( counterfactual_data, test_data; @@ -22,115 +34,42 @@ function run_experiment( output_path=output_path, params_path=params_path, pretrained_path=pretrained_path, - retrain=false, - epochs=100, - n_hidden=16, - activation=Flux.swish, - builder=MLJFlux.MLP( - hidden=(n_hidden, n_hidden, n_hidden), - σ=activation - ), - n_ens=5, + pretrained=true, + models::Union{Nothing, Dict}=nothing, + builder::Union{Nothing, MLJFlux.GenericBuilder}=nothing, ð’Ÿx=Normal(), sampling_batch_size=50, - α=[1.0, 1.0, 1e-1], - verbosity=10, - sampling_steps=30, - use_ensembling=false, coverage=.95, - λâ‚=0.25, - λ₂ = 0.75, - λ₃ = 0.75, - opt=Flux.Optimise.Descent(0.01), - use_class_loss=false, - use_variants=true, - n_individuals=25, generators=nothing, + n_individuals=50, ) # SETUP ---------- # Data - X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) - X = table(permutedims(X)) - labels = counterfactual_data.output_encoder.labels - input_dim, n_obs = size(counterfactual_data.X) - output_dim = length(unique(labels)) - save_name = replace(lowercase(dataname), " " => "_") - - # Model parameters: - batch_size = minimum([Int(round(n_obs / 10)), 128]) - sampling_batch_size = isnothing(sampling_batch_size) ? batch_size : sampling_batch_size - _loss = Flux.Losses.crossentropy # loss function - _finaliser = Flux.softmax # finaliser function - - # JEM parameters: - ð’Ÿy = Categorical(ones(output_dim) ./ output_dim) - sampler = ConditionalSampler( - ð’Ÿx, ð’Ÿy, - input_size=(input_dim,), - batch_size=sampling_batch_size, + X, labels, n_obs, save_name, batch_size, sampler = prepare_data( + counterfactual_data; + ð’Ÿx=ð’Ÿx, + sampling_batch_size=sampling_batch_size, ) # MODELS ---------- - - # Simple MLP: - mlp = NeuralNetworkClassifier( - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, - ) - - # Deep Ensemble: - mlp_ens = EnsembleModel(model=mlp, n=n_ens) - - # Joint Energy Model: - jem = JointEnergyClassifier( - sampler; - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, - jem_training_params=( - α=α, verbosity=verbosity, - ), - sampling_steps=sampling_steps - ) - - # Deep Ensemble of Joint Energy Models: - jem_ens = EnsembleModel(model=jem, n=n_ens) - - # Dictionary of models: - if !use_ensembling - models = Dict( - "MLP" => mlp, - "JEM" => jem, - ) - else - models = Dict( - "MLP" => mlp, - "MLP Ensemble" => mlp_ens, - "JEM" => jem, - "JEM Ensemble" => jem_ens, + if isnothing(builder) + builder = default_builder() + end + if isnothing(models) + @info "Using default models." + models = default_models(; + sampler=sampler, + builder=builder, + batch_size=batch_size, ) end # TRAINING ---------- - function _train(model, X=X, y=labels; cov=coverage, method=:simple_inductive, mod_name="model") - conf_model = conformal_model(model; method=method, coverage=cov) - mach = machine(conf_model, X, y) - @info "Begin training $mod_name." - fit!(mach) - @info "Finished training $mod_name." - M = ECCCo.ConformalModel(mach.model, mach.fitresult) - return M - end - if retrain - @info "Retraining models." - model_dict = Dict(mod_name => _train(model; mod_name=mod_name) for (mod_name, model) in models) + if !pretrained + @info "Training models." + model_dict = train_models(models, X, labels; coverage=coverage) Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict) else @info "Loading pre-trained models." @@ -174,74 +113,26 @@ function run_experiment( println(model_performance) # COUNTERFACTUALS ---------- - - @info "Begin benchmarking counterfactual explanations." - Λ = [λâ‚, λ₂, λ₃] + # Benchmark generators: + bmk, generator_dict = run_benchmark(; + n_individuals=n_individuals, + dataname=dataname, + counterfactual_data=counterfactual_data, + model_dict=model_dict, + generators=generators, + measures=measures, + ) + # Output: + opt = first(values(generator_dict)).opt generator_params = DataFrame( Dict( - :λ1 => λâ‚, - :λ2 => λ₂, - :λ3 => λ₃, :opt => string(typeof(opt)), :eta => opt.eta, :dataname => dataname, ) ) CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params) - - # Benchmark generators: - if !isnothing(generators) - generator_dict = generators - elseif use_variants - generator_dict = Dict( - "Wachter" => WachterGenerator(λ=λâ‚, opt=opt), - "REVISE" => REVISEGenerator(λ=λâ‚, opt=opt), - "Schut" => GreedyGenerator(), - "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss), - "ECCCo (no CP)" => ECCCoGenerator(λ=[λâ‚, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss), - "ECCCo (no EBM)" => ECCCoGenerator(λ=[λâ‚, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss), - ) - else - generator_dict = Dict( - "Wachter" => WachterGenerator(λ=λâ‚, opt=opt), - "REVISE" => REVISEGenerator(λ=λâ‚, opt=opt), - "Schut" => GreedyGenerator(), - "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss), - ) - end - - # Measures: - measures = [ - CounterfactualExplanations.distance, - ECCCo.distance_from_energy, - ECCCo.distance_from_targets, - CounterfactualExplanations.Evaluation.validity, - CounterfactualExplanations.Evaluation.redundancy, - ECCCo.set_size_penalty - ] - - bmks = [] - for target in sort(unique(labels)) - for factual in sort(unique(labels)) - if factual == target - continue - end - bmk = benchmark( - counterfactual_data; - models=model_dict, - generators=generator_dict, - measure=measures, - suppress_training=true, dataname=dataname, - n_individuals=n_individuals, - target=target, factual=factual, - initialization=:identity, - converge_when=:generator_conditions - ) - push!(bmks, bmk) - end - end - bmk = reduce(vcat, bmks) CSV.write(joinpath(output_path, "$(save_name)_benchmark.csv"), bmk()) end \ No newline at end of file diff --git a/notebooks/.CondaPkg/env/conda-meta/history b/notebooks/.CondaPkg/env/conda-meta/history index 6c7683b9e57e48bee3f8f60a08c00c9a8f1c2d8b..5637f50087d0ff3b125f0cd66ac6371249d198ae 100644 --- a/notebooks/.CondaPkg/env/conda-meta/history +++ b/notebooks/.CondaPkg/env/conda-meta/history @@ -1,4 +1,4 @@ -==> 2023-08-10 11:22:12 <== +==> 2023-08-17 12:26:42 <== # cmd: /Users/FA31DU/.julia/artifacts/6ecf04294c7f327e02e84972f34835649a5eb35e/bin/micromamba -r /Users/FA31DU/.julia/scratchspaces/0b3b1443-0f03-428d-bdfb-f27f9c1191ea/root create -y -p /Users/FA31DU/code/ECCCo.jl/notebooks/.CondaPkg/env --override-channels --no-channel-priority numpy[version='*'] pip[version='>=22.0.0'] python[version='>=3.7,<4',channel='conda-forge',build='*cpython*'] -c conda-forge # conda version: 3.8.0 +https://conda.anaconda.org/conda-forge/osx-64::xz-5.2.6-h775f41a_0 diff --git a/notebooks/.CondaPkg/meta b/notebooks/.CondaPkg/meta index 6b7ba5c59ae065f773611edc2fc41659836fb772..dbc26a2ad2201143acc917b9818eb94a6520946b 100644 Binary files a/notebooks/.CondaPkg/meta and b/notebooks/.CondaPkg/meta differ diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index 8622af582d97144b077cc214f0cc19273882ecd3..4872255face583f3bcafdde699c16be1369d1c6c 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -4,6 +4,11 @@ julia_version = "1.9.0" manifest_format = "2.0" project_hash = "5dac8702e1bf52ac1887686257c409c28f8872ae" +[[deps.ANSIColoredPrinters]] +git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" +uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" +version = "0.0.1" + [[deps.ARFFFiles]] deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" @@ -118,10 +123,10 @@ uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" version = "0.1.0" [[deps.Automa]] -deps = ["ScanByte", "TranscodingStreams"] -git-tree-sha1 = "48e54446df62fdf9ef76959c32dc33f3cff659ee" +deps = ["TranscodingStreams"] +git-tree-sha1 = "ef9997b3d5547c48b41c7bd8899e812a917b409d" uuid = "67c07d97-cdcb-5c2c-af73-a7f9c32a568b" -version = "0.8.3" +version = "0.8.4" [[deps.AxisAlgorithms]] deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] @@ -186,9 +191,9 @@ uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" version = "0.1.5" [[deps.BufferedStreams]] -git-tree-sha1 = "5bcb75a2979e40b29eb250cb26daab67aa8f97f5" +git-tree-sha1 = "4ae47f9a4b1dc19897d3743ff13685925c5202ec" uuid = "e1450e63-4bb3-523b-b2a4-4ffa8c0fd77d" -version = "1.2.0" +version = "1.2.1" [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -272,9 +277,9 @@ version = "1.0.5" [[deps.CairoMakie]] deps = ["Base64", "Cairo", "Colors", "FFTW", "FileIO", "FreeType", "GeometryBasics", "LinearAlgebra", "Makie", "PrecompileTools", "SHA"] -git-tree-sha1 = "e041782fed7614b1726fa250f2bf24fd5c789689" +git-tree-sha1 = "30562a68ded3dabe80109caf6b4de73a48ac27bc" uuid = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -version = "0.10.7" +version = "0.10.8" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] @@ -368,9 +373,9 @@ version = "0.15.4" [[deps.CodeTracking]] deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "8dd599a2fdbf3132d4c0be3a016f8f1518e28fa8" +git-tree-sha1 = "a1296f0fe01a4c3f9bf0dc2934efbf4416f5db31" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.3.2" +version = "1.3.4" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] @@ -386,9 +391,9 @@ version = "0.4.0" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "dd3000d954d483c1aad05fe1eb9e6a715c97013e" +git-tree-sha1 = "d9a8f86737b665e15a9641ecbac64deef9ce6724" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.22.0" +version = "3.23.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -435,12 +440,6 @@ weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] CompatLinearAlgebraExt = "LinearAlgebra" -[[deps.CompatHelperLocal]] -deps = ["DocStringExtensions", "Pkg", "UUIDs"] -git-tree-sha1 = "be25ab802a22a212ce4da944fe60d7c250ddcfe1" -uuid = "5224ae11-6099-4aaa-941d-3aab004bd678" -version = "0.1.25" - [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" @@ -481,10 +480,10 @@ uuid = "992eb4ea-22a4-4c89-a5bb-47a3300528ab" version = "0.2.18" [[deps.ConformalPrediction]] -deps = ["CategoricalArrays", "ChainRules", "ComputationalResources", "Flux", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "Plots", "ProgressMeter", "Random", "StatsBase", "Tables"] +deps = ["CategoricalArrays", "ChainRules", "ComputationalResources", "Flux", "LazyArtifacts", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlux", "MLJModelInterface", "MLUtils", "NaturalSort", "Plots", "ProgressMeter", "Random", "Serialization", "StatsBase", "Tables"] path = "../../ConformalPrediction.jl" uuid = "98bfc277-1877-43dc-819b-a3e38c30242f" -version = "0.1.7" +version = "0.1.8" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] @@ -515,10 +514,10 @@ uuid = "150eb455-5306-5404-9cee-2592286d6298" version = "0.6.3" [[deps.CounterfactualExplanations]] -deps = ["CSV", "CUDA", "CategoricalArrays", "Chain", "ChainRulesCore", "CompatHelperLocal", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "JuliaFormatter", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "PythonCall", "RCall", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"] +deps = ["CSV", "CUDA", "CategoricalArrays", "Chain", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "JuliaFormatter", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "PythonCall", "RCall", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"] path = "../../CounterfactualExplanations.jl" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.13" +version = "0.1.14" [[deps.CpuId]] deps = ["Markdown"] @@ -629,9 +628,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "27a18994a5991b1d2e2af7833c4f8ecf9af6b9ea" +git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.99" +version = "0.25.100" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -647,6 +646,12 @@ git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" version = "0.9.3" +[[deps.Documenter]] +deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] +git-tree-sha1 = "39fd748a73dce4c05a9655475e437170d8fb1b67" +uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +version = "0.27.25" + [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" @@ -688,9 +693,9 @@ version = "0.5.2" [[deps.EvoTrees]] deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "1b418518c0eb1fd1ef0a6d0bfc8051e6abb1232b" +git-tree-sha1 = "5023442c1f797c0fd6677b1a1886ab44f43f3378" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.15.2" +version = "0.16.0" [[deps.ExactPredicates]] deps = ["IntervalArithmetic", "Random", "StaticArraysCore", "Test"] @@ -839,9 +844,9 @@ version = "0.4.2" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.35" +version = "0.10.36" weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] @@ -924,10 +929,10 @@ uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" version = "0.72.8+0" [[deps.GZip]] -deps = ["Libdl"] -git-tree-sha1 = "039be665faf0b8ae36e089cd694233f5dee3f7d6" +deps = ["Libdl", "Zlib_jll"] +git-tree-sha1 = "e062176cedd5fc4551eb15b956cbf433c246ce3f" uuid = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" -version = "0.5.1" +version = "0.6.0" [[deps.GeoInterface]] deps = ["Extents"] @@ -978,9 +983,9 @@ version = "1.8.0" [[deps.GridLayoutBase]] deps = ["GeometryBasics", "InteractiveUtils", "Observables"] -git-tree-sha1 = "678d136003ed5bceaab05cf64519e3f956ffa4ba" +git-tree-sha1 = "f57a64794b336d4990d90f80b147474b869b1bc4" uuid = "3955a311-db13-416c-9275-1d80ed98e5e9" -version = "0.9.1" +version = "0.9.2" [[deps.Grisu]] git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" @@ -1025,9 +1030,9 @@ version = "0.3.1" [[deps.HostCPUFeatures]] deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"] -git-tree-sha1 = "d38bd0d9759e3c6cfa19bdccc314eccf8ce596cc" +git-tree-sha1 = "eb8fed28f4994600e29beef49744639d985a04b2" uuid = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" -version = "0.1.15" +version = "0.1.16" [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] @@ -1035,6 +1040,12 @@ git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" version = "0.3.23" +[[deps.IOCapture]] +deps = ["Logging", "Random"] +git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" +uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" +version = "0.2.3" + [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" @@ -1199,9 +1210,9 @@ version = "0.1.5" [[deps.IntelOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83" +git-tree-sha1 = "ad37c091f7d7daf900963171600d7c1c5c3ede32" uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2023.1.0+0" +version = "2023.2.0+0" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -1280,10 +1291,10 @@ uuid = "1019f520-868f-41f5-a6de-eb00f4b6a39c" version = "0.1.5" [[deps.JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "a7e91ef94114d5bc8952bcaa8d6ff952cf709808" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.4.1" +version = "1.4.2" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -1317,9 +1328,9 @@ version = "2.1.91+0" [[deps.JuliaFormatter]] deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"] -git-tree-sha1 = "60567b51bd9e1e19ae2fd8a54dcd6bc5994727f0" +git-tree-sha1 = "680fb31c8b8e2cf482f48e55d8fa01ccc4469e04" uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" -version = "1.0.34" +version = "1.0.35" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -1602,15 +1613,15 @@ version = "0.10.5" [[deps.MKL_jll]] deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "154d7aaa82d24db6d8f7e4ffcfe596f40bff214b" +git-tree-sha1 = "eb006abbd7041c28e0d16260e50a24f8f9104913" uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -version = "2023.1.0+0" +version = "2023.2.0+0" [[deps.MLDatasets]] deps = ["CSV", "Chemfiles", "DataDeps", "DataFrames", "DelimitedFiles", "FileIO", "FixedPointNumbers", "GZip", "Glob", "HDF5", "ImageShow", "JLD2", "JSON3", "LazyModules", "MAT", "MLUtils", "NPZ", "Pickle", "Printf", "Requires", "SparseArrays", "Statistics", "Tables"] -git-tree-sha1 = "a03a093b03824f07fe00931df76b18d99398ebb9" +git-tree-sha1 = "41922968c0aaca46baa5d658d3a173828313e2d0" uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458" -version = "0.7.11" +version = "0.7.12" [[deps.MLJ]] deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "ScientificTypes", "Statistics", "StatsBase", "Tables"] @@ -1650,9 +1661,9 @@ version = "0.5.1" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "c8b7e632d6754a5e36c0d94a4b466a5ba3a30128" +git-tree-sha1 = "e89d1ea12c5a50057bfb0c124d905669e5ed4ec9" uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.8.0" +version = "1.9.1" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] @@ -1685,15 +1696,15 @@ version = "0.5.10" [[deps.Makie]] deps = ["Animations", "Base64", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG", "FileIO", "FixedPointNumbers", "Formatting", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "Match", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Setfield", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"] -git-tree-sha1 = "729640354756782c89adba8857085a69e19be7ab" +git-tree-sha1 = "e81675589ba7199a82443e87fc52e17eeceac2e8" uuid = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" -version = "0.19.7" +version = "0.19.8" [[deps.MakieCore]] deps = ["Observables"] -git-tree-sha1 = "87a85ff81583bd392642869557cb633532989517" +git-tree-sha1 = "f56b09c8b964919373d61750c6d8d4d2c602a2be" uuid = "20f20a25-4f0e-4fdf-b5d1-57303727442b" -version = "0.6.4" +version = "0.6.5" [[deps.ManualMemory]] git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" @@ -1709,6 +1720,12 @@ version = "0.4.2" deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[deps.MarketData]] +deps = ["CSV", "Dates", "HTTP", "JSON3", "Random", "Reexport", "TimeSeries"] +git-tree-sha1 = "715536b6af6292883128e22857c83291e30fea25" +uuid = "945b72a4-3b13-509d-9b46-1525bb5c06de" +version = "0.13.12" + [[deps.Match]] git-tree-sha1 = "1d9bc5c1a6e7ee24effb93f175c9342f9154d97f" uuid = "7eb4fadd-790c-5f42-8a69-bfa0b872bfbf" @@ -1970,9 +1987,9 @@ version = "0.5.5+0" [[deps.Optim]] deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] -git-tree-sha1 = "e3a6546c1577bfd701771b477b794a52949e7594" +git-tree-sha1 = "963b004d15216f8129f6c0f7d187efa136570be0" uuid = "429524aa-4258-5aef-a3af-852621145aeb" -version = "1.7.6" +version = "1.7.7" [[deps.OptimBase]] deps = ["NLSolversBase", "Printf", "Reexport"] @@ -2026,6 +2043,12 @@ git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" version = "0.5.12" +[[deps.PalmerPenguins]] +deps = ["CSV", "DataDeps"] +git-tree-sha1 = "e7c581b0e29f7d35f47927d65d4965b413c10d90" +uuid = "8b842266-38fa-440a-9b57-31493939ab85" +version = "0.1.4" + [[deps.Pango_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "84a314e3926ba9ec66ac097e3635e270986b0f10" @@ -2080,9 +2103,9 @@ version = "1.9.0" [[deps.PkgTemplates]] deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"] -git-tree-sha1 = "8cd6b2e37d6fa76dff55d8832be2f3ca3b73ae56" +git-tree-sha1 = "82186fe066cbdc3a25ed9247ea709da73c52e941" uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" -version = "0.7.38" +version = "0.7.40" [[deps.PkgVersion]] deps = ["Pkg"] @@ -2243,9 +2266,9 @@ version = "0.7.4" [[deps.RCall]] deps = ["CategoricalArrays", "Conda", "DataFrames", "DataStructures", "Dates", "Libdl", "Missings", "REPL", "Random", "Requires", "StatsModels", "WinReg"] -git-tree-sha1 = "d441bdeea943f8e8f293e0e3a78fe2d7c3aa24e6" +git-tree-sha1 = "d9310ed05c2ff94c4e3a545a0e4c58ed36496179" uuid = "6f49c342-dc21-5d91-9882-a32aef131414" -version = "0.13.15" +version = "0.13.16" [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] @@ -2362,12 +2385,6 @@ version = "0.2.1" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" -[[deps.SIMD]] -deps = ["PrecompileTools"] -git-tree-sha1 = "0e270732477b9e551d884e6b07e23bb2ec947790" -uuid = "fdea26ae-647d-5447-a871-4b548cad5224" -version = "3.4.5" - [[deps.SIMDTypes]] git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" uuid = "94e857df-77ce-4151-89e5-788b33177be4" @@ -2379,12 +2396,6 @@ git-tree-sha1 = "4b8586aece42bee682399c4c4aee95446aa5cd19" uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" version = "0.6.39" -[[deps.ScanByte]] -deps = ["Libdl", "SIMD"] -git-tree-sha1 = "d49e35f413186528f1d7cc675e67d0ed16fd7800" -uuid = "7b38b023-a4d7-4c5e-8d43-3f3097f304eb" -version = "0.4.0" - [[deps.ScientificTypes]] deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"] git-tree-sha1 = "75ccd10ca65b939dab03b812994e571bf1e3e1da" @@ -2479,9 +2490,9 @@ version = "0.3.0" [[deps.SimplePolynomials]] deps = ["Mods", "Multisets", "Polynomials", "Primes"] -git-tree-sha1 = "ac7b9bd0d2d2ee86e9c7016fb76ff7c1037838e9" +git-tree-sha1 = "9f1b1f47279018b35316c62e829af1f3f6725a47" uuid = "cc47b68c-3164-5771-a705-2bc0097375a0" -version = "0.2.12" +version = "0.2.13" [[deps.SimpleRandom]] deps = ["Distributions", "LinearAlgebra", "Random"] @@ -2534,9 +2545,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3" +git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.0" +version = "2.3.1" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -2736,10 +2747,39 @@ uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" version = "0.5.2" [[deps.Tidier]] -deps = ["Chain", "Cleaner", "DataFrames", "MacroTools", "Reexport", "ShiftedArrays", "Statistics"] -git-tree-sha1 = "251298c9b705805b26e1ed0ee5f150155aafa818" +deps = ["Reexport", "TidierCats", "TidierData", "TidierDates", "TidierPlots", "TidierStrings"] +git-tree-sha1 = "c7a6c4db043a4d27a4150a3ea07b03d3a9a158ca" uuid = "f0413319-3358-4bb0-8e7c-0c83523a93bd" -version = "0.7.7" +version = "1.0.1" + +[[deps.TidierCats]] +deps = ["CategoricalArrays", "DataFrames", "Reexport", "Statistics"] +git-tree-sha1 = "c4660f2c0ffd733ec243ea0a5447bd3bfae40c6d" +uuid = "79ddc9fe-4dbf-4a56-a832-df41fb326d23" +version = "0.1.1" + +[[deps.TidierData]] +deps = ["Chain", "Cleaner", "DataFrames", "MacroTools", "Reexport", "ShiftedArrays", "Statistics"] +git-tree-sha1 = "a4a83e2f5083ee6b18e0f01c99b4483b4f7978a2" +uuid = "fe2206b3-d496-4ee9-a338-6a095c4ece80" +version = "0.10.0" + +[[deps.TidierDates]] +deps = ["Dates", "Documenter", "Reexport"] +git-tree-sha1 = "ba1e0e3e7c99cdccb7c8d9d568e413283323716f" +uuid = "20186a3f-b5d3-468e-823e-77aae96fe2d8" +version = "0.1.0" + +[[deps.TidierPlots]] +deps = ["AlgebraOfGraphics", "CairoMakie", "DataFrames", "Makie", "MarketData", "PalmerPenguins", "Reexport"] +git-tree-sha1 = "1e2f273690efe000786b142bbe83b431fceb29f1" +uuid = "337ecbd1-5042-4e2a-ae6f-ca776f97570a" +version = "0.1.0" + +[[deps.TidierStrings]] +git-tree-sha1 = "1e704fbaf9f4d651ed9c59b4b6a6c325c0f09558" +uuid = "248e6834-d0f8-40ef-8fbb-8e711d883e9c" +version = "0.1.0" [[deps.TiffImages]] deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] @@ -2753,6 +2793,12 @@ git-tree-sha1 = "1176cc31e867217b06928e2f140c90bd1bc88283" uuid = "06e1c1a7-607b-532d-9fad-de7d9aa2abac" version = "0.5.0" +[[deps.TimeSeries]] +deps = ["Dates", "DelimitedFiles", "DocStringExtensions", "RecipesBase", "Reexport", "Statistics", "Tables"] +git-tree-sha1 = "8b9288d84da88ea44693ca8cf9c236da1778f274" +uuid = "9e3dc215-6440-5c97-bce1-76c03772f85e" +version = "0.23.2" + [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" @@ -2847,9 +2893,9 @@ version = "0.4.1" [[deps.Unitful]] deps = ["Dates", "LinearAlgebra", "Random"] -git-tree-sha1 = "1cd9b6d3f637988ca788007b7466c132feebe263" +git-tree-sha1 = "607c142139151faa591b5e80d8055a15e487095b" uuid = "1986cc42-f94f-5a68-af5c-568840ba703d" -version = "1.16.1" +version = "1.16.3" [deps.Unitful.extensions] ConstructionBaseUnitfulExt = "ConstructionBase" @@ -3094,9 +3140,9 @@ version = "1.5.5+0" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "5be3ddb88fc992a7d8ea96c3f10a49a7e98ebc7b" +git-tree-sha1 = "e2fe78907130b521619bc88408c859a472c4172b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.62" +version = "0.6.63" weakdeps = ["Colors", "Distances", "Tracker"] [deps.Zygote.extensions] diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 91b18b157d88c906f57a626f42b6d10e177cf814..dc6ed2663da9224e44925773825f7e07460ce844 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -164,8 +164,8 @@ end ```{julia} # Hyper: -_retrain = false -_regen = false +_retrain = true +_regen = true # Data: n_obs = 10000 @@ -547,32 +547,6 @@ lenet = NeuralNetworkClassifier( ) ``` -Robust Neural network: - -```{julia} -mutable struct RobustNetBuilder - n_hidden::Int - lipschitz_bound::Float32 -end - -function MLJFlux.build(b::RobustNetBuilder, rng, n_in, n_out) - n_hidden, γ = b.n_hidden, b.lipschitz_bound - _n_hidden = fill(n_hidden,2) - model_ps = DenseLBDNParams{Float32}(n_in, _n_hidden, n_out, γ; rng) - chain = Flux.Chain(DiffLBDN(model_ps)) - return chain -end - -# Final model: -rob_net = NeuralNetworkClassifier( - builder=RobustNetBuilder(60, 5.0), - epochs=600, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, -) -``` - Training all of them: ```{julia} diff --git a/paper/paper.pdf b/paper/paper.pdf index ff97640153419adbc346ad33dedef9f984346f65..6a465936f6699620fc527e6085cfbfc0d5a1467e 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ