Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vasilis/orf speed #316

Merged
merged 31 commits into from
Nov 17, 2020
Merged

Vasilis/orf speed #316

merged 31 commits into from
Nov 17, 2020

Conversation

vsyrgkanis
Copy link
Collaborator

@vsyrgkanis vsyrgkanis commented Nov 13, 2020

  • Speed up prediction time of ORF. Replaced in-class pointwise_effect function with static method that takes the subset of arrays for each point. Also using processes instead of threading. Leads to at around a 10x speed-up.
  • Enabled global residualization in continuous treatment orthoforest. This then replicates exactly the functionality of grf and brings parity with grf. Added class causal_forest.CausalForest, that essentially is a continuous treatment ortho forest with global residualization. Global residualization is computationally much faster, but might lose statistical power especially with W not none, as it fits a global residualizer and not a local residualizer for each target test point. This now allows us to do sth like ForestDML but with many treatments. Moreover, unlike ForestDML this uses exactly the splitting criterion of the causal forest.
  • Enabled _inference in all ortho forests by adding inference methods in the BLBInference class in orthoforest.

Speed up example

For instance, here is a code example with 1000 trees and 1000 prediction points:

np.random.seed(123)
n = 2000
p = 10
X = np.random.normal(size=(n, p))
def true_propensity(x): return .4 + .2 * (x[:, 0] > 0)
def true_effect(x): return (x[:, 0] * (x[:, 0] > 0))
def true_conf(x): return x[:, 1] + np.clip(x[:, 2], - np.inf, 0)

T = np.random.binomial(1, true_propensity(X))
Y = true_effect(X) * T + true_conf(X) + np.random.normal(size=(n,))

X_test = np.zeros((1000, p))
X_test[:, 0] = np.linspace(-2, 2, 1000)


est3 = DiscreteTreatmentOrthoForest(model_Y=Lasso(alpha=0.01),
                                    propensity_model=LogisticRegression(C=1),
                                    model_Y_final=WeightedLassoCV(cv=3),
                                    propensity_model_final=LogisticRegressionCV(cv=3),
                                    n_trees=1000, min_leaf_size=10)
est3.fit(Y, T, X)
lb, ub = est3.effect_interval(X_test)
pred3 = est3.effect(X_test)

Here is the running time of the new code:
image

and here is the running time of the current master code:
image

Moreover, now there is a big benefit for parallelism as with 8 cores we get around 5x speed-up, compared to non-parallel.
image

@vsyrgkanis vsyrgkanis added the enhancement New feature or request label Nov 14, 2020
…rs in discrete ORF and passed directly the one hot encoding as T. Removed cross-fitting for Y_hat in the first stage from discrete ORF, which was there by mistake. Removed the creation of split_indices if split_indices is None, but we are in the first stage, since we are not doing cross fitting. Removed the use of np.insert, as it was slower than setting a slice to a constant.
…lynomialfeatures(degree=1, include_bias=True) with np.hstack
… ortho forest. This now replicates exactly the grf functionality. Added some missing tests regarding shape of output of orf and fixed some bad shapes according to API for column y or column t. Added tests for the global residualization. Replaced polynomial fit trasnform in second stage param func with hstack.
… with global_res=True in forest basic examples notebook.
econml/causal_forest.py Outdated Show resolved Hide resolved
econml/ortho_forest.py Show resolved Hide resolved
econml/ortho_forest.py Show resolved Hide resolved
econml/sklearn_extensions/model_selection.py Outdated Show resolved Hide resolved
econml/utilities.py Show resolved Hide resolved
vasilismsr and others added 3 commits November 16, 2020 17:22
…old names with warning. Made Regwrapper private
… both in local and global residualization. Also adding check that all treatments are represented in nuisance estimator, similar to the drorthoforest.
econml/ortho_forest.py Outdated Show resolved Hide resolved
econml/ortho_forest.py Show resolved Hide resolved
econml/ortho_forest.py Outdated Show resolved Hide resolved
econml/ortho_forest.py Outdated Show resolved Hide resolved
econml/sklearn_extensions/model_selection.py Outdated Show resolved Hide resolved
econml/tests/test_orf.py Outdated Show resolved Hide resolved
@vsyrgkanis vsyrgkanis merged commit 839c225 into master Nov 17, 2020
@vsyrgkanis vsyrgkanis deleted the vasilis/orf_speed branch November 17, 2020 20:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants