
[bug] broadcasting may be broken in stoch_pulse_grad #4274

Closed
Qottmann opened this issue Jun 20, 2023 · 2 comments · Fixed by #4275
Labels
bug 🐛 Something isn't working

Comments

@Qottmann
Contributor

Broadcasting, which speeds up classical simulation in qml.gradients.stoch_pulse_grad, appears to be broken on the latest master / rc branch. I suspect it was introduced by the merge of #4215: the same code runs without errors on earlier commits such as 3bf2964.
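For context: stoch_pulse_grad approximates a time integral by Monte Carlo sampling over num_split_times split times. With use_broadcasting=True, instead of one set of shifted tapes per split time, a single set of broadcast tapes covering all split times is created. The NumPy sketch below illustrates only that idea. The integrand `shifted_difference` is a hypothetical stand-in for the parameter-shifted circuit evaluations, and the uniform sampling over [0, timespan] is an assumption; this is not PennyLane's implementation.

```python
import numpy as np


def shifted_difference(tau):
    """Hypothetical toy integrand standing in for the difference of the two
    parameter-shifted circuit evaluations at split time tau."""
    return np.cos(3.0 * tau) * np.exp(-tau)


def mc_integral_loop(f, timespan, taus):
    # One evaluation per sampled split time (analogous to one tape set each).
    total = 0.0
    for tau in taus:
        total += f(tau)
    return timespan * total / len(taus)


def mc_integral_broadcast(f, timespan, taus):
    # All split times evaluated in a single vectorized call
    # (analogous to a single broadcast tape set).
    return timespan * np.mean(f(taus))


rng = np.random.default_rng(seed=42)
timespan = 0.1
num_split_times = 6
taus = rng.uniform(0.0, timespan, size=num_split_times)

loop_estimate = mc_integral_loop(shifted_difference, timespan, taus)
broadcast_estimate = mc_integral_broadcast(shifted_difference, timespan, taus)
print(loop_estimate, broadcast_estimate)
```

Both estimators compute the same Monte Carlo estimate. The broadcast version only changes how the samples are evaluated, which is why broadcasting should not change the resulting gradient.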

import jax
import jax.numpy as jnp
import pennylane as qml

n_wires = 2
atol = 1e-8
timespan = 0.1

# Objective: ZZ couplings between neighbouring wires (periodic boundary)
H_obj = qml.dot(
    [1. for i in range(n_wires)],
    [qml.PauliZ(i) @ qml.PauliZ((i+1)%n_wires) for i in range(n_wires)]
)
# Pulse Hamiltonian: YY couplings with constant, trainable amplitudes
H = qml.dot(
    [qml.pulse.constant for i in range(n_wires)],
    [qml.PauliY(i) @ qml.PauliY((i+1)%n_wires) for i in range(n_wires)]
)

dev = qml.device("default.qubit.jax", wires=n_wires)

def circuit(params):
    # Evolve under H from t=0 to t=timespan; atol is passed to the ODE solver
    qml.evolve(H, atol=atol)(params, t=timespan)
    return qml.expval(H_obj)


num_split_times = 6

cost_ps = qml.QNode(
    circuit,
    dev,
    interface="jax",
    diff_method=qml.gradients.stoch_pulse_grad,
    num_split_times=num_split_times,
    use_broadcasting=True
)

params = jnp.ones(n_wires)

# works for https://github.com/PennyLaneAI/pennylane/commit/3bf2964e7038c8caed17ab722e2c45dc3bcf46e9
# breaks for master / v0.31.0-rc0
grad = jax.grad(cost_ps)(params)

Output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[2], line 35
     31 params = jnp.ones(n_wires)
     33 # works for https://github.com/PennyLaneAI/pennylane/commit/3bf2964e7038c8caed17ab722e2c45dc3bcf46e9
     34 # breaks for 
---> 35 grad = jax.grad(cost_ps)(params)

    [... skipping hidden 10 frame]

File ~/Qottmann/Xanadu/pennylane/pennylane/qnode.py:950, in QNode.__call__(self, *args, **kwargs)
    948     self.execute_kwargs.pop("mode")
    949 # pylint: disable=unexpected-keyword-arg
--> 950 res = qml.execute(
    951     [self.tape],
    952     device=self.device,
    953     gradient_fn=self.gradient_fn,
    954     interface=self.interface,
    955     gradient_kwargs=self.gradient_kwargs,
    956     override_shots=override_shots,
    957     **self.execute_kwargs,
    958 )
    960 res = res[0]
    962 # convert result to the interface in case the qfunc has no parameters

File ~/Qottmann/Xanadu/pennylane/pennylane/interfaces/execution.py:642, in execute(tapes, device, gradient_fn, interface, grad_on_execution, gradient_kwargs, cache, cachesize, max_diff, override_shots, expand_fn, max_expansion, device_batch_transform)
    639     elif mapped_interface == "jax":
    640         _execute = _get_jax_execute_fn(interface, tapes)
--> 642     res = _execute(
    643         tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff
    644     )
    646 except ImportError as e:
    647     raise qml.QuantumFunctionError(
    648         f"{mapped_interface} not found. Please install the latest "
    649         f"version of {mapped_interface} to enable the '{mapped_interface}' interface."
    650     ) from e

File ~/Qottmann/Xanadu/pennylane/pennylane/interfaces/jax.py:416, in execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n, max_diff)
    407     return _execute_fwd(
    408         parameters,
    409         tapes,
   (...)
    412         _n=_n,
    413     )
    415 # PennyLane backward execution
--> 416 return _execute_bwd(
    417     parameters,
    418     tapes,
    419     device,
    420     execute_fn,
    421     gradient_fn,
    422     gradient_kwargs,
    423     _n=_n,
    424     max_diff=max_diff,
    425 )

File ~/Qottmann/Xanadu/pennylane/pennylane/interfaces/jax.py:502, in _execute_bwd(params, tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n, max_diff)
    498         jvps = _compute_jvps(jacs, tangents[0], multi_measurements)
    500     return res, jvps
--> 502 return execute_wrapper(params)

    [... skipping hidden 5 frame]

File ~/Qottmann/Xanadu/pennylane/pennylane/interfaces/jax.py:475, in _execute_bwd.<locals>.execute_wrapper_jvp(primals, tangents)
    473 if at_max_diff:
    474     jvp_tapes, processing_fn = qml.gradients.batch_jvp(*_args, **_kwargs)
--> 475     jvps = processing_fn(execute_fn(jvp_tapes)[0])
    476 else:
    477     jvp_tapes, processing_fn = qml.gradients.batch_jvp(*_args, **_kwargs)

File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/jvp.py:439, in batch_jvp.<locals>.processing_fn(results)
    436 start += res_len
    438 # postprocess results to compute the JVP
--> 439 jvp_ = processing_fns[t_idx](res_t)
    441 if jvp_ is None:
    442     if reduction == "append":

File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/jvp.py:329, in jvp.<locals>.processing_fn(results)
    327 def processing_fn(results):
    328     # postprocess results to compute the Jacobian
--> 329     jac = fn(results)
    330     _jvp_fn = compute_jvp_multi if multi_m else compute_jvp_single
    332     # Jacobian without shot vectors

File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/pulse_gradient.py:839, in _expval_stoch_pulse_grad.<locals>.processing_fn(results)
    836     start += num_tapes
    837     # Apply the postprocessing of the parameter-shift rule and contract
    838     # with classical Jacobian, effectively computing the integral approximation
--> 839     g = _parshift_and_integrate(
    840         res,
    841         cjacs,
    842         int_prefactor,
    843         psr_coeffs,
    844         single_measure,
    845         has_partitioned_shots,
    846         use_broadcasting,
    847     )
    848     grads.append(g)
    850 # g will have been defined at least once (because otherwise all gradients would have
    851 # been zero), providing a representative for a zero gradient to emulate its type/shape.

File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/pulse_gradient.py:269, in _parshift_and_integrate(results, cjacs, int_prefactor, psr_coeffs, single_measure, has_partitioned_shots, use_broadcasting)
    266     return tuple(_psr_and_contract(r, cjacs, int_prefactor) for r in zip(*results))
    267 if nesting_layers == 0:
    268     # Single measurement without shot vector
--> 269     return _psr_and_contract(results, cjacs, int_prefactor)
    271 # Multiple measurements with shot vector. Not supported with broadcasting yet.
    272 if use_broadcasting:
    273     # TODO: Remove once #2690 is resolved

File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/pulse_gradient.py:262, in _parshift_and_integrate.<locals>._psr_and_contract(res_list, cjacs, int_prefactor)
    257     res = jnp.moveaxis(jnp.reshape(res, new_shape), 1, 0)
    259 # Contract the results, parameter-shift rule coefficients and (classical) Jacobians,
    260 # and include the rescaling factor from the Monte Carlo integral and from global
    261 # prefactors of Pauli word generators.
--> 262 return _contract(psr_coeffs, res, cjacs) * int_prefactor

File ~/Qottmann/Xanadu/pennylane/pennylane/gradients/pulse_gradient.py:193, in _parshift_and_integrate.<locals>._contract(coeffs, res, cjac)
    190 def _contract(coeffs, res, cjac):
    191     """Contract three tensors, the first two like a standard matrix multiplication
    192     and the result with the third tensor along the first axes."""
--> 193     return jnp.tensordot(jnp.tensordot(coeffs, res, axes=1), cjac, axes=[[0], [0]])

File ~/anaconda3/envs/pennylane311/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3162, in tensordot(a, b, axes, precision)
   3159   msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
   3160          "of lists/tuples of ints.")
   3161   raise TypeError(msg)
-> 3162 return lax.dot_general(a, b, (contracting_dims, ((), ())),
   3163                        precision=precision)

    [... skipping hidden 26 frame]

File ~/anaconda3/envs/pennylane311/lib/python3.11/site-packages/jax/_src/lax/lax.py:2496, in _dot_general_shape_rule(lhs, rhs, dimension_numbers, precision, preferred_element_type)
   2493 if not core.symbolic_equal_shape(lhs_contracting_shape, rhs_contracting_shape):
   2494   msg = ("dot_general requires contracting dimensions to have the same "
   2495          "shape, got {} and {}.")
-> 2496   raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
   2498 return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)

TypeError: dot_general requires contracting dimensions to have the same shape, got (2,) and (6,).
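The failing call is the `_contract` helper in the traceback, which chains two `tensordot` calls. The NumPy sketch below mirrors that helper under assumed shapes: two placeholder shift-rule coefficients, a results array laid out as (shifts, split times), and one classical-Jacobian entry per split time (num_split_times = 6). It shows which axes must agree. The error above reports an axis of length 2 being contracted against an axis of length 6. The transposed call at the end is only an illustration of such a mismatch, not a diagnosis of the actual bug. NumPy raises a ValueError where JAX raises the TypeError shown.

```python
import numpy as np


def contract(coeffs, res, cjac):
    """NumPy mirror of the `_contract` helper from the traceback: contract
    coeffs with the leading axis of res, then contract the leading axis of
    that intermediate result with the leading axis of cjac."""
    return np.tensordot(np.tensordot(coeffs, res, axes=1), cjac, axes=[[0], [0]])


num_split_times = 6
psr_coeffs = np.array([0.5, -0.5])  # placeholder values, one per shift (assumed)
# Assumed layout: (shifts, split times)
res = np.vstack([np.full(num_split_times, 1.0), np.full(num_split_times, 0.2)])
cjacs = np.ones(num_split_times)  # assumed: one Jacobian entry per split time

# Shapes line up: 2 is contracted with 2, then 6 with 6.
grad_estimate = contract(psr_coeffs, res, cjacs)
print(grad_estimate)

# Illustration only: with the axes of `res` swapped, the length-2 shift axis
# is paired with a length-6 axis and the contraction fails.
try:
    contract(psr_coeffs, res.T, cjacs)
except ValueError as err:
    print("shape mismatch:", err)
```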
@github-actions github-actions bot added the bug 🐛 Something isn't working label Jun 20, 2023
@dwierichs
Contributor

The problem seems to lie in the interplay between qml.ops.op_math.sum.Sum and broadcasting. When the observable is replaced by, say, qml.PauliZ(0) @ qml.PauliZ(1), the example above runs without error.
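When comparing results against that workaround, note that with n_wires = 2 the two terms of H_obj are PauliZ(0) @ PauliZ(1) and PauliZ(1) @ PauliZ(0). These are the same operator, so H_obj equals 2 · Z₀Z₁, and the expectation value (and therefore the exact gradient) with the workaround observable is half the original one. The NumPy sketch below checks this on a hypothetical batch of six random states, standing in for six broadcast entries. It also shows the shape a broadcast expectation value should have: one value per batch entry for the Sum, exactly as for the single tensor-product term. It uses plain matrices and is not PennyLane's code path.

```python
import numpy as np

Z = np.diag([1.0, -1.0])
ZZ = np.kron(Z, Z)  # matrix of PauliZ(0) @ PauliZ(1) on two wires

# H_obj for n_wires = 2: terms for wires (0, 1) and (1, 0). Both act as Z on
# each wire, so each term is the same 4x4 matrix ZZ.
terms = [ZZ, ZZ]
H_obj = sum(terms)

# Hypothetical batch of 6 normalized two-qubit states, one per broadcast entry.
rng = np.random.default_rng(seed=0)
states = rng.normal(size=(6, 4)) + 1j * rng.normal(size=(6, 4))
states /= np.linalg.norm(states, axis=1, keepdims=True)


def batched_expval(obs, psi):
    """<psi_b| obs |psi_b> for every row psi_b of psi; returns shape (batch,)."""
    return np.real(np.einsum("bi,ij,bj->b", psi.conj(), obs, psi))


sum_expvals = batched_expval(H_obj, states)
# Evaluating the Sum term by term must also give one value per batch entry.
termwise = np.sum([batched_expval(t, states) for t in terms], axis=0)
zz_expvals = batched_expval(ZZ, states)

print(sum_expvals.shape, termwise.shape)  # (6,) (6,)
print(np.allclose(sum_expvals, 2 * zz_expvals))  # True
```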

@Qottmann
Contributor Author

Fixed thanks to @dwierichs 👌
