TestFCI failed in graph_utils.adj_precision #149

Closed
winstonyu opened this issue Nov 7, 2023 · 3 comments

@winstonyu

I tried to run TestFCI without any modifications, but got the errors below. Any ideas what's going on? Thanks.

=======================================================================

/Users/xxx/opt/anaconda3/envs/python311/bin/python /Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pycharm/_jb_unittest_runner.py --target TestFCI.TestFCI
Testing started at 16:17 ...
Launching unittests with arguments python -m unittest TestFCI.TestFCI in /Users/xxx/PycharmProjects/pythonProject/causal-learn/tests

Depth=0, working on node 7: 100%|██████████| 8/8 [00:00<00:00, 522.64it/s]
Depth=0, working on node 4: 100%|██████████| 5/5 [00:00<00:00, 822.09it/s]
Depth=0, working on node 4: 100%|██████████| 5/5 [00:00<00:00, 1059.54it/s]
X3 --> X4
X3 --> X5
Depth=0, working on node 10: 100%|██████████| 11/11 [00:00<00:00, 516.45it/s]
X4 --> X2
X8 --> X2
X8 --> X3
X9 --> X3
X8 --> X4
X9 --> X4
X8 --> X5
X9 --> X5
X8 --> X9
Depth=0, working on node 5: 100%|██████████| 6/6 [00:00<00:00, 694.36it/s]
Depth=0, working on node 36: 100%|██████████| 37/37 [00:00<00:00, 150.56it/s]
X5 --> X2
X5 --> X3
X7 --> X36
X35 --> X9
X35 --> X10
X35 --> X12
X31 --> X16
X33 --> X16
X31 --> X18
X20 --> X21
X32 --> X20
X24 --> X21
X30 --> X26
X29 --> X30
X30 --> X31
X31 --> X32
X32 --> X33
X33 --> X34
X34 --> X35
X35 --> X36
Depth=0, working on node 47: 100%|██████████| 48/48 [00:00<00:00, 113.35it/s]
X4 --> X11
X7 --> X8
X8 --> X14
X8 --> X16
X8 --> X25
X10 --> X12
X10 --> X13
X11 --> X12
X13 --> X14
X14 --> X15
X15 --> X24
X19 --> X36
X20 --> X21
X20 --> X34
X39 --> X20
X21 --> X41
X24 --> X25
X28 --> X33
X32 --> X33
X34 --> X44
X38 --> X42
Depth=0, working on node 19: 100%|██████████| 20/20 [00:00<00:00, 334.04it/s]
X2 --> X8
X16 --> X2
X17 --> X2
X3 --> X9
X17 --> X3
X5 --> X11
X19 --> X5
X6 --> X13
X12 --> X14
X12 --> X16
X12 --> X17
X12 --> X19
Depth=0, working on node 26: 100%|██████████| 27/27 [00:00<00:00, 203.18it/s]
X2 --> X1
X2 --> X14
X3 --> X9
X3 --> X18
X4 --> X10
X4 --> X18
X4 --> X19
X4 --> X27
X8 --> X6
X6 --> X15
X9 --> X7
X7 --> X24
X10 --> X8
X8 --> X21
X8 --> X23
X8 --> X26
X9 --> X12
X9 --> X17
X9 --> X25
X13 --> X10
X13 --> X27
X15 --> X20
X21 --> X20
X24 --> X23
X25 --> X24
Depth=0, working on node 31: 100%|██████████| 32/32 [00:00<00:00, 174.52it/s]
X10 --> X2
X2 --> X11
X2 --> X12
X18 --> X10
X10 --> X19
X11 --> X19
X12 --> X20
X18 --> X26
X19 --> X27
X19 --> X30
X20 --> X28
X20 --> X31
X21 --> X29
X23 --> X31
X24 --> X32
Depth=0, working on node 55: 100%|██████████| 56/56 [00:00<00:00, 89.87it/s]
X8 --> X9
X8 --> X13
X12 --> X13
X15 --> X17
X15 --> X20
X15 --> X25
X17 --> X19
X25 --> X44
X35 --> X36
X40 --> X42
X42 --> X43
X43 --> X44
Depth=0, working on node 69: 100%|██████████| 70/70 [00:00<00:00, 82.04it/s]
X7 --> X8
X8 --> X9
X13 --> X14
X20 --> X13
X14 --> X24
X14 --> X29
X14 --> X32
X14 --> X33
X14 --> X49
X14 --> X53
X14 --> X54
X14 --> X70
X18 --> X19
X19 --> X37
X19 --> X39
X19 --> X59
X19 --> X61
X19 --> X64
X19 --> X65
X19 --> X68
X19 --> X70
X22 --> X30
X39 --> X40
Depth=0, working on node 75: 100%|██████████| 76/76 [00:00<00:00, 78.02it/s]
X3 --> X29
X9 --> X15
X15 --> X17
X17 --> X30
X17 --> X66
X21 --> X30
X25 --> X30
X31 --> X39
X31 --> X43
X31 --> X46
X31 --> X48
X31 --> X54
X31 --> X75
X39 --> X72
X46 --> X49
X48 --> X49
X54 --> X55
X71 --> X72
Depth=0, working on node 222: 100%|██████████| 223/223 [00:09<00:00, 24.69it/s]
X24 --> X26
X26 --> X28
X30 --> X133
X32 --> X36
X35 --> X36
X38 --> X39
X39 --> X54
X41 --> X44
X46 --> X48
X48 --> X50
X50 --> X52
X50 --> X66
X52 --> X54
X52 --> X84
X52 --> X89
X54 --> X59
X59 --> X60
X59 --> X127
X62 --> X144
X66 --> X68
X66 --> X74
X70 --> X72
X72 --> X74
X78 --> X80
X81 --> X82
X84 --> X124
X87 --> X128
X91 --> X92
X91 --> X129
X92 --> X129
X96 --> X101
X105 --> X106
X110 --> X112
X112 --> X114
X114 --> X116
X116 --> X118
X116 --> X131
X118 --> X120
X118 --> X149
X118 --> X151
X120 --> X122
X122 --> X123
X124 --> X125
X131 --> X133
X131 --> X144
X131 --> X147
X131 --> X153
X133 --> X153
X144 --> X153
X151 --> X223
X155 --> X157
X157 --> X158
X158 --> X159
X161 --> X163
X161 --> X164
X163 --> X194
X163 --> X195
X163 --> X196
X163 --> X208
X163 --> X222
X164 --> X166
X164 --> X168
X164 --> X170
X164 --> X186
X164 --> X193
X166 --> X174
X166 --> X175
X168 --> X171
X174 --> X179
X174 --> X181
X179 --> X183
X186 --> X187
X186 --> X206
X187 --> X191
X194 --> X199
X194 --> X200
X199 --> X203
X200 --> X201
X206 --> X223
X219 --> X221
X222 --> X223
Depth=0, working on node 19: 100%|██████████| 20/20 [00:00<00:00, 454.79it/s]
X3 --> X6
X3 --> X10
X4 --> X6
X10 --> X12
X10 --> X16
X12 --> X15
X15 --> X19
X16 --> X20
[(3, 13), (4, 11), (6, 1), (8, 4), (9, 1), (9, 5), (10, 2), (11, 7), (12, 1), (12, 2), (13, 4), (13, 5), (14, 0), (14, 8), (14, 13)]
/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/search/ConstraintBased/FCI.py:736: UserWarning: The number of features is much larger than the sample size!
  warnings.warn("The number of features is much larger than the sample size!")
Depth=0, working on node 9: 100%|██████████| 10/10 [00:00<00:00, 956.95it/s]
X5 --> X8
Graph Nodes:
X1;X2;X3;X4;X5;X6;X7;X8;X9;X10;X11;X12;X13;X14;X15

Graph Edges:

  1. X15 --> X1
  2. X7 --> X2
  3. X10 --> X2
  4. X13 --> X2
  5. X11 --> X3
  6. X13 --> X3
  7. X4 --> X14
  8. X9 --> X5
  9. X5 --> X12
  10. X14 --> X5
  11. X10 --> X6
  12. X14 --> X6
  13. X12 --> X8
  14. X15 --> X9
  15. X15 --> X14

pag:
Graph Nodes:
X1;X7;X6;X3;X2;X8;X5;X10;X9;X4

Graph Edges:

  1. X1 o-> X6
  2. X1 o-> X5
  3. X1 o-o X9
  4. X7 o-> X2
  5. X5 o-> X6
  6. X10 o-> X6
  7. X9 o-> X6
  8. X4 o-> X6
  9. X3 o-> X2
  10. X10 o-> X2
  11. X5 --> X8
  12. X9 o-> X5
  13. X4 o-> X5

fci graph:
Graph Nodes:
X1;X2;X3;X4;X5;X6;X7;X8;X9;X10

Graph Edges:

  1. X1 o-> X5
  2. X1 o-> X6
  3. X1 o-o X9
  4. X3 o-> X2
  5. X7 o-> X2
  6. X10 o-> X2
  7. X4 o-> X5
  8. X4 o-> X6
  9. X5 o-> X6
  10. X5 --> X8
  11. X9 o-> X5
  12. X9 o-> X6
  13. X10 o-> X6

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 213, in test_er_graph
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Depth=0, working on node 6: 100%|██████████| 7/7 [00:00<00:00, 1045.92it/s]

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 147, in test_fritl
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Depth=0, working on node 3: 100%|██████████| 4/4 [00:00<00:00, 2099.78it/s]

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 81, in test_simple_test
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Depth=0, working on node 6: 100%|██████████| 7/7 [00:00<00:00, 993.74it/s]
X4 --> X1
X2 --> X5

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 107, in test_simple_test2
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Depth=0, working on node 4: 100%|██████████| 5/5 [00:00<00:00, 1673.84it/s]
X3 --> X4
X3 --> X5

fci(data, d_separation, 0.05):

Error
Traceback (most recent call last):
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 127, in test_simple_test3
    self.run_simulate_data_test(pag, G)
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/tests/TestFCI.py", line 154, in run_simulate_data_test
    arrow_precision = graph_utils.arrow_precision(truth, est)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/utils/GraphUtils.py", line 431, in arrow_precision
    confusion = ArrowConfusion(truth, est)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/ArrowConfusion.py", line 43, in __init__
    if truth.get_endpoint(truth.get_node(nodes_name[i]), truth.get_node(nodes_name[j])) == Endpoint.ARROW:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxx/PycharmProjects/pythonProject/causal-learn/causallearn/graph/Endpoint.py", line 23, in __eq__
    return self.value == other.value
                         ^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'value'

Ran 7 tests in 232.656s

FAILED (errors=5)

Process finished with exit code 1

@priamai

priamai commented Nov 16, 2023

Can you show me the code exactly from the unit test?

@winstonyu
Author

winstonyu commented Nov 20, 2023

> Can you show me the code exactly from the unit test?

import hashlib
import os
import random
import sys
sys.path.append("")
import time
import unittest

from networkx import DiGraph, erdos_renyi_graph, is_directed_acyclic_graph
import numpy as np
import pandas as pd

from causallearn.graph.Dag import Dag
from causallearn.graph.GraphNode import GraphNode
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.utils.cit import chisq, fisherz, kci, d_separation
from causallearn.utils.DAG2PAG import dag2pag
from causallearn.utils.GraphUtils import GraphUtils
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge

######################################### Test Notes ###########################################
# All the benchmark results of loaded files (e.g. "./TestData/benchmark_returned_results/")    #
# are obtained from the code of causal-learn as of commit                                      #
# https://github.com/py-why/causal-learn/commit/5918419 (02-03-2022).                          #
#                                                                                              #
# We are not sure if the results are completely "correct" (reflect ground truth graph) or not. #
# So if you find your tests failed, it means that your modified code is logically inconsistent #
# with the code as of 5918419, but that does not necessarily mean your code is "wrong".        #
# If you are sure that your modification is "correct" (e.g. fixed some bugs in 5918419),       #
# please report it to us. We will then modify these benchmark results accordingly. Thanks :)   #
######################################### Test Notes ###########################################

BENCHMARK_TXTFILE_TO_MD5 = {
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_asia_fci_chisq_0.05.txt": "65f54932a9d8224459e56c40129e6d8b",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_cancer_fci_chisq_0.05.txt": "0312381641cb3b4818e0c8539f74e802",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_earthquake_fci_chisq_0.05.txt": "a1160b92ce15a700858552f08e43b7de",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_sachs_fci_chisq_0.05.txt": "dced4a202fc32eceb75f53159fc81f3b",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_survey_fci_chisq_0.05.txt": "b1a28eee1e0c6ea8a64ac1624585c3f4",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_alarm_fci_chisq_0.05.txt": "c3bbc2b8aba456a4258dd071a42085bc",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_barley_fci_chisq_0.05.txt": "4a5000e7a582083859ee6aef15073676",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_child_fci_chisq_0.05.txt": "6b7858589e12f04b0f489ba4589a1254",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_insurance_fci_chisq_0.05.txt": "9975942b936aa2b1fc90c09318ca2d08",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_water_fci_chisq_0.05.txt": "48eee804d59526187b7ecd0519556ee5",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_hailfinder_fci_chisq_0.05.txt": "6b9a6b95b6474f8530e85c022f4e749c",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_hepar2_fci_chisq_0.05.txt": "4aae21ff3d9aa2435515ed2ee402294c",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_win95pts_fci_chisq_0.05.txt": "648fdf271e1440c06ca2b31b55ef1f3f",
    "tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_andes_fci_chisq_0.05.txt": "04092ae93e54c727579f08bf5dc34c77",
    "tests/TestData/benchmark_returned_results/linear_10_fci_fisherz_0.05.txt": "289c86f9c665bf82bbcc4c9e1dcec3e7"
}
#
INCONSISTENT_RESULT_GRAPH_ERRMSG = "Returned graph is inconsistent with the benchmark. Please check your code with the commit 5918419."
INCONSISTENT_RESULT_GRAPH_WITH_PAG_ERRMSG = "Returned graph is inconsistent with the truth PAG."

# verify files integrity first
for file_path, expected_MD5 in BENCHMARK_TXTFILE_TO_MD5.items():
    with open(file_path, 'rb') as fin:
        assert hashlib.md5(fin.read()).hexdigest() == expected_MD5, \
            f'{file_path} is corrupted. Please download it again from https://github.com/cmu-phil/causal-learn/blob/5918419/tests/TestData'


def gen_coef():
    return np.random.uniform(1, 3)


class TestFCI(unittest.TestCase):
    def test_simple_test(self):
        data = np.empty(shape=(0, 4))
        true_dag = DiGraph()
        ground_truth_edges = [(0, 1), (0, 2), (1, 3), (2, 3)]
        true_dag.add_edges_from(ground_truth_edges)
        G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)

        ground_truth_nodes = []
        for i in range(4):
            ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
        ground_truth_dag = Dag(ground_truth_nodes)
        for u, v in ground_truth_edges:
            ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])
        pag = dag2pag(ground_truth_dag, [])

        print(f'fci(data, d_separation, 0.05):')
        self.run_simulate_data_test(pag, G)

        nodes = G.get_nodes()
        assert G.is_adjacent_to(nodes[0], nodes[1])

        bk = BackgroundKnowledge().add_forbidden_by_node(nodes[0], nodes[1]).add_forbidden_by_node(nodes[1], nodes[0])
        G_with_background_knowledge, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag,
                                                 background_knowledge=bk)
        assert not G_with_background_knowledge.is_adjacent_to(nodes[0], nodes[1])

    def test_simple_test2(self):
        data = np.empty(shape=(0, 7))
        true_dag = DiGraph()
        ground_truth_edges = [(7, 0), (7, 1), (8, 3), (8, 4), (2, 5), (2, 6), (5, 1), (6, 3), (3, 0), (1, 4)]
        true_dag.add_edges_from(ground_truth_edges)
        G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)
        ground_truth_nodes = []
        for i in range(9):
            ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
        ground_truth_dag = Dag(ground_truth_nodes)
        for u, v in ground_truth_edges:
            ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])

        pag = dag2pag(ground_truth_dag, ground_truth_nodes[7: 9])

        print(f'fci(data, d_separation, 0.05):')
        self.run_simulate_data_test(pag, G)

    def test_simple_test3(self):

        data = np.empty(shape=(0, 5))
        true_dag = DiGraph()
        ground_truth_edges = [(0, 2), (1, 2), (2, 3), (2, 4)]
        true_dag.add_edges_from(ground_truth_edges)
        G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)

        ground_truth_nodes = []
        for i in range(5):
            ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
        ground_truth_dag = Dag(ground_truth_nodes)
        for u, v in ground_truth_edges:
            ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])

        pag = dag2pag(ground_truth_dag, [])

        print(f'fci(data, d_separation, 0.05):')
        self.run_simulate_data_test(pag, G)

    def test_fritl(self):
        data = np.empty(shape=(0, 7))
        true_dag = DiGraph()
        ground_truth_edges = [(7, 0), (7, 5), (8, 0), (8, 6), (9, 3), (9, 4), (9, 6),
                              (0, 1), (0, 2), (1, 2), (2, 4), (5, 6)]
        true_dag.add_edges_from(ground_truth_edges)
        G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)

        ground_truth_nodes = []
        for i in range(10):
            ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
        ground_truth_dag = Dag(ground_truth_nodes)
        for u, v in ground_truth_edges:
            ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])

        pag = dag2pag(ground_truth_dag, ground_truth_nodes[7: 10])

        print(f'fci(data, d_separation, 0.05):')
        self.run_simulate_data_test(pag, G)

    @staticmethod
    def run_simulate_data_test(truth, est):
        graph_utils = GraphUtils()
        adj_precision = graph_utils.adj_precision(truth, est)
        adj_recall = graph_utils.adj_recall(truth, est)
        arrow_precision = graph_utils.arrow_precision(truth, est)
        arrow_recall = graph_utils.arrow_recall(truth, est)

        print(f'adj_precision: {adj_precision}')
        print(f'adj_recall: {adj_recall}')
        print(f'arrow_precision: {arrow_precision}')
        print(f'arrow_recall: {arrow_recall}')
        print()
        assert np.isclose([adj_precision, adj_recall, arrow_precision, arrow_recall], [1.0, 1.0, 1.0, 1.0]).all()

    def test_bnlearn_discrete_datasets(self):
        benchmark_names = [
            "asia", "cancer", "earthquake", "sachs", "survey",
            "alarm", "barley", "child", "insurance", "water",
            "hailfinder", "hepar2", "win95pts",
            "andes"
        ]

        bnlearn_path = 'tests/TestData/bnlearn_discrete_10000/data'
        for bname in benchmark_names:
            data = np.loadtxt(os.path.join(bnlearn_path, f'{bname}.txt'), skiprows=1)
            G, edges = fci(data, chisq, 0.05, verbose=False)
            benchmark_returned_graph = np.loadtxt(
                f'tests/TestData/benchmark_returned_results/bnlearn_discrete_10000_{bname}_fci_chisq_0.05.txt')
            assert np.all(G.graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG

    def test_continuous_dataset(self):
        data = np.loadtxt('tests/data_linear_10.txt', skiprows=1)
        G, edges = fci(data, fisherz, 0.05, verbose=False)
        benchmark_returned_graph = np.loadtxt(
            f'tests/TestData/benchmark_returned_results/linear_10_fci_fisherz_0.05.txt')
        assert np.all(G.graph == benchmark_returned_graph), INCONSISTENT_RESULT_GRAPH_ERRMSG

    def test_er_graph(self):
        random.seed(42)
        np.random.seed(42)
        p = 0.1
        for _ in range(5):
            data = np.empty(shape=(0, 10))
            true_dag = erdos_renyi_graph(15, p, directed=True)  # The last 5 variables are latent variables
            while not is_directed_acyclic_graph(true_dag):
                true_dag = erdos_renyi_graph(15, p, directed=True)
            ground_truth_edges = list(true_dag.edges)
            print(ground_truth_edges)
            G, edges = fci(data, d_separation, 0.05, verbose=False, true_dag=true_dag)

            ground_truth_nodes = []
            for i in range(15):
                ground_truth_nodes.append(GraphNode(f'X{i + 1}'))
            ground_truth_dag = Dag(ground_truth_nodes)
            for u, v in ground_truth_edges:
                ground_truth_dag.add_directed_edge(ground_truth_nodes[u], ground_truth_nodes[v])
            print(ground_truth_dag)
            pag = dag2pag(ground_truth_dag, ground_truth_nodes[10:])
            print('pag:')
            print(pag)
            print('fci graph:')
            print(G)
            print(f'fci(data, d_separation, 0.05):')
            self.run_simulate_data_test(pag, G)
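`run_simulate_data_test` asserts that each computed score is close to 1.0. The sketch below shows what such adjacency and arrowhead scores measure. It assumes the conventional definitions precision = TP / (TP + FP) and recall = TP / (TP + FN). The helper names, the edge and arrowhead encodings, and the 0.0 fallback for empty sets are all hypothetical. causal-learn's own `AdjacencyConfusion` and `ArrowConfusion` classes may count some cases differently, for example arrowheads on edges that exist in only one of the two graphs.

```python
def adjacencies(edges):
    """Unordered node pairs, ignoring edge marks (X1 o-> X3 and X3 --> X1 both count)."""
    return {frozenset(pair) for pair in edges}


def precision_recall(truth, est):
    """TP / (TP + FP) and TP / (TP + FN) for two sets; 0.0 when a set is empty (assumed)."""
    tp = len(truth & est)
    precision = tp / len(est) if est else 0.0
    recall = tp / len(truth) if truth else 0.0
    return precision, recall


def adj_precision_recall(true_edges, est_edges):
    # Hypothetical helper: compare adjacencies only, not edge marks.
    return precision_recall(adjacencies(true_edges), adjacencies(est_edges))


def arrow_precision_recall(true_arrows, est_arrows):
    # Each arrowhead is an ordered pair (u, v): an arrow mark at v on the u-v edge.
    # A missing edge simply contributes no pair, rather than an endpoint of None.
    return precision_recall(set(true_arrows), set(est_arrows))


# Toy edges loosely modeled on test_simple_test3; the estimate misses X3-X5.
true_edges = [("X1", "X3"), ("X2", "X3"), ("X3", "X4"), ("X3", "X5")]
est_edges = [("X3", "X1"), ("X2", "X3"), ("X3", "X4")]
true_arrows = [("X1", "X3"), ("X2", "X3"), ("X3", "X4"), ("X3", "X5")]
est_arrows = [("X1", "X3"), ("X2", "X3"), ("X3", "X4")]

print(adj_precision_recall(true_edges, est_edges))      # (1.0, 0.75)
print(arrow_precision_recall(true_arrows, est_arrows))  # (1.0, 0.75)
```

Encoding a missing edge as "no pair" keeps the comparison free of `None` endpoints. That is exactly the case that trips up the endpoint comparison in the tracebacks.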

@MarkDana
Collaborator

Thank you @winstonyu for helping us identify this! The issue comes from the endpoint comparison. A patch has been pushed in #154, which should address the issue here.
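The failing line in every traceback is `truth.get_endpoint(...) == Endpoint.ARROW`. Since `Endpoint.ARROW` is not `None`, the left operand must be `None`, presumably because `get_endpoint` finds no edge between the two nodes in the truth PAG. Python first tries `NoneType.__eq__`, which returns `NotImplemented`. It then falls back to the reflected `Endpoint.ARROW.__eq__(None)`, and the body `self.value == other.value` shown in the traceback ends up reading `None.value`.

The reproduction below uses hypothetical stand-in enums. They are not causal-learn's actual `Endpoint` class, their member values are illustrative, and the guarded variant is not necessarily what the patch in #154 does.

```python
from enum import Enum


class FragileEndpoint(Enum):
    """Hypothetical stand-in mirroring the __eq__ body shown in the traceback."""
    TAIL = -1  # illustrative values only
    ARROW = 1
    CIRCLE = 2

    def __eq__(self, other):
        # Assumes `other` is always an endpoint; raises AttributeError for None.
        return self.value == other.value

    def __hash__(self):
        # Overriding __eq__ alone would set __hash__ to None; keep members hashable.
        return hash(self.value)


class SafeEndpoint(Enum):
    """Hypothetical guarded variant: defer to Python for foreign types."""
    TAIL = -1
    ARROW = 1
    CIRCLE = 2

    def __eq__(self, other):
        if not isinstance(other, SafeEndpoint):
            # With both sides returning NotImplemented, == falls back to identity -> False.
            return NotImplemented
        return self.value == other.value

    def __hash__(self):
        return hash(self.value)


def missing_edge_endpoint():
    # Stand-in for get_endpoint() on two non-adjacent nodes (assumed to return None).
    return None


def compare_fragile():
    try:
        return missing_edge_endpoint() == FragileEndpoint.ARROW
    except AttributeError as exc:
        return f"AttributeError: {exc}"


def compare_safe():
    return missing_edge_endpoint() == SafeEndpoint.ARROW


print(compare_fragile())  # AttributeError: 'NoneType' object has no attribute 'value'
print(compare_safe())     # False
```

With the guarded form, `SafeEndpoint.ARROW == SafeEndpoint.ARROW` still compares by value, while `None == SafeEndpoint.ARROW` evaluates to `False` instead of raising. A caller-side check such as `endpoint is not None and endpoint == Endpoint.ARROW` would be another way to avoid the crash.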

kunwuz closed this as completed Jan 16, 2024