diff --git a/src/pyjuice/queries/sample.py b/src/pyjuice/queries/sample.py index ea7f95e5..ba5f44f1 100644 --- a/src/pyjuice/queries/sample.py +++ b/src/pyjuice/queries/sample.py @@ -307,7 +307,7 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo if not conditional: assert num_samples is not None, "`num_samples` should be specified when doing unconditioned sampling." else: - num_samples = pc.node_mars.size(0) # Reuse the batch size + num_samples = pc.node_mars.size(1) # Reuse the batch size root_ns = pc.root_ns assert root_ns._output_ind_range[1] - root_ns._output_ind_range[0] == 1, "It is ambiguous to sample from multi-head PCs."