Skip to content

Commit

Permalink
Encoders: refactor (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
IceKhan13 authored May 27, 2023
1 parent b00fe9f commit 5161a44
Show file tree
Hide file tree
Showing 19 changed files with 783 additions and 612 deletions.
65 changes: 29 additions & 36 deletions blackwater/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,45 @@
.. currentmodule:: blackwater.data
Classes
=======
Circuit encoders
================
.. autosummary::
:toctree: ../stubs/
ExpValDataSet
CircuitGraphExpValMitigationDataset
DefaultNumpyEstimatorInputEncoder
NodeEncoder
DefaultNodeEncoder
BackendNodeEncoder
DefaultPyGEstimatorEncoder
PygData
ExpValData
Functions
DefaultCircuitEncoder
DefaultPyGCircuitEncoder
Operator encoders
=================
.. autosummary::
:toctree: ../stubs/
DefaultOperatorEncoder
Backend encoders
================
.. autosummary::
:toctree: ../stubs/
DefaultPyGBackendEncoder
Utilities
=========
.. autosummary::
:toctree: ../stubs/
extract_properties_from_backend
circuit_to_json_graph
backend_to_json_graph
encode_pauli_sum_operator
encode_operator
encode_sparse_pauli_operatpr
DefaultNumpyEstimatorInputEncoder
"""

from .loaders.dataclasses import ExpValDataSet
from .loaders.exp_val import CircuitGraphExpValMitigationDataset
from .encoders.numpy import DefaultNumpyEstimatorInputEncoder
from .encoders.torch import (
NodeEncoder,
from .encoders.primtives_utils import DefaultNumpyEstimatorInputEncoder
from .encoders.backend import DefaultPyGBackendEncoder
from .encoders.operator import DefaultOperatorEncoder
from .encoders.circuit import (
DefaultNodeEncoder,
BackendNodeEncoder,
DefaultPyGEstimatorEncoder,
extract_properties_from_backend,
circuit_to_json_graph,
backend_to_json_graph,
PygData,
ExpValData,
)
from .encoders.utils import (
encode_pauli_sum_operator,
encode_operator,
encode_sparse_pauli_operatpr,
DefaultCircuitEncoder,
DefaultPyGCircuitEncoder,
)
57 changes: 57 additions & 0 deletions blackwater/data/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
"""Dataclasses module."""
from abc import abstractmethod
from dataclasses import dataclass
from typing import List, Any

from qiskit import QuantumCircuit
from qiskit.dagcircuit import DAGNode
from qiskit.providers import Backend
from qiskit.quantum_info import Operator


# pylint: disable=arguments-differ
class DataEncoder:
"""Base data encode class."""

@abstractmethod
def encode(self, **kwargs):
"""Encodes data
Expand All @@ -17,6 +26,54 @@ def encode(self, **kwargs):
raise NotImplementedError


class DataDecoder:
"""Base data decoder class."""

@classmethod
def decode(cls, data: Any):
"""Decodes from data to object.
Args:
data: encoded data
Returns:
decoded object
"""
raise NotImplementedError


class CircuitEncoder(DataEncoder):
"""Base encoder class for circuit objects."""

@abstractmethod
def encode(self, circuit: QuantumCircuit, **kwargs): # type: ignore
raise NotImplementedError


class OperatorEncoder(DataEncoder):
"""Base encoder class for operator objects."""

@abstractmethod
def encode(self, operator: Operator, **kwargs): # type: ignore
raise NotImplementedError


class BackendEncoder(DataEncoder):
"""Base encoder class for backend objects."""

@abstractmethod
def encode(self, backend: Backend, **kwargs): # type: ignore
raise NotImplementedError


class NodeEncoder(DataEncoder):
"""Base class for circuit dag node encoder."""

def encode(self, node: DAGNode, **kwargs) -> List[float]: # type: ignore
"""Encodes node of circuit dag."""
raise NotImplementedError


# pylint: disable=no-member
@dataclass
class BlackwaterData:
Expand Down
2 changes: 1 addition & 1 deletion blackwater/data/dataio/dataio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List

from blackwater.data.core import BlackwaterData
from blackwater.data.encoders.torch import ExpValData
from blackwater.data.encoders.graph_utils import ExpValData


# pylint: disable=unspecified-encoding
Expand Down
35 changes: 35 additions & 0 deletions blackwater/data/encoders/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Backend encoders."""

from typing import Union

import torch
from qiskit.circuit.library import get_standard_gate_name_mapping
from qiskit.providers import BackendV1, BackendV2
from torch_geometric.data import Data

from blackwater.data.core import BackendEncoder
from blackwater.data.encoders.graph_utils import backend_to_json_graph

N_QUBIT_PROPERTIES = 2
ALL_INSTRUCTIONS = list(get_standard_gate_name_mapping().keys())


# pylint: disable=no-member
class DefaultPyGBackendEncoder(BackendEncoder):
"""Default pytorch geometric backend encoder.
Turns backend into pyg data.
"""

def encode(self, backend: Union[BackendV1, BackendV2], **kwargs): # type: ignore
backend_graph = backend_to_json_graph(backend)
backend_nodes = torch.tensor(backend_graph.nodes, dtype=torch.float)
backend_edges = torch.transpose(
torch.tensor(backend_graph.edges, dtype=torch.float), 0, 1
)
backend_edge_features = torch.tensor(
backend_graph.edge_features, dtype=torch.float
)
return Data(
x=backend_nodes, edge_index=backend_edges, edge_attr=backend_edge_features
)
95 changes: 95 additions & 0 deletions blackwater/data/encoders/circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Circuit encoders."""

from typing import Optional

import numpy as np
import torch
from qiskit import QuantumCircuit
from qiskit.providers import BackendV2
from torch_geometric.data import Data

from blackwater.data.core import CircuitEncoder
from blackwater.data.encoders.graph_utils import (
DefaultNodeEncoder,
circuit_to_json_graph,
BackendNodeEncoder,
)


# pylint: disable=no-member
class DefaultCircuitEncoder(CircuitEncoder):
"""Default circuit encoder to transform circuit into numpy array for training
Returns:
numpy array where:
- first element - depth of circuit
- second element - 2q depth of circuit
- 3rd - number of 1q gates
- 4th - number of 2q gates
- 5th - num qubits
"""

def encode(self, circuit: QuantumCircuit, **kwargs) -> np.ndarray: # type: ignore
"""Encodes circuit.
Args:
circuit: circuit to encoder
**kwargs: other arguments
Returns:
numpy array
"""
depth = circuit.depth()
two_qubit_depth = circuit.depth(lambda x: x[0].num_qubits == 2)

num_one_q_gates = 0
num_two_q_gates = 0
for instr in circuit._data:
num_qubits = len(instr.qubits)
if num_qubits == 1:
num_one_q_gates += 1
if num_qubits == 2:
num_two_q_gates += 1

return np.array(
[
depth,
two_qubit_depth,
num_one_q_gates,
num_two_q_gates,
circuit.num_qubits,
]
)


class DefaultPyGCircuitEncoder(CircuitEncoder):
"""Default pytorch geometric circuit encoder.
Turns circuit into pyg data.
"""

def __init__(self, backend: Optional[BackendV2]):
"""Constructor.
Args:
backend: optional backend. Will be used for node data encoding.
"""
self.backend = backend

def encode(self, circuit: QuantumCircuit, **kwargs): # type: ignore
node_encoder = (
DefaultNodeEncoder()
if self.backend is None
else BackendNodeEncoder(self.backend)
)
circuit_graph = circuit_to_json_graph(circuit, node_encoder=node_encoder)
circuit_nodes = torch.tensor(circuit_graph.nodes, dtype=torch.float)
circuit_edges = torch.transpose(
torch.tensor(circuit_graph.edges, dtype=torch.long), 0, 1
)
circuit_edge_features = torch.tensor(
circuit_graph.edge_features, dtype=torch.float
)
return Data(
x=circuit_nodes, edge_index=circuit_edges, edge_attr=circuit_edge_features
)
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Graph encoders."""

from abc import ABC
from dataclasses import dataclass, asdict
from typing import Union, List, Dict, Optional, Tuple

Expand All @@ -16,23 +15,15 @@
from qiskit.transpiler import Target
from torch_geometric.data import Data

from blackwater.data.core import DataEncoder, BlackwaterData
from blackwater.data.core import DataEncoder, BlackwaterData, NodeEncoder
from blackwater.data.encoders.utils import OperatorData, encode_operator
from blackwater.exception import BlackwaterException

N_QUBIT_PROPERTIES = 2
ALL_INSTRUCTIONS = list(get_standard_gate_name_mapping().keys())


# pylint: disable=no-member
class NodeEncoder(ABC):
"""Base class for circuit dag node encoder."""

def encode(self, node: DAGNode) -> List[float]:
"""Encodes node of circuit dag."""
raise NotImplementedError


# pylint: disable=no-member, arguments-differ
class DefaultNodeEncoder(NodeEncoder):
"""DefaultNodeEncoder."""

Expand All @@ -58,7 +49,7 @@ def __init__(self, available_instructions: Optional[List[str]] = None):
for idx, inst in enumerate(available_instructions)
}

def encode(self, node: DAGNode) -> List[float]:
def encode(self, node: DAGNode, **kwargs) -> List[float]: # type: ignore
if isinstance(node, DAGOpNode):
params_encoding = [0.0, 0.0, 0.0]
for i, param in enumerate(node.op.params):
Expand Down Expand Up @@ -144,7 +135,7 @@ def __init__(self, backend: BackendV2):
self.num_qubits = backend.num_qubits
self.properties: BackendProperties = extract_properties_from_backend(backend)

def encode(self, node: DAGNode) -> List[float]:
def encode(self, node: DAGNode, **kwargs) -> List[float]: # type: ignore
if isinstance(node, DAGOpNode):
params_encoding = [0.0, 0.0, 0.0]
for i, param in enumerate(node.op.params):
Expand Down Expand Up @@ -380,19 +371,18 @@ class DefaultPyGEstimatorEncoder(DataEncoder):
"""Default encoder for pyg data.
Converts circuit data into torch_geometric.Data"""

def encode(self, **kwargs) -> Tuple[Data, float]:
circuit: QuantumCircuit = kwargs.get("circuit")
operator: PauliSumOp = kwargs.get("operator")
exp_value = kwargs.get("exp_val")
backend = kwargs.get("backend")

if circuit is None or operator is None or exp_value is None or backend is None:
raise BlackwaterException("Missing encoder input.")

def encode( # type: ignore
self,
circuit: QuantumCircuit,
operator: PauliSumOp,
exp_val: float,
backend: BackendV2,
**kwargs,
) -> Tuple[Data, float]:
data = ExpValData.build(
circuit=circuit,
expectation_values=[exp_value],
expectation_values=[exp_val],
observable=operator,
backend=backend,
).to_pyg()
return data, exp_value
return data, exp_val
Loading

0 comments on commit 5161a44

Please sign in to comment.