Skip to content

Commit

Permalink
Improve unit tests, better handling when allow_unbalanced_panel = TRU…
Browse files Browse the repository at this point in the history
…E but data is balanced, and speed improvements.
  • Loading branch information
pedrohcgs committed Sep 15, 2024
1 parent d889771 commit b4c8b89
Show file tree
Hide file tree
Showing 12 changed files with 682 additions and 597 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: did
Title: Treatment Effects with Multiple Periods and Groups
Version: 2.2.0.902
Version: 2.2.0.903
Authors@R: c(person("Brantly", "Callaway", email = "[email protected]", role = c("aut", "cre")), person("Pedro H. C.", "Sant'Anna", email="[email protected]", role = c("aut")))
URL: https://bcallaway11.github.io/did/, https://github.com/bcallaway11/did/
Description: The standard Difference-in-Differences (DID) setup involves two periods and two groups -- a treated group and untreated group. Many applications of DID methods involve more than two periods and have individuals that are treated at different points in time. This package contains tools for computing average treatment effect parameters in Difference in Differences setups with more than two periods and with variation in treatment timing using the methods developed in Callaway and Sant'Anna (2021) <doi:10.1016/j.jeconom.2020.12.001>. The main parameters are group-time average treatment effects which are the average treatment effect for a particular group at a particular time. These can be aggregated into a fewer number of treatment effect parameters, and the package deals with the cases where there is selective treatment timing, dynamic treatment effects, calendar time effects, or combinations of these. There are also functions for testing the Difference in Differences assumption, and plotting group-time average treatment effects.
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
* Code improvements that made the package much faster and memory efficient

* Improved automated testing and regression testing

* When `panel = TRUE` and `allow_unbalanced_panel = TRUE`, check whether the data is actually balanced. If it is, disable `allow_unbalanced_panel` and proceed with the panel data setup. This differs from the previous behavior, which always proceeded as if `panel = FALSE`.

# did 2.1.2

