Creating a 'marginsplot' in R


Inspired by this YouTube video (https://www.youtube.com/watch?v=7maMbX_65b0) by the ever-fantastic Chuck Huber, how can I recreate Stata's marginsplot in R?

In other words, for the cplot() line at the end of the code block, how can I get the plot to show the predicted values at increments of age for each level of smoke?

Any help is always much appreciated!

library(margins)
set.seed(42)
n <- 1000
patient <- data.frame(id=1:n,
                      treat = factor(sample(c('Treat','Control'), n, replace=TRUE, prob=c(.5, .5))),
                      age=sample(18:80, n, replace=TRUE),
                      sex = factor(sample(c('Male','Female'), n, replace=TRUE, prob=c(.6, .4))),
                      smoke=factor(sample(c("Never", 'Former', 'Current'), n, replace=TRUE, prob=c(.25, .6, .15))),
                      outcome=runif(n, min=16, max=45))

model <- lm(outcome ~ treat*age + smoke, data = patient)
cplot(model, x="age", by="smoke", overlay=TRUE)

There are 3 answers

jay.sf (best answer)

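To replicate Stata's marginsplot, what you want is to plot the mean of your model's predictions over all combinations of the remaining covariates, along the variables of interest: here age on the x-axis, with one line per level of smoke. Here the model has the full treat*age*smoke interaction (unlike the question's treat*age + smoke):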

lm1 <- lm(outcome ~ treat*age*smoke, data=patient)

So first, we create all desired combinations of the covariates using expand.grid,

.newdata <- expand.grid(
  treat=unique(patient$treat),
  age=with(patient, min(age):max(age)),
  sex=unique(patient$sex),
  smoke=unique(patient$smoke)
  )

which we then pass to predict(), binding its fitted values and confidence bounds onto the grid with cbind():

.newdata <- cbind(.newdata, predict(lm1, newdata=.newdata, interval='conf'))
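Since sex is not a term in lm1, including it in the grid only repeats every row, so the averages are the same with or without it. A quick base-R check of that claim, regenerating the simulated patient data as in the Data section so the block runs on its own (the names g_sex, g_nosex, m_sex and m_nosex are just illustrative):

```r
# simulated data, following the Data section of this answer
set.seed(42)
n <- 1000
patient <- data.frame(
  treat   = factor(sample(c('Treat', 'Control'), n, replace = TRUE)),
  age     = sample(18:80, n, replace = TRUE),
  sex     = factor(sample(c('Male', 'Female'), n, replace = TRUE, prob = c(.6, .4))),
  smoke   = factor(sample(c('Never', 'Former', 'Current'), n, replace = TRUE,
                          prob = c(.25, .6, .15))),
  outcome = runif(n, min = 16, max = 45))
lm1 <- lm(outcome ~ treat * age * smoke, data = patient)

# grid with the unused sex variable vs. the same grid without it
g_sex   <- expand.grid(treat = levels(patient$treat), age = 18:80,
                       sex = levels(patient$sex), smoke = levels(patient$smoke))
g_nosex <- expand.grid(treat = levels(patient$treat), age = 18:80,
                       smoke = levels(patient$smoke))
g_sex$fit   <- predict(lm1, newdata = g_sex)
g_nosex$fit <- predict(lm1, newdata = g_nosex)

# mean prediction per smoke x age cell in each grid
m_sex   <- aggregate(fit ~ smoke + age, data = g_sex, FUN = mean)
m_nosex <- aggregate(fit ~ smoke + age, data = g_nosex, FUN = mean)

nrow(g_sex)                         # 756 rows: 2 treat x 63 ages x 2 sex x 3 smoke
nrow(g_nosex)                       # 378 rows
all.equal(m_sex$fit, m_nosex$fit)   # TRUE
```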

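Next, we average the fitted values, together with the lwr and upr bounds of the confidence intervals, within each combination of smoke and age. (Averaging the bounds only approximates the confidence interval of the averaged prediction, erring on the wide side.)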

agg <- aggregate(cbind(fit, lwr, upr) ~ smoke + age, .newdata, mean)
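Averaging lwr and upr gives a band centred on the right value but at least as wide as the confidence interval of the averaged prediction. Because a linear model's predictions are linear in the coefficients, the mean prediction in one cell equals a'b, where a is the mean of the corresponding model-matrix rows and b the coefficients, so its standard error is sqrt(a' V a) with V = vcov(lm1). Below is a minimal sketch of that swapped-in calculation (not part of the original answer) for a single, arbitrarily chosen cell (smoke = "Never", age = 40), on the same simulated data:

```r
set.seed(42)
n <- 1000
patient <- data.frame(
  treat   = factor(sample(c('Treat', 'Control'), n, replace = TRUE)),
  age     = sample(18:80, n, replace = TRUE),
  sex     = factor(sample(c('Male', 'Female'), n, replace = TRUE, prob = c(.6, .4))),
  smoke   = factor(sample(c('Never', 'Former', 'Current'), n, replace = TRUE,
                          prob = c(.25, .6, .15))),
  outcome = runif(n, min = 16, max = 45))
lm1 <- lm(outcome ~ treat * age * smoke, data = patient)

# the grid rows averaged into one smoke x age cell (sex omitted: not in lm1)
cell <- expand.grid(treat = levels(patient$treat), age = 40, smoke = "Never")

# build the model matrix for these rows the same way predict.lm() does
tt <- delete.response(terms(lm1))
mf <- model.frame(tt, cell, xlev = lm1$xlevels)
X  <- model.matrix(tt, mf, contrasts.arg = lm1$contrasts)

a     <- colMeans(X)                             # average model-matrix row
fit   <- sum(a * coef(lm1))                      # average prediction
se    <- sqrt(drop(t(a) %*% vcov(lm1) %*% a))    # its standard error
tcrit <- qt(0.975, df.residual(lm1))
exact <- c(fit = fit, lwr = fit - tcrit * se, upr = fit + tcrit * se)

# what averaging predict()'s own bounds gives for the same cell
avg_bounds <- colMeans(predict(lm1, newdata = cell, interval = "confidence"))

rbind(exact, avg_bounds)
```

Both rows share the same fit; the exact interval is never wider than the averaged one. Looping this over all smoke x age cells would give a drop-in replacement for the lwr and upr columns of agg.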

With this, we are already done with preprocessing and are ready to plot.

# note: the \(x) lambda shorthand needs R >= 4.1
par(mar=c(4, 4, 3, 2) + .1)
# empty canvas; x padded for the dodge, y padded for the legend
plot.new(); plot.window(range(agg$age) + c(-1, 1), range(agg[3:5]) + c(0, 2))
# one line per smoke level, dodged horizontally by -1, 0, +1 years
invisible(by(agg, agg$smoke, \(x) with(x, lines(age + as.integer(smoke) - 2, fit, col=smoke))))
# points and CI whiskers only at ages 20, 30, ..., 80
dec <- agg$age %% 10 == 0
invisible(by(agg[dec, ], agg[dec, ]$smoke, \(x) 
   with(x, points(age + as.integer(smoke) - 2, fit, col=smoke, pch=20))))
invisible(by(agg[dec, ], agg[dec, ]$smoke, \(x) 
   with(x, arrows(age + as.integer(smoke) - 2, lwr, age + as.integer(smoke) - 2, upr, 
                  col=smoke, code=3, angle=90, length=.05))))
axis(1, axTicks(1)); axis(2, axTicks(2))
mtext('age', 1, 2.5); mtext('pred. outcome', 2, 2.5)
# colours 1:3 are the integer codes of the smoke levels used above
legend('topleft', pch=20, col=1:3, legend=levels(agg$smoke),
       title='smoke', horiz=TRUE, cex=.9)
box()

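The grid above weights both treat groups equally. As far as I know, Stata's margins by default averages over the estimation sample as observed instead: for each requested value it sets age and smoke on every observation, keeps everything else as observed, predicts, and averages. For this model the two differ only through the treat shares of the sample, but they are different estimands. A minimal base-R sketch of the as-observed version, at the hypothetical plotting points age = 20, 30, ..., 80 (chosen to match dec above; margins_obs is an illustrative name):

```r
set.seed(42)
n <- 1000
patient <- data.frame(
  treat   = factor(sample(c('Treat', 'Control'), n, replace = TRUE)),
  age     = sample(18:80, n, replace = TRUE),
  sex     = factor(sample(c('Male', 'Female'), n, replace = TRUE, prob = c(.6, .4))),
  smoke   = factor(sample(c('Never', 'Former', 'Current'), n, replace = TRUE,
                          prob = c(.25, .6, .15))),
  outcome = runif(n, min = 16, max = 45))
lm1 <- lm(outcome ~ treat * age * smoke, data = patient)

ages <- seq(20, 80, by = 10)   # hypothetical plotting points

margins_obs <- expand.grid(age = ages, smoke = levels(patient$smoke))
# counterfactual average: every patient gets this age and smoke value,
# keeps their own treat, and the predictions are averaged
margins_obs$margin <- mapply(function(a, s) {
  cf <- patient
  cf$age   <- a
  cf$smoke <- factor(s, levels = levels(patient$smoke))
  mean(predict(lm1, newdata = cf))
}, margins_obs$age, as.character(margins_obs$smoke))

margins_obs
```

The margin column can be plotted with the same base-graphics code as above, with one line per smoke level.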

Data:

set.seed(42)
n <- 1000
patient <- data.frame(
  id=1:n, treat=factor(sample(c('Treat','Control'), n, T)),
  age=sample(18:80, n, T), sex=factor(sample(c('Male','Female'), n, T, c(.6, .4))),
  smoke=factor(sample(c("Never", 'Former', 'Current'), n, T, c(.25, .6, .15))),
  outcome=runif(n, min=16, max=45))
SamR

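What I think you are looking for can be done with sjPlot::plot_model() using type = "pred". Here ci.lvl = NA suppresses the confidence bands; drop it to show them, as marginsplot does: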

library(ggplot2)
library(sjPlot)

plot_model(
    model,
    type = "pred",
    terms = c("age", "smoke"),
    ci.lvl = NA
) +
    theme_bw()


Vincent

You can do all of this with the marginaleffects package (disclaimer: I am the maintainer). On the website, you will find over 25 vignettes, including a full vignette on plots:

Note that I interacted smoke with treat and age as well (treat * age * smoke), so the lines are no longer parallel:

library(marginaleffects)
library(ggplot2)
set.seed(42)
n <- 1000
patient <- data.frame(id=1:n,
                      treat = factor(sample(c('Treat','Control'), n, replace=TRUE, prob=c(.5, .5))),
                      age=sample(18:80, n, replace=TRUE),
                      sex = factor(sample(c('Male','Female'), n, replace=TRUE, prob=c(.6, .4))),
                      smoke=factor(sample(c("Never", 'Former', 'Current'), n, replace=TRUE, prob=c(.25, .6, .15))),
                      outcome=runif(n, min=16, max=45))

model <- lm(outcome ~ treat * age * smoke, data = patient)

plot_predictions(model, condition = c("age", "smoke")) + theme_minimal()

Or you can plot slopes, here the effect of treat across age for each level of smoke:

plot_slopes(model, variables = "treat", condition = c("age", "smoke")) + theme_minimal()
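For a factor such as treat, my understanding is that plot_slopes() shows the contrast against the reference level, i.e. Treat minus Control, at each combination of the conditioning variables. In this model every other regressor (age, smoke) is in condition, so that contrast can be reproduced in base R with two predict() calls. A minimal sketch of the point estimates only (no standard errors), on the same simulated data; grid and pred_at are illustrative names:

```r
set.seed(42)
n <- 1000
patient <- data.frame(id = 1:n,
  treat   = factor(sample(c('Treat', 'Control'), n, replace = TRUE, prob = c(.5, .5))),
  age     = sample(18:80, n, replace = TRUE),
  sex     = factor(sample(c('Male', 'Female'), n, replace = TRUE, prob = c(.6, .4))),
  smoke   = factor(sample(c('Never', 'Former', 'Current'), n, replace = TRUE,
                          prob = c(.25, .6, .15))),
  outcome = runif(n, min = 16, max = 45))
model <- lm(outcome ~ treat * age * smoke, data = patient)

# every age x smoke combination used as the conditioning grid
grid <- expand.grid(age = 18:80, smoke = levels(patient$smoke))

# predictions with treat set to one level for every grid row
pred_at <- function(level) {
  nd <- grid
  nd$treat <- level
  predict(model, newdata = nd)
}

# contrast of the factor treat against its reference level (Control)
grid$contrast <- pred_at("Treat") - pred_at("Control")

head(grid)
```

Because age enters the model linearly, the contrast is a straight line in age within each smoke level, so plotting contrast against age by smoke should trace the same lines as the plot_slopes() call, whatever grid of ages the package uses.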