Skip to content

Commit

Permalink
fix sampling
Browse files Browse the repository at this point in the history
liuanji committed Jun 20, 2024

Verified

This commit was signed with the committer’s verified signature.
mistic Tiago Costa
1 parent 534b1c4 commit 217259b
Showing 2 changed files with 63 additions and 4 deletions.
20 changes: 17 additions & 3 deletions src/pyjuice/queries/sample.py
Original file line number Diff line number Diff line change
@@ -184,6 +184,7 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo

# Iterate over sum layers in the current layer group
for layer in layer_group:

# Gather the indices to be processed
lsid, leid = layer._layer_nid_range
ind_n, ind_b = torch.where((node_samples >= lsid) & (node_samples < leid))
@@ -234,14 +235,27 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo
local_nid_offsets = ((global_cids[:,None] - nids[None,:]) * is_match.long()).sum(dim = 1)

target_nids = cids[local_nids,:] + local_nid_offsets[:,None]
target_cids = cids[local_nids,:]

target_idx = node_pointers[ind_b] + (torch.cumsum(mask, dim = 0)[ind_n, ind_b] - 1) * cids.size(1)

target_nids = target_nids[is_match.any(dim = 1),:]
target_idx = target_idx[is_match.any(dim = 1)]
if target_idx.max() + cids.size(1) > node_samples.size(0):

node_samples_new = torch.zeros([target_idx.max() + cids.size(1), num_samples], dtype = torch.long, device = pc.device)
node_samples_new[:,:] = -1

node_samples_new[:node_samples.size(0),:] = node_samples
node_samples = node_samples_new

match_filter = is_match.any(dim = 1)
target_nids = target_nids[match_filter,:]
target_cids = target_cids[match_filter,:]
target_idx = target_idx[match_filter]
target_b = ind_b[match_filter]

for i in range(cids.size(1)):
node_samples[target_idx+i, ind_b] = target_nids[:,i]
cmask = target_cids[:,i] != 0
node_samples[target_idx[cmask]+i, target_b[cmask]] = target_nids[cmask,i]

node_pointers = (node_samples != -1).sum(dim = 0)

47 changes: 46 additions & 1 deletion tests/queries/sample_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pyjuice as juice
import torch
import torchvision
import numpy as np
from torch.utils.data import TensorDataset, DataLoader

import pyjuice.nodes.distributions as dists
from pyjuice.utils import BitSet
@@ -37,5 +39,48 @@ def test_sample():
assert ((samples >= 0) & (samples < 2)).all()


def test_sample_hclt():
    """Check that samples drawn from an HCLT built on MNIST are valid pixel values.

    Builds an HCLT structure from the flattened MNIST training images,
    initializes its parameters, wraps it in a ``TensorCircuit`` moved to
    ``cuda:0``, draws 16 samples, and asserts that every sampled value lies
    in ``[0, 256)``.

    Raises:
        unittest.SkipTest: If no CUDA device is available, since the test is
            hard-wired to ``cuda:0``.
    """
    # Function-scope import: only needed for the skip signal below.
    import unittest

    # Skip (rather than error out) on machines without a GPU.
    if not torch.cuda.is_available():
        raise unittest.SkipTest("test_sample_hclt requires a CUDA device")

    device = torch.device("cuda:0")

    # Only the training split is used; the structure is constructed from it.
    train_dataset = torchvision.datasets.MNIST(root = "./examples/data", train = True, download = True)

    # Flatten each 28x28 image into one row of 784 variables.
    train_data = train_dataset.data.reshape(60000, 28*28)

    ns = juice.structures.HCLT(
        train_data.float().to(device),
        num_bins = 32,
        sigma = 0.5 / 32,
        num_latents = 128,
        chunk_size = 32
    )
    ns.init_parameters(perturbation = 2.0)
    pc = juice.TensorCircuit(ns)

    pc.to(device)

    samples = juice.queries.sample(pc, num_samples = 16)

    # Every sampled value must be a valid 8-bit pixel intensity.
    assert ((samples >= 0) & (samples < 256)).all()


if __name__ == "__main__":
    # Run each test exactly once when the file is executed as a script.
    test_sample()
    test_sample_hclt()

0 comments on commit 217259b

Please sign in to comment.