Expand Down
114 changes: 77 additions & 37 deletions R/compute.att_gt.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ compute.att_gt <- function(dp) {
#-----------------------------------------------------------------------------
# unpack DIDparams
#-----------------------------------------------------------------------------
data <- as.data.frame(dp$data)
data <- data.table::as.data.table(dp$data)
yname <- dp$yname
tname <- dp$tname
idname <- dp$idname
Expand Down Expand Up @@ -51,32 +51,30 @@ compute.att_gt <- function(dp) {
counter <- 1

# number of time periods
tlist.length <- length(tlist)
tfac <- 0

if (base_period != "universal") {
tlist.length <- tlist.length - 1
tfac <- 1
}
tlist.length <- ifelse(base_period != "universal", length(tlist) - 1, length(tlist))
tfac <- ifelse(base_period != "universal", 1, 0)

# influence function
inffunc <- Matrix::Matrix(data=0,nrow=n, ncol=nG*(nT-tfac), sparse=TRUE)
inffunc <- Matrix::Matrix(data = 0, nrow = n, ncol = nG*(nT-tfac), sparse = TRUE)

# list to collect sparse matrix updates
inffunc_updates <- list()
# counter for keeping track of updates
update_counter <- 1

# never treated option
nevertreated <- (control_group[1] == "nevertreated")

if(nevertreated) {
data$.C <- 1*(data[,gname] == 0)
data[, .C := as.integer(get(gname) == 0)]
}

# rename yname to .y
data$.y <- data[,yname]
data[, .y := get(yname), .SDcols = yname]

# loop over groups
for (g in 1:nG) {

# Set up .G once
data$.G <- 1*(data[,gname] == glist[g])
data[, .G := as.integer(get(gname) == glist[g]), .SDcols = gname]

# loop over time periods
for (t in 1:tlist.length) {
Expand All @@ -97,9 +95,9 @@ compute.att_gt <- function(dp) {
# that is, never treated + units that are eventually treated,
# but not treated by the current period (+ anticipation)
if(!nevertreated) {
data$.C <- 1 * ((data[,gname] == 0) |
((data[,gname] > (tlist[max(t,pret)+tfac]+anticipation)) &
(data[,gname] != glist[g])))
data[, .C := as.integer((get(gname) == 0) |
((get(gname) > (tlist[max(t, pret) + tfac] + anticipation)) &
(get(gname) != glist[g])))]
}


Expand Down Expand Up @@ -127,8 +125,16 @@ compute.att_gt <- function(dp) {
if (base_period == "universal") {
if (tlist[pret] == tlist[(t+tfac)]) {
attgt.list[[counter]] <- list(att=0, group=glist[g], year=tlist[(t+tfac)], post=0)
inffunc[,counter] <- rep(0,n)
counter <- counter+1
# inffunc[,counter] <- rep(0,n)
# counter <- counter+1
inffunc_updates[[update_counter]] <- list(
indices = rep(TRUE, n), # Apply to all units
values = as.matrix(rep(0, n)) # Zero influence function
)

# Update the counters
update_counter <- update_counter + 1
counter <- counter + 1
next
}
}
Expand All @@ -148,13 +154,15 @@ compute.att_gt <- function(dp) {
post.treat <- 1*(glist[g] <= tlist[t+tfac])

# total number of units (not just included in G or C)
disdat <- data[data[,tname] == tlist[t+tfac] | data[,tname] == tlist[pret],]
# disdat <- data[data[,tname] == tlist[t+tfac] | data[,tname] == tlist[pret],]
target_times <- c(tlist[t+tfac], tlist[pret])
disdat <- data[get(tname) %in% target_times]


if (panel) {
# transform disdat into "cross-sectional" data where one of the columns
# contains the change in the outcome over time.
disdat <- panel2cs2(disdat, yname, idname, tname, balance_panel=FALSE)
disdat <- get_wide_data(disdat, yname, idname, tname)

# still total number of units (not just included in G or C)
n <- nrow(disdat)
Expand All @@ -163,7 +171,7 @@ compute.att_gt <- function(dp) {
disidx <- disdat$.G==1 | disdat$.C==1

# pick up the data that will be used to compute ATT(g,t)
disdat <- disdat[disidx,]
disdat <- disdat[disidx]

n1 <- nrow(disdat) # num obs. for computing ATT(g,t)

Expand Down Expand Up @@ -197,7 +205,8 @@ compute.att_gt <- function(dp) {

# checks for pscore based methods
if (est_method %in% c("dr", "ipw")) {
preliminary_logit <- glm(G ~ -1 + covariates, family=binomial(link=logit))
#preliminary_logit <- glm(G ~ -1 + covariates, family=binomial(link=logit))
preliminary_logit <- parglm::parglm(G ~ -1 + covariates, family = "binomial")
preliminary_pscores <- predict(preliminary_logit, type="response")
if (max(preliminary_pscores) >= 0.999) {
pscore_problems_likely <- TRUE
Expand All @@ -217,8 +226,16 @@ compute.att_gt <- function(dp) {

if (reg_problems_likely | pscore_problems_likely) {
attgt.list[[counter]] <- list(att=NA, group=glist[g], year=tlist[(t+tfac)], post=post.treat)
inffunc[,counter] <- NA
counter <- counter+1
# inffunc[,counter] <- NA
# counter <- counter+1
inffunc_updates[[update_counter]] <- list(
indices = rep(TRUE, n), # Apply to all units
values = as.matrix(rep(NA, n)) # NA influence function
)

# Update the counters
update_counter <- update_counter + 1
counter <- counter + 1
next
}
}
Expand Down Expand Up @@ -270,7 +287,7 @@ compute.att_gt <- function(dp) {
# this is the fix for unbalanced panels; 2nd criteria shouldn't do anything
# with true repeated cross sections, but should pick up the right time periods
# only with unbalanced panel
disidx <- (data$.rowid %in% rightids) & ( (data[,tname] == tlist[t+tfac]) | (data[,tname]==tlist[pret]))
disidx <- (data$.rowid %in% rightids) & ( (data[[tname]] == tlist[t+tfac]) | (data[[tname]]==tlist[pret]))

# pick up the data that will be used to compute ATT(g,t)
disdat <- data[disidx,]
Expand All @@ -281,8 +298,8 @@ compute.att_gt <- function(dp) {
# give short names for data in this iteration
G <- disdat$.G
C <- disdat$.C
Y <- disdat[,yname]
post <- 1*(disdat[,tname] == tlist[t+tfac])
Y <- disdat[[yname]]
post <- 1*(disdat[[tname]] == tlist[t+tfac])
# num obs. for computing ATT(g,t), have to be careful here
n1 <- sum(G+C)
w <- disdat$.w
Expand All @@ -309,8 +326,16 @@ compute.att_gt <- function(dp) {

if (skip_this_att_gt) {
attgt.list[[counter]] <- list(att=NA, group=glist[g], year=tlist[(t+tfac)], post=post.treat)
inffunc[,counter] <- NA
counter <- counter+1
# inffunc[,counter] <- NA
# counter <- counter+1
inffunc_updates[[update_counter]] <- list(
indices = rep(TRUE, n), # Apply to all units
values = as.matrix(rep(NA, n)) # NA influence function
)

# Update the counters
update_counter <- update_counter + 1
counter <- counter + 1
next
}

Expand Down Expand Up @@ -374,30 +399,45 @@ compute.att_gt <- function(dp) {
att = attgt$ATT, group = glist[g], year = tlist[(t+tfac)], post = post.treat
)

# recover the influence function
# start with vector of 0s because influence function
# for units that are not in G or C will be equal to 0
inf.func <- rep(0, n)

# populate the influence function in the right places
if(panel) {
inf.func[disidx] <- attgt$att.inf.func
# inf.func[disidx] <- attgt$att.inf.func
# Collect the indices and corresponding values for the update
inffunc_updates[[update_counter]] <- list(
indices = disidx,
values = attgt$att.inf.func
)
} else {
# aggregate inf functions by id (order by id)
aggte_inffunc = suppressWarnings(stats::aggregate(attgt$att.inf.func, list(rightids), sum))
disidx <- (unique(data$.rowid) %in% aggte_inffunc[,1])
inf.func[disidx] <- aggte_inffunc[,2]
#inf.func[disidx] <- aggte_inffunc[,2]
inffunc_updates[[update_counter]] <- list(
indices = disidx,
values = aggte_inffunc[,2]
)
}


# save it in influence function matrix
# inffunc[g,t,] <- inf.func
inffunc[,counter] <- inf.func
#inffunc[,counter] <- inf.func
update_counter <- update_counter + 1

# update counter
counter <- counter+1
} # end looping over t
} # end looping over g

# Apply the updates to the influence function matrix
update_inffunc <- rbindlist(lapply(seq_along(inffunc_updates), function(i) {
update <- inffunc_updates[[i]]
data.table(row = which(update$indices), col = i, value = update$values)
}))
# Apply updates to the sparse matrix
inffunc[cbind(update_inffunc$row, update_inffunc$col)] <- update_inffunc$value


return(list(attgt.list=attgt.list, inffunc=inffunc))
}
1 change: 1 addition & 0 deletions R/imports.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
#' @importFrom tidyr gather
#' @importFrom methods is
NULL
# Register the symbols used via non-standard evaluation (e.g. inside
# data.table `:=` expressions) as known global variables for this package.
utils::globalVariables(
  names = c(".", ".G", ".y")
)
Loading

0 comments on commit b4c8b89

Please sign in to comment.