diff --git a/src/pyjuice/queries/sample.py b/src/pyjuice/queries/sample.py index 7fcade47..ea7f95e5 100644 --- a/src/pyjuice/queries/sample.py +++ b/src/pyjuice/queries/sample.py @@ -303,7 +303,7 @@ def sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_targ ) -def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bool = False): +def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bool = False, _sample_input_ns: bool = True): if not conditional: assert num_samples is not None, "`num_samples` should be specified when doing unconditioned sampling." else: @@ -425,8 +425,12 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo ind_node = node_samples[ind_n, ind_b] pc.node_flows[ind_node, ind_b] = 1.0 - for layer in pc.input_layer_group: - seed = random.randint(0, 2**31) - layer.sample(samples, pc.node_flows, seed = seed) + if _sample_input_ns: + for layer in pc.input_layer_group: + seed = random.randint(0, 2**31) + layer.sample(samples, pc.node_flows, seed = seed) - return samples.permute(1, 0).contiguous() + return samples.permute(1, 0).contiguous() + else: + # In this case, we do not explicitly sample input nodes + return None