My model is evaluated with several complicated functions (long equations, essentially polynomials), and I want to run efficient MCMC (NUTS) with Turing.jl. (For now the data are only simulated with added noise, but I'll use experimental data later.)
What is the best practice for this kind of inference? For this particular case the code runs reasonably well, but I want to use the script for much more complicated equations. Even for a slightly more complicated case (longer polynomials in `hh_generated`, `ii_generated`, `h2_generated` and `Q_generated`), sampling becomes extremely slow.
What do you think I should do?
Here is my code:
```julia
# Load necessary libraries
using Turing, MCMCChains, Random, Plots, StatsPlots, Statistics

# Constants used inside the model functions. Declaring them `const` lets the compiler
# infer their types; non-const globals make every call that reads them type-unstable.
const R = 0.001987   # constant R
const T0 = 273.15    # constant T0

# Generated functions used to evaluate the model
@generated function w_generated(x, dG, dH, dcp)
    quote
        return exp(-1 / R * (dG / T0 + dH * (1 / x - 1 / T0) + dcp * (1 - T0 / x - log(x / T0))))
    end
end
@generated function dh_generated(x, dH, dcp)
    quote
        return dH + dcp * (x - T0)
    end
end
@generated function ii_generated(v, w, dcp)
    quote
        return dcp * v^2 * w * (6 * v^2 * w + 9 * v^2 + 6 * v * w^2 + 12 * v * w + 12 * v + 5 * w^4 + 8 * w^3 + 9 * w^2 + 8 * w + 5)
    end
end
@generated function hh_generated(v, w, dh)
    quote
        return dh * v^2 * w * (6 * v^2 * w + 9 * v^2 + 6 * v * w^2 + 12 * v * w + 12 * v + 5 * w^4 + 8 * w^3 + 9 * w^2 + 8 * w + 5)
    end
end
@generated function h2_generated(v, w, dh)
    quote
        return dh^2 * v^2 * w * (12 * v^2 * w + 9 * v^2 + 18 * v * w^2 + 24 * v * w + 12 * v + 25 * w^4 + 32 * w^3 + 27 * w^2 + 16 * w + 5)
    end
end
@generated function Q_generated(v, w)
    quote
        return 3 * v^5 + 3 * v^4 * w^2 + 9 * v^4 * w + 19 * v^4 + 2 * v^3 * w^3 + 6 * v^3 * w^2 + 12 * v^3 * w + 30 * v^3 + v^2 * w^5 + 2 * v^2 * w^4 + 3 * v^2 * w^3 + 4 * v^2 * w^2 + 5 * v^2 * w + 21 * v^2 + 7 * v + 1
    end
end

# Generate some artificial data with the model functions
Random.seed!(12)
N = 101                        # number of data points
x = collect(273.15:1:373.15)   # temperatures in kelvin, from 0 to 100 °C
# True parameter values
v_true = 0.048
dG_true = -0.24
dH_true = -1.0
dcp_true = -0.01
σ_true = 0.001                 # standard deviation of the added noise
# Generate data using these functions and the true parameter values
w_true = w_generated.(x, dG_true, dH_true, dcp_true)
dh_true = dh_generated.(x, dH_true, dcp_true)
Q_true = Q_generated.(v_true, w_true)
ii_true = ii_generated.(v_true, w_true, dcp_true)
hh_true = hh_generated.(v_true, w_true, dh_true)
h2_true = h2_generated.(v_true, w_true, dh_true)
# Simulated data
y = (ii_true .+ (hh_true .+ h2_true .^ 2) ./ (R .* x .^ 2)) ./ Q_true .+ randn(N) .* σ_true

# Define the polynomial regression model
@model function model_cp(x, y, N)
    σ ~ truncated(Normal(0, 10), 0, Inf)  # noise standard deviation (half-normal prior)
    v ~ Normal(0.048, 0.06)
    dG ~ Normal(-0.2, 0.2)
    dH ~ Normal(-1, 0.5)
    dcp ~ Normal(0.01, 0.03)
    for n in 1:N
        w = w_generated(x[n], dG, dH, dcp)
        dh = dh_generated(x[n], dH, dcp)
        Q = Q_generated(v, w)
        ii = ii_generated(v, w, dcp)
        hh = hh_generated(v, w, dh)
        h2 = h2_generated(v, w, dh)
        μ = (ii + (hh + h2^2) / (R * x[n]^2)) / Q
        y[n] ~ Normal(μ, σ)
    end
end

# Perform MCMC inference with the NUTS sampler.
# `discard_initial` drops the first draws as burn-in; NUTS adaptation draws are discarded by default.
model = model_cp(x, y, N)
chains = sample(model, NUTS(), 30_000; discard_initial = 6_000)

# Plot the results
p = plot(chains)                     # MCMC diagnostic plots
savefig(p, "mcmc_diagnostics.png")   # save the diagnostic plots

# Plot the data with the best-fit line
v_hat = mean(chains[:v])       # estimated v
dG_hat = mean(chains[:dG])     # estimated dG
dH_hat = mean(chains[:dH])     # estimated dH
dcp_hat = mean(chains[:dcp])   # estimated dcp
w_hat = w_generated.(x, dG_hat, dH_hat, dcp_hat)
dh_hat = dh_generated.(x, dH_hat, dcp_hat)
Q_hat = Q_generated.(v_hat, w_hat)
ii_hat = ii_generated.(v_hat, w_hat, dcp_hat)
hh_hat = hh_generated.(v_hat, w_hat, dh_hat)
h2_hat = h2_generated.(v_hat, w_hat, dh_hat)
y_hat = (ii_hat .+ (hh_hat .+ h2_hat .^ 2) ./ (R .* x .^ 2)) ./ Q_hat
p = scatter(x, y, label = "Data")           # plot the original data
plot!(p, x, y_hat, label = "Best fit")      # add the best-fit line
savefig(p, "data_and_best_fit.png")         # save the data and best-fit plot
```
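A minimal sketch of one common restructuring, assuming the same constants, priors and formulas as in the script: ordinary functions reading `const` globals, the polynomial factor shared by `ii` and `hh` computed once per data point, and a single vectorized likelihood `y ~ MvNormal(μ, σ^2 * I)` in place of the per-point `~` loop. The names `w_fun`, `dh_fun`, `P_fun`, `h2_poly`, `Q_fun`, `mean_response` and `model_cp_vec` are hypothetical, and no speed-up has been measured here.

```julia
using Turing, LinearAlgebra, Random, Statistics

# Assumption: same constants as in the question, declared `const` for type stability.
const R = 0.001987
const T0 = 273.15

# Ordinary functions: their bodies do not depend on argument types, so `@generated` adds nothing.
w_fun(x, dG, dH, dcp) = exp(-1 / R * (dG / T0 + dH * (1 / x - 1 / T0) + dcp * (1 - T0 / x - log(x / T0))))
dh_fun(x, dH, dcp) = dH + dcp * (x - T0)
# Polynomial factor that ii_generated and hh_generated share in the question.
P_fun(v, w) = 6v^2 * w + 9v^2 + 6v * w^2 + 12v * w + 12v + 5w^4 + 8w^3 + 9w^2 + 8w + 5
# Polynomial factor of h2_generated.
h2_poly(v, w) = 12v^2 * w + 9v^2 + 18v * w^2 + 24v * w + 12v + 25w^4 + 32w^3 + 27w^2 + 16w + 5
Q_fun(v, w) = 3v^5 + 3v^4 * w^2 + 9v^4 * w + 19v^4 + 2v^3 * w^3 + 6v^3 * w^2 + 12v^3 * w + 30v^3 +
              v^2 * w^5 + 2v^2 * w^4 + 3v^2 * w^3 + 4v^2 * w^2 + 5v^2 * w + 21v^2 + 7v + 1

# Hypothetical helper: mean response μ at one temperature x (same formula as the question).
function mean_response(x, v, dG, dH, dcp)
    w = w_fun(x, dG, dH, dcp)
    dh = dh_fun(x, dH, dcp)
    base = v^2 * w              # common prefactor v^2 * w
    p = P_fun(v, w)             # shared factor, evaluated once instead of twice
    ii = dcp * base * p
    hh = dh * base * p
    h2 = dh^2 * base * h2_poly(v, w)
    return (ii + (hh + h2^2) / (R * x^2)) / Q_fun(v, w)
end

@model function model_cp_vec(x, y)
    σ ~ truncated(Normal(0, 10), 0, Inf)
    v ~ Normal(0.048, 0.06)
    dG ~ Normal(-0.2, 0.2)
    dH ~ Normal(-1, 0.5)
    dcp ~ Normal(0.01, 0.03)
    # One broadcast over all temperatures, then one multivariate likelihood term.
    μ = mean_response.(x, v, dG, dH, dcp)
    y ~ MvNormal(μ, σ^2 * I)
end

# Simulated data with the true values from the question.
Random.seed!(12)
x = collect(273.15:1:373.15)
y = mean_response.(x, 0.048, -0.24, -1.0, -0.01) .+ 0.001 .* randn(length(x))
chn = sample(model_cp_vec(x, y), NUTS(), 200; progress = false)
```

Computing μ as one vector and issuing a single `~` statement keeps the model body short and gives the AD backend one likelihood term to differentiate instead of N separate ones; the per-point arithmetic is unchanged.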
I tried using `@generated` functions, but they don't seem to help much with the longer, more complicated equations.
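`@generated` only pays off when the generated code depends on the argument *types*. Here each body is a fixed expression, and Julia already compiles a specialized method for each combination of concrete argument types, so an ordinary function should perform about the same. For long polynomials, reducing the arithmetic itself is a more promising lever, for example Horner's scheme. The sketch below assumes `Q_generated` is regrouped by powers of `v`, with each coefficient a polynomial in `w`, and evaluated with `Base.evalpoly` (available since Julia 1.4); `Q_expanded` and `Q_horner` are hypothetical names. The loop checks that both forms agree and that ForwardDiff, Turing's default AD backend, differentiates through the nested `evalpoly` calls.

```julia
using ForwardDiff

# Expanded form, identical to the body of Q_generated in the question.
Q_expanded(v, w) = 3 * v^5 + 3 * v^4 * w^2 + 9 * v^4 * w + 19 * v^4 + 2 * v^3 * w^3 + 6 * v^3 * w^2 +
                   12 * v^3 * w + 30 * v^3 + v^2 * w^5 + 2 * v^2 * w^4 + 3 * v^2 * w^3 + 4 * v^2 * w^2 +
                   5 * v^2 * w + 21 * v^2 + 7 * v + 1

# Horner form: Q = 1 + 7v + c2(w) v^2 + c3(w) v^3 + c4(w) v^4 + 3v^5.
# evalpoly(x, (a0, a1, a2, ...)) computes a0 + a1*x + a2*x^2 + ... with nested muladd calls.
function Q_horner(v, w)
    c2 = evalpoly(w, (21, 5, 4, 3, 2, 1))  # 21 + 5w + 4w^2 + 3w^3 + 2w^4 + w^5
    c3 = evalpoly(w, (30, 12, 6, 2))       # 30 + 12w + 6w^2 + 2w^3
    c4 = evalpoly(w, (19, 9, 3))           # 19 + 9w + 3w^2
    return evalpoly(v, (1, 7, c2, c3, c4, 3))
end

# Both forms agree in value and in the ForwardDiff derivative with respect to v.
for (v, w) in ((0.048, 0.7), (0.1, 1.3), (-0.2, 2.0))
    @assert isapprox(Q_horner(v, w), Q_expanded(v, w); rtol = 1e-12)
    dq_h = ForwardDiff.derivative(t -> Q_horner(t, w), v)
    dq_e = ForwardDiff.derivative(t -> Q_expanded(t, w), v)
    @assert isapprox(dq_h, dq_e; rtol = 1e-10)
end
```

The same regrouping applies to the factors inside `ii_generated`, `hh_generated` and `h2_generated`. For much longer expressions it is usually easier to let a symbolic tool produce the nested form than to regroup terms by hand.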