diff --git a/Manifest.toml b/Manifest.toml
index 144bf11e5b9b34ebd0baf7785cb8ecdf4f9c2a96..d16a0962307e2701b4cb33190dcace9d31955db4 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -780,7 +780,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 version = "1.12.0"
 
 [[deps.JointEnergyModels]]
-deps = ["Distributions", "Flux", "StatsBase"]
+deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
 path = "../JointEnergyModels.jl"
 uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
 version = "0.1.0"
@@ -997,7 +997,7 @@ version = "0.3.2"
 
 [[deps.MLJFlux]]
 deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
-git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0"
+path = "../MLJFlux.jl"
 uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
 version = "0.2.9"
 
diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml
index c38fdac5e7699d8f05b711dc80007a8dca43330c..6dd7ba4c2af64778888eecd9f46414a07292e102 100644
--- a/notebooks/Manifest.toml
+++ b/notebooks/Manifest.toml
@@ -180,9 +180,9 @@ version = "0.2.0"
 
 [[deps.CUDA_Runtime_jll]]
 deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
-git-tree-sha1 = "9ac3ffda60eeae5291be20f35ca264eb8e95bbc6"
+git-tree-sha1 = "81eed046f28a0cdd0dc1f61d00a49061b7cc9433"
 uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
-version = "0.5.0+1"
+version = "0.5.0+2"
 
 [[deps.CUDNN_jll]]
 deps = ["Artifacts", "CUDA_Runtime_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"]
@@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 version = "1.12.0"
 
 [[deps.JointEnergyModels]]
-deps = ["ChainRulesCore", "Distributions", "Flux", "StatsBase"]
+deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
 path = "../../JointEnergyModels.jl"
 uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
 version = "0.1.0"
@@ -1293,7 +1293,7 @@ version = "0.3.2"
 
 [[deps.MLJFlux]]
 deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
-git-tree-sha1 = "2ecdce4dd9214789ee1796103d29eaee7619ebd0"
+path = "../../MLJFlux.jl"
 uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
 version = "0.2.9"
 
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 4b2b2d3d5346cae6d1418e4b254099aa420d0ae7..433e3ba30df6509ee002fed887ba04f5fa7bef51 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -7,50 +7,80 @@ eval(setup_notebooks)
 
 ```{julia}
 # Data:
-counterfactual_data = load_mnist()
+n_obs = 1000
+counterfactual_data = load_mnist(n_obs)
 X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
 X = table(permutedims(X))
 labels = counterfactual_data.output_encoder.labels
 input_dim, n_obs = size(counterfactual_data.X)
+n_digits = Int(sqrt(input_dim))
+output_dim = length(unique(labels))
 ```
 
 First, let's create a couple of image classifier architectures:
 
 ```{julia}
+# Model parameters:
 epochs = 100
-batchsize = Int(round(n_obs/10))
+batch_size = Int(round(n_obs/10))
+n_hidden = 32
+activation = Flux.swish
+builder = MLJFlux.MLP(hidden=(n_hidden,), σ=activation)
+α = [0.33,1.0,1e-1]
 
 # Simple MLP:
 mlp = NeuralNetworkClassifier(
-    builder=MLJFlux.MLP(hidden=(32,), σ=relu), 
+    builder=builder, 
     epochs=epochs,
-    batch_size=batchsize,
+    batch_size=batch_size,
 )
 
-# Deep Ensemble:
-mlp_ens = EnsembleModel(model=mlp, n=5)
-
-# ResNet:
-using MLJFlux, Metalhead
-builder = MLJFlux.image_builder(ResNet, 18, pretrain=true)
-resnet = ImageClassifier(
+# Joint Energy Model:
+𝒟x = Uniform(-1,1)
+𝒟y = Categorical(ones(output_dim) ./ output_dim)
+sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(input_dim,), batch_size=batch_size)
+jem = JointEnergyClassifier(
+    sampler;
     builder=builder,
-    epochs=epochs,
-    batch_size=batchsize,
+    batch_size=batch_size,
+    finaliser=Flux.softmax,
+    loss=Flux.Losses.crossentropy,
+    jem_training_params=(α=α,verbosity=10,),
+    sampling_steps=20,
 )
 
-# ResNet Ensemble:
-resnet_ens = EnsembleModel(model=resnet, n=5)
+# Deep Ensemble:
+mlp_ens = EnsembleModel(model=mlp, n=5)
 ```
 
 ```{julia}
-
-conf_model = conformal_model(mlp; method=:adaptive_inductive, coverage=.99)
+cov = .9
+conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov)
 mach = machine(conf_model, X, labels)
 fit!(mach)
 M = CCE.ConformalModel(mach.model, mach.fitresult)
 ```
 
+```{julia}
+jem = mach.model.model.jem
+n_iter = 5000
+_w = 1500
+plts = []
+neach = 10
+for i in 1:10
+    x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)
+    plts_i = []
+    for j in 1:size(x, 2)
+        xj = x[:,j]
+        xj = reshape(xj, (n_digits, n_digits))
+        plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
+    end
+    plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
+    plts = [plts..., plt]
+end
+Plots.plot(plts..., size=(_w,_w), layout=(10,1))
+```
+
 ```{julia}
 test_data = load_mnist_test()
 f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
@@ -69,7 +99,7 @@ T = 100
 # Generate counterfactual using generic generator:
 generator = GenericGenerator()
 ce_wachter = generate_counterfactual(
-    x, target, dt_reduced, M, generator; 
+    x, target, counterfactual_data, M, generator; 
     decision_threshold=γ, max_iter=T,
     initialization=:identity,
 )
@@ -81,7 +111,7 @@ generator = CCEGenerator(
     # opt=CounterfactualExplanations.Generators.JSMADescent(η=0.5),
 )
 ce_conformal = generate_counterfactual(
-    x, target, dt_reduced, M, generator; 
+    x, target, counterfactual_data, M, generator; 
     decision_threshold=γ, max_iter=T,
     initialization=:identity,
     converge_when=:generator_conditions,
@@ -115,12 +145,8 @@ display(plt)
 savefig(plt, joinpath(www_path, "cce_mnist.png"))
 ```
 
-
-
 ## Benchmark
 
-
-
 ```{julia}
 # Benchmark generators:
 generators = Dict(
diff --git a/src/CCE.jl b/src/CCE.jl
index 1b678a8121a413e2282e4a0096bdd67a02b0d605..809408f16ccc7b2e5619e2fcb5e163eef83c5501 100644
--- a/src/CCE.jl
+++ b/src/CCE.jl
@@ -9,9 +9,6 @@ include("losses.jl")
 include("generator.jl")
 include("sampling.jl")
 
-using MLJFlux
-MLJFlux.reformat(X, ::Type{<:AbstractMatrix}) = X'
-
 export CCEGenerator, EnergySampler, set_size_penalty, distance_from_energy
 
 end
\ No newline at end of file
diff --git a/www/cce_mnist.png b/www/cce_mnist.png
index 1fd6cc5db60fa44b4b1a090e151e4fababaef3fc..8af78f6df7b2364c9d6883f4b5dea1046c1ac173 100644
Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