From a49c28752775fea225fcdd6525f9395048c37a5a Mon Sep 17 00:00:00 2001
From: Vytautas Abramavicius
Date: Fri, 17 Nov 2023 22:10:46 +0200
Subject: [PATCH] added local and global constraints

---
 qadence/analog/addressing.py      | 175 +++++++++++++++++-------
 qadence/analog/utils.py           |  33 ++++--
 qadence/backends/pulser/pulses.py |  20 +++-
 tests/analog/test_patterns.py     |  32 +++---
 4 files changed, 157 insertions(+), 103 deletions(-)

diff --git a/qadence/analog/addressing.py b/qadence/analog/addressing.py
index ab68a38a4..9c74bb448 100644
--- a/qadence/analog/addressing.py
+++ b/qadence/analog/addressing.py
@@ -10,8 +10,10 @@
 from qadence.parameters import Parameter, evaluate
 from qadence.types import StrEnum
 
-DEFAULT_MAX_AMPLITUDE = 2 * pi * 3
-DEFAULT_MAX_DETUNING = 2 * pi * 20
+GLOBAL_MAX_AMPLITUDE = 300
+GLOBAL_MAX_DETUNING = 2 * pi * 2000
+LOCAL_MAX_AMPLITUDE = 3
+LOCAL_MAX_DETUNING = 2 * pi * 20
 
 
 class WeightConstraint(StrEnum):
@@ -39,93 +41,116 @@ class AddressingPattern:
     # list of weights for fixed detuning pattern that cannot be changed during the execution
     weights_det: dict[int, float | torch.Tensor | Parameter]
 
-    # maximum amplitude can also be chosen as a variational parameter if needed
-    max_amp: float | torch.Tensor | Parameter = DEFAULT_MAX_AMPLITUDE
-
-    # maximum detuning can also be chosen as a variational parameter if needed
-    max_det: float | torch.Tensor | Parameter = DEFAULT_MAX_DETUNING
-
-    # weight constraint
-    weight_constraint: WeightConstraint = WeightConstraint.NORMALIZE
-
-    def _normalize_weights(self) -> None:
-        self.weights_amp = {
-            k: Parameter(v) if not isinstance(v, Parameter) else abs(v)
-            for k, v in self.weights_amp.items()
+    # amplitude can also be chosen as a variational parameter if needed
+    amp: float | torch.Tensor | Parameter = LOCAL_MAX_AMPLITUDE
+
+    # detuning can also be chosen as a variational parameter if needed
+    det: float | torch.Tensor | Parameter = LOCAL_MAX_DETUNING
+
+    def _validate_weights(
+        self,
+        weights: dict[int, float | torch.Tensor | Parameter],
+    ) -> None:
+        for v in weights.values():
+            if not isinstance(v, Parameter):
+                if not (v >= 0.0 and v <= 1.0):
+                    raise ValueError("Addressing pattern weights must fall in the range [0.0, 1.0]")
+
+    def _constrain_weights(
+        self,
+        weights: dict[int, float | torch.Tensor | Parameter],
+    ) -> dict:
+        # augment weight dict if needed
+        weights = {
+            i: Parameter(0.0)
+            if i not in weights
+            else (Parameter(weights[i]) if not isinstance(weights[i], Parameter) else weights[i])
+            for i in range(self.n_qubits)
         }
-        sum_weights_amp = sum(list(self.weights_amp.values()))
-        self.weights_amp = {k: v / sum_weights_amp for k, v in self.weights_amp.items()}
-        self.weights_det = {
-            k: Parameter(v) if not isinstance(v, Parameter) else abs(v)
-            for k, v in self.weights_det.items()
+        # restrict weights to [0, 1] range
+        weights = {
+            k: abs(v * (sigmoid(v, 20, 1.0) - sigmoid(v, 20.0, -1.0))) for k, v in weights.items()
         }
-        sum_weights_det = sum(list(self.weights_det.values()))
-        self.weights_det = {k: v / sum_weights_det for k, v in self.weights_det.items()}
 
-    def _restrict_weights(self) -> None:
-        self.weights_amp = {
-            k: v * (sigmoid(v, 20, 0.0) - sigmoid(v, 20.0, -1.0))
-            for k, v in self.weights_amp.items()
-        }
-        self.weights_det = {
-            k: v * (sigmoid(v, 20.0, 0.0) - sigmoid(v, 20.0, -1.0))
-            for k, v in self.weights_det.items()
+        return weights
+
+    def _constrain_max_vals(self) -> None:
+        # enforce constraints:
+        # 0 <= amp <= GLOBAL_MAX_AMPLITUDE
+        # 0 <= abs(det) <= GLOBAL_MAX_DETUNING
+        self.amp = abs(
+            self.amp
+            * (
+                sympy.Heaviside(self.amp + GLOBAL_MAX_AMPLITUDE)
+                - sympy.Heaviside(self.amp - GLOBAL_MAX_AMPLITUDE)
+            )
+        )
+        self.det = -abs(
+            self.det
+            * (
+                sympy.Heaviside(self.det + GLOBAL_MAX_DETUNING)
+                - sympy.Heaviside(self.det - GLOBAL_MAX_DETUNING)
+            )
+        )
+
+    def _create_local_constraint(self, val: sympy.Expr, weights: dict, max_val: float) -> dict:
+        # enforce local constraints:
+        # amp * w_amp_i < LOCAL_MAX_AMPLITUDE or
+        # abs(det) * w_det_i < LOCAL_MAX_DETUNING
+        local_constr = {k: val * v for k, v in weights.items()}
+        local_constr = {
+            k: sympy.Heaviside(v) - sympy.Heaviside(v - max_val) for k, v in local_constr.items()
         }
 
-    def _restrict_max_vals(self) -> None:
-        self.max_amp = self.max_amp * (
-            sympy.Heaviside(self.max_amp) - sympy.Heaviside(self.max_amp - DEFAULT_MAX_AMPLITUDE)
-        )
-        self.max_det = self.max_det * (
-            sympy.Heaviside(self.max_det) - sympy.Heaviside(self.max_det - DEFAULT_MAX_DETUNING)
+        return local_constr
+
+    def _create_global_constraint(
+        self, val: sympy.Expr, weights: dict, max_val: float
+    ) -> sympy.Expr:
+        # enforce global constraints:
+        # amp * sum(w_amp_0, w_amp_1, ...) < GLOBAL_MAX_AMPLITUDE or
+        # abs(det) * sum(w_det_0, w_det_1, ...) < GLOBAL_MAX_DETUNING
+        weighted_vals_global = val * sum([v for v in weights.values()])
+        weighted_vals_global = sympy.Heaviside(weighted_vals_global) - sympy.Heaviside(
+            weighted_vals_global - max_val
         )
+        return weighted_vals_global
 
     def __post_init__(self) -> None:
-        # validate weights
-        if all([not isinstance(v, Parameter) for v in self.weights_amp.values()]):
-            if not torch.isclose(
-                torch.tensor(list(self.weights_amp.values())).sum(),
-                torch.tensor(1.0),
-                atol=1e-3,
-            ):
-                raise ValueError("Amplitude addressing pattern weights must sum to 1.0")
-        if all([not isinstance(v, Parameter) for v in self.weights_det.values()]):
-            if not torch.isclose(
-                torch.tensor(list(self.weights_det.values())).sum(),
-                torch.tensor(1.0),
-                atol=1e-3,
-            ):
-                raise ValueError("Detuning addressing pattern weights must sum to 1.0")
-
-        # validate detuning value
-        if not isinstance(self.max_amp, Parameter):
-            if self.max_amp > DEFAULT_MAX_AMPLITUDE:
+        # validate amplitude/detuning weights
+        self._validate_weights(self.weights_amp)
+        self._validate_weights(self.weights_det)
+
+        # validate maximum global amplitude/detuning values
+        if not isinstance(self.amp, Parameter):
+            if self.amp > GLOBAL_MAX_AMPLITUDE:
                 warn("Maximum absolute value of amplitude is exceeded")
-        if not isinstance(self.max_det, Parameter):
-            if self.max_det > DEFAULT_MAX_DETUNING:
+        if not isinstance(self.det, Parameter):
+            if abs(self.det) > GLOBAL_MAX_DETUNING:
                 warn("Maximum absolute value of detuning is exceeded")
 
-        # augment weight dicts if needed
-        self.weights_amp = {
-            i: Parameter(0.0) if i not in self.weights_amp else self.weights_amp[i]
-            for i in range(self.n_qubits)
-        }
-        self.weights_det = {
-            i: Parameter(0.0) if i not in self.weights_det else self.weights_det[i]
-            for i in range(self.n_qubits)
-        }
+        # constrain amplitude/detuning parameterized weights to [0.0, 1.0] interval
+        self.weights_amp = self._constrain_weights(self.weights_amp)
+        self.weights_det = self._constrain_weights(self.weights_det)
 
-        # apply weight constraint
-        if self.weight_constraint == WeightConstraint.NORMALIZE:
-            self._normalize_weights()
-        elif self.weight_constraint == WeightConstraint.RESTRICT:
-            self._restrict_weights()
-        else:
-            raise ValueError("Weight constraint type not found.")
+        # constrain max global amplitude and detuning to strict interval
+        self._constrain_max_vals()
 
-        # restrict max amplitude and detuning to strict interval
-        self._restrict_max_vals()
+        # create additional local and global constraints for amplitude/detuning masks
+        self.local_constr_amp = self._create_local_constraint(
+            self.amp, self.weights_amp, LOCAL_MAX_AMPLITUDE
+        )
+        self.local_constr_det = self._create_local_constraint(
+            -self.det, self.weights_det, LOCAL_MAX_DETUNING
+        )
+        self.global_constr_amp = self._create_global_constraint(
+            self.amp, self.weights_amp, GLOBAL_MAX_AMPLITUDE
+        )
+        self.global_constr_det = self._create_global_constraint(
+            -self.det, self.weights_det, GLOBAL_MAX_DETUNING
+        )
 
         # validate number of qubits in mask
         if max(list(self.weights_amp.keys())) >= self.n_qubits:
diff --git a/qadence/analog/utils.py b/qadence/analog/utils.py
index e0e222871..98e3b4586 100644
--- a/qadence/analog/utils.py
+++ b/qadence/analog/utils.py
@@ -137,16 +137,33 @@ def rot_generator(block: ConstantAnalogRotation) -> AbstractBlock:
 def add_pattern(register: Register, pattern: Union[AddressingPattern, None]) -> AbstractBlock:
     support = tuple(range(register.n_qubits))
     if pattern is not None:
-        max_amp = pattern.max_amp
-        max_det = pattern.max_det
+        amp = pattern.amp
+        det = pattern.det
         weights_amp = pattern.weights_amp
         weights_det = pattern.weights_det
+        local_constr_amp = pattern.local_constr_amp
+        local_constr_det = pattern.local_constr_det
+        global_constr_amp = pattern.global_constr_amp
+        global_constr_det = pattern.global_constr_det
     else:
-        max_amp = 0.0
-        max_det = 0.0
+        amp = 0.0
+        det = 0.0
         weights_amp = {i: 0.0 for i in support}
         weights_det = {i: 0.0 for i in support}
-
-    p_drive_terms = (1 / 2) * max_amp * add(X(i) * weights_amp[i] for i in support)
-    p_detuning_terms = -max_det * add(0.5 * (I(i) - Z(i)) * weights_det[i] for i in support)
-    return p_drive_terms + p_detuning_terms  # type: ignore[no-any-return]
+        local_constr_amp = {i: 0.0 for i in support}
+        local_constr_det = {i: 0.0 for i in support}
+        global_constr_amp = 0.0
+        global_constr_det = 0.0
+
+    p_amp_terms = (
+        (1 / 2)
+        * amp
+        * global_constr_amp
+        * add(X(i) * weights_amp[i] * local_constr_amp[i] for i in support)
+    )
+    p_det_terms = (
+        -det
+        * global_constr_det
+        * add(0.5 * (I(i) - Z(i)) * weights_det[i] * local_constr_det[i] for i in support)
+    )
+    return p_amp_terms + p_det_terms  # type: ignore[no-any-return]
diff --git a/qadence/backends/pulser/pulses.py b/qadence/backends/pulser/pulses.py
index fa60b64e4..7687b56d1 100644
--- a/qadence/backends/pulser/pulses.py
+++ b/qadence/backends/pulser/pulses.py
@@ -51,15 +51,23 @@ def add_addressing_pattern(
     support = tuple(range(n_qubits))
 
     if config.addressing_pattern is not None:
-        max_amp = config.addressing_pattern.max_amp
-        max_det = config.addressing_pattern.max_det
+        amp = config.addressing_pattern.amp
+        det = config.addressing_pattern.det
         weights_amp = config.addressing_pattern.weights_amp
         weights_det = config.addressing_pattern.weights_det
+        local_constr_amp = config.addressing_pattern.local_constr_amp
+        local_constr_det = config.addressing_pattern.local_constr_det
+        global_constr_amp = config.addressing_pattern.global_constr_amp
+        global_constr_det = config.addressing_pattern.global_constr_det
     else:
-        max_amp = 0.0
-        max_det = 0.0
+        amp = 0.0
+        det = 0.0
         weights_amp = {i: Parameter(0.0) for i in support}
         weights_det = {i: Parameter(0.0) for i in support}
+        local_constr_amp = {i: Parameter(0.0) for i in support}
+        local_constr_det = {i: Parameter(0.0) for i in support}
+        global_constr_amp = 0.0
+        global_constr_det = 0.0
 
     for i in support:
         # declare separate local channel for each qubit
@@ -77,8 +85,8 @@ def add_addressing_pattern(
             if weights_det[i].is_number  # type: ignore [union-attr]
             else sequence.declare_variable(f"w-det-{i}")
         )
-        omega = max_amp * w_amp
-        detuning = -max_det * w_det
+        omega = amp * w_amp
+        detuning = -det * w_det
         pulse = Pulse.ConstantPulse(
             duration=total_duration, amplitude=omega, detuning=detuning, phase=0
         )
diff --git a/tests/analog/test_patterns.py b/tests/analog/test_patterns.py
index 5f70169ff..0788fae11 100644
--- a/tests/analog/test_patterns.py
+++ b/tests/analog/test_patterns.py
@@ -22,14 +22,14 @@
 
 
 @pytest.mark.parametrize(
-    "max_amp,max_det",
+    "amp,det",
     [(0.0, 10.0), (15.0, 0.0), (15.0, 9.0)],
 )
 @pytest.mark.parametrize(
     "spacing",
     [8.0, 30.0],
 )
-def test_pulser_pyq_addressing(max_amp: float, max_det: float, spacing: float) -> None:
+def test_pulser_pyq_addressing(amp: float, det: float, spacing: float) -> None:
     n_qubits = 3
     block = AnalogRX("x")
     circ = QuantumCircuit(n_qubits, block)
@@ -43,8 +43,8 @@ def test_pulser_pyq_addressing(max_amp: float, max_det: float, spacing: float) -> None:
     w_det = {i: rand_weights_det[i] for i in range(n_qubits)}
     p = AddressingPattern(
         n_qubits=n_qubits,
-        max_det=max_det,
-        max_amp=max_amp,
+        det=det,
+        amp=amp,
         weights_det=w_det,
         weights_amp=w_amp,
     )
@@ -80,12 +80,12 @@ def test_addressing_training() -> None:
     # define training parameters
     w_amp = {i: Parameter(f"w_amp{i}", trainable=True) for i in range(n_qubits)}
     w_det = {i: Parameter(f"w_det{i}", trainable=True) for i in range(n_qubits)}
-    max_amp = Parameter("max_amp", trainable=True)
-    max_det = Parameter("max_det", trainable=True)
+    amp = Parameter("amp", trainable=True)
+    det = Parameter("det", trainable=True)
     p = AddressingPattern(
         n_qubits=n_qubits,
-        max_det=max_det,
-        max_amp=max_amp,
+        det=det,
+        amp=amp,
         weights_det=w_det,  # type: ignore [arg-type]
         weights_amp=w_amp,  # type: ignore [arg-type]
     )
@@ -116,10 +116,14 @@ def test_addressing_training() -> None:
     # get final results
     f_value_model = model.expectation({}).detach()
 
-    assert torch.all(
-        torch.tensor(list(p.evaluate(p.weights_amp, model.vparams).values())) > 0.0
-    ) and torch.all(torch.tensor(list(p.evaluate(p.weights_amp, model.vparams).values())) < 1.0)
-    assert torch.all(
-        torch.tensor(list(p.evaluate(p.weights_det, model.vparams).values())) > 0.0
-    ) and torch.all(torch.tensor(list(p.evaluate(p.weights_det, model.vparams).values())) < 1.0)
+    weights_amp = torch.tensor(list(p.evaluate(p.weights_amp, model.vparams).values()))
+    weights_amp_mask = weights_amp.abs() < 0.001
+    weights_amp[weights_amp_mask] = 0.0
+
+    weights_det = torch.tensor(list(p.evaluate(p.weights_det, model.vparams).values()))
+    weights_det_mask = weights_det.abs() < 0.001
+    weights_det[weights_det_mask] = 0.0
+
+    assert torch.all(weights_amp >= 0.0) and torch.all(weights_amp <= 1.0)
+    assert torch.all(weights_det >= 0.0) and torch.all(weights_det <= 1.0)
     assert torch.isclose(f_value, f_value_model, atol=ATOL_DICT[BackendName.PULSER])
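
Note on the constraint construction in addressing.py: instead of clipping values non-differentiably, the patch builds smooth surrogates. Steep sigmoids form a soft window that clamps each weight into [0, 1] in _constrain_weights, and sympy.Heaviside differences act as indicator windows for the local and global amplitude/detuning bounds. The standalone sketch below illustrates the weight-clamping idea numerically; the logistic form of qadence's sigmoid helper is an assumption here (only its role as a steep step function matters for the illustration).

    import torch


    def sigmoid(x: torch.Tensor, k: float, x0: float) -> torch.Tensor:
        # steep logistic step centered at x0 (assumed form of the helper)
        return 1.0 / (1.0 + torch.exp(-k * (x - x0)))


    def constrain_weight(v: torch.Tensor) -> torch.Tensor:
        # abs(v * (step at +1 - step at -1)) is ~|v| inside (-1, 1) and ~0
        # outside: a differentiable clamp of a raw weight into [0, 1]
        return torch.abs(v * (sigmoid(v, 20.0, 1.0) - sigmoid(v, 20.0, -1.0)))


    v = torch.linspace(-2.0, 2.0, 9)
    w = constrain_weight(v)
    assert torch.all(w >= 0.0) and torch.all(w <= 1.0)
    print(w)  # ~0 outside (-1, 1), close to |v| inside

Keeping the clamp smooth is what lets amp, det, and the weights remain trainable parameters, as exercised by test_addressing_training above: gradients flow through the sigmoid window, whereas a hard clamp would have zero gradient almost everywhere.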
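
For reference, a minimal construction sketch mirroring the updated tests (amp=15.0, det=9.0, and random weights come straight from the parametrized test). The import path is inferred from the file this patch touches (qadence/analog/addressing.py) and may differ from the public re-export.

    import torch

    from qadence.analog.addressing import AddressingPattern  # path assumed from this patch

    n_qubits = 3

    # fixed weights in [0, 1]; _validate_weights raises for values outside that range
    w_amp = {i: w.item() for i, w in enumerate(torch.rand(n_qubits))}
    w_det = {i: w.item() for i, w in enumerate(torch.rand(n_qubits))}

    p = AddressingPattern(
        n_qubits=n_qubits,
        amp=15.0,  # clamped into [0, GLOBAL_MAX_AMPLITUDE] by _constrain_max_vals
        det=9.0,   # stored as -abs(det), bounded by GLOBAL_MAX_DETUNING
        weights_amp=w_amp,
        weights_det=w_det,
    )

    # __post_init__ attaches the new constraint factors: per-qubit Heaviside
    # windows against LOCAL_MAX_* and a single window against GLOBAL_MAX_*.
    # They multiply the drive/detuning terms in add_pattern, zeroing any term
    # that exceeds its bound.
    print(p.local_constr_amp)   # dict: qubit index -> local window factor
    print(p.global_constr_amp)  # global window factor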