Skip to content

Commit

Permalink
added local and global constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
vytautas-a committed Nov 17, 2023
1 parent 5bcb4c5 commit 0aa22bc
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 103 deletions.
175 changes: 100 additions & 75 deletions qadence/analog/addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from qadence.parameters import Parameter, evaluate
from qadence.types import StrEnum

DEFAULT_MAX_AMPLITUDE = 2 * pi * 3
DEFAULT_MAX_DETUNING = 2 * pi * 20
GLOBAL_MAX_AMPLITUDE = 300
GLOBAL_MAX_DETUNING = 2 * pi * 2000
LOCAL_MAX_AMPLITUDE = 3
LOCAL_MAX_DETUNING = 2 * pi * 20


class WeightConstraint(StrEnum):
Expand Down Expand Up @@ -39,93 +41,116 @@ class AddressingPattern:
# list of weights for fixed detuning pattern that cannot be changed during the execution
weights_det: dict[int, float | torch.Tensor | Parameter]

# maximum amplitude can also be chosen as a variational parameter if needed
max_amp: float | torch.Tensor | Parameter = DEFAULT_MAX_AMPLITUDE

# maximum detuning can also be chosen as a variational parameter if needed
max_det: float | torch.Tensor | Parameter = DEFAULT_MAX_DETUNING

# weight constraint
weight_constraint: WeightConstraint = WeightConstraint.NORMALIZE

