qml.Hamiltonian inside of a pure_callback doesn't work #1093

Open
isaacdevlugt opened this issue Sep 3, 2024 · 6 comments
@isaacdevlugt
Contributor

Issue description

I'm trying to use qml.qchem.molecular_hamiltonian inside a Catalyst pure_callback, and it isn't working.

  • Expected behavior: generating a Hamiltonian inside a callback should be possible, since a Hamiltonian is a pytree.

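  • Actual behavior: it fails with a TypeError raised by jax.tree_util.tree_unflatten while building the callback's abstract return type (see the traceback below).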

  • Reproduces how often: 100%

  • System information:

Name: PennyLane
Version: 0.37.0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/isaac/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, packaging, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane-Catalyst, PennyLane_Lightning

Platform info:           macOS-14.6.1-arm64-arm-64bit
Python version:          3.11.9
Numpy version:           1.26.4
Scipy version:           1.12.0
Installed devices:
- lightning.qubit (PennyLane_Lightning-0.37.0)
- nvidia.custatevec (PennyLane-Catalyst-0.7.0)
- nvidia.cutensornet (PennyLane-Catalyst-0.7.0)
- oqc.cloud (PennyLane-Catalyst-0.7.0)
- softwareq.qpp (PennyLane-Catalyst-0.7.0)
- default.clifford (PennyLane-0.37.0)
- default.gaussian (PennyLane-0.37.0)
- default.mixed (PennyLane-0.37.0)
- default.qubit (PennyLane-0.37.0)
- default.qubit.autograd (PennyLane-0.37.0)
- default.qubit.jax (PennyLane-0.37.0)
- default.qubit.legacy (PennyLane-0.37.0)
- default.qubit.tf (PennyLane-0.37.0)
- default.qubit.torch (PennyLane-0.37.0)
- default.qutrit (PennyLane-0.37.0)
- default.qutrit.mixed (PennyLane-0.37.0)
- default.tensor (PennyLane-0.37.0)
- null.qubit (PennyLane-0.37.0)

Source code and tracebacks

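# Imports (assumed to be run in an earlier notebook cell)
import jax
import jax.numpy as jnp
import catalyst
from pennylane import qchem

coordinates = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.1]])
symbols = ['H', 'H']

# Construct the Molecule object
molecule = qchem.Molecule(symbols, coordinates)

# Compute the Hamiltonian once, eagerly, to get its pytree structure
H, qubits = qchem.molecular_hamiltonian(molecule, method='openfermion')
data, shape = jax.tree_util.tree_flatten(H)

# This abstracts all 15 coefficients as a single ShapedArray(float64[15]) rather than
# one abstract value per leaf, which is what makes tree_unflatten raise the TypeError below
abstract = jax._src.api_util.shaped_abstractify(jnp.array(data))
H_abstract = jax.tree_util.tree_unflatten(shape, abstract)

@catalyst.pure_callback
def get_hamiltonian(coords, molecule) -> (H_abstract, int):
    H, qubits = qchem.molecular_hamiltonian(molecule, method='openfermion') # can't be jit'd because of deep numpy calls
    return H, qubits
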
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[11], line 11
      8 data, shape = jax.tree_util.tree_flatten(H)
     10 abstract = jax._src.api_util.shaped_abstractify(jnp.array(data))
---> 11 H_abstract = jax.tree_util.tree_unflatten(shape, abstract)
     13 @catalyst.pure_callback
     14 def get_hamiltonian(coords, molecule) -> (H_abstract, int):
     15     H, qubits = qchem.molecular_hamiltonian(molecule, method='openfermion') # can't be jit'd because of deep numpy calls

File ~/.virtualenvs/pennylane-catalyst/lib/python3.11/site-packages/jax/_src/tree_util.py:100, in tree_unflatten(treedef, leaves)
     86 def tree_unflatten(treedef: PyTreeDef, leaves: Iterable[Leaf]) -> Any:
     87   """Reconstructs a pytree from the treedef and the leaves.
     88 
     89   The inverse of :func:`tree_flatten`.
   (...)
     98     described by ``treedef``.
     99   """
--> 100   return treedef.unflatten(leaves)

TypeError: unflatten(): incompatible function arguments. The following argument types are supported:
    1. (self: jaxlib.xla_extension.pytree.PyTreeDef, arg0: Iterable) -> object

Invoked with: PyTreeDef(CustomNode(Sum[(None,)], [CustomNode(SProd[()], [*, CustomNode(Identity[(<Wires = [0]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [0]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [1]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [2]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(PauliZ[(<Wires = [3]>, ())], [])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [0]>, ())], []), CustomNode(PauliZ[(<Wires = [1]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliY[(<Wires = [0]>, ())], []), CustomNode(PauliX[(<Wires = [1]>, ())], []), CustomNode(PauliX[(<Wires = [2]>, ())], []), CustomNode(PauliY[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliY[(<Wires = [0]>, ())], []), CustomNode(PauliY[(<Wires = [1]>, ())], []), CustomNode(PauliX[(<Wires = [2]>, ())], []), CustomNode(PauliX[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliX[(<Wires = [0]>, ())], []), CustomNode(PauliX[(<Wires = [1]>, ())], []), CustomNode(PauliY[(<Wires = [2]>, ())], []), CustomNode(PauliY[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliX[(<Wires = [0]>, ())], []), CustomNode(PauliY[(<Wires = [1]>, ())], []), CustomNode(PauliY[(<Wires = [2]>, ())], []), CustomNode(PauliX[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [0]>, ())], []), CustomNode(PauliZ[(<Wires = [2]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [0]>, ())], []), CustomNode(PauliZ[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [1]>, ())], []), CustomNode(PauliZ[(<Wires = [2]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [1]>, ())], []), CustomNode(PauliZ[(<Wires = [3]>, ())], [])])]), CustomNode(SProd[()], [*, CustomNode(Prod[()], [CustomNode(PauliZ[(<Wires = [2]>, ())], []), CustomNode(PauliZ[(<Wires = [3]>, ())], [])])])])), ShapedArray(float64[15])


erick-xanadu self-assigned this on Sep 3, 2024
@dime10
Contributor

dime10 commented Sep 3, 2024

Your traceback seems to indicate that the error occurs outside of the callback, in the tree_unflatten call that builds H_abstract.
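A minimal JAX-only sketch of the difference, using a hypothetical stand-in dictionary pytree rather than a PennyLane Hamiltonian: jax.tree_util.tree_unflatten expects an iterable with one entry per leaf, whereas the issue's code passes a single stacked ShapedArray(float64[15]). Building one jax.ShapeDtypeStruct per flattened leaf gives an abstract pytree with the original structure. Whether Catalyst then accepts such a pytree as a callback result type is a separate question that this sketch does not address.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in pytree: three scalar coefficients under one key
tree = {"coeffs": [jnp.array(0.5), jnp.array(-0.25), jnp.array(1.0)]}
leaves, treedef = jax.tree_util.tree_flatten(tree)

# One abstract value (shape and dtype, no data) per leaf, so the count matches treedef.num_leaves
abstract_leaves = [jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x)) for x in leaves]
abstract_tree = jax.tree_util.tree_unflatten(treedef, abstract_leaves)
```

Passing a single jax.ShapeDtypeStruct((3,), ...) in place of the list is expected to fail, since the treedef describes three separate leaves.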

@erick-xanadu
Contributor

@dime10 I told @isaacdevlugt to create an issue so that I can narrow it down for him. This is not urgent.

@isaacdevlugt
Contributor Author

Thanks @dime10! Yes, there's a bit of context missing 😅

@josh146
Member

josh146 commented Sep 4, 2024

Note that the following does work (assuming you know in advance how many terms the Hamiltonian has):

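import jax
import catalyst
import pennylane as qml

# H is the Hamiltonian from the issue; flattening it gives 15 leaves (one coefficient per term)
data, shape = jax.tree_util.tree_flatten(H)

def get_hamiltonian(coordinates):
    # Runs as ordinary Python at runtime, so the NumPy-based qchem code is fine here
    molecule = qml.qchem.Molecule(["H", "H"], coordinates)
    H, qubits = qml.qchem.molecular_hamiltonian(molecule)
    return H

@qml.qjit
def f(coordinates):
    # One float scalar per coefficient; the number of terms (15) must be known in advance
    return catalyst.pure_callback(get_hamiltonian, result_type=[jax.ShapeDtypeStruct((), dtype=float)] * 15)(coordinates)
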
>>> f(coordinates)
[Array(9.7983828, dtype=float64),
 Array(0.3060033, dtype=float64),
 Array(0.3060033, dtype=float64),
 Array(0.19345955, dtype=float64),
 Array(-0.74817268, dtype=float64),
 Array(0.15271652, dtype=float64),
 Array(0.19164807, dtype=float64),
 Array(0.03893155, dtype=float64),
 Array(-0.03893155, dtype=float64),
 Array(-0.03893155, dtype=float64),
 Array(0.03893155, dtype=float64),
 Array(-0.74817268, dtype=float64),
 Array(0.19164807, dtype=float64),
 Array(0.15271652, dtype=float64),
 Array(0.20376722, dtype=float64)]

However, attempting to include the pytree structure of the Hamiltonian in the result_type leads to a failure in Catalyst validation, so for now the callback returns a flat list of coefficients rather than a Hamiltonian.
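A rough sketch of the flat-result pattern, assuming JAX's own jax.pure_callback as a stand-in for catalyst.pure_callback and a hypothetical dictionary pytree in place of the Hamiltonian: the callback returns only flat leaves, and the tree structure, known in advance, is re-attached afterwards with jax.tree_util.tree_unflatten. This has not been checked against qml.qjit or a PennyLane Hamiltonian, where rebuilding operators from traced coefficients may behave differently.

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_coefficients(x):
    # Hypothetical host-side function in plain NumPy, standing in for molecular_hamiltonian
    x = np.asarray(x)
    return [np.asarray(x * k, dtype=np.float32) for k in (1.0, 2.0, 3.0)]

# The structure is fixed in advance: a dict holding three scalar coefficients
treedef = jax.tree_util.tree_structure({"coeffs": [0.0, 0.0, 0.0]})

@jax.jit
def f(x):
    # Flat result type: one float32 scalar per leaf
    flat = jax.pure_callback(
        host_coefficients,
        [jax.ShapeDtypeStruct((), jnp.float32)] * treedef.num_leaves,
        x,
    )
    # Re-attach the known structure to the flat leaves
    return jax.tree_util.tree_unflatten(treedef, flat)

out = f(jnp.float32(0.5))
```

The design choice mirrors the workaround above: the callback's declared result type stays a flat list of scalars, and the structure never has to pass through the callback boundary.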

@isaacdevlugt, something to note here: the underlying solution is to make qml.qchem.Molecule and qml.qchem.molecular_hamiltonian compatible with jax.jit end to end. There is no need for them to convert JAX arrays to NumPy arrays, and avoiding those conversions should remove the need for a callback altogether.
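To illustrate why NumPy conversions block jax.jit, here is a minimal sketch with two hypothetical toy functions (not PennyLane code) that compute the interatomic distance for the coordinates used in this issue. The version that calls np.asarray fails during tracing, while the jax.numpy version compiles.

```python
import numpy as np
import jax
import jax.numpy as jnp

def distance_numpy(coords):
    # Converts to a NumPy array, which is not possible for a tracer under jax.jit
    c = np.asarray(coords)
    return np.linalg.norm(c[0] - c[1])

def distance_jax(coords):
    # Stays within jax.numpy, so it can be traced and compiled
    return jnp.linalg.norm(coords[0] - coords[1])

coords = jnp.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.1]])
d = jax.jit(distance_jax)(coords)

try:
    jax.jit(distance_numpy)(coords)
    numpy_version_traced = True
except jax.errors.TracerArrayConversionError:
    numpy_version_traced = False
```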

@josh146
Member

josh146 commented Sep 4, 2024

cc @soranjh

@isaacdevlugt
Contributor Author

@josh146 Agreed, molecule and molecular_hamiltonian could be jit-compatible!
