Is there a way to retrieve the data from a BART package model in R?


I was wondering if there is a way to retrieve the training data from a model built with the BART package in R?

It seems to be possible with other BART implementations, such as dbarts, but I can't find a way to get the original data back from a model fitted with BART. For example, if I create some data and fit both a BART and a dbarts model, like so:

library(BART)
library(dbarts)

# create data
df <- data.frame(
  x = runif(100),
  y = runif(100),
  z = runif(100)
)

# create BART
BARTmodel <- wbart(x.train = df[,1:2],
                   y.train = df[,3])

# create dbarts
DBARTSmodel <- bart(x.train = df[,1:2],
                    y.train = df[,3],
                    keeptrees = TRUE)

Using the keeptrees option in dbarts allows me to retrieve the data using:

# retrieve data from dbarts
DBARTSmodel$fit$data@x

However, there doesn't seem to be any type of similar option when using BART. Is it even possible to retrieve the data from a BART model?


There are 2 answers

Answer by alan ocallaghan (accepted answer)

The Value section of ?wbart indicates that the input data is not returned as part of the output, and none of wbart's arguments appear to change this.

Furthermore, the output of str() confirms that it's not present:

library(BART)
library(dbarts)

# create data
df <- data.frame(
  x = runif(100),
  y = runif(100),
  z = runif(100)
)

# create BART
BARTmodel <- wbart(x.train = df[,1:2],
                   y.train = df[,3])

# create dbarts
DBARTSmodel <- bart(x.train = df[,1:2],
                    y.train = df[,3],
                    keeptrees = TRUE)

str(BARTmodel)
#> List of 13
#>  $ sigma          : num [1:1100] 0.258 0.262 0.295 0.278 0.273 ...
#>  $ yhat.train.mean: num [1:100] 0.584 0.457 0.505 0.54 0.403 ...
#>  $ yhat.train     : num [1:1000, 1:100] 0.673 0.62 0.433 0.711 0.634 ...
#>  $ yhat.test.mean : num(0) 
#>  $ yhat.test      : num[1:1000, 0 ] 
#>  $ varcount       : int [1:1000, 1:2] 109 114 111 118 115 114 115 110 114 117 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : NULL
#>   .. ..$ : chr [1:2] "x" "y"
#>  $ varprob        : num [1:1000, 1:2] 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : NULL
#>   .. ..$ : chr [1:2] "x" "y"
#>  $ treedraws      :List of 2
#>   ..$ cutpoints:List of 2
#>   .. ..$ x: num [1:100] 0.0147 0.0245 0.0343 0.0442 0.054 ...
#>   .. ..$ y: num [1:100] 0.0395 0.0491 0.0586 0.0681 0.0776 ...
#>   ..$ trees    : chr "1000 200 2\n1\n1 0 0 0.01185590432\n3\n1 1 30 -0.01530736435\n2 0 0 0.01064412946\n3 0 0 0.02413784284\n3\n1 0 "| __truncated__
#>  $ proc.time      : 'proc_time' Named num [1:5] 1.406 0.008 1.415 0 0
#>   ..- attr(*, "names")= chr [1:5] "user.self" "sys.self" "elapsed" "user.child" ...
#>  $ mu             : num 0.501
#>  $ varcount.mean  : Named num [1:2] 115 110
#>   ..- attr(*, "names")= chr [1:2] "x" "y"
#>  $ varprob.mean   : Named num [1:2] 0.5 0.5
#>   ..- attr(*, "names")= chr [1:2] "x" "y"
#>  $ rm.const       : int [1:2] 1 2
#>  - attr(*, "class")= chr "wbart"

By contrast, the str() output for the dbarts bart() fit, while long, does contain the input (under fit$data):

str(DBARTSmodel)
#> List of 11
#>  $ call           : language bart(x.train = df[, 1:2], y.train = df[, 3], keeptrees = TRUE)
#>  $ first.sigma    : num [1:100] 0.289 0.311 0.268 0.253 0.242 ...
#>  $ sigma          : num [1:1000] 0.288 0.307 0.248 0.257 0.293 ...
#>  $ sigest         : num 0.295
#>  $ yhat.train     : num [1:1000, 1:100] 0.715 0.677 0.508 0.51 0.827 ...
#>  $ yhat.train.mean: num [1:100] 0.583 0.456 0.504 0.544 0.404 ...
#>  $ yhat.test      : NULL
#>  $ yhat.test.mean : NULL
#>  $ varcount       : int [1:1000, 1:2] 128 118 120 142 130 145 145 150 138 138 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : NULL
#>   .. ..$ : chr [1:2] "x" "y"
#>  $ y              : num [1:100] 0.8489 0.0817 0.4371 0.8566 0.0878 ...
#>  $ fit            :Reference class 'dbartsSampler' [package "dbarts"] with 5 fields
#>   ..$ pointer:<externalptr> 
#>   ..$ control:Formal class 'dbartsControl' [package "dbarts"] with 18 slots
#>   .. .. ..@ binary          : logi FALSE
#>   .. .. ..@ verbose         : logi TRUE
#>   .. .. ..@ keepTrainingFits: logi TRUE
#>   .. .. ..@ useQuantiles    : logi FALSE
#>   .. .. ..@ keepTrees       : logi TRUE
#>   .. .. ..@ n.samples       : int 1000
#>   .. .. ..@ n.burn          : int 100
#>   .. .. ..@ n.trees         : int 200
#>   .. .. ..@ n.chains        : int 1
#>   .. .. ..@ n.threads       : int 1
#>   .. .. ..@ n.thin          : int 1
#>   .. .. ..@ printEvery      : int 100
#>   .. .. ..@ printCutoffs    : int 0
#>   .. .. ..@ rngKind         : chr "default"
#>   .. .. ..@ rngNormalKind   : chr "default"
#>   .. .. ..@ rngSeed         : int NA
#>   .. .. ..@ updateState     : logi TRUE
#>   .. .. ..@ call            : language bart(x.train = df[, 1:2], y.train = df[, 3], keeptrees = TRUE)
#>   ..$ model  :Formal class 'dbartsModel' [package "dbarts"] with 9 slots
#>   .. .. ..@ p.birth_death  : num 0.5
#>   .. .. ..@ p.swap         : num 0.1
#>   .. .. ..@ p.change       : num 0.4
#>   .. .. ..@ p.birth        : num 0.5
#>   .. .. ..@ node.scale     : num 0.5
#>   .. .. ..@ tree.prior     :Formal class 'dbartsCGMPrior' [package "dbarts"] with 3 slots
#>   .. .. .. .. ..@ power             : num 2
#>   .. .. .. .. ..@ base              : num 0.95
#>   .. .. .. .. ..@ splitProbabilities: num(0) 
#>   .. .. ..@ node.prior     :Formal class 'dbartsNormalPrior' [package "dbarts"] with 0 slots
#>  list()
#>   .. .. ..@ node.hyperprior:Formal class 'dbartsFixedHyperprior' [package "dbarts"] with 1 slot
#>   .. .. .. .. ..@ k: num 2
#>   .. .. ..@ resid.prior    :Formal class 'dbartsChiSqPrior' [package "dbarts"] with 2 slots
#>   .. .. .. .. ..@ df      : num 3
#>   .. .. .. .. ..@ quantile: num 0.9
#>   ..$ data   :Formal class 'dbartsData' [package "dbarts"] with 10 slots
#>   .. .. ..@ y                    : num [1:100] 0.8489 0.0817 0.4371 0.8566 0.0878 ...
#>   .. .. ..@ x                    : num [1:100, 1:2] 0.152 0.666 0.967 0.248 0.668 ...
#>   .. .. .. ..- attr(*, "dimnames")=List of 2
#>   .. .. .. .. ..$ : NULL
#>   .. .. .. .. ..$ : chr [1:2] "x" "y"
#>   .. .. .. ..- attr(*, "drop")=List of 2
#>   .. .. .. .. ..$ x: logi FALSE
#>   .. .. .. .. ..$ y: logi FALSE
#>   .. .. .. ..- attr(*, "term.labels")= chr [1:2] "x" "y"
#>   .. .. ..@ varTypes             : int [1:2] 0 0
#>   .. .. ..@ x.test               : NULL
#>   .. .. ..@ weights              : NULL
#>   .. .. ..@ offset               : NULL
#>   .. .. ..@ offset.test          : NULL
#>   .. .. ..@ n.cuts               : int [1:2] 100 100
#>   .. .. ..@ sigma                : num 0.295
#>   .. .. ..@ testUsesRegularOffset: logi NA
#>   ..$ state  :List of 1
#>   .. ..$ :Formal class 'dbartsState' [package "dbarts"] with 6 slots
#>   .. .. .. ..@ trees     : int [1:1055] 0 18 -1 0 49 -1 -1 0 60 -1 ...
#>   .. .. .. ..@ treeFits  : num [1:100, 1:200] -0.02252 0.00931 0.00931 0.02688 0.00931 ...
#>   .. .. .. ..@ savedTrees: int [1:2340360] 0 797997482 1070928224 1 -402902351 1070268808 -1 -1094651769 -1081938039 -1 ...
#>   .. .. .. ..@ sigma     : num 0.297
#>   .. .. .. ..@ k         : num 2
#>   .. .. .. ..@ rng.state : int [1:18] 0 1078575104 0 1078575104 -1657977906 1075613906 0 1078558720 277209871 -1068236140 ...
#>   .. ..- attr(*, "runningTime")= num 0.477
#>   .. ..- attr(*, "currentNumSamples")= int 1000
#>   .. ..- attr(*, "currentSampleNum")= int 0
#>   .. ..- attr(*, "numCuts")= int [1:2] 100 100
#>   .. ..- attr(*, "cutPoints")=List of 2
#>   .. .. ..$ : num [1:100] 0.0147 0.0245 0.0343 0.0442 0.054 ...
#>   .. .. ..$ : num [1:100] 0.0395 0.0491 0.0586 0.0681 0.0776 ...
#>   ..and 40 methods, of which 26 are  possibly relevant:
#>   ..  copy#envRefClass, getLatents, getPointer, getTrees, initialize, plotTree,
#>   ..  predict, printTrees, run, sampleNodeParametersFromPrior,
#>   ..  sampleTreesFromPrior, setControl, setCutPoints, setData, setModel,
#>   ..  setOffset, setPredictor, setResponse, setSigma, setState, setTestOffset,
#>   ..  setTestPredictor, setTestPredictorAndOffset, setWeights,
#>   ..  show#envRefClass, storeState
#>  - attr(*, "class")= chr "bart"
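Since wbart() doesn't keep x.train or y.train, one workaround is to store them on the fitted object yourself. The sketch below assumes you control the call to wbart(); `fit_keep_data` is a hypothetical helper name, not part of the BART package. The returned object is still a list of class "wbart", just with two extra components, and whether every BART method tolerates those extra components is an assumption worth checking.

```r
# Hypothetical helper (not part of BART): call a BART-style fitting function
# and attach the training data to the returned object.
fit_keep_data <- function(fit_fun, x.train, y.train, ...) {
  fit <- fit_fun(x.train = x.train, y.train = y.train, ...)
  # wbart() does not return its inputs, so store them as extra list elements
  fit$x.train <- x.train
  fit$y.train <- y.train
  fit
}

# Usage with BART (requires the BART package):
# library(BART)
# BARTmodel <- fit_keep_data(wbart, df[, 1:2], df[, 3])
# BARTmodel$x.train
```

Passing the fitting function as an argument keeps the helper independent of which BART fitting function you use; any extra arguments are forwarded through `...`.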
Answer by Le Paul

You can achieve what you are looking for using the bartModelMatrix() function from the BART package. Note that it rebuilds the matrix from df itself rather than reading it out of the fitted model, so you still need the original data frame.

This function can also determine the number of cutpoints needed for each column.

The resulting matrix has as many columns as there are variables in your df. In your example you're only interested in x and y, so you only need the first and second columns of the matrix returned by bartModelMatrix().

So for the example you gave:

library(BART)
library(dbarts)

# create data (no trailing comma after the last column)
df <- data.frame(
  x = runif(100),
  y = runif(100),
  z = runif(100)
)

# create BART
BARTmodel <- wbart(x.train = df[,1:2],
                   y.train = df[,3])

# create dbarts
DBARTSmodel <- bart(x.train = df[,1:2],
                    y.train = df[,3],
                    keeptrees = TRUE)
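
# rebuild the training matrix from df and keep the predictor columns x and y
BARTmatrix <- bartModelMatrix(df)
BARTmatrix <- BARTmatrix[, 1:2]

# TRUE if every entry matches the data stored in the dbarts fit
all(BARTmatrix == DBARTSmodel$fit$data@x)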

Hope that helps.