def _normalize_weights(self) -> None:
self.weights_amp = {
k: Parameter(v) if not isinstance(v, Parameter) else abs(v)
for k, v in self.weights_amp.items()
# amplitude can also be chosen as a variational parameter if needed
amp: float | torch.Tensor | Parameter = LOCAL_MAX_AMPLITUDE

# detuning can also be chosen as a variational parameter if needed
det: float | torch.Tensor | Parameter = LOCAL_MAX_DETUNING

def _validate_weights(
self,
weights: dict[int, float | torch.Tensor | Parameter],
) -> None:
for v in weights.values():
if not isinstance(v, Parameter):
if not (v >= 0.0 and v <= 1.0):
raise ValueError("Addressing pattern weights must sum fall in range [0.0, 1.0]")

def _constrain_weights(
self,
weights: dict[int, float | torch.Tensor | Parameter],
) -> dict:
# augment weight dict if needed
weights = {
i: Parameter(0.0)
if i not in weights
else (Parameter(weights[i]) if not isinstance(weights[i], Parameter) else weights[i])
for i in range(self.n_qubits)
}
sum_weights_amp = sum(list(self.weights_amp.values()))
self.weights_amp = {k: v / sum_weights_amp for k, v in self.weights_amp.items()}

self.weights_det = {
k: Parameter(v) if not isinstance(v, Parameter) else abs(v)
for k, v in self.weights_det.items()
# restrict weights to [0, 1] range
weights = {
k: abs(v * (sigmoid(v, 20, 1.0) - sigmoid(v, 20.0, -1.0))) for k, v in weights.items()
}
sum_weights_det = sum(list(self.weights_det.values()))
self.weights_det = {k: v / sum_weights_det for k, v in self.weights_det.items()}

def _restrict_weights(self) -> None:
self.weights_amp = {
k: v * (sigmoid(v, 20, 0.0) - sigmoid(v, 20.0, -1.0))
for k, v in self.weights_amp.items()
}
self.weights_det = {
k: v * (sigmoid(v, 20.0, 0.0) - sigmoid(v, 20.0, -1.0))
for k, v in self.weights_det.items()
return weights

def _constrain_max_vals(self) -> None:
# enforce constraints:
# 0 <= amp <= GLOBAL_MAX_AMPLITUDE
# 0 <= abs(det) <= GLOBAL_MAX_DETUNING
self.amp = abs(
self.amp
* (
sympy.Heaviside(self.amp + GLOBAL_MAX_AMPLITUDE)
- sympy.Heaviside(self.amp - GLOBAL_MAX_AMPLITUDE)
)
)
self.det = -abs(
self.det
* (
sympy.Heaviside(self.det + GLOBAL_MAX_DETUNING)
- sympy.Heaviside(self.det - GLOBAL_MAX_DETUNING)
)
)

def _create_local_constraint(self, val: sympy.Expr, weights: dict, max_val: float) -> dict:
# enforce local constraints:
# amp * w_amp_i < LOCAL_MAX_AMPLITUDE or
# abs(det) * w_det_i < LOCAL_MAX_DETUNING
local_constr = {k: val * v for k, v in weights.items()}
local_constr = {
k: sympy.Heaviside(v) - sympy.Heaviside(v - max_val) for k, v in local_constr.items()
}

def _restrict_max_vals(self) -> None:
self.max_amp = self.max_amp * (
sympy.Heaviside(self.max_amp) - sympy.Heaviside(self.max_amp - DEFAULT_MAX_AMPLITUDE)
)
self.max_det = self.max_det * (
sympy.Heaviside(self.max_det) - sympy.Heaviside(self.max_det - DEFAULT_MAX_DETUNING)
return local_constr

def _create_global_constraint(
self, val: sympy.Expr, weights: dict, max_val: float
) -> sympy.Expr:
# enforce global constraints:
# amp * sum(w_amp_0, w_amp_1, ...) < GLOBAL_MAX_AMPLITUDE or
# abs(det) * sum(w_det_0, w_det_1, ...) < GLOBAL_MAX_DETUNING
weighted_vals_global = val * sum([v for v in weights.values()])
weighted_vals_global = sympy.Heaviside(weighted_vals_global) - sympy.Heaviside(
weighted_vals_global - max_val
)

return weighted_vals_global

def __post_init__(self) -> None:
# validate weights
if all([not isinstance(v, Parameter) for v in self.weights_amp.values()]):
if not torch.isclose(
torch.tensor(list(self.weights_amp.values())).sum(),
torch.tensor(1.0),
atol=1e-3,
):
raise ValueError("Amplitude addressing pattern weights must sum to 1.0")
if all([not isinstance(v, Parameter) for v in self.weights_det.values()]):
if not torch.isclose(
torch.tensor(list(self.weights_det.values())).sum(),
torch.tensor(1.0),
atol=1e-3,
):
raise ValueError("Detuning addressing pattern weights must sum to 1.0")

# validate detuning value
if not isinstance(self.max_amp, Parameter):
if self.max_amp > DEFAULT_MAX_AMPLITUDE:
# validate amplitude/detuning weights
self._validate_weights(self.weights_amp)
self._validate_weights(self.weights_det)

# validate maximum global amplitude/detuning values
if not isinstance(self.amp, Parameter):
if self.amp > GLOBAL_MAX_AMPLITUDE:
warn("Maximum absolute value of amplitude is exceeded")
if not isinstance(self.max_det, Parameter):
if self.max_det > DEFAULT_MAX_DETUNING:
if not isinstance(self.det, Parameter):
if abs(self.det) > GLOBAL_MAX_DETUNING:
warn("Maximum absolute value of detuning is exceeded")

# augment weight dicts if needed
self.weights_amp = {
i: Parameter(0.0) if i not in self.weights_amp else self.weights_amp[i]
for i in range(self.n_qubits)
}
self.weights_det = {
i: Parameter(0.0) if i not in self.weights_det else self.weights_det[i]
for i in range(self.n_qubits)
}
# constrain amplitude/detuning parameterized weights to [0.0, 1.0] interval
self.weights_amp = self._constrain_weights(self.weights_amp)
self.weights_det = self._constrain_weights(self.weights_det)

# apply weight constraint
if self.weight_constraint == WeightConstraint.NORMALIZE:
self._normalize_weights()
elif self.weight_constraint == WeightConstraint.RESTRICT:
self._restrict_weights()
else:
raise ValueError("Weight constraint type not found.")
# constrain max global amplitude and detuning to strict interval
self._constrain_max_vals()

# restrict max amplitude and detuning to strict interval
self._restrict_max_vals()
# create additional local and global constraints for amplitude/detuning masks
self.local_constr_amp = self._create_local_constraint(
self.amp, self.weights_amp, LOCAL_MAX_AMPLITUDE
)
self.local_constr_det = self._create_local_constraint(
-self.det, self.weights_det, LOCAL_MAX_DETUNING
)
self.global_constr_amp = self._create_global_constraint(
self.amp, self.weights_amp, GLOBAL_MAX_AMPLITUDE
)
self.global_constr_det = self._create_global_constraint(
-self.det, self.weights_det, GLOBAL_MAX_DETUNING
)

# validate number of qubits in mask
if max(list(self.weights_amp.keys())) >= self.n_qubits:
Expand Down
33 changes: 25 additions & 8 deletions qadence/analog/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,33 @@ def rot_generator(block: ConstantAnalogRotation) -> AbstractBlock:
def add_pattern(register: Register, pattern: Union[AddressingPattern, None]) -> AbstractBlock:
support = tuple(range(register.n_qubits))
if pattern is not None:
max_amp = pattern.max_amp
max_det = pattern.max_det
amp = pattern.amp
det = pattern.det
weights_amp = pattern.weights_amp
weights_det = pattern.weights_det
local_constr_amp = pattern.local_constr_amp
local_constr_det = pattern.local_constr_det
global_constr_amp = pattern.global_constr_amp
global_constr_det = pattern.global_constr_det
else:
max_amp = 0.0
max_det = 0.0
amp = 0.0
det = 0.0
weights_amp = {i: 0.0 for i in support}
weights_det = {i: 0.0 for i in support}

p_drive_terms = (1 / 2) * max_amp * add(X(i) * weights_amp[i] for i in support)
p_detuning_terms = -max_det * add(0.5 * (I(i) - Z(i)) * weights_det[i] for i in support)
return p_drive_terms + p_detuning_terms # type: ignore[no-any-return]
local_constr_amp = {i: 0.0 for i in support}
local_constr_det = {i: 0.0 for i in support}
global_constr_amp = 0.0
global_constr_det = 0.0

p_amp_terms = (
(1 / 2)
* amp
* global_constr_amp
* add(X(i) * weights_amp[i] * local_constr_amp[i] for i in support)
)
p_det_terms = (
-det
* global_constr_det
* add(0.5 * (I(i) - Z(i)) * weights_det[i] * local_constr_det[i] for i in support)
)
return p_amp_terms + p_det_terms # type: ignore[no-any-return]
20 changes: 14 additions & 6 deletions qadence/backends/pulser/pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,23 @@ def add_addressing_pattern(

support = tuple(range(n_qubits))
if config.addressing_pattern is not None:
max_amp = config.addressing_pattern.max_amp
max_det = config.addressing_pattern.max_det
amp = config.addressing_pattern.amp
det = config.addressing_pattern.det
weights_amp = config.addressing_pattern.weights_amp
weights_det = config.addressing_pattern.weights_det
local_constr_amp = config.addressing_pattern.local_constr_amp
local_constr_det = config.addressing_pattern.local_constr_det
global_constr_amp = config.addressing_pattern.global_constr_amp
global_constr_det = config.addressing_pattern.global_constr_det
else:
max_amp = 0.0
max_det = 0.0
amp = 0.0
det = 0.0
weights_amp = {i: Parameter(0.0) for i in support}
weights_det = {i: Parameter(0.0) for i in support}
local_constr_amp = {i: Parameter(0.0) for i in support}
local_constr_det = {i: Parameter(0.0) for i in support}
global_constr_amp = 0.0
global_constr_det = 0.0

for i in support:
# declare separate local channel for each qubit
Expand All @@ -77,8 +85,8 @@ def add_addressing_pattern(
if weights_det[i].is_number # type: ignore [union-attr]
else sequence.declare_variable(f"w-det-{i}")
)
omega = max_amp * w_amp
detuning = -max_det * w_det
omega = amp * w_amp
detuning = -det * w_det
pulse = Pulse.ConstantPulse(
duration=total_duration, amplitude=omega, detuning=detuning, phase=0
)
Expand Down
32 changes: 18 additions & 14 deletions tests/analog/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@


@pytest.mark.parametrize(
"max_amp,max_det",
"amp,det",
[(0.0, 10.0), (15.0, 0.0), (15.0, 9.0)],
)
@pytest.mark.parametrize(
"spacing",
[8.0, 30.0],
)
def test_pulser_pyq_addressing(max_amp: float, max_det: float, spacing: float) -> None:
def test_pulser_pyq_addressing(amp: float, det: float, spacing: float) -> None:
n_qubits = 3
block = AnalogRX("x")
circ = QuantumCircuit(n_qubits, block)
Expand All @@ -43,8 +43,8 @@ def test_pulser_pyq_addressing(max_amp: float, max_det: float, spacing: float) -
w_det = {i: rand_weights_det[i] for i in range(n_qubits)}
p = AddressingPattern(
n_qubits=n_qubits,
max_det=max_det,
max_amp=max_amp,
det=det,
amp=amp,
weights_det=w_det,
weights_amp=w_amp,
)
Expand Down Expand Up @@ -80,12 +80,12 @@ def test_addressing_training() -> None:
# define training parameters
w_amp = {i: Parameter(f"w_amp{i}", trainable=True) for i in range(n_qubits)}
w_det = {i: Parameter(f"w_det{i}", trainable=True) for i in range(n_qubits)}
max_amp = Parameter("max_amp", trainable=True)
max_det = Parameter("max_det", trainable=True)
amp = Parameter("amp", trainable=True)
det = Parameter("det", trainable=True)
p = AddressingPattern(
n_qubits=n_qubits,
max_det=max_det,
max_amp=max_amp,
det=det,
amp=amp,
weights_det=w_det, # type: ignore [arg-type]
weights_amp=w_amp, # type: ignore [arg-type]
)
Expand Down Expand Up @@ -116,10 +116,14 @@ def test_addressing_training() -> None:
# get final results
f_value_model = model.expectation({}).detach()

assert torch.all(
torch.tensor(list(p.evaluate(p.weights_amp, model.vparams).values())) > 0.0
) and torch.all(torch.tensor(list(p.evaluate(p.weights_amp, model.vparams).values())) < 1.0)
assert torch.all(
torch.tensor(list(p.evaluate(p.weights_det, model.vparams).values())) > 0.0
) and torch.all(torch.tensor(list(p.evaluate(p.weights_det, model.vparams).values())) < 1.0)
weights_amp = torch.tensor(list(p.evaluate(p.weights_amp, model.vparams).values()))
weights_amp_mask = weights_amp.abs() < 0.001
weights_amp[weights_amp_mask] = 0.0

weights_det = torch.tensor(list(p.evaluate(p.weights_det, model.vparams).values()))
weights_det_mask = weights_det.abs() < 0.001
weights_det[weights_det_mask] = 0.0

assert torch.all(weights_amp >= 0.0) and torch.all(weights_amp <= 1.0)
assert torch.all(weights_det >= 0.0) and torch.all(weights_det <= 1.0)
assert torch.isclose(f_value, f_value_model, atol=ATOL_DICT[BackendName.PULSER])

0 comments on commit 0aa22bc

Please sign in to comment.