Skip to content

Commit

Permalink
R interface to gradient
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Jul 8, 2024
1 parent 2cb36c3 commit 4e5acda
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ S3method(vcov,cyclopsFit)
export(Multitype)
export(aconfint)
export(cacheCyclopsModelForJava)
export(clearCyclopsModelCache)
export(convertToCyclopsData)
export(convertToTimeVaryingCoef)
export(coverage)
Expand All @@ -38,6 +39,7 @@ export(getNumberOfRows)
export(getNumberOfStrata)
export(getUnivariableCorrelation)
export(getUnivariableSeparability)
export(gradient)
export(isInitialized)
export(listGPUDevices)
export(meanLinearPredictor)
Expand Down
18 changes: 18 additions & 0 deletions R/ModelFit.R
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,24 @@ logLik.cyclopsFit <- function(object, ...) {
out
}

#' @title Extract gradient
#'
#' @description
#' \code{gradient} returns the current gradient wrt the regression parameters of
#' the log-likelihood of the fit in a Cyclops model fit object
#'
#' @param object A Cyclops model fit object
#'
#' @export
gradient <- function(object) {

.checkInterface(object$cyclopsData, testOnly = TRUE)
gradient <- .cyclopsGetLogLikelihoodGradient(object$interface)
names(gradient) <- names(coef(object))

return(gradient)
}


#' @method print cyclopsFit
#' @title Print a Cyclops model fit object
Expand Down
14 changes: 14 additions & 0 deletions man/cyclopsLibraryFileName.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions man/gradient.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/cyclops/CyclicCoordinateDescent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,7 @@ double CyclicCoordinateDescent::getLogLikelihoodGradient(int index) {
checkAllLazyFlags();

double gradient, hessian;
computeNumeratorForGradient(index);
modelSpecifics.computeGradientAndHessian(index, &gradient, &hessian, useCrossValidation);

return gradient;
Expand Down
15 changes: 15 additions & 0 deletions tests/testthat/test-gradient.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
library("testthat")
library("survival")

suppressWarnings(RNGversion("3.5.0"))

test_that("gradient", {

data <- Cyclops::createCyclopsData(Surv(stop, event) ~ (rx - 1) + size, data = bladder, modelType = "cox")

fit <- Cyclops::fitCyclopsModel(data)

gradientAtMode <- gradient(fit)
})


0 comments on commit 4e5acda

Please sign in to comment.