Skip to content

Commit

Permalink
Merge pull request #209 from marcelortizv/even-faster-did
Browse files Browse the repository at this point in the history
Even faster did
  • Loading branch information
pedrohcgs authored Nov 13, 2024
2 parents 8d33d4c + d59c156 commit effec10
Show file tree
Hide file tree
Showing 22 changed files with 2,756 additions and 1,236 deletions.
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: did
Title: Treatment Effects with Multiple Periods and Groups
Version: 2.2.0.908
Version: 2.2.1.909
Authors@R: c(person("Brantly", "Callaway", email = "[email protected]", role = c("aut", "cre")), person("Pedro H. C.", "Sant'Anna", email="[email protected]", role = c("aut")))
URL: https://bcallaway11.github.io/did/, https://github.com/bcallaway11/did/
Description: The standard Difference-in-Differences (DID) setup involves two periods and two groups -- a treated group and untreated group. Many applications of DID methods involve more than two periods and have individuals that are treated at different points in time. This package contains tools for computing average treatment effect parameters in Difference in Differences setups with more than two periods and with variation in treatment timing using the methods developed in Callaway and Sant'Anna (2021) <doi:10.1016/j.jeconom.2020.12.001>. The main parameters are group-time average treatment effects which are the average treatment effect for a particular group at a a particular time. These can be aggregated into a fewer number of treatment effect parameters, and the package deals with the cases where there is selective treatment timing, dynamic treatment effects, calendar time effects, or combinations of these. There are also functions for testing the Difference in Differences assumption, and plotting group-time average treatment effects.
Expand All @@ -19,8 +19,9 @@ Imports:
generics,
methods,
tidyr,
parglm (>= 0.1.7),
data.table (>= 1.15.4),
parglm (>= 0.1.7)
dreamerr (>= 1.4.0)
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.2
VignetteBuilder: knitr
Expand Down
11 changes: 10 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ export(att_gt)
export(build_sim_dataset)
export(compute.aggte)
export(compute.att_gt)
export(compute.att_gt2)
export(conditional_did_pretest)
export(get_wide_data)
export(ggdid)
export(glance)
export(gplot)
export(indicator)
export(mboot)
export(pre_process_did)
export(pre_process_did2)
export(process_attgt)
export(reset.sim)
export(sim)
Expand All @@ -41,7 +42,15 @@ import(ggplot2)
import(ggpubr)
import(stats)
import(utils)
importFrom(DRDID,drdid_panel)
importFrom(DRDID,drdid_rc)
importFrom(DRDID,reg_did_panel)
importFrom(DRDID,reg_did_rc)
importFrom(DRDID,std_ipw_did_panel)
importFrom(DRDID,std_ipw_did_rc)
importFrom(dreamerr,check_set_arg)
importFrom(generics,glance)
importFrom(generics,tidy)
importFrom(methods,as)
importFrom(methods,is)
importFrom(tidyr,gather)
2 changes: 2 additions & 0 deletions R/DIDparams.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ DIDparams <- function(yname,
clustervars=NULL,
cband=TRUE,
print_details=TRUE,
faster_mode=FALSE,
pl=FALSE,
cores=1,
est_method="dr",
Expand Down Expand Up @@ -59,6 +60,7 @@ DIDparams <- function(yname,
clustervars=clustervars,
cband=cband,
print_details=print_details,
faster_mode=faster_mode,
pl=pl,
cores=cores,
est_method=est_method,
Expand Down
91 changes: 91 additions & 0 deletions R/DIDparams2.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#' @title DIDparams
#'
#' @description Object to hold DiD parameters that are passed across functions
#'
#' @inheritParams att_gt2
#' @inheritParams pre_process_did2
#' @param did_tensor list of outcome tensors that are used in the estimation
#' @param args list of arguments that are used in the estimation
#' @noRd
DIDparams2 <- function(did_tensors, args, call=NULL) {
# get the arguments from args
yname <- args$yname
tname <- args$tname
idname <- args$idname
gname <- args$gname
xformla <- args$xformla # formula of covariates
panel <- args$panel
est_method <- args$est_method
bstrap <- args$bstrap
biters <- args$biters
cband <- args$cband
anticipation <- args$anticipation
control_group <- args$control_group
allow_unbalanced_panel <- args$allow_unbalanced_panel
weightsname <- args$weightsname
base_period <- args$base_period
clustervars <- args$clustervars
cores <- args$cores
pl <- args$pl
print_details <- args$print_details
faster_mode <- args$faster_mode
alp <- args$alp
true_repeated_cross_sections <- args$true_repeated_cross_sections
time_periods_count <- args$time_periods_count
time_periods <- args$time_periods
treated_groups_count <- args$treated_groups_count
treated_groups <- args$treated_groups
id_count <- args$id_count

# get the arguments from did_tensors
outcomes_tensor <- did_tensors$outcomes_tensor
data <- did_tensors$data
time_invariant_data <- did_tensors$time_invariant_data
cohort_counts <- did_tensors$cohort_counts
period_counts <- did_tensors$period_counts
crosstable_counts <- did_tensors$crosstable_counts
covariates <- did_tensors$covariates # matrix of covariates
cluster_vector <- did_tensors$cluster
weights_vector <- did_tensors$weights


out <- list(yname=yname,
tname=tname,
idname=idname,
gname=gname,
xformla=xformla,
panel=panel,
est_method=est_method,
bstrap=bstrap,
biters=biters,
cband=cband,
anticipation=anticipation,
control_group=control_group,
allow_unbalanced_panel=allow_unbalanced_panel,
weightsname=weightsname,
base_period=base_period,
clustervars=clustervars,
cores=cores,
pl = pl,
print_details=print_details,
faster_mode=faster_mode,
alp=alp,
true_repeated_cross_sections=true_repeated_cross_sections,
time_periods_count=time_periods_count,
time_periods=time_periods,
treated_groups_count=treated_groups_count,
treated_groups=treated_groups,
id_count=id_count,
outcomes_tensor=outcomes_tensor,
data=data,
time_invariant_data=time_invariant_data,
cohort_counts=cohort_counts,
period_counts=period_counts,
crosstable_counts=crosstable_counts,
covariates=covariates,
cluster_vector=cluster_vector,
weights_vector=weights_vector,
call=call)
class(out) <- "DIDparams"
return(out)
}
117 changes: 86 additions & 31 deletions R/att_gt.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@
#' @param anticipation The number of time periods before participating
#' in the treatment where units can anticipate participating in the
#' treatment and therefore it can affect their untreated potential outcomes
#' @param faster_mode This option enables a faster version of `did`, optimizing
#' computation time for large datasets by improving data management within the package.
#' The default is set to `FALSE`. While the difference is minimal for small datasets,
#' it is recommended for use with large datasets.
#' @param base_period Whether to use a "varying" base period or a
#' "universal" base period. Either choice results in the same
#' post-treatment estimates of ATT(g,t)'s. In pre-treatment
Expand Down Expand Up @@ -181,47 +185,98 @@ att_gt <- function(yname,
clustervars=NULL,
est_method="dr",
base_period="varying",
faster_mode=FALSE,
print_details=FALSE,
pl=FALSE,
cores=1) {

# this is a DIDparams object
dp <- pre_process_did(yname=yname,
tname=tname,
idname=idname,
gname=gname,
xformla=xformla,
data=data,
panel=panel,
allow_unbalanced_panel=allow_unbalanced_panel,
control_group=control_group,
anticipation=anticipation,
weightsname=weightsname,
alp=alp,
bstrap=bstrap,
cband=cband,
biters=biters,
clustervars=clustervars,
est_method=est_method,
base_period=base_period,
print_details=print_details,
pl=pl,
cores=cores,
call=match.call()
)

#-----------------------------------------------------------------------------
# Compute all ATT(g,t)
#-----------------------------------------------------------------------------
results <- compute.att_gt(dp)
# Check if user wants to run faster mode:
if (faster_mode) {
# this is a DIDparams2 object
dp <- pre_process_did2(yname=yname,
tname=tname,
idname=idname,
gname=gname,
xformla=xformla,
data=data,
panel=panel,
allow_unbalanced_panel=allow_unbalanced_panel,
control_group=control_group,
anticipation=anticipation,
weightsname=weightsname,
alp=alp,
bstrap=bstrap,
cband=cband,
biters=biters,
clustervars=clustervars,
est_method=est_method,
base_period=base_period,
print_details=print_details,
faster_mode=faster_mode,
pl=pl,
cores=cores,
call=match.call()
)

#-----------------------------------------------------------------------------
# Compute all ATT(g,t)
#-----------------------------------------------------------------------------
results <- compute.att_gt2(dp)

} else {
# this is a DIDparams object
dp <- pre_process_did(yname=yname,
tname=tname,
idname=idname,
gname=gname,
xformla=xformla,
data=data,
panel=panel,
allow_unbalanced_panel=allow_unbalanced_panel,
control_group=control_group,
anticipation=anticipation,
weightsname=weightsname,
alp=alp,
bstrap=bstrap,
cband=cband,
biters=biters,
clustervars=clustervars,
est_method=est_method,
base_period=base_period,
print_details=print_details,
pl=pl,
cores=cores,
call=match.call()
)

#-----------------------------------------------------------------------------
# Compute all ATT(g,t)
#-----------------------------------------------------------------------------
results <- compute.att_gt(dp)
}

# extract ATT(g,t) and influence functions
attgt.list <- results$attgt.list
inffunc <- results$inffunc

# process results
attgt.results <- process_attgt(attgt.list)
# attgt.results <- process_attgt(attgt.list)
tryCatch(
{
# Attempt to run this line for process results
attgt.results <- process_attgt(attgt.list)
},
error = function(e) {
# Handle the error
if (faster_mode) {
# If faster_mode is TRUE, send this stop message
stop("An unexpected error occurred, normally associated with a singular matrix due to not enough control units. Try changing faster_mode=FALSE.")
} else {
# If faster_mode is FALSE, send this stop message
stop("An unexpected error occurred, normally associated with a singular matrix due to not enough control units.")
}
}
)
group <- attgt.results$group
att <- attgt.results$att
tt <- attgt.results$tt
Expand All @@ -236,7 +291,7 @@ att_gt <- function(yname,
# note to self: this def. won't work with unbalanced panel,
# same with clustered standard errors
# but it is always ignored b/c bstrap has to be true in that case
n <- dp$n
n <- ifelse(faster_mode, dp$id_count, dp$n)
V <- Matrix::t(inffunc)%*%inffunc/n
se <- sqrt(Matrix::diag(V)/n)

Expand Down
Loading

0 comments on commit effec10

Please sign in to comment.