It seems that broadcasting, which is used to speed up classical simulation, is broken in qml.gradients.stoch_pulse_grad on the latest master / rc branch. I suspect it was introduced by the merge of #4215: the same code runs fine on earlier commits such as 3bf2964. Output:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[2], line 35
31 params = jnp.ones(n_wires)
33 # works for https://github.com/PennyLaneAI/pennylane/commit/3bf2964e7038c8caed17ab722e2c45dc3bcf46e9
34 # breaks for
---> 35 grad = jax.grad(cost_ps)(params)
[... skipping hidden 10 frame]
File ~/Qottmann/Xanadu/pennylane/pennylane/qnode.py:950, in QNode.__call__(self, *args, **kwargs)
948 self.execute_kwargs.pop("mode")
949 # pylint: disable=unexpected-keyword-arg
--> 950 res = qml.execute(
951 [self.tape],
952 device=self.device,
953 gradient_fn=self.gradient_fn,
954 interface=self.interface,
955 gradient_kwargs=self.gradient_kwargs,
956 override_shots=override_shots,
957 **self.execute_kwargs,
958 )
960 res = res[0]
962 # convert result to the interface in case the qfunc has no parameters
File ~/Qottmann/Xanadu/pennylane/pennylane/interfaces/execution.py:642, in execute(tapes, device, gradient_fn, interface, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform)
639 elif mapped_interface == "jax":
640 _execute = _get_jax_execute_fn(interface, tapes)
--> 642 res = _execute(
643 tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
644 )
646 except ImportError as e:
647 raise qml.QuantumFunctionError(
648 f"{mapped_interface} not found. Please install the latest "
649 f"version of {mapped_interface} to enable the '{mapped_interface}' interface."
650 ) from e
File ~/Qottmann/Xanadu/pennylane/pennylane/interfaces/jax.py:416, in execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n, max_diff)
407 return _execute_fwd(
408 parameters,
409 tapes,
(...)
412 _n=_n,
413 )
415 # PennyLane backward execution
--> 416 return _execute_bwd(
417 parameters,
418 tapes,
419 device,
420 execute_fn,
421 gradient_fn,
422 gradient_kwargs,
423 _n=_n,
424 max_diff=max_diff,
425 )
File ~/Qottmann/Xanadu/pennylane/pennylane/interfaces/jax.py:502, in _execute_bwd(params, tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n, max_diff)
498 jvps = _compute_jvps(jacs, tangents[0], multi_measurements)
500 return res, jvps
--> 502 return execute_wrapper(params)
[... skipping hidden 5 frame]
File ~/Qottmann/Xanadu/pennylane/pennylane/interfaces/jax.py:475, in _execute_bwd.<locals>.execute_wrapper_jvp(primals, tangents)
473 if at_max_diff:
474 jvp_tapes, processing_fn = qml.gradients.batch_jvp(*_args, **_kwargs)
--> 475 jvps = processing_fn(execute_fn(jvp_tapes)[0])
476 else:
477 jvp_tapes, processing_fn = qml.gradients.batch_jvp(*_args, **_kwargs)
File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/jvp.py:439, in batch_jvp.<locals>.processing_fn(results)
436 start += res_len
438 # postprocess results to compute the JVP
--> 439 jvp_ = processing_fns[t_idx](res_t)
441 if jvp_ is None:
442 if reduction == "append":
File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/jvp.py:329, in jvp.<locals>.processing_fn(results)
327 def processing_fn(results):
328 # postprocess results to compute the Jacobian
--> 329 jac = fn(results)
330 _jvp_fn = compute_jvp_multi if multi_m else compute_jvp_single
332 # Jacobian without shot vectors
File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/pulse_gradient.py:839, in _expval_stoch_pulse_grad.<locals>.processing_fn(results)
836 start += num_tapes
837 # Apply the postprocessing of the parameter-shift rule and contract
838 # with classical Jacobian, effectively computing the integral approximation
--> 839 g = _parshift_and_integrate(
840 res,
841 cjacs,
842 int_prefactor,
843 psr_coeffs,
844 single_measure,
845 has_partitioned_shots,
846 use_broadcasting,
847 )
848 grads.append(g)
850 # g will have been defined at least once (because otherwise all gradients would have
851 # been zero), providing a representative for a zero gradient to emulate its type/shape.
File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/pulse_gradient.py:269, in _parshift_and_integrate(results, cjacs, int_prefactor, psr_coeffs, single_measure, has_partitioned_shots, use_broadcasting)
266 return tuple(_psr_and_contract(r, cjacs, int_prefactor) for r in zip(*results))
267 if nesting_layers == 0:
268 # Single measurement without shot vector
--> 269 return _psr_and_contract(results, cjacs, int_prefactor)
271 # Multiple measurements with shot vector. Not supported with broadcasting yet.
272 if use_broadcasting:
273 # TODO: Remove once #2690 is resolved
File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/pulse_gradient.py:262, in _parshift_and_integrate.<locals>._psr_and_contract(res_list, cjacs, int_prefactor)
257 res = jnp.moveaxis(jnp.reshape(res, new_shape), 1, 0)
259 # Contract the results, parameter-shift rule coefficients and (classical) Jacobians,
260 # and include the rescaling factor from the Monte Carlo integral and from global
261 # prefactors of Pauli word generators.
--> 262 return _contract(psr_coeffs, res, cjacs) * int_prefactor
File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/pulse_gradient.py:193, in _parshift_and_integrate.<locals>._contract(coeffs, res, cjac)
190 def _contract(coeffs, res, cjac):
191 """Contract three tensors, the first two like a standard matrix multiplication
192 and the result with the third tensor along the first axes."""
--> 193 return jnp.tensordot(jnp.tensordot(coeffs, res, axes=1), cjac, axes=[[0], [0]])
File ~/anaconda3/envs/pennylane311/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3162, in tensordot(a, b, axes, precision)
3159 msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
3160 "of lists/tuples of ints.")
3161 raise TypeError(msg)
-> 3162 return lax.dot_general(a, b, (contracting_dims, ((), ())),
3163 precision=precision)
[... skipping hidden 26 frame]
File ~/anaconda3/envs/pennylane311/lib/python3.11/site-packages/jax/_src/lax/lax.py:2496, in _dot_general_shape_rule(lhs, rhs, dimension_numbers, precision, preferred_element_type)
2493 if not core.symbolic_equal_shape(lhs_contracting_shape, rhs_contracting_shape):
2494 msg = ("dot_general requires contracting dimensions to have the same "
2495 "shape, got {} and {}.")
-> 2496 raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
2498 return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)
TypeError: dot_general requires contracting dimensions to have the same shape, got (2,) and (6,).
The problem seems to lie in the interplay between qml.ops.op_math.sum.Sum and broadcasting: when the observable is replaced by, say, qml.PauliZ(0) @ qml.PauliZ(1), the example above runs without error.