Skip to content

Commit

Permalink
Merge pull request #202 from OliverSchacht/causallearn-pr
Browse files Browse the repository at this point in the history
Add two variants of the KCI test
  • Loading branch information
MarkDana authored Nov 5, 2024
2 parents d450dd8 + b803533 commit f6aa500
Show file tree
Hide file tree
Showing 8 changed files with 1,069 additions and 3 deletions.
534 changes: 534 additions & 0 deletions causallearn/utils/FastKCI/FastKCI.py

Large diffs are not rendered by default.

Empty file.
403 changes: 403 additions & 0 deletions causallearn/utils/RCIT/RCIT.py

Large diffs are not rendered by default.

Empty file.
58 changes: 56 additions & 2 deletions causallearn/utils/cit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from scipy.stats import chi2, norm

from causallearn.utils.KCI.KCI import KCI_CInd, KCI_UInd
from causallearn.utils.FastKCI.FastKCI import FastKCI_CInd, FastKCI_UInd
from causallearn.utils.RCIT.RCIT import RCIT as RCIT_CInd
from causallearn.utils.RCIT.RCIT import RIT as RCIT_UInd
from causallearn.utils.PCUtils import Helper

CONST_BINCOUNT_UNIQUE_THRESHOLD = 1e5
Expand All @@ -13,6 +16,8 @@
mv_fisherz = "mv_fisherz"
mc_fisherz = "mc_fisherz"
kci = "kci"
rcit = "rcit"
fastkci = "fastkci"
chisq = "chisq"
gsq = "gsq"
d_separation = "d_separation"
Expand All @@ -23,15 +28,19 @@ def CIT(data, method='fisherz', **kwargs):
Parameters
----------
data: numpy.ndarray of shape (n_samples, n_features)
method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "chisq", "gsq"]
kwargs: placeholder for future arguments, or for KCI specific arguments now
method: str, in ["fisherz", "mv_fisherz", "mc_fisherz", "kci", "rcit", "fastkci", "chisq", "gsq"]
kwargs: placeholder for future arguments, or for KCI, FastKCI or RCIT specific arguments now
TODO: ultimately kwargs should be replaced by explicit named parameters.
check https://github.com/cmu-phil/causal-learn/pull/62#discussion_r927239028
'''
if method == fisherz:
return FisherZ(data, **kwargs)
elif method == kci:
return KCI(data, **kwargs)
elif method == fastkci:
return FastKCI(data, **kwargs)
elif method == rcit:
return RCIT(data, **kwargs)
elif method in [chisq, gsq]:
return Chisq_or_Gsq(data, method_name=method, **kwargs)
elif method == mv_fisherz:
Expand All @@ -43,6 +52,7 @@ def CIT(data, method='fisherz', **kwargs):
else:
raise ValueError("Unknown method: {}".format(method))


class CIT_Base(object):
# Base class for CIT, contains basic operations for input check and caching, etc.
def __init__(self, data, cache_path=None, **kwargs):
Expand Down Expand Up @@ -193,6 +203,50 @@ def __call__(self, X, Y, condition_set=None):
self.pvalue_cache[cache_key] = p
return p

class FastKCI(CIT_Base):
    """Adapter exposing the FastKCI tests through the ``CIT_Base`` interface.

    Queries with an empty conditioning set are delegated to ``FastKCI_UInd``;
    all other queries go to ``FastKCI_CInd``. Computed p-values are stored in
    ``self.pvalue_cache`` and returned from there on repeated queries.
    """

    def __init__(self, data, **kwargs):
        """Build the two FastKCI testers from the recognised keyword arguments.

        Parameters
        ----------
        data: forwarded unchanged to ``CIT_Base.__init__``.
        kwargs: forwarded to ``CIT_Base.__init__``. In addition, ``K``, ``J`` and
            ``alpha`` are handed to ``FastKCI_UInd``, and ``K``, ``J``, ``alpha``
            and ``use_gp`` to ``FastKCI_CInd``; any other key is ignored here.
        """
        super().__init__(data, **kwargs)

        unconditional_keys = ['K', 'J', 'alpha']
        conditional_keys = ['K', 'J', 'alpha', 'use_gp']
        unconditional_args = {
            name: value for name, value in kwargs.items() if name in unconditional_keys
        }
        conditional_args = {
            name: value for name, value in kwargs.items() if name in conditional_keys
        }

        # The conditional arguments are a superset of the unconditional ones, so
        # hashing them alone fingerprints the whole configuration. sort_keys makes
        # the digest independent of the order in which kwargs were supplied.
        serialized = json.dumps(conditional_args, sort_keys=True)
        parameters_digest = hashlib.md5(serialized.encode('utf-8')).hexdigest()
        # NOTE(review): the label registered here is 'kci', not this module's
        # "fastkci" method name. check_cache_method_consistent is not visible in
        # this view, so the label is left untouched -- confirm whether sharing
        # the 'kci' label is intentional before changing it.
        self.check_cache_method_consistent('kci', parameters_digest)
        self.assert_input_data_is_valid()

        self.kci_ui = FastKCI_UInd(**unconditional_args)
        self.kci_ci = FastKCI_CInd(**conditional_args)

    def __call__(self, X, Y, condition_set=None):
        """Return the p-value for testing X against Y given ``condition_set``.

        Parameters
        ----------
        X, Y: variable selectors, normalised by ``get_formatted_XYZ_and_cachekey``
            into column index collections for ``self.data``.
        condition_set: selectors of the conditioning variables, or None.

        Returns
        -------
        p: first element of the tuple returned by the tester's ``compute_pvalue``.
        """
        # Kernel-based conditional independence test.
        Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set)
        if cache_key in self.pvalue_cache:
            return self.pvalue_cache[cache_key]

        x_columns = self.data[:, Xs]
        y_columns = self.data[:, Ys]
        if len(condition_set) == 0:
            # Nothing to condition on: use the unconditional tester.
            outcome = self.kci_ui.compute_pvalue(x_columns, y_columns)
        else:
            outcome = self.kci_ci.compute_pvalue(x_columns, y_columns, self.data[:, condition_set])

        p = outcome[0]
        self.pvalue_cache[cache_key] = p
        return p

class RCIT(CIT_Base):
    """Adapter exposing the RIT / RCIT tests through the ``CIT_Base`` interface.

    Queries with an empty conditioning set are delegated to ``RCIT_UInd`` (the
    imported ``RIT``); all other queries go to ``RCIT_CInd`` (the imported
    ``RCIT``). Computed p-values are stored in ``self.pvalue_cache``.
    """

    def __init__(self, data, **kwargs):
        """Build the two testers from the recognised keyword arguments.

        Parameters
        ----------
        data: forwarded unchanged to ``CIT_Base.__init__``.
        kwargs: forwarded to ``CIT_Base.__init__``. In addition, ``approx`` is
            handed to ``RCIT_UInd``, and ``approx``, ``num_f``, ``num_f2`` and
            ``rcit`` to ``RCIT_CInd``; any other key is ignored here.
        """
        super().__init__(data, **kwargs)

        unconditional_keys = ['approx']
        conditional_keys = ['approx', 'num_f', 'num_f2', 'rcit']
        unconditional_args = {
            name: value for name, value in kwargs.items() if name in unconditional_keys
        }
        conditional_args = {
            name: value for name, value in kwargs.items() if name in conditional_keys
        }

        # The conditional arguments are a superset of the unconditional ones, so
        # hashing them alone fingerprints the whole configuration. sort_keys makes
        # the digest independent of the order in which kwargs were supplied.
        serialized = json.dumps(conditional_args, sort_keys=True)
        parameters_digest = hashlib.md5(serialized.encode('utf-8')).hexdigest()
        # NOTE(review): the label registered here is 'kci', not this module's
        # "rcit" method name. check_cache_method_consistent is not visible in
        # this view, so the label is left untouched -- confirm whether sharing
        # the 'kci' label is intentional before changing it.
        self.check_cache_method_consistent('kci', parameters_digest)
        self.assert_input_data_is_valid()

        self.rit = RCIT_UInd(**unconditional_args)
        self.rcit = RCIT_CInd(**conditional_args)

    def __call__(self, X, Y, condition_set=None):
        """Return the p-value for testing X against Y given ``condition_set``.

        Parameters
        ----------
        X, Y: variable selectors, normalised by ``get_formatted_XYZ_and_cachekey``
            into column index collections for ``self.data``.
        condition_set: selectors of the conditioning variables, or None.

        Returns
        -------
        p: first element of the tuple returned by the tester's ``compute_pvalue``.
        """
        # Kernel-based conditional independence test.
        Xs, Ys, condition_set, cache_key = self.get_formatted_XYZ_and_cachekey(X, Y, condition_set)
        if cache_key in self.pvalue_cache:
            return self.pvalue_cache[cache_key]

        x_columns = self.data[:, Xs]
        y_columns = self.data[:, Ys]
        if len(condition_set) == 0:
            # Nothing to condition on: use the unconditional tester.
            outcome = self.rit.compute_pvalue(x_columns, y_columns)
        else:
            outcome = self.rcit.compute_pvalue(x_columns, y_columns, self.data[:, condition_set])

        p = outcome[0]
        self.pvalue_cache[cache_key] = p
        return p

class Chisq_or_Gsq(CIT_Base):
def __init__(self, data, method_name, **kwargs):
def _unique(column):
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
'matplotlib',
'networkx',
'pydot',
'tqdm'
'tqdm',
'momentchi2'
],
url='https://github.com/py-why/causal-learn',
packages=setuptools.find_packages(),
Expand Down
36 changes: 36 additions & 0 deletions tests/TestCIT_FastKCI.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

import numpy as np

import causallearn.utils.cit as cit


class TestCIT_FastKCI(unittest.TestCase):
    """Smoke test for the 'fastkci' CIT method over a small grid of settings."""

    def test_Gaussian_dist(self):
        """Check that every p-value produced on Gaussian data lies in [0, 1].

        Column layout of ``data``: 0 = X, 1 = X_prime (drawn independently of X),
        2 = Y (X plus noise), 3 = Z (Y plus noise).
        """
        np.random.seed(10)
        n_samples = 1200
        # The order of these draws is part of the fixture: the seed is set once.
        X = np.random.randn(n_samples, 1)
        X_prime = np.random.randn(n_samples, 1)
        Y = X + 0.5 * np.random.randn(n_samples, 1)
        Z = Y + 0.5 * np.random.randn(n_samples, 1)
        data = np.hstack((X, X_prime, Y, Z))

        settings = [
            (K, J, use_gp)
            for K in [3, 10]
            for J in [8, 16]
            for use_gp in [True, False]
        ]

        pvalue01, pvalue03, pvalue032 = [], [], []
        for K, J, use_gp in settings:
            tester = cit.CIT(data, 'fastkci', K=K, J=J, use_gp=use_gp)
            pvalue01.append(round(tester(0, 1), 4))          # X vs X_prime
            pvalue03.append(round(tester(0, 3), 4))          # X vs Z
            pvalue032.append(round(tester(0, 3, {2}), 4))    # X vs Z given Y

        checks = (
            (np.array(pvalue01), "pvalue01 contains invalid values"),
            (np.array(pvalue03), "pvalue03 contains invalid values"),
            (np.array(pvalue032), "pvalue032 contains invalid values"),
        )
        for pvalues, message in checks:
            within_unit_interval = (0.0 <= pvalues) & (pvalues <= 1.0)
            self.assertTrue(np.all(within_unit_interval), message)
38 changes: 38 additions & 0 deletions tests/TestCIT_RCIT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import unittest

import numpy as np

import causallearn.utils.cit as cit


class TestCIT_RCIT(unittest.TestCase):
    """Smoke test for the 'rcit' CIT method over a small grid of settings."""

    def test_Gaussian_dist(self):
        """Check that every p-value produced on Gaussian data lies in [0, 1].

        Column layout of ``data``: 0 = X, 1 = X_prime (drawn independently of X),
        2 = Y (X plus noise), 3 = Z (Y plus noise).
        """
        np.random.seed(10)
        n_samples = 300
        # The order of these draws is part of the fixture: the seed is set once.
        X = np.random.randn(n_samples, 1)
        X_prime = np.random.randn(n_samples, 1)
        Y = X + 0.5 * np.random.randn(n_samples, 1)
        Z = Y + 0.5 * np.random.randn(n_samples, 1)
        data = np.hstack((X, X_prime, Y, Z))

        settings = [
            (approx, num_f, num_f2, rcit)
            for approx in ["lpd4", "hbe", "gamma", "chi2", "perm"]
            for num_f in [50, 100]
            for num_f2 in [5, 10]
            for rcit in [True, False]
        ]

        pvalue01, pvalue03, pvalue032 = [], [], []
        for approx, num_f, num_f2, rcit in settings:
            tester = cit.CIT(data, 'rcit', approx=approx, num_f=num_f,
                             num_f2=num_f2, rcit=rcit)
            pvalue01.append(round(tester(0, 1), 4))          # X vs X_prime
            pvalue03.append(round(tester(0, 3), 4))          # X vs Z
            pvalue032.append(round(tester(0, 3, {2}), 4))    # X vs Z given Y

        checks = (
            (np.array(pvalue01), "pvalue01 contains invalid values"),
            (np.array(pvalue03), "pvalue03 contains invalid values"),
            (np.array(pvalue032), "pvalue032 contains invalid values"),
        )
        for pvalues, message in checks:
            within_unit_interval = (0.0 <= pvalues) & (pvalues <= 1.0)
            self.assertTrue(np.all(within_unit_interval), message)

0 comments on commit f6aa500

Please sign in to comment.