9.2 Using generalized additive models (GAMs) to learn non-monotone relationships

Part of the document Practical Data Science with R (pages 248-259)

In chapter 7, we looked at how to use linear regression to model and predict quantitative output, and how to use logistic regression to predict class probabilities. Linear and logistic regression models are powerful tools, especially when you want to understand the relationship between the input variables and the output. They’re robust to correlated variables (when regularized), and logistic regression preserves the marginal probabilities of the data. The primary shortcoming of both these models is that they assume that the relationship between the inputs and the output is monotone. That is, if more is good, then much more is always better.

But what if the actual relationship is non-monotone? For example, for underweight patients, increasing weight can increase health. But there’s a limit: at some point more weight is bad. Linear and logistic regression miss this distinction (but still often perform surprisingly well, hiding the issue). Generalized additive models (GAMs) are a way to model non-monotone responses within the framework of a linear or logistic model (or any other generalized linear model).

9.2.1 Understanding GAMs

Recall that, if y[i] is the numeric quantity you want to predict, and x[i,] is a row of inputs that corresponds to output y[i], then linear regression finds a function f(x) such that

f(x[i,]) = b[0] + b[1] x[i,1] + b[2] x[i,2] + ... + b[n] x[i,n]

and f(x[i,]) is as close to y[i] as possible.

In its simplest form, a GAM relaxes the linearity constraint and finds a set of functions s_i() (and a constant term a0) such that

f(x[i,]) = a0 + s_1(x[i,1]) + s_2(x[i,2]) + ... + s_n(x[i,n])

and f(x[i,]) is as close to y[i] as possible. The functions s_i() are smooth curve fits that are built up from polynomials. The curves are called splines and are designed to pass as closely as possible through the data without being too “wiggly” (without overfitting). An example of a spline fit is shown in figure 9.2.
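The additive form above can be checked numerically. The following is a minimal sketch, assuming the mgcv package (introduced later in this section) and a made-up two-variable dataset; the names d, m, a0, terms, and f are hypothetical and chosen only for illustration. It fits a GAM with two smooth terms and confirms that the prediction is the constant term plus the sum of the per-variable functions.

```r
library(mgcv)

set.seed(1)
n  <- 400
x1 <- runif(n, -2, 2)
x2 <- runif(n, -2, 2)
# Hypothetical data: each input contributes its own nonlinear effect.
y  <- sin(2 * x1) + x2^2 + rnorm(n, sd = 0.3)
d  <- data.frame(y = y, x1 = x1, x2 = x2)

m <- gam(y ~ s(x1) + s(x2), data = d)

a0    <- coef(m)[["(Intercept)"]]    # the constant term a0
terms <- predict(m, type = "terms")  # one column per smooth: s(x1), s(x2)
f     <- as.numeric(predict(m))      # the fitted f(x[i,])

# f(x[i,]) = a0 + s_1(x[i,1]) + s_2(x[i,2]), up to floating-point error
print(max(abs(f - (a0 + rowSums(terms)))))
```

The printed value should be zero up to rounding, which is the sense in which the model is additive: each s_i() depends on a single input, and the contributions are simply summed.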

Let’s work on a concrete example.

222 CHAPTER 9 Exploring advanced methods

9.2.2 A one-dimensional regression example

Let’s consider a toy example where the response y is a noisy nonlinear function of the input variable x (in fact, it’s the function shown in figure 9.2). As usual, we’ll split the data into training and test sets.

Listing 9.6 Preparing an artificial problem

set.seed(602957)
x <- rnorm(1000)
noise <- rnorm(1000, sd=1.5)
y <- 3*sin(2*x) + cos(0.75*x) - 1.5*(x^2) + noise
select <- runif(1000)
frame <- data.frame(y=y, x=x)
train <- frame[select > 0.1,]
test <- frame[select <= 0.1,]

Given that the data is generated from the nonlinear functions sin() and cos(), there shouldn’t be a good linear fit from x to y. We’ll start by building a (poor) linear regression.

Listing 9.7 Linear regression applied to our artificial example

> lin.model <- lm(y ~ x, data=train)
> summary(lin.model)

Figure 9.2 A spline that has been fit through a series of points (x on the horizontal axis, y on the vertical axis)


Call:
lm(formula = y ~ x, data = train)

Residuals:
    Min      1Q  Median      3Q     Max
-17.698  -1.774   0.193   2.499   7.529

Coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept)  -0.8330     0.1161  -7.175 1.51e-12 ***
x             0.7395     0.1197   6.180 9.74e-10 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 3.485 on 899 degrees of freedom
Multiple R-squared:  0.04075,  Adjusted R-squared:  0.03968
F-statistic: 38.19 on 1 and 899 DF,  p-value: 9.737e-10

#
# calculate the root mean squared error (rmse)
#
> resid.lin <- train$y-predict(lin.model)
> sqrt(mean(resid.lin^2))
[1] 3.481091

The resulting model’s predictions are plotted versus true response in figure 9.3. As expected, it’s a very poor fit, with an R-squared of about 0.04. In particular, the errors are heteroscedastic (see footnote 5): there are regions where the model systematically underpredicts and regions where it systematically overpredicts. If the relationship between x and y were truly linear (with noise), then the errors would be homoscedastic: the errors would be evenly distributed (mean 0) around the predicted value everywhere.

Figure 9.3 Linear model’s predictions versus actual response (panel title: Response v. Prediction, linear model; pred on the horizontal axis, actual on the vertical axis). The solid line is the line of perfect prediction (prediction=actual).
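One rough way to see the regions of systematic over- and underprediction without a plot is to average the residuals within slices of x. The sketch below is an illustration rather than part of the original listings: it regenerates the artificial data with the same recipe as listing 9.6 but, as a simplifying assumption, fits the linear model to all 1,000 rows instead of the training split. The bin and bin_means names are hypothetical.

```r
set.seed(602957)
x <- rnorm(1000)
y <- 3*sin(2*x) + cos(0.75*x) - 1.5*(x^2) + rnorm(1000, sd = 1.5)
d <- data.frame(x = x, y = y)

lin <- lm(y ~ x, data = d)
d$resid <- d$y - predict(lin)   # actual minus predicted

# Hypothetical diagnostic: cut x into five equal-count bins
# and average the residuals inside each bin.
breaks <- quantile(d$x, probs = seq(0, 1, by = 0.2))
d$bin <- cut(d$x, breaks = breaks, include.lowest = TRUE)
bin_means <- tapply(d$resid, d$bin, mean)
print(round(bin_means, 2))
```

For a well-specified linear model, every bin mean would sit near 0. Here the residuals average to 0 overall (as they must for a least-squares fit with an intercept), but individual bins are clearly positive or negative, which is the systematic pattern described above.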

Let’s try finding a nonlinear model that maps x to y. We’ll use the function gam() in the package mgcv (see footnote 6). When using gam(), you can model variables as either linear or nonlinear. You model a variable x as nonlinear by wrapping it in the s() notation. In this example, rather than using the formula y ~ x to describe the model, you’d use the formula y ~ s(x). Then gam() will search for the spline s() that best describes the relationship between x and y, as shown in listing 9.8. Only terms surrounded by s() get the GAM/spline treatment.

5 Heteroscedastic errors are errors whose magnitude is correlated with the quantity to be predicted. Heteroscedastic errors are bad because they’re systematic and violate the assumption that errors are uncorrelated with outcomes, which is used in many proofs of the good properties of regression methods.

6 There’s an older package called gam, written by Hastie and Tibshirani, the inventors of GAMs. The gam package works fine. But it’s incompatible with the mgcv package, which ggplot already loads. Since we’re using ggplot for plotting, we’ll use mgcv for our examples.

Listing 9.8 GAM applied to our artificial example

# Load the mgcv package.
> library(mgcv)
# Build the model, specifying that x should be treated as a nonlinear variable.
> glin.model <- gam(y~s(x), data=train)
# The converged parameter tells you if the algorithm converged.
# You should only trust the output if this is TRUE.
> glin.model$converged
[1] TRUE
> summary(glin.model)

# family=gaussian and link=identity tell you that the model was treated with
# the same distribution assumptions as a standard linear regression.
Family: gaussian
Link function: identity

Formula:
y ~ s(x)

# The parametric coefficients are the linear terms (in this example, only the
# constant term). This section of the summary tells you which linear terms
# were significantly different from 0.
Parametric coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept) -0.83467    0.04852   -17.2   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# The smooth terms are the nonlinear terms. This section of the summary tells
# you which nonlinear terms were significantly different from 0. It also tells
# you the effective degrees of freedom (edf) used up to build each smooth term.
# An edf near 1 indicates that the variable has an approximately linear
# relationship to the output.
Approximate significance of smooth terms:
       edf Ref.df     F p-value
s(x) 8.685  8.972 497.8  <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

R-sq.(adj) =  0.832   Deviance explained = 83.4%
GCV score =  2.144   Scale est. = 2.121   n = 901

#
# calculate the root mean squared error (rmse)
#
> resid.glin <- train$y-predict(glin.model)
> sqrt(mean(resid.glin^2))
[1] 1.448514

The resulting model’s predictions are plotted versus true response in figure 9.4. This fit is much better: the model explains over 80% of the variance (R-squared of 0.83), and the root mean squared error (RMSE) over the training data is less than half the RMSE of the linear model. Note that the points in figure 9.4 are distributed more or less evenly around the line of perfect prediction. The GAM has been fit to be homoscedastic, and any given prediction is as likely to be an overprediction as an underprediction.

Figure 9.4 GAM’s predictions versus actual response (panel title: Response v. Prediction, GAM; pred on the horizontal axis, actual on the vertical axis). The solid line is the theoretical line of perfect prediction (prediction=actual).

In the model summary, “R-sq.(adj)” is the adjusted R-squared, and “Deviance explained” is the raw R-squared (0.834).
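The link between “Deviance explained” and the raw R-squared can be verified by hand for a Gaussian GAM. The sketch below is illustrative only: it regenerates the artificial data with the recipe from listing 9.6, fits to all 1,000 rows rather than the training split (an assumption made to keep the example short), and computes 1 - RSS/TSS directly. The names rss, tss, and raw_r2 are hypothetical; dev.expl and r.sq are components of the object returned by mgcv's summary().

```r
library(mgcv)

set.seed(602957)
x <- rnorm(1000)
y <- 3*sin(2*x) + cos(0.75*x) - 1.5*(x^2) + rnorm(1000, sd = 1.5)
d <- data.frame(x = x, y = y)

m <- gam(y ~ s(x), data = d)
s <- summary(m)

rss <- sum((d$y - fitted(m))^2)   # residual sum of squares
tss <- sum((d$y - mean(d$y))^2)   # total sum of squares around the mean
raw_r2 <- 1 - rss / tss           # raw (unadjusted) R-squared

# dev.expl should match raw_r2; r.sq is the adjusted version and is smaller.
print(c(raw_r2 = raw_r2, dev.expl = s$dev.expl, r.sq = s$r.sq))
```

The adjusted value is lower because it penalizes the effective degrees of freedom the spline uses.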


The use of splines gives GAMs a richer model space to choose from; this increased flexibility brings a higher risk of overfitting. Let’s check the models’ performances on the test data.

Listing 9.9 Comparing linear regression and GAM performance

> actual <- test$y
# Call both models on the test data.
> pred.lin <- predict(lin.model, newdata=test)
> pred.glin <- predict(glin.model, newdata=test)
> resid.lin <- actual-pred.lin
> resid.glin <- actual-pred.glin
# Compare the RMSE of the linear model and the GAM on the test data.
> sqrt(mean(resid.lin^2))
[1] 2.792653
> sqrt(mean(resid.glin^2))
[1] 1.401399
# Compare the R-squared of the linear model and the GAM on test data.
> cor(actual, pred.lin)^2
[1] 0.1543172
> cor(actual, pred.glin)^2
[1] 0.7828869

The GAM performed similarly on both sets (RMSE of 1.40 on test versus 1.45 on training; R-squared of 0.78 on test versus 0.83 on training). So there’s likely no overfit.

9.2.3 Extracting the nonlinear relationships

Once you fit a GAM, you’ll probably be interested in what the s() functions look like. Calling plot() on a GAM will give you a plot for each s() curve, so you can visualize nonlinearities. In our example, plot(glin.model) produces the top curve in figure 9.5.

The shape of the curve is quite similar to the scatter plot we saw in figure 9.2 (which is reproduced as the lower half of figure 9.5). In fact, the spline that’s superimposed on the scatter plot in figure 9.2 is the same curve.

We can extract the data points that were used to make this graph by using the predict() function with the argument type="terms". This produces a matrix where the ith column represents s(x[,i]). Listing 9.10 demonstrates how to reproduce the lower plot in figure 9.5.

Modeling linear relationships using gam()

By default, gam() will perform standard linear regression. If you were to call gam() with the formula y ~ x, you’d get the same model that you got using lm(). More generally, the call gam(y ~ x1 + s(x2), data=...) would model the variable x1 as having a linear relationship with y, and try to fit the best possible smooth curve to model the relationship between x2 and y. Of course, the best smooth curve could be a straight line, so if you’re not sure whether the relationship between x and y is linear, you can use s(x). If you see that the smooth term has an edf (effective degrees of freedom; see the model summary in listing 9.8) of about 1, then you can try refitting the variable as a linear term.
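The claim that gam() with a plain formula reproduces lm() is easy to check. This is a small sketch on made-up data, assuming only the mgcv package; the data frame d and the model names m_lm and m_gam are hypothetical.

```r
library(mgcv)

set.seed(2024)
# Hypothetical data with a truly linear relationship plus noise.
d <- data.frame(x = runif(300, 0, 10))
d$y <- 2 + 0.5 * d$x + rnorm(300, sd = 1)

m_lm  <- lm(y ~ x, data = d)
m_gam <- gam(y ~ x, data = d)   # no s(): x is modeled as a plain linear term

# The two sets of coefficients should agree up to numerical precision.
print(rbind(lm = coef(m_lm), gam = coef(m_gam)))
print(all.equal(unname(coef(m_lm)), unname(coef(m_gam)), tolerance = 1e-6))
```

Mixing the two styles works the same way: only the variables wrapped in s() are given a spline, and everything else is fit as an ordinary linear term.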

Figure 9.5 Top: The nonlinear function s(x) discovered by gam(), as output by plot(glin.model) (x on the horizontal axis, s(x,8.69) on the vertical axis). Bottom: The same spline superimposed over the training data (x on the horizontal axis, y on the vertical axis).

Listing 9.10 Extracting a learned spline from a GAM

> sx <- predict(glin.model, type="terms")
> summary(sx)
      s(x)
 Min.   :-17.527035
 1st Qu.: -2.378636
 Median :  0.009427
 Mean   :  0.000000
 3rd Qu.:  2.869166
 Max.   :  4.084999
> xframe <- cbind(train, sx=sx[,1])
> ggplot(xframe, aes(x=x)) + geom_point(aes(y=y), alpha=0.4) + geom_line(aes(y=sx))

Now that we’ve worked through a simple example, let’s try a more realistic example with more variables.

9.2.4 Using GAM on actual data

For this example, we’ll predict a newborn baby’s weight (DBWT) using data from the CDC 2010 natality dataset that we used in section 7.2 (though this is not the risk data used in that chapter; see footnote 7). As input, we’ll consider mother’s weight (PWGT), mother’s pregnancy weight gain (WTGAIN), mother’s age (MAGER), and the number of prenatal medical visits (UPREVIS) (see footnote 8).

7 The dataset can be found at https://github.com/WinVector/zmPDSwR/blob/master/CDC/NatalBirthData.rData. A script for preparing the dataset from the original CDC extract can be found at https://github.com/WinVector/zmPDSwR/blob/master/CDC/prepBirthWeightData.R.

8 We’ve chosen this example to highlight the mechanisms of gam(), not to find the best model for birth weight. Adding other variables beyond the four we’ve chosen will improve the fit, but obscure the exposition.

In the following listing, we’ll fit a linear model and a GAM, and compare.

Listing 9.11 Applying linear regression (with and without GAM) to health data

> library(mgcv)
> library(ggplot2)
> load("NatalBirthData.rData")
> train <- sdata[sdata$ORIGRANDGROUP<=5,]
> test <- sdata[sdata$ORIGRANDGROUP>5,]
> form.lin <- as.formula("DBWT ~ PWGT + WTGAIN + MAGER + UPREVIS")
# Build a linear model with four variables.
> linmodel <- lm(form.lin, data=train)
> summary(linmodel)

Call:
lm(formula = form.lin, data = train)

Residuals:
     Min       1Q   Median       3Q      Max
-3155.43  -272.09    45.04   349.81  2870.55

Coefficients:
             Estimate Std. Error t value Pr(>|t|)
(Intercept) 2419.7090    31.9291  75.784  < 2e-16 ***
PWGT           2.1713     0.1241  17.494  < 2e-16 ***
WTGAIN         7.5773     0.3178  23.840  < 2e-16 ***
MAGER          5.3213     0.7787   6.834  8.6e-12 ***
UPREVIS       12.8753     1.1786  10.924  < 2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# The model explains about 7% of the variance; all coefficients are
# significantly different from 0.
Residual standard error: 562.7 on 14381 degrees of freedom
Multiple R-squared:  0.06596,  Adjusted R-squared:  0.0657
F-statistic: 253.9 on 4 and 14381 DF,  p-value: < 2.2e-16

# Build a GAM with the same variables.
> form.glin <- as.formula("DBWT ~ s(PWGT) + s(WTGAIN) + s(MAGER) + s(UPREVIS)")
> glinmodel <- gam(form.glin, data=train)
# Verify that the model has converged.
> glinmodel$converged
[1] TRUE
> summary(glinmodel)

Family: gaussian
Link function: identity

Formula:
DBWT ~ s(PWGT) + s(WTGAIN) + s(MAGER) + s(UPREVIS)

Parametric coefficients:
            Estimate Std. Error t value Pr(>|t|)
(Intercept) 3276.948      4.623   708.8   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Approximate significance of smooth terms:
             edf Ref.df       F  p-value
s(PWGT)    5.374  6.443  68.981  < 2e-16 ***
s(WTGAIN)  4.719  5.743 102.313  < 2e-16 ***
s(MAGER)   7.742  8.428   6.959 1.82e-09 ***
s(UPREVIS) 5.491  6.425  48.423  < 2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# The model explains just under 10% of the variance; all variables have a
# nonlinear effect significantly different from 0.
R-sq.(adj) =  0.0927   Deviance explained = 9.42%
GCV score = 3.0804e+05   Scale est. = 3.0752e+05   n = 14386

The GAM has improved the fit, and all four variables seem to have a nonlinear relationship with birth weight, as evidenced by edfs all greater than 1. We could use plot(glinmodel) to examine the shape of the s() functions; instead, we’ll compare them with a direct smoothing curve of each variable against birth weight.
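Reading the edf values off the printed summary works, but they can also be pulled out programmatically. Because the birth data file isn't bundled with this text, the sketch below uses a made-up dataset as a stand-in (x1 with a curved effect, x2 with a straight-line effect); all variable names are hypothetical. It relies on the s.table component of the object returned by mgcv's summary(), which holds the “Approximate significance of smooth terms” block.

```r
library(mgcv)

set.seed(42)
n <- 500
# Hypothetical inputs: x1 has a curved effect on y, x2 a straight-line effect.
d <- data.frame(x1 = runif(n, -2, 2), x2 = runif(n, -2, 2))
d$y <- d$x1^2 + 0.5 * d$x2 + rnorm(n, sd = 0.5)

m <- gam(y ~ s(x1) + s(x2), data = d)

# s.table has one row per smooth term.
smooth_tab <- summary(m)$s.table
print(round(smooth_tab, 3))

# Pull out just the effective degrees of freedom, one value per smooth term.
edf <- smooth_tab[, "edf"]
print(edf)
```

In a setup like this you would expect the edf for s(x1) to be well above 1, while the edf for s(x2) may come out close to 1, suggesting x2 could be refit as a plain linear term.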

Listing 9.12 Plotting GAM results

# Get the matrix of s() functions.
> terms <- predict(glinmodel, type="terms")
# Bind in birth weight; convert to data frame.
> tframe <-
     cbind(DBWT = train$DBWT, as.data.frame(terms))
# Make the column names reference-friendly (“s(PWGT)” is converted to “sPWGT”, etc.).
> colnames(tframe) <- gsub('[()]', '', colnames(tframe))
# Bind in the input variables.
> pframe <- cbind(tframe, train[,c("PWGT", "WTGAIN",
                                   "MAGER", "UPREVIS")])
# Plot s(PWGT) shifted to be zero mean versus PWGT (mother’s weight) as points.
# Plot the smoothing curve of DBWT (birth weight) shifted to be zero mean
# versus PWGT (mother’s weight).
# Repeat for remaining variables (omitted for brevity).
> p1 <- ggplot(pframe, aes(x=PWGT)) +
     geom_point(aes(y=scale(sPWGT, scale=F))) +
     geom_smooth(aes(y=scale(DBWT, scale=F))) + [...]

The plots of the s() splines compared with the smooth curves directly relating the input variables to birth weight are shown in figure 9.6. The smooth curves in each case are similar to the corresponding s() in shape, and nonlinear for all of the variables.

As usual, we should check for overfit with hold-out data.

Listing 9.13 Checking GAM model performance on hold-out data

# Run both the linear model and the GAM on the test data.
pred.lin <- predict(linmodel, newdata=test)
pred.glin <- predict(glinmodel, newdata=test)
# Calculate R-squared for both models.
cor(pred.lin, test$DBWT)^2
# [1] 0.0616812
cor(pred.glin, test$DBWT)^2
# [1] 0.08857426

Both the linear model and the GAM performed about as well on the test set as they did on the training set (R-squared of 0.062 versus 0.066 for the linear model, and 0.089 versus 0.094 for the GAM), so in this example there’s no substantial overfit.
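The same train-versus-test comparison can be wrapped in a small helper so that every model is scored the same way. Since the birth data isn't bundled with this text, the sketch below reuses the artificial problem from listing 9.6; the fit_quality() function and the report table are hypothetical names introduced only for this illustration.

```r
library(mgcv)

# Hypothetical helper: RMSE and squared correlation of predictions vs. actuals.
fit_quality <- function(model, data, outcome) {
  pred <- as.numeric(predict(model, newdata = data))
  actual <- data[[outcome]]
  c(rmse = sqrt(mean((actual - pred)^2)),
    rsq  = cor(actual, pred)^2)
}

# Same data recipe and split as listing 9.6.
set.seed(602957)
x <- rnorm(1000)
y <- 3*sin(2*x) + cos(0.75*x) - 1.5*(x^2) + rnorm(1000, sd = 1.5)
frame <- data.frame(x = x, y = y)
select <- runif(1000)
train <- frame[select > 0.1, ]
test  <- frame[select <= 0.1, ]

models <- list(linear = lm(y ~ x, data = train),
               gam    = gam(y ~ s(x), data = train))

# One row per model/data-set combination.
report <- rbind(
  linear_train = fit_quality(models$linear, train, "y"),
  linear_test  = fit_quality(models$linear, test,  "y"),
  gam_train    = fit_quality(models$gam,    train, "y"),
  gam_test     = fit_quality(models$gam,    test,  "y"))
print(round(report, 3))
```

A large gap between a model's training row and its test row would be the warning sign for overfit; similar values in both rows suggest the model generalizes about as well as it fits.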

9.2.5 Using GAM for logistic regression

The gam() function can be used for logistic regression as well. Suppose that we wanted to predict the birth of underweight babies (defined as DBWT < 2000) from the same variables we’ve been using. The logistic regression call to do that would be as shown in the following listing.

Listing 9.14 GLM logistic regression

form <- as.formula("DBWT < 2000 ~ PWGT + WTGAIN + MAGER + UPREVIS")
logmod <- glm(form, data=train, family=binomial(link="logit"))

Figure 9.6 Smoothing curves of each of the four input variables plotted against birth weight, compared with the splines discovered by gam(). All curves have been shifted to be zero mean for comparison of shape. (Four panels: scale(sPWGT, scale = F) versus PWGT, scale(sWTGAIN, scale = F) versus WTGAIN, scale(sMAGER, scale = F) versus MAGER, and scale(sUPREVIS, scale = F) versus UPREVIS.)


The corresponding call to gam() also specifies the binomial family with the logit link.

Listing 9.15 GAM logistic regression

> form2 <- as.formula("DBWT<2000~s(PWGT)+s(WTGAIN)+
                                 s(MAGER)+s(UPREVIS)")
> glogmod <- gam(form2, data=train, family=binomial(link="logit"))
> glogmod$converged
[1] TRUE
> summary(glogmod)

Family: binomial
Link function: logit

Formula:
DBWT < 2000 ~ s(PWGT) + s(WTGAIN) + s(MAGER) + s(UPREVIS)

Parametric coefficients:
            Estimate Std. Error z value Pr(>|z|)
(Intercept) -3.94085    0.06794     -58   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

# Note that the summary gives no evidence that the mother’s weight (PWGT)
# has a significant effect on outcome.
Approximate significance of smooth terms:
             edf Ref.df  Chi.sq  p-value
s(PWGT)    1.905  2.420   2.463  0.36412
s(WTGAIN)  3.674  4.543  64.426 1.72e-12 ***
s(MAGER)   1.003  1.005   8.335  0.00394 **
s(UPREVIS) 6.802  7.216 217.631  < 2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

R-sq.(adj) =  0.0331   Deviance explained = 9.14%
UBRE score = -0.76987   Scale est. = 1   n = 14386

As with the standard logistic regression call, we recover the class probabilities with the call predict(glogmod, newdata=train, type="response"). Again these models are coming out with low quality, and in practice we would look for more explanatory variables to build better screening models.
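To make the role of type="response" concrete, here is a small self-contained sketch on made-up data, since the birth data isn't bundled with this text. It assumes a hypothetical outcome, event, whose probability is highest for mid-range values of x (a non-monotone risk), and compares a plain logistic regression with a GAM logistic regression. All names other than the glm(), gam(), predict(), and deviance() calls are illustrative.

```r
library(mgcv)

set.seed(7)
n <- 2000
# Hypothetical data: the event is most likely for mid-range x.
d <- data.frame(x = runif(n, -3, 3))
p_true <- plogis(1 - d$x^2)
d$event <- rbinom(n, size = 1, prob = p_true) == 1

logmod  <- glm(event ~ x,    data = d, family = binomial(link = "logit"))
glogmod <- gam(event ~ s(x), data = d, family = binomial(link = "logit"))

# type = "response" returns probabilities; the default is the link (log-odds) scale.
p_glm <- as.numeric(predict(logmod,  newdata = d, type = "response"))
p_gam <- as.numeric(predict(glogmod, newdata = d, type = "response"))
print(range(p_gam))

# Compare the fits by deviance on the training data (lower is better).
print(c(glm = deviance(logmod), gam = deviance(glogmod)))
```

Because the plain logistic model can only express a monotone effect of x, it should fit this kind of data noticeably worse than the GAM, which is the situation this section set out to address.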
