-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathpredict.sgd.R
70 lines (68 loc) · 2.09 KB
/
predict.sgd.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#' Model Predictions
#'
#' Form predictions using the estimated model parameters from stochastic
#' gradient descent.
#'
#' @param object object of class \code{sgd}.
#' @param newdata design matrix to form predictions on.
#' @param type the type of prediction required. The default \code{"link"} is
#'   on the scale of the linear predictors; the alternative
#'   \code{"response"} is on the scale of the response variable. Thus for a
#'   default binomial model the default predictions are of log-odds
#'   (probabilities on logit scale) and \code{type = "response"} gives the
#'   predicted probabilities. The \code{"terms"} option (also accepted as
#'   \code{"term"}) returns a matrix giving the fitted values of each term in
#'   the model formula on the linear predictor scale.
#' @param \dots further arguments passed to or from other methods.
#'
#' @return For \code{predict.sgd}: with \code{type = "link"}, the linear
#'   predictor \code{newdata \%*\% coef(object)}; with
#'   \code{type = "response"}, the model family's inverse link applied to the
#'   linear predictor (for \code{"m"} models this equals the linear
#'   predictor); with \code{type = "terms"}, a matrix whose column \eqn{j} is
#'   \code{newdata[, j] * coef(object)[j]}.
#'
#' @details
#' A column of 1's must be included to \code{newdata} if the
#' parameters include a bias (intercept) term.
#'
#' @export
predict.sgd <- function(object, newdata, type="link", ...) {
  model <- object$model
  # Guard the length first so a NULL/vector `model` yields the intended
  # error instead of "argument is of length zero" from `if`.
  if (length(model) != 1L || !(model %in% c("lm", "glm", "m"))) {
    stop("'model' not supported")
  }
  # "terms" is the documented spelling; "term" is still accepted for
  # callers relying on the previously accepted value.
  if (length(type) != 1L ||
      !(type %in% c("link", "response", "terms", "term"))) {
    stop("'type' not recognized")
  }
  beta <- coef(object)
  if (type %in% c("terms", "term")) {
    # Column j is newdata[, j] * beta[j], i.e. term j's contribution to the
    # linear predictor. `nrow` stops diag() from treating a single
    # coefficient as a matrix size (which would return an identity matrix);
    # as.vector() stops it from extracting a diagonal if coef() returned a
    # one-column matrix.
    beta <- as.vector(beta)
    return(newdata %*% diag(beta, nrow = length(beta)))
  }
  eta <- newdata %*% beta
  # "m" models have no inverse link applied here, so "response" == "link".
  if (type == "link" || model == "m") {
    return(eta)
  }
  object$model.out$family$linkinv(eta)
}
#' @export
#' @rdname predict.sgd
predict_all <- function(object, newdata, ...) {
  # Choose how to map the linear predictor to the response scale; the
  # model type is validated before any matrix product is formed.
  if (object$model %in% c("lm", "glm")) {
    to_response <- object$model.out$family$linkinv
  } else if (object$model == "m") {
    # TODO: "m" models currently return the linear predictor unchanged.
    to_response <- identity
  } else {
    stop("'model' not recognized")
  }
  # object$estimates is multiplied against newdata as a whole, so each
  # column of estimates yields one column of predictions.
  to_response(newdata %*% object$estimates)
}