From 043e9e3261969d089560bda6099a82be06c6710e Mon Sep 17 00:00:00 2001 From: Anji Liu Date: Mon, 12 Aug 2024 21:39:41 +0800 Subject: [PATCH] fix conditional sampling --- src/pyjuice/queries/sample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyjuice/queries/sample.py b/src/pyjuice/queries/sample.py index ea7f95e5..ba5f44f1 100644 --- a/src/pyjuice/queries/sample.py +++ b/src/pyjuice/queries/sample.py @@ -307,7 +307,7 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo if not conditional: assert num_samples is not None, "`num_samples` should be specified when doing unconditioned sampling." else: - num_samples = pc.node_mars.size(0) # Reuse the batch size + num_samples = pc.node_mars.size(1) # Reuse the batch size root_ns = pc.root_ns assert root_ns._output_ind_range[1] - root_ns._output_ind_range[0] == 1, "It is ambiguous to sample from multi-head PCs."