Skip to content

Commit

Permalink
Merge pull request #136 from OHDSI/is134_getDbCohortMethodData
Browse files Browse the repository at this point in the history
Is134 get db cohort method data
  • Loading branch information
schuemie authored Apr 13, 2023
2 parents f40b976 + bbbb1d4 commit 6a766e5
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 90 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ URL: https://ohdsi.github.io/CohortMethod, https://github.com/OHDSI/CohortMethod
BugReports: https://github.com/OHDSI/CohortMethod/issues
Depends:
R (>= 3.6.0),
DatabaseConnector (>= 4.0.0),
DatabaseConnector (>= 6.0.0),
Cyclops (>= 3.1.2),
FeatureExtraction (>= 3.0.0),
Andromeda (>= 0.6.3)
Expand Down
227 changes: 142 additions & 85 deletions R/DataLoadingSaving.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,30 +118,13 @@ getDbCohortMethodData <- function(connectionDetails,
outcomeTable = "condition_occurrence",
cdmVersion = "5",
firstExposureOnly = FALSE,
removeDuplicateSubjects = FALSE,
removeDuplicateSubjects = "keep all",
restrictToCommonPeriod = FALSE,
washoutPeriod = 0,
maxCohortSize = 0,
covariateSettings) {
if (is.null(studyStartDate)) {
studyStartDate <- ""
}
if (is.null(studyEndDate)) {
studyEndDate <- ""
}
if (is.logical(removeDuplicateSubjects)) {
if (removeDuplicateSubjects) {
removeDuplicateSubjects <- "remove all"
} else {
removeDuplicateSubjects <- "keep all"
}
}
errorMessages <- checkmate::makeAssertCollection()
if (is(connectionDetails, "connectionDetails")) {
checkmate::assertClass(connectionDetails, "connectionDetails", add = errorMessages)
} else {
checkmate::assertClass(connectionDetails, "ConnectionDetails", add = errorMessages)
}
checkmate::assertClass(connectionDetails, "ConnectionDetails", add = errorMessages)
checkmate::assertCharacter(cdmDatabaseSchema, len = 1, add = errorMessages)
checkmate::assertCharacter(tempEmulationSchema, len = 1, null.ok = TRUE, add = errorMessages)
checkmate::assertInt(targetId, add = errorMessages)
Expand All @@ -162,18 +145,27 @@ getDbCohortMethodData <- function(connectionDetails,
checkmate::assertList(covariateSettings, add = errorMessages)
checkmate::reportAssertions(collection = errorMessages)

if (studyStartDate != "" && regexpr("^[12][0-9]{3}[01][0-9][0-3][0-9]$", studyStartDate) == -1) {
if (is.null(studyStartDate)) {
studyStartDate <- ""
}
if (is.null(studyEndDate)) {
studyEndDate <- ""
}
if (studyStartDate != "" &&
regexpr("^[12][0-9]{3}[01][0-9][0-3][0-9]$", studyStartDate) == -1) {
stop("Study start date must have format YYYYMMDD")
}
if (studyEndDate != "" && regexpr("^[12][0-9]{3}[01][0-9][0-3][0-9]$", studyEndDate) == -1) {
if (studyEndDate != "" &&
regexpr("^[12][0-9]{3}[01][0-9][0-3][0-9]$", studyEndDate) == -1) {
stop("Study end date must have format YYYYMMDD")
}

connection <- DatabaseConnector::connect(connectionDetails)
on.exit(DatabaseConnector::disconnect(connection))

message("Constructing target and comparator cohorts")
renderedSql <- SqlRender::loadRenderTranslateSql("CreateCohorts.sql",
renderedSql <- SqlRender::loadRenderTranslateSql(
sqlFilename = "CreateCohorts.sql",
packageName = "CohortMethod",
dbms = connectionDetails$dbms,
tempEmulationSchema = tempEmulationSchema,
Expand All @@ -191,65 +183,39 @@ getDbCohortMethodData <- function(connectionDetails,
)
DatabaseConnector::executeSql(connection, renderedSql)

sampled <- FALSE
if (maxCohortSize != 0) {
renderedSql <- SqlRender::loadRenderTranslateSql("CountCohorts.sql",
packageName = "CohortMethod",
dbms = connectionDetails$dbms,
outList <- downSample(
connection = connection,
tempEmulationSchema = tempEmulationSchema,
target_id = targetId
)
counts <- DatabaseConnector::querySql(connection, renderedSql, snakeCaseToCamelCase = TRUE)
ParallelLogger::logDebug("Pre-sample total row count is ", sum(counts$rowCount))
preSampleCounts <- dplyr::tibble(dummy = 0)
idx <- which(counts$treatment == 1)
if (length(idx) == 0) {
preSampleCounts$targetPersons <- 0
preSampleCounts$targetExposures <- 0
} else {
preSampleCounts$targetPersons <- counts$personCount[idx]
preSampleCounts$targetExposures <- counts$rowCount[idx]
}
idx <- which(counts$treatment == 0)
if (length(idx) == 0) {
preSampleCounts$comparatorPersons <- 0
preSampleCounts$comparatorExposures <- 0
} else {
preSampleCounts$comparatorPersons <- counts$personCount[idx]
preSampleCounts$comparatorExposures <- counts$rowCount[idx]
}
preSampleCounts$dummy <- NULL
if (preSampleCounts$targetExposures > maxCohortSize) {
message("Downsampling target cohort from ", preSampleCounts$targetExposures, " to ", maxCohortSize)
sampled <- TRUE
}
if (preSampleCounts$comparatorExposures > maxCohortSize) {
message("Downsampling comparator cohort from ", preSampleCounts$comparatorExposures, " to ", maxCohortSize)
sampled <- TRUE
}
if (sampled) {
renderedSql <- SqlRender::loadRenderTranslateSql("SampleCohorts.sql",
packageName = "CohortMethod",
dbms = connectionDetails$dbms,
tempEmulationSchema = tempEmulationSchema,
max_cohort_size = maxCohortSize
)
DatabaseConnector::executeSql(connection, renderedSql)
}
targetId = targetId,
maxCohortSize = maxCohortSize)
sampled <- outList$sampled
preSampleCounts <- outList$preSampleCounts
} else {
sampled <- FALSE
}

message("Fetching cohorts from server")
start <- Sys.time()
cohortSql <- SqlRender::loadRenderTranslateSql("GetCohorts.sql",
cohortSql <- SqlRender::loadRenderTranslateSql(
sqlFilename = "GetCohorts.sql",
packageName = "CohortMethod",
dbms = connectionDetails$dbms,
tempEmulationSchema = tempEmulationSchema,
target_id = targetId,
sampled = sampled
)
cohorts <- DatabaseConnector::querySql(connection, cohortSql, snakeCaseToCamelCase = TRUE)
cohorts <- DatabaseConnector::querySql(
connection = connection,
sql = cohortSql,
snakeCaseToCamelCase = TRUE)
cohorts$rowId <- as.numeric(cohorts$rowId)
ParallelLogger::logDebug("Fetched cohort total rows in target is ", sum(cohorts$treatment), ", total rows in comparator is ", sum(!cohorts$treatment))
ParallelLogger::logDebug(
"Fetched cohort total rows in target is ",
sum(cohorts$treatment),
", total rows in comparator is ",
sum(!cohorts$treatment)
)
if (nrow(cohorts) == 0) {
warning("Target and comparator cohorts are empty")
} else if (sum(cohorts$treatment == 1) == 0) {
Expand All @@ -264,7 +230,8 @@ getDbCohortMethodData <- function(connectionDetails,
studyEndDate = studyEndDate
)
if (firstExposureOnly || removeDuplicateSubjects != "keep all" || washoutPeriod != 0) {
rawCountSql <- SqlRender::loadRenderTranslateSql("CountOverallExposedPopulation.sql",
rawCountSql <- SqlRender::loadRenderTranslateSql(
sqlFilename = "CountOverallExposedPopulation.sql",
packageName = "CohortMethod",
dbms = connectionDetails$dbms,
tempEmulationSchema = tempEmulationSchema,
Expand All @@ -276,7 +243,10 @@ getDbCohortMethodData <- function(connectionDetails,
study_start_date = studyStartDate,
study_end_date = studyEndDate
)
rawCount <- DatabaseConnector::querySql(connection, rawCountSql, snakeCaseToCamelCase = TRUE)
rawCount <- DatabaseConnector::querySql(
connection = connection,
sql = rawCountSql,
snakeCaseToCamelCase = TRUE)
if (nrow(rawCount) == 0) {
counts <- dplyr::tibble(
description = "Original cohorts",
Expand Down Expand Up @@ -315,16 +285,21 @@ getDbCohortMethodData <- function(connectionDetails,
substring(label, 1) <- toupper(substring(label, 1, 1))
if (sampled) {
preSampleCounts$description <- label
metaData$attrition <- rbind(metaData$attrition, preSampleCounts)
metaData$attrition <- rbind(metaData$attrition, getCounts(cohorts, "Random sample"))
metaData$attrition <-
rbind(metaData$attrition, preSampleCounts)
metaData$attrition <-
rbind(metaData$attrition, getCounts(cohorts, "Random sample"))
} else {
metaData$attrition <- rbind(metaData$attrition, getCounts(cohorts, label))
metaData$attrition <-
rbind(metaData$attrition, getCounts(cohorts, label))
}
} else {
if (sampled) {
preSampleCounts$description <- "Original cohorts"
metaData$attrition <- preSampleCounts
metaData$attrition <- rbind(metaData$attrition, getCounts(cohorts, "Random sample"))
metaData$attrition <- rbind(
metaData$attrition,
getCounts(cohorts, "Random sample"))
} else {
metaData$attrition <- getCounts(cohorts, "Original cohorts")
}
Expand All @@ -336,7 +311,10 @@ getDbCohortMethodData <- function(connectionDetails,
} else {
cohortTable <- "#cohort_person"
}
covariateSettings <- handleCohortCovariateBuilders(covariateSettings, exposureDatabaseSchema, exposureTable)
covariateSettings <- handleCohortCovariateBuilders(
covariateSettings = covariateSettings,
exposureDatabaseSchema = exposureDatabaseSchema,
exposureTable = exposureTable)
covariateData <- FeatureExtraction::getDbCovariateData(
connection = connection,
oracleTempSchema = tempEmulationSchema,
Expand All @@ -347,10 +325,13 @@ getDbCohortMethodData <- function(connectionDetails,
rowIdField = "row_id",
covariateSettings = covariateSettings
)
ParallelLogger::logDebug("Fetched covariates total count is ", nrow_temp(covariateData$covariates))
ParallelLogger::logDebug(
"Fetched covariates total count is ",
nrow_temp(covariateData$covariates))
message("Fetching outcomes from server")
start <- Sys.time()
outcomeSql <- SqlRender::loadRenderTranslateSql("GetOutcomes.sql",
outcomeSql <- SqlRender::loadRenderTranslateSql(
sqlFilename = "GetOutcomes.sql",
packageName = "CohortMethod",
dbms = connectionDetails$dbms,
tempEmulationSchema = tempEmulationSchema,
Expand All @@ -360,25 +341,29 @@ getDbCohortMethodData <- function(connectionDetails,
outcome_ids = outcomeIds,
sampled = sampled
)
outcomes <- DatabaseConnector::querySql(connection, outcomeSql, snakeCaseToCamelCase = TRUE)
outcomes <- DatabaseConnector::querySql(
connection = connection,
sql = outcomeSql,
snakeCaseToCamelCase = TRUE)
outcomes$rowId <- as.numeric(outcomes$rowId)
metaData$outcomeIds <- outcomeIds
delta <- Sys.time() - start
message("Fetching outcomes took ", signif(delta, 3), " ", attr(delta, "units"))
ParallelLogger::logDebug("Fetched outcomes total count is ", nrow(outcomes))

# Remove temp tables:
renderedSql <- SqlRender::loadRenderTranslateSql("RemoveCohortTempTables.sql",
renderedSql <- SqlRender::loadRenderTranslateSql(
sqlFilename = "RemoveCohortTempTables.sql",
packageName = "CohortMethod",
dbms = connectionDetails$dbms,
tempEmulationSchema = tempEmulationSchema,
sampled = sampled
)
DatabaseConnector::executeSql(connection,
renderedSql,
DatabaseConnector::executeSql(
connection = connection,
sql = renderedSql,
progressBar = FALSE,
reportOverallTime = FALSE
)
reportOverallTime = FALSE)

covariateData$cohorts <- cohorts
covariateData$outcomes <- outcomes
Expand All @@ -388,13 +373,85 @@ getDbCohortMethodData <- function(connectionDetails,
return(covariateData)
}

handleCohortCovariateBuilders <- function(covariateSettings, exposureDatabaseSchema, exposureTable) {
# Count the target and comparator cohorts and, when either one has more
# exposures than maxCohortSize, run the cohort sampling SQL.
#
# Args:
#   connection:          Open database connection; its dbms slot selects the
#                        SQL dialect used for translation.
#   tempEmulationSchema: Passed through to SqlRender when rendering the SQL.
#   targetId:            Passed to CountCohorts.sql as target_id.
#   maxCohortSize:       Exposure count above which a cohort is downsampled.
#
# Returns:
#   A list with two elements: sampled (TRUE when SampleCohorts.sql was
#   executed) and preSampleCounts (one-row table holding the person and
#   exposure counts of both cohorts as observed before any sampling).
downSample <- function(connection,
                       tempEmulationSchema,
                       targetId,
                       maxCohortSize) {
  countSql <- SqlRender::loadRenderTranslateSql(
    "CountCohorts.sql",
    packageName = "CohortMethod",
    dbms = connection@dbms,
    tempEmulationSchema = tempEmulationSchema,
    target_id = targetId
  )
  cohortCounts <- DatabaseConnector::querySql(
    connection,
    countSql,
    snakeCaseToCamelCase = TRUE
  )
  ParallelLogger::logDebug(
    "Pre-sample total row count is ",
    sum(cohortCounts$rowCount)
  )

  # Target columns first, comparator columns second.
  preSampleCounts <- dplyr::bind_cols(
    countPreSample(id = 1, counts = cohortCounts),
    countPreSample(id = 0, counts = cohortCounts)
  )

  targetTooLarge <- preSampleCounts$targetExposures > maxCohortSize
  if (targetTooLarge) {
    message(
      "Downsampling target cohort from ", preSampleCounts$targetExposures,
      " to ", maxCohortSize
    )
  }
  comparatorTooLarge <- preSampleCounts$comparatorExposures > maxCohortSize
  if (comparatorTooLarge) {
    message(
      "Downsampling comparator cohort from ", preSampleCounts$comparatorExposures,
      " to ", maxCohortSize
    )
  }

  # A single sampling pass is executed if at least one cohort is over the limit.
  sampled <- targetTooLarge || comparatorTooLarge
  if (sampled) {
    sampleSql <- SqlRender::loadRenderTranslateSql(
      "SampleCohorts.sql",
      packageName = "CohortMethod",
      dbms = connection@dbms,
      tempEmulationSchema = tempEmulationSchema,
      max_cohort_size = maxCohortSize
    )
    DatabaseConnector::executeSql(connection, sampleSql)
  }

  list(sampled = sampled, preSampleCounts = preSampleCounts)
}

# Extract the pre-sampling person and exposure counts for one study arm.
#
# Args:
#   id:     1 for the target cohort, 0 for the comparator cohort; matched
#           against counts$treatment.
#   counts: Data frame with columns treatment, personCount and rowCount.
#
# Returns:
#   A one-row tibble with columns targetPersons / targetExposures (id = 1) or
#   comparatorPersons / comparatorExposures (id = 0). Both values are 0 when
#   counts holds no row for the requested arm.
#
# Raises:
#   An error when id is not 0 or 1, or when counts holds more than one row for
#   the requested arm.
countPreSample <- function(id, counts) {
  # Fail with a clear message instead of an "object not found" error later on.
  if (length(id) != 1 || is.na(id) || !id %in% c(0, 1)) {
    stop("id must be 0 (comparator) or 1 (target)", call. = FALSE)
  }
  prefix <- if (id == 1) "target" else "comparator"

  idx <- which(counts$treatment == id)
  if (length(idx) > 1) {
    # The result is a single row, so duplicate arm rows cannot be represented.
    stop(
      "Expected at most one row in counts for treatment = ", id,
      call. = FALSE
    )
  }

  if (length(idx) == 0) {
    persons <- 0
    exposures <- 0
  } else {
    persons <- counts$personCount[idx]
    exposures <- counts$rowCount[idx]
  }

  preSampleCounts <- dplyr::tibble(persons = persons, exposures = exposures)
  names(preSampleCounts) <- paste0(prefix, c("Persons", "Exposures"))
  preSampleCounts
}

handleCohortCovariateBuilders <- function(covariateSettings,
exposureDatabaseSchema,
exposureTable) {
if (is(covariateSettings, "covariateSettings")) {
covariateSettings <- list(covariateSettings)
}
for (i in 1:length(covariateSettings)) {
object <- covariateSettings[[i]]
if ("covariateCohorts" %in% names(object) && is.null(object$covariateCohortTable)) {
if ("covariateCohorts" %in% names(object) &&
is.null(object$covariateCohortTable)) {
object$covariateCohortDatabaseSchema <- exposureDatabaseSchema
object$covariateCohortTable <- exposureTable
covariateSettings[[i]] <- object
Expand Down
2 changes: 1 addition & 1 deletion man/getDbCohortMethodData.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions tests/testthat/test-simulation.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ test_that("createCohortMethodDataSimulationProfile", {
cdmVersion = "5",
washoutPeriod = 183,
firstExposureOnly = TRUE,
removeDuplicateSubjects = TRUE,
removeDuplicateSubjects = "keep all",
restrictToCommonPeriod = TRUE,
maxCohortSize = 100000,
covariateSettings = covarSettings
Expand Down Expand Up @@ -67,7 +67,7 @@ test_that("Test bad covariate data", {
cdmVersion = "5",
washoutPeriod = 183,
firstExposureOnly = TRUE,
removeDuplicateSubjects = TRUE,
removeDuplicateSubjects = "keep all",
restrictToCommonPeriod = TRUE,
maxCohortSize = 100000,
covariateSettings = covarSettings
Expand All @@ -93,7 +93,7 @@ test_that("Test bad covariate data", {
cdmVersion = "5",
washoutPeriod = 183,
firstExposureOnly = TRUE,
removeDuplicateSubjects = TRUE,
removeDuplicateSubjects = "keep all",
restrictToCommonPeriod = TRUE,
maxCohortSize = 100000,
covariateSettings = covarSettings
Expand Down

0 comments on commit 6a766e5

Please sign in to comment.