Does Python have an analogue to R's splines::ns()?


I would like to replicate this Q matrix in Python, but I can't get it to work.

min = 0
max = 10
tau = seq(min, max)
pDegree = 5
Q <- splines::ns(tau, pDegree)
print(Q)

Here are some attempts in Python:

import numpy as np
from patsy import dmatrix
from scipy import interpolate
min = 0
max = 10
tau = np.arange(min, max + 1)
pDegree = 5
# try one
spline_basis = dmatrix("bs(x, df=" + str(pDegree) + ", include_intercept=True) - 1", {"x": tau})
print(spline_basis)
# try two
spline_basis = dmatrix("bs(x, df=" + str(pDegree) + ", include_intercept=False) - 1", {"x": tau})
print(spline_basis)

This is the matrix I get in R:

               1          2           3          4           5
 [1,] 0.00000000 0.00000000  0.00000000 0.00000000  0.00000000
 [2,] 0.02083333 0.00000000 -0.11620871 0.34862613 -0.23241742
 [3,] 0.16666667 0.00000000 -0.16903085 0.50709255 -0.33806170
 [4,] 0.47916667 0.02083333 -0.12149092 0.36447277 -0.24298185
 [5,] 0.66666667 0.16666667 -0.04225771 0.12677314 -0.08451543
 [6,] 0.47916667 0.47916667  0.01406302 0.02031093 -0.01354062
 [7,] 0.16666667 0.66666667  0.15476190 0.03571429 -0.02380952
 [8,] 0.02083333 0.47916667  0.44196429 0.11160714 -0.05357143
 [9,] 0.00000000 0.16666667  0.59523810 0.21428571  0.02380952
[10,] 0.00000000 0.02083333  0.35119048 0.32142857  0.30654762
[11,] 0.00000000 0.00000000 -0.14285714 0.42857143  0.71428571

Answer by langtang

A couple of things are going on here:

  1. tau in your R code is 0:11, but tau in your Python code is equivalent to 0:10.
  2. You are generating B-splines (bs()) in your Python code, but natural cubic splines (ns()) in your R code.
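To make point 2 concrete: a natural cubic spline basis is a cubic B-spline basis (what bs() builds) constrained to have zero second derivative at both boundary knots, which removes two degrees of freedom. The sketch below builds such a basis with scipy.interpolate.BSpline and scipy.linalg.null_space. natural_spline_basis is a hypothetical helper, not part of scipy or patsy. Its columns span the same function space as ns(x, knots = ..., intercept = TRUE), but the individual columns differ from R's, because ns() uses a QR-based projection rather than an orthonormal null-space basis.

```python
import numpy as np
from scipy.interpolate import BSpline
from scipy.linalg import null_space


def natural_spline_basis(x, interior_knots):
    """Hypothetical helper: natural cubic spline basis via a null-space projection.

    Spans the same space as R's ns(x, knots = interior_knots, intercept = TRUE),
    but the individual columns differ from R's.
    """
    x = np.asarray(x, dtype=float)
    k = 3  # natural splines are cubic
    lo, hi = x.min(), x.max()
    # Clamped knot vector: each boundary knot repeated k + 1 = 4 times
    t = np.concatenate(([lo] * (k + 1), np.sort(interior_knots), [hi] * (k + 1)))
    n = len(t) - k - 1  # number of cubic B-spline basis functions
    bsplines = BSpline(t, np.eye(n), k)  # column j is the j-th B-spline (a bs()-style basis)
    # Second derivatives of every B-spline at the two boundary knots: shape (2, n)
    d2 = bsplines.derivative(2)([lo, hi])
    # Natural constraint f''(lo) = f''(hi) = 0 keeps only coefficients in the null space of d2
    z = null_space(d2)  # shape (n, n - 2), orthonormal columns
    natural = BSpline(t, z, k)  # column j: cubic spline with coefficients z[:, j]
    return natural(x), natural


tau = np.arange(0, 11)
knots = np.quantile(tau, [0.2, 0.4, 0.6, 0.8])  # the interior knots ns(tau, df = 5) uses
N, natural = natural_spline_basis(tau, knots)
print(N.shape)  # (11, 6): two fewer columns than the 8 cubic B-splines
print(np.abs(natural.derivative(2)([0.0, 10.0])).max())  # effectively zero
```

With the default intercept = FALSE, R additionally drops the first B-spline before projecting, which leaves 5 columns; those columns lie inside the span of N.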

I've found that I can use patsy's dmatrix to replicate the cubic regression spline basis estimated by the mgcv package.

For example:

mgcv::smoothCon(mgcv::s(tau, bs="cr", k=5), data=data.frame(tau=tau))[[1]]$X
              [,1]        [,2]        [,3]        [,4]         [,5]
 [1,]  1.000000000  0.00000000  0.00000000  0.00000000  0.000000000
 [2,]  0.551840721  0.55522164 -0.13523666  0.03380917 -0.005634861
 [3,]  0.180959536  0.93527960 -0.14682838  0.03670709 -0.006117849
 [4,] -0.035821616  0.96624450  0.08768917 -0.02173446  0.003622411
 [5,] -0.076875604  0.62353762  0.54792315 -0.11350220  0.018917033
 [6,] -0.027771815  0.17264141  0.93603091 -0.09708061  0.016180101
 [7,]  0.016180101 -0.09708061  0.93603091  0.17264141 -0.027771815
 [8,]  0.018917033 -0.11350220  0.54792315  0.62353762 -0.076875604
 [9,]  0.003622411 -0.02173446  0.08768917  0.96624450 -0.035821616
[10,] -0.006117849  0.03670709 -0.14682838  0.93527960  0.180959536
[11,] -0.005634861  0.03380917 -0.13523666  0.55522164  0.551840721
[12,]  0.000000000  0.00000000  0.00000000  0.00000000  1.000000000

Python:

import numpy as np
from patsy import dmatrix
tau = np.arange(12)
print(dmatrix("cr(x, df=5) - 1", {"x": tau}))

[[ 1.          0.          0.          0.          0.        ]
 [ 0.55184072  0.55522164 -0.13523666  0.03380917 -0.00563486]
 [ 0.18095954  0.9352796  -0.14682838  0.03670709 -0.00611785]
 [-0.03582162  0.9662445   0.08768917 -0.02173446  0.00362241]
 [-0.0768756   0.62353762  0.54792315 -0.1135022   0.01891703]
 [-0.02777181  0.17264141  0.93603091 -0.09708061  0.0161801 ]
 [ 0.0161801  -0.09708061  0.93603091  0.17264141 -0.02777181]
 [ 0.01891703 -0.1135022   0.54792315  0.62353762 -0.0768756 ]
 [ 0.00362241 -0.02173446  0.08768917  0.9662445  -0.03582162]
 [-0.00611785  0.03670709 -0.14682838  0.9352796   0.18095954]
 [-0.00563486  0.03380917 -0.13523666  0.55522164  0.55184072]
 [ 0.          0.          0.          0.          1.        ]]
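The cr basis can also be built without patsy. In mgcv's cr parameterization each coefficient is the fitted function's value at one knot, so column j is the natural cubic spline that interpolates 1 at knot j and 0 at the other knots. The sketch below does this with scipy.interpolate.CubicSpline(..., bc_type="natural"), assuming the df knots sit at evenly spaced quantiles of x (including its minimum and maximum); under that assumption, for tau = 0..11 it should reproduce the matrix above. cardinal_cr_basis is a hypothetical helper name.

```python
import numpy as np
from scipy.interpolate import CubicSpline


def cardinal_cr_basis(x, df):
    """Hypothetical helper: cubic regression spline basis in mgcv's 'cr' parameterization."""
    x = np.asarray(x, dtype=float)
    # Assumption: df knots at evenly spaced quantiles of x, including min(x) and max(x)
    knots = np.quantile(x, np.linspace(0, 1, df))
    # Column j is the natural cubic spline through the j-th unit vector at the knots,
    # so each coefficient is the fitted function's value at one knot
    cardinal = CubicSpline(knots, np.eye(df), bc_type="natural")
    return cardinal(x)


tau = np.arange(12)
B = cardinal_cr_basis(tau, 5)
print(np.round(B, 8))
print(B.sum(axis=1))  # every row sums to 1 (up to round-off)
```

The rows sum to 1 because interpolating the constant data 1 at every knot gives the constant function 1, which is why the first and last rows of the matrix above are unit vectors.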
 
Answer by Francis

I was dealing with a similar problem recently, so I wrote my own ns() by examining R's ns() and patsy's bs(). Here is the Python code:

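import numpy as np
from scipy.interpolate import splev

def ns(x, df=None, knots=None, boundary_knots=None, include_intercept=False):
    """Natural cubic spline basis, modeled on R's splines::ns()."""
    x = np.asarray(x, dtype=float)
    degree = 3
    if boundary_knots is None:
        boundary_knots = [np.min(x), np.max(x)]
    boundary_knots = np.asarray(boundary_knots, dtype=float)

    # As in R: df - 1 - intercept interior knots at evenly spaced quantiles of x
    if df is not None and knots is None:
        nIknots = max(df - 1 - int(include_intercept), 0)
        if nIknots > 0:
            knots = np.linspace(0, 1, num=nIknots + 2)[1:-1]
            knots = np.quantile(x, knots)
    if knots is None:
        knots = np.array([])  # no interior knots

    # Each boundary knot repeated 4 times (order 4 = cubic), plus the interior knots
    Aknots = np.sort(np.concatenate((np.tile(boundary_knots, 4), knots)))
    n_bases = len(Aknots) - (degree + 1)

    # Evaluate each cubic B-spline by passing splev a unit coefficient vector
    basis = np.empty((x.shape[0], n_bases), dtype=float)
    for i in range(n_bases):
        coefs = np.zeros((n_bases,))
        coefs[i] = 1
        basis[:, i] = splev(x, (Aknots, coefs, degree))

    # Second derivative of each B-spline at the two boundary knots
    const = np.empty((2, n_bases), dtype=float)
    for i in range(n_bases):
        coefs = np.zeros((n_bases,))
        coefs[i] = 1
        const[:, i] = splev(boundary_knots, (Aknots, coefs, degree), der=2)

    if not include_intercept:
        basis = basis[:, 1:]
        const = const[:, 1:]

    # Like R's qr.qty(): multiply by Q^T from a full QR of const.T and drop the
    # first two columns, leaving combinations with zero second derivative at both ends
    qr_const = np.linalg.qr(const.T, mode='complete')[0]
    basis = (qr_const.T @ basis.T).T[:, 2:]

    return basis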
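The last two lines of ns() are the least obvious part. The sketch below uses random stand-in matrices (hypothetical data, not real spline values) to show what they do: with a full QR factorization const.T = Q R, only the first two rows of R are nonzero, so the last n_bases - 2 columns of Q are orthogonal to both boundary-constraint rows. Keeping only those directions is what forces a zero second derivative at both boundary knots.

```python
import numpy as np

rng = np.random.default_rng(0)
n_bases = 6
# Hypothetical stand-ins for the matrices built inside ns():
basis = rng.normal(size=(11, n_bases))  # B-spline values at 11 points
const = rng.normal(size=(2, n_bases))   # second derivatives at the two boundary knots

# Full QR of const.T: const.T = Q @ R, where only the first two rows of R are nonzero
Q = np.linalg.qr(const.T, mode="complete")[0]

# The expression used in ns() equals keeping the last n_bases - 2 columns of Q
projected = (Q.T @ basis.T).T[:, 2:]
print(np.allclose(projected, basis @ Q[:, 2:]))  # True

# Those columns of Q are orthogonal to both constraint rows, so any spline whose
# coefficients come from them has zero second derivative at both boundaries
print(np.allclose(const @ Q[:, 2:], 0.0))  # True
```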

Running your example:

min = 0
max = 10
tau = np.arange(min, max + 1)
pDegree = 5
# try one
spline_basis = ns(tau, df=pDegree, include_intercept=True)
print(spline_basis)
# try two
spline_basis = ns(tau, df=pDegree, include_intercept=False)
print(spline_basis)

I get:

[[-0.26726124  0.         -0.21428571  0.64285714 -0.42857143]
 [ 0.19783305  0.01066667 -0.14498945  0.43496835 -0.2899789 ]
 [ 0.52378074  0.08533333 -0.08463035  0.25389105 -0.1692607 ]
 [ 0.57850426  0.28266667 -0.04058487  0.12575462 -0.08383642]
 [ 0.39251408  0.53866667  0.01566699  0.06099904 -0.04066603]
 [ 0.15890871  0.66666667  0.1485417   0.0543749  -0.03624993]
 [ 0.03432428  0.53866667  0.38218025  0.09745926 -0.05430618]
 [ 0.00127127  0.28266667  0.57337881  0.17186357 -0.02924238]
 [ 0.          0.08533333  0.54361905  0.25714286  0.11390476]
 [ 0.          0.01066667  0.26438095  0.34285714  0.38209524]
 [ 0.          0.         -0.14285714  0.42857143  0.71428571]]
[[ 0.          0.          0.          0.          0.        ]
 [ 0.02083333  0.         -0.11620871  0.34862613 -0.23241742]
 [ 0.16666667  0.         -0.16903085  0.50709255 -0.3380617 ]
 [ 0.47916667  0.02083333 -0.12149092  0.36447277 -0.24298185]
 [ 0.66666667  0.16666667 -0.04225771  0.12677314 -0.08451543]
 [ 0.47916667  0.47916667  0.01406302  0.02031093 -0.01354062]
 [ 0.16666667  0.66666667  0.1547619   0.03571429 -0.02380952]
 [ 0.02083333  0.47916667  0.44196429  0.11160714 -0.05357143]
 [ 0.          0.16666667  0.5952381   0.21428571  0.02380952]
 [ 0.          0.02083333  0.35119048  0.32142857  0.30654762]
 [ 0.          0.         -0.14285714  0.42857143  0.71428571]]

The second matrix (include_intercept=False, which corresponds to R's default intercept = FALSE) matches R's output.

Note that R's ns() contains a lot more case-specific processing that I did not implement (for example, linear extrapolation of the basis for x values outside the boundary knots), so use this at your own risk.