Complicated, long equations in MCMC [Julia, Turing.jl]


In my model I use several complicated functions (long equations, basically polynomials) to evaluate the model, and I want to perform efficient MCMC (NUTS) with Turing.jl. [For now the data are only simulated with added noise, but I'll use experimental data later.]

What is the best practice for this kind of inference? For this particular case the code runs reasonably well, but I want to reuse the script for much more complicated equations, and with slightly longer polynomials (in the hh_generated, ii_generated, h2_generated and Q_generated functions) sampling already becomes extremely slow.

What do you think I should do?

Here is my code...

# Load necessary libraries
using Turing, MCMCChains, Random, Plots, StatsPlots, Statistics

# Generated function used to evaluate model
@generated function w_generated(x, dG, dH, dcp)
    quote
        return exp(-1 / R * (dG / T0 + dH * (1 / x - 1 / T0) + dcp * (1 - T0 / x - log(x / T0))))
    end
end

@generated function dh_generated(x, dH, dcp)
    quote
        return dH + dcp * (x - T0)
    end
end

@generated function ii_generated(v, w, dcp)
    quote
        return dcp * v^2 * w * (6 * v^2 * w + 9 * v^2 + 6 * v * w^2 + 12 * v * w + 12 * v + 5 * w^4 + 8 * w^3 + 9 * w^2 + 8 * w + 5)
    end
end

@generated function hh_generated(v, w, dh)
    quote
        return dh * v^2 * w * (6 * v^2 * w + 9 * v^2 + 6 * v * w^2 + 12 * v * w + 12 * v + 5 * w^4 + 8 * w^3 + 9 * w^2 + 8 * w + 5)
    end
end

@generated function h2_generated(v, w, dh)
    quote
        return dh^2 * v^2 * w * (12 * v^2 * w + 9 * v^2 + 18 * v * w^2 + 24 * v * w + 12 * v + 25 * w^4 + 32 * w^3 + 27 * w^2 + 16 * w + 5)
    end
end

@generated function Q_generated(v, w)
    quote
        return 3 * v^5 + 3 * v^4 * w^2 + 9 * v^4 * w + 19 * v^4 + 2 * v^3 * w^3 + 6 * v^3 * w^2 + 12 * v^3 * w + 30 * v^3 + v^2 * w^5 + 2 * v^2 * w^4 + 3 * v^2 * w^3 + 4 * v^2 * w^2 + 5 * v^2 * w + 21 * v^2 + 7 * v + 1
    end
end
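As a sanity check, here is a plain (non-generated) version of the first function. The quoted bodies above never use the argument types, so as far as I understand @generated should compile to the same code as an ordinary definition. (w_plain, R_ and T0_ are names I made up so this snippet stands alone; they mirror the R and T0 globals defined below.)

```julia
# Self-contained sketch: a plain function with the same body as w_generated.
# The quote in w_generated contains no type-dependent logic, so the compiler
# should produce equivalent specialized code for this ordinary definition.
const R_ = 0.001987  # same value as the script's global R
const T0_ = 273.15   # same value as the script's global T0

w_plain(x, dG, dH, dcp) =
    exp(-1 / R_ * (dG / T0_ + dH * (1 / x - 1 / T0_) + dcp * (1 - T0_ / x - log(x / T0_))))
```

At x = T0 the dH and dcp terms vanish, so w_plain reduces to exp(-dG / (R * T0)), which is an easy way to spot-check either version.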

# Generate some artificial data with model functions
Random.seed!(12)
N = 101 # number of data points
const R = 0.001987 # constant R; `const` matters — the model functions read this global on every call
const T0 = 273.15 # constant T0 (K); likewise `const` to avoid non-const-global overhead

x = collect(273.15:1:373.15) # temperature values in Kelvin from 0 to 100°C

# True parameter values
v_true = 0.048
dG_true = -0.24
dH_true = -1.0
dcp_true = -0.01
σ_true = 0.001 # added noise

# Generate data using these functions and the true parameter values
w_true = w_generated.(x, dG_true, dH_true, dcp_true)
dh_true = dh_generated.(x, dH_true, dcp_true)
Q_true = Q_generated.(v_true, w_true)
ii_true = ii_generated.(v_true, w_true, dcp_true)
hh_true = hh_generated.(v_true, w_true, dh_true)
h2_true = h2_generated.(v_true, w_true, dh_true)

# Simulated data
y = (ii_true .+ (hh_true .+ h2_true .^ 2) ./ (R .* x .^ 2)) ./ Q_true .+ randn(N) .* σ_true


# Define the polynomial regression model
@model function model_cp(x, y, N)
    σ ~ truncated(Normal(0, 10), 0, Inf) # half-normal prior on the noise
    v ~ Normal(0.048,0.06)
    dG ~ Normal(-0.2,0.2)
    dH ~ Normal(-1,0.5)
    dcp ~ Normal(0.01,0.03)
    for n in 1:N
        w = w_generated(x[n], dG, dH, dcp)
        dh = dh_generated(x[n], dH, dcp)
        Q = Q_generated(v, w)
        ii = ii_generated(v, w, dcp)
        hh = hh_generated(v, w, dh)
        h2 = h2_generated(v, w, dh)

        μ = (ii + (hh + h2^2) / (R * x[n]^2)) / Q

        y[n] ~ Normal(μ, σ)
    end
end
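As a variant I also tried vectorizing the model (a sketch, not sure it's the recommended pattern — the model_cp_vec name and the MvNormal formulation are mine): all per-point quantities are computed with broadcasts, and the observation loop is replaced by a single multivariate normal, so there is one tilde statement for y instead of N:

```julia
using Turing
using LinearAlgebra: I

# Sketch: same model as model_cp, but vectorized. Assumes the helper
# functions (w_generated, ..., Q_generated) and the global R from the
# script are already defined.
@model function model_cp_vec(x, y)
    σ ~ truncated(Normal(0, 10), 0, Inf)
    v ~ Normal(0.048, 0.06)
    dG ~ Normal(-0.2, 0.2)
    dH ~ Normal(-1, 0.5)
    dcp ~ Normal(0.01, 0.03)

    w = w_generated.(x, dG, dH, dcp)
    dh = dh_generated.(x, dH, dcp)
    Q = Q_generated.(v, w)
    ii = ii_generated.(v, w, dcp)
    hh = hh_generated.(v, w, dh)
    h2 = h2_generated.(v, w, dh)

    μ = (ii .+ (hh .+ h2 .^ 2) ./ (R .* x .^ 2)) ./ Q
    y ~ MvNormal(μ, σ^2 * I)  # one joint likelihood instead of N scalar ones
end
```

This keeps the math identical but gives the AD backend one vectorized pass instead of N scalar observe statements.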


# Perform MCMC inference, NUTS sampler
model = model_cp(x, y, N)
chains = sample(model, NUTS(), 30_000; discard_initial=6_000) # discard_initial is the warm-up keyword; there is no burn_in keyword


# Plot the results
p = plot(chains) # MCMC diagnostics plots
savefig(p, "mcmc_diagnostics.png") # Save the diagnostics plot

# Plot the data with the best fit line
v_hat = mean(chains[:v])  # estimated v
dG_hat = mean(chains[:dG])  # estimated dG
dH_hat = mean(chains[:dH])  # estimated dH
dcp_hat = mean(chains[:dcp])  # estimated dcp

w_hat = w_generated.(x, dG_hat, dH_hat, dcp_hat)

dh_hat = dh_generated.(x, dH_hat, dcp_hat)
Q_hat = Q_generated.(v_hat, w_hat)
ii_hat = ii_generated.(v_hat, w_hat, dcp_hat)
hh_hat = hh_generated.(v_hat, w_hat, dh_hat)
h2_hat = h2_generated.(v_hat, w_hat, dh_hat)

y_hat = (ii_hat .+ (hh_hat .+ h2_hat .^ 2) ./ (R .* x .^ 2)) ./ Q_hat

p = scatter(x, y, label="Data")  # plot the original data
plot!(x, y_hat, label="Best Fit Line")  # add the best fit line
savefig(p, "data_and_best_fit.png")  # Save the data and best fit plot

I tried @generated functions, but they don't seem to help much on the more complicated, longer equations.
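To check whether @generated changes anything at all, I timed the two variants of one kernel head to head with BenchmarkTools (Q_gen and Q_plain are standalone copies I made of Q_generated's formula):

```julia
using BenchmarkTools

# Standalone copies of the two variants, so this snippet runs on its own.
@generated function Q_gen(v, w)
    quote
        3 * v^5 + 3 * v^4 * w^2 + 9 * v^4 * w + 19 * v^4 + 2 * v^3 * w^3 +
        6 * v^3 * w^2 + 12 * v^3 * w + 30 * v^3 + v^2 * w^5 + 2 * v^2 * w^4 +
        3 * v^2 * w^3 + 4 * v^2 * w^2 + 5 * v^2 * w + 21 * v^2 + 7 * v + 1
    end
end

Q_plain(v, w) = 3 * v^5 + 3 * v^4 * w^2 + 9 * v^4 * w + 19 * v^4 + 2 * v^3 * w^3 +
    6 * v^3 * w^2 + 12 * v^3 * w + 30 * v^3 + v^2 * w^5 + 2 * v^2 * w^4 +
    3 * v^2 * w^3 + 4 * v^2 * w^2 + 5 * v^2 * w + 21 * v^2 + 7 * v + 1

v0, w0 = 0.048, 0.9
@btime Q_gen($v0, $w0)    # $-interpolation keeps global-variable lookups out of the timing
@btime Q_plain($v0, $w0)
```

For me the two timings are indistinguishable, which suggests the bottleneck is not the function definitions themselves but something else (e.g. non-const globals or AD through the long polynomials).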
