-
Notifications
You must be signed in to change notification settings - Fork 1
/
test_core.py
97 lines (78 loc) · 3.34 KB
/
test_core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import pytest
import random
import numpy as np
from copy import copy
from core import sample_steiner_trees
from sample_pool import TreeSamplePool
from graph_helpers import is_steiner_tree
from tree_stat import TreeBasedStatistics
from graph_helpers import observe_uninfected_node
from random_steiner_tree.util import isolate_vertex, edges as gi_edges
from fixture import g, gi, obs
@pytest.mark.parametrize("return_type", ['nodes', 'tuples', 'tree'])
@pytest.mark.parametrize("method", ['cut', 'loop_erased'])
def test_sample_steiner_trees(g, gi, obs, return_type, method):
    """Draw Steiner-tree samples and validate each one for its return type.

    For every combination of sampling ``method`` and ``return_type``, exactly
    ``expected_count`` samples must come back. Each sample is then checked
    according to the requested representation:

    * ``'tree'``   -- ``is_steiner_tree`` must accept it for ``obs``;
    * ``'nodes'``  -- it must contain every observed node;
    * ``'tuples'`` -- it must be a ``tuple``.
    """
    expected_count = 100
    samples = sample_steiner_trees(
        g, obs, method, expected_count,
        gi=gi,
        return_type=return_type,
    )
    assert len(samples) == expected_count

    for sample in samples:
        if return_type == 'tree':
            assert is_steiner_tree(sample, obs)
        elif return_type == 'nodes':
            # every terminal (observed node) must be part of the sample
            assert set(obs).issubset(sample)
        elif return_type == 'tuples':
            assert isinstance(sample, tuple)
        else:
            # unreachable for the parametrized values above
            raise Exception
@pytest.mark.parametrize("method", ['cut', 'loop_erased'])
@pytest.mark.parametrize("edge_weight", [1.0, 0.5, 0.0])
def test_TreeSamplePool_with_incremental_sampling(g, gi, obs, method, edge_weight):
    """Check TreeSamplePool fill/update with incremental sampling enabled.

    Steps:

    1. Set every edge's ``'weights'`` property to ``edge_weight``.
    2. Fill a pool of ``n_samples`` node-set samples for ``obs``.
    3. Isolate one randomly chosen unobserved vertex ``n_rm``, mark it as
       uninfected, and call ``update_samples``.

    Checks:

    * the pool size never changes;
    * exactly the samples that contained ``n_rm`` are resampled;
    * no sample contains ``n_rm`` after the update;
    * samples that never contained ``n_rm`` are neither mutated nor dropped.

    With ``edge_weight == 1.0`` every sample is expected to cover all
    vertices, or all but the isolated one after the update.
    """
    edge_weights = g.new_edge_property("float")
    # with weight 1.0 every sample is expected to include all nodes
    edge_weights.set_value(edge_weight)
    g.edge_properties['weights'] = edge_weights

    n_samples = 100
    sampler = TreeSamplePool(g, n_samples, method,
                             gi=gi,
                             return_type='nodes',
                             with_inc_sampling=True)
    sampler.fill(obs)
    assert len(sampler.samples) == n_samples
    for t in sampler.samples:
        assert isinstance(t, set)
        assert set(obs).issubset(t)
        if edge_weight == 1.0:
            # if edge weight is 1, all nodes are infected
            assert len(t) == g.num_vertices()

    # update: pick a random unobserved vertex (as a plain int), isolate it in
    # the sampling graph and record it as an uninfected observation
    candidates = sorted(set(range(g.num_vertices())) - set(obs))
    n_rm = random.choice(candidates)
    isolate_vertex(gi, n_rm)
    observe_uninfected_node(g, n_rm, obs)

    # diagnostics, shown by pytest only when the test fails
    print('n_rm', n_rm)
    print('n_rm.out_edges()', list(g.vertex(n_rm).out_edges()))
    print('n_rm.in_edges()', list(g.vertex(n_rm).in_edges()))
    edges = {e for e in gi_edges(gi) if n_rm in set(e)}
    print('gi.vertex(n_rm).edges()', edges)

    num_invalid_trees = sum(1 for t in sampler.samples if n_rm in t)
    # Trees without n_rm are still valid and must survive the update
    # unchanged. Snapshot them *by value*: a shallow copy of the list would
    # share the same set objects and make the final comparison tautological.
    valid_trees = [t for t in sampler.samples if n_rm not in t]
    valid_trees_old = [set(t) for t in valid_trees]

    new_samples = sampler.update_samples(obs, {n_rm: 0})
    assert len(sampler.samples) == n_samples
    assert len(new_samples) == num_invalid_trees
    for t in new_samples:
        # new samples are also incrementally grown
        assert isinstance(t, set)
        assert set(obs).issubset(t)
        if edge_weight == 1.0:
            # n_rm is isolated, so it can no longer be reached
            assert len(t) == (g.num_vertices() - 1)
        else:
            assert len(t) < (g.num_vertices() - 1)

    for t in sampler.samples:
        assert n_rm not in t  # because n_rm is removed

    # valid trees must not have been mutated in place ...
    for t_now, t_before in zip(valid_trees, valid_trees_old):
        assert t_now == t_before
    # ... nor dropped from the pool by the update
    for t_before in valid_trees_old:
        assert t_before in sampler.samples