From 73825cf59de53aff129fd0c275d21638aed90429 Mon Sep 17 00:00:00 2001 From: Vytautas Abramavicius Date: Mon, 2 Oct 2023 14:35:00 +0300 Subject: [PATCH] pulser backend Co-authored-by: Kaonan Micadei Co-authored-by: Vytautas Abramavicius --- qadence/backends/pulser/__init__.py | 5 + qadence/backends/pulser/backend.py | 242 +++++++++++++++++++++++++ qadence/backends/pulser/channels.py | 16 ++ qadence/backends/pulser/config.py | 54 ++++++ qadence/backends/pulser/convert_ops.py | 42 +++++ qadence/backends/pulser/devices.py | 77 ++++++++ qadence/backends/pulser/pulses.py | 215 ++++++++++++++++++++++ qadence/backends/pulser/waveforms.py | 78 ++++++++ 8 files changed, 729 insertions(+) create mode 100644 qadence/backends/pulser/__init__.py create mode 100644 qadence/backends/pulser/backend.py create mode 100644 qadence/backends/pulser/channels.py create mode 100644 qadence/backends/pulser/config.py create mode 100644 qadence/backends/pulser/convert_ops.py create mode 100644 qadence/backends/pulser/devices.py create mode 100644 qadence/backends/pulser/pulses.py create mode 100644 qadence/backends/pulser/waveforms.py diff --git a/qadence/backends/pulser/__init__.py b/qadence/backends/pulser/__init__.py new file mode 100644 index 00000000..f4faa035 --- /dev/null +++ b/qadence/backends/pulser/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .backend import Backend, Configuration +from .devices import Device +from .pulses import supported_gates diff --git a/qadence/backends/pulser/backend.py b/qadence/backends/pulser/backend.py new file mode 100644 index 00000000..8e71ec07 --- /dev/null +++ b/qadence/backends/pulser/backend.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from collections import Counter +from dataclasses import dataclass +from typing import Any + +import numpy as np +import qutip +import torch +from pulser import Register as PulserRegister +from pulser import Sequence +from pulser.pulse import Pulse +from pulser_simulation.simresults import SimulationResults +from pulser_simulation.simulation import QutipEmulator +from torch import Tensor + +from qadence.backend import Backend as BackendInterface +from qadence.backend import BackendName, ConvertedCircuit, ConvertedObservable +from qadence.backends.utils import to_list_of_dicts +from qadence.blocks import AbstractBlock +from qadence.circuit import QuantumCircuit +from qadence.measurements import Measurements +from qadence.overlap import overlap_exact +from qadence.register import Register +from qadence.utils import Endianness + +from .channels import GLOBAL_CHANNEL, LOCAL_CHANNEL +from .config import Configuration +from .convert_ops import convert_observable +from .devices import Device, IdealDevice, RealisticDevice +from .pulses import add_pulses + +WEAK_COUPLING_CONST = 1.2 + +DEFAULT_SPACING = 8.0 # µm (standard value) + + +def create_register(register: Register, spacing: float = DEFAULT_SPACING) -> PulserRegister: + """Create Pulser register instance. + + Args: + register (Register): graph representing a register with accompanying coordinate data + spacing (float): distance between qubits in micrometers + + Returns: + Register: Pulser register + """ + + # create register from coordinates + coords = np.array(list(register.coords.values())) + return PulserRegister.from_coordinates(coords * spacing) + + +def make_sequence(circ: QuantumCircuit, config: Configuration) -> Sequence: + if config.device_type == Device.IDEALIZED: + device = IdealDevice + elif config.device_type == Device.REALISTIC: + device = RealisticDevice + else: + raise ValueError("Specified device is not supported.") + + max_amp = device.channels["rydberg_global"].max_amp + min_duration = device.channels["rydberg_global"].min_duration + + if config.spacing is not None: + spacing = config.spacing + elif max_amp is not None: + # Ideal spacing for entanglement gate + spacing = WEAK_COUPLING_CONST * device.rydberg_blockade_radius(max_amp) # type: ignore + else: + spacing = DEFAULT_SPACING + + pulser_register = create_register(circ.register, spacing) + + sequence = Sequence(pulser_register, device) + sequence.declare_channel(GLOBAL_CHANNEL, "rydberg_global") + sequence.declare_channel(LOCAL_CHANNEL, "rydberg_local", initial_target=0) + + # add a minimum duration pulse omega=0 pulse at the beginning for simulation convergence reasons + # since Pulser's QutipEmulator doesn't allow simulation of sequences with total duration < 4ns + zero_pulse = Pulse.ConstantPulse( + duration=max(sequence.device.channels["rydberg_global"].min_duration, 4), + amplitude=0.0, + detuning=0.0, + phase=0.0, + ) + sequence.add(zero_pulse, GLOBAL_CHANNEL, "wait-for-all") + + add_pulses(sequence, circ.block, config, circ.register, spacing) + sequence.measure() + + return sequence + + +# TODO: make it parallelized +# TODO: add execution on the cloud platform +def simulate_sequence( + sequence: Sequence, config: Configuration, state: Tensor +) -> SimulationResults: + simulation = QutipEmulator.from_sequence( + sequence, + sampling_rate=config.sampling_rate, + config=config.sim_config, + with_modulation=config.with_modulation, + ) + if state is not None: + simulation.set_initial_state(qutip.Qobj(state.cpu().numpy())) + + return simulation.run(nsteps=config.n_steps_solv, method=config.method_solv) + + +@dataclass(frozen=True, eq=True) +class Backend(BackendInterface): + """The Pulser backend""" + + name: BackendName = BackendName.PULSER + supports_ad: bool = False + support_bp: bool = False + is_remote: bool = False + with_measurements: bool = True + with_noise: bool = False + native_endianness: Endianness = Endianness.BIG + config: Configuration = Configuration() + + def circuit(self, circ: QuantumCircuit) -> Sequence: + native = make_sequence(circ, self.config) + + return ConvertedCircuit(native=native, abstract=circ, original=circ) + + def observable(self, observable: AbstractBlock, n_qubits: int = None) -> Tensor: + from qadence.transpile import flatten, scale_primitive_blocks_only, transpile + + # make sure only leaves, i.e. primitive blocks are scaled + block = transpile(flatten, scale_primitive_blocks_only)(observable) + + (native,) = convert_observable(block, n_qubits=n_qubits, config=self.config) + return ConvertedObservable(native=native, abstract=block, original=observable) + + def assign_parameters( + self, + circuit: ConvertedCircuit, + param_values: dict[str, Tensor], + ) -> Any: + if param_values == {} and circuit.native.is_parametrized(): + missing = list(circuit.native.declared_variables.keys()) + raise ValueError(f"Please, provide values for the following parameters: {missing}") + + if param_values == {}: + return circuit.native + + numpy_param_values = { + k: v.detach().cpu().numpy() + for (k, v) in param_values.items() + if k in circuit.native.declared_variables + } + + return circuit.native.build(**numpy_param_values) + + def run( + self, + circuit: ConvertedCircuit, + param_values: dict[str, Tensor] = {}, + state: Tensor | None = None, + endianness: Endianness = Endianness.BIG, + ) -> Tensor: + vals = to_list_of_dicts(param_values) + + batched_wf = np.zeros((len(vals), 2**circuit.abstract.n_qubits), dtype=np.complex128) + + for i, param_values_el in enumerate(vals): + sequence = self.assign_parameters(circuit, param_values_el) + sim_result = simulate_sequence(sequence, self.config, state) + wf = ( + sim_result.get_final_state(ignore_global_phase=False, normalize=True) + .full() + .flatten() + ) + + # We flip the wavefunction coming out of pulser, + # essentially changing logic 0 with logic 1 in the basis states. + batched_wf[i] = np.flip(wf) + + batched_wf_torch = torch.from_numpy(batched_wf) + + if endianness != self.native_endianness: + from qadence.transpile import invert_endianness + + batched_wf_torch = invert_endianness(batched_wf_torch) + + return batched_wf_torch + + def sample( + self, + circuit: ConvertedCircuit, + param_values: dict[str, Tensor] = {}, + n_shots: int = 1, + state: Tensor | None = None, + endianness: Endianness = Endianness.BIG, + ) -> list[Counter]: + if n_shots < 1: + raise ValueError("You can only call sample with n_shots>0.") + + vals = to_list_of_dicts(param_values) + + samples = [] + for param_values_el in vals: + sequence = self.assign_parameters(circuit, param_values_el) + sim_result = simulate_sequence(sequence, self.config, state) + sample = sim_result.sample_final_state(n_shots) + samples.append(sample) + if endianness != self.native_endianness: + from qadence.transpile import invert_endianness + + samples = invert_endianness(samples) + return samples + + def expectation( + self, + circuit: ConvertedCircuit, + observable: list[ConvertedObservable] | ConvertedObservable, + param_values: dict[str, Tensor] = {}, + state: Tensor | None = None, + protocol: Measurements | None = None, + endianness: Endianness = Endianness.BIG, + ) -> Tensor: + state = self.run(circuit, param_values=param_values, state=state, endianness=endianness) + + observables = observable if isinstance(observable, list) else [observable] + support = sorted(list(circuit.abstract.register.support)) + res_list = [obs.native(state, param_values, qubit_support=support) for obs in observables] + + res = torch.transpose(torch.stack(res_list), 0, 1).squeeze() + res = res if len(res.shape) > 0 else res.reshape(1) + return res.real + + @staticmethod + def _overlap(bras: Tensor, kets: Tensor) -> Tensor: + return overlap_exact(bras, kets) + + @staticmethod + def default_configuration() -> Configuration: + return Configuration() diff --git a/qadence/backends/pulser/channels.py b/qadence/backends/pulser/channels.py new file mode 100644 index 00000000..8d576e42 --- /dev/null +++ b/qadence/backends/pulser/channels.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from pulser.channels.channels import Rydberg + +GLOBAL_CHANNEL = "Global" +LOCAL_CHANNEL = "Local" + + +@dataclass(frozen=True) +class CustomRydberg(Rydberg): + name: str = "Rydberg" + + duration_steps: int = 1 # ns + amplitude_steps: float = 0.01 # rad/µs diff --git a/qadence/backends/pulser/config.py b/qadence/backends/pulser/config.py new file mode 100644 index 00000000..d247f912 --- /dev/null +++ b/qadence/backends/pulser/config.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from pulser_simulation.simconfig import SimConfig + +from qadence.backend import BackendConfiguration +from qadence.blocks.analog import Interaction + +from .devices import Device + + +@dataclass +class Configuration(BackendConfiguration): + # device type + device_type: Device = Device.IDEALIZED + + # atomic spacing + spacing: Optional[float] = None + + # sampling rate to be used for local simulations + sampling_rate: float = 1.0 + + # solver method to pass to the Qutip solver + method_solv: str = "adams" + + # number of solver steps to pass to the Qutip solver + n_steps_solv: float = 1e8 + + # simulation configuration with optional noise options + sim_config: Optional[SimConfig] = None + + # add modulation to the local execution + with_modulation: bool = False + + # Use gate-level parameters + use_gate_params = True + + # pulse amplitude on local channel + amplitude_local: Optional[float] = None + + # pulse amplitude on global channel + amplitude_global: Optional[float] = None + + # detuning value + detuning: Optional[float] = None + + # interaction type + interaction: Interaction = Interaction.NN + + def __post_init__(self) -> None: + if self.sim_config is not None and not isinstance(self.sim_config, SimConfig): + raise TypeError("Wrong 'sim_config' attribute type, pass a valid SimConfig object!") diff --git a/qadence/backends/pulser/convert_ops.py b/qadence/backends/pulser/convert_ops.py new file mode 100644 index 00000000..1b254e79 --- /dev/null +++ b/qadence/backends/pulser/convert_ops.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import Sequence + +import torch +from torch.nn import Module + +from qadence.blocks import ( + AbstractBlock, +) +from qadence.blocks.block_to_tensor import ( + block_to_tensor, +) +from qadence.utils import Endianness + +from .config import Configuration + + +def convert_observable( + block: AbstractBlock, n_qubits: int | None, config: Configuration = None +) -> Sequence[Module]: + return [PulserObservable(block, n_qubits)] + + +class PulserObservable(Module): + def __init__(self, block: AbstractBlock, n_qubits: int | None): + super().__init__() + self.block = block + self.n_qubits = n_qubits + + def forward( + self, + state: torch.Tensor, + values: dict[str, torch.Tensor] | list = {}, + qubit_support: tuple | None = None, + endianness: Endianness = Endianness.BIG, + ) -> torch.Tensor: + # FIXME: cache this, it is very inefficient for non-parametric observables + block_mat = block_to_tensor( + self.block, values, qubit_support=qubit_support, endianness=endianness # type: ignore [arg-type] # noqa + ).squeeze(0) + return torch.sum(torch.matmul(state, block_mat) * state.conj(), dim=1) diff --git a/qadence/backends/pulser/devices.py b/qadence/backends/pulser/devices.py new file mode 100644 index 00000000..6b479b6a --- /dev/null +++ b/qadence/backends/pulser/devices.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from numpy import pi +from pulser.channels.channels import Rydberg +from pulser.channels.eom import RydbergBeam, RydbergEOM +from pulser.devices._device_datacls import Device as PulserDevice +from pulser.devices._device_datacls import VirtualDevice + +from qadence.types import StrEnum + +# Idealized virtual device +IdealDevice = VirtualDevice( + name="IdealizedDevice", + dimensions=2, + rydberg_level=60, + max_atom_num=100, + max_radial_distance=100, + min_atom_distance=0, + channel_objects=( + Rydberg.Global(max_abs_detuning=2 * pi * 4, max_amp=2 * pi * 3), + Rydberg.Local(max_targets=1000, max_abs_detuning=2 * pi * 4, max_amp=2 * pi * 3), + ), +) + + +# device with realistic specs with local channels and custom bandwith. +RealisticDevice = PulserDevice( + name="RealisticDevice", + dimensions=2, + rydberg_level=60, + max_atom_num=100, + max_radial_distance=60, + min_atom_distance=5, + channel_objects=( + Rydberg.Global( + max_abs_detuning=2 * pi * 4, + max_amp=2 * pi * 3, + clock_period=4, + min_duration=16, + max_duration=2**26, + mod_bandwidth=16, + eom_config=RydbergEOM( + limiting_beam=RydbergBeam.RED, + max_limiting_amp=40 * 2 * pi, + intermediate_detuning=700 * 2 * pi, + mod_bandwidth=24, + controlled_beams=(RydbergBeam.BLUE,), + ), + ), + Rydberg.Local( + max_targets=20, + max_abs_detuning=2 * pi * 4, + max_amp=2 * pi * 3, + clock_period=4, + min_duration=16, + max_duration=2**26, + mod_bandwidth=16, + eom_config=RydbergEOM( + limiting_beam=RydbergBeam.RED, + max_limiting_amp=40 * 2 * pi, + intermediate_detuning=700 * 2 * pi, + mod_bandwidth=24, + controlled_beams=(RydbergBeam.BLUE,), + ), + ), + ), +) + + +class Device(StrEnum): + """Supported types of devices for Pulser backend""" + + IDEALIZED = IdealDevice + "idealized device, least realistic" + + REALISTIC = RealisticDevice + "device with realistic specs" diff --git a/qadence/backends/pulser/pulses.py b/qadence/backends/pulser/pulses.py new file mode 100644 index 00000000..8a0a3cae --- /dev/null +++ b/qadence/backends/pulser/pulses.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from functools import partial +from typing import Union + +import numpy as np +from pulser.channels.base_channel import Channel +from pulser.parametrized.variable import Variable, VariableItem +from pulser.pulse import Pulse +from pulser.sequence.sequence import Sequence +from pulser.waveforms import CompositeWaveform, ConstantWaveform, RampWaveform + +from qadence import Register +from qadence.blocks import AbstractBlock, CompositeBlock +from qadence.blocks.analog import ( + AnalogBlock, + AnalogComposite, + ConstantAnalogRotation, + Interaction, + WaitBlock, +) +from qadence.operations import RX, RY, AnalogEntanglement, OpName +from qadence.parameters import evaluate + +from .channels import GLOBAL_CHANNEL, LOCAL_CHANNEL +from .config import Configuration +from .waveforms import SquareWaveform + +TVar = Union[Variable, VariableItem] + +supported_gates = [ + OpName.ZERO, + OpName.RX, + OpName.RY, + OpName.ANALOGENTANG, + OpName.ANALOGRX, + OpName.ANALOGRY, + OpName.ANALOGSWAP, + OpName.WAIT, +] + + +def add_pulses( + sequence: Sequence, + block: AbstractBlock, + config: Configuration, + qc_register: Register, + spacing: float, +) -> None: + # we need this because of the case with a single type of block in a KronBlock + # TODO: document properly + + n_qubits = len(sequence.register.qubits) + + # define qubit support + qubit_support = block.qubit_support + if not isinstance(qubit_support[0], int): + qubit_support = tuple(range(n_qubits)) + + if isinstance(block, AnalogBlock) and config.interaction != Interaction.NN: + raise ValueError(f"Pulser does not support other interactions than '{Interaction.NN}'") + + local_channel = sequence.device.channels["rydberg_local"] + global_channel = sequence.device.channels["rydberg_global"] + + rx = partial(digital_rot_pulse, channel=local_channel, phase=0, config=config) + ry = partial(digital_rot_pulse, channel=local_channel, phase=np.pi / 2, config=config) + + # TODO: lets move those to `@singledipatch`ed functions + if isinstance(block, WaitBlock): + # wait if its a global wait + if block.qubit_support.is_global: + (uuid, duration) = block.parameters.uuid_param("duration") + t = evaluate(duration) if duration.is_number else sequence.declare_variable(uuid) + pulse = Pulse.ConstantPulse(duration=t, amplitude=0, detuning=0, phase=0) + sequence.add(pulse, GLOBAL_CHANNEL, "wait-for-all") + + # do nothing if its a non-global wait, because that means we are doing a rotation + # on other qubits + else: + support = set(block.qubit_support) + if not support.issubset(sequence.register.qubits): + raise ValueError("Trying to wait on qubits outside of support.") + + elif isinstance(block, ConstantAnalogRotation): + ps = block.parameters + (a_uuid, alpha) = ps.uuid_param("alpha") + (w_uuid, omega) = ps.uuid_param("omega") + (p_uuid, phase) = ps.uuid_param("phase") + (d_uuid, detuning) = ps.uuid_param("delta") + + a = evaluate(alpha) if alpha.is_number else sequence.declare_variable(a_uuid) + w = evaluate(omega) if omega.is_number else sequence.declare_variable(w_uuid) + p = evaluate(phase) if phase.is_number else sequence.declare_variable(p_uuid) + d = evaluate(detuning) if detuning.is_number else sequence.declare_variable(d_uuid) + + # calculate generator eigenvalues + block.eigenvalues_generator = block.compute_eigenvalues_generator( + qc_register, block, spacing + ) + + if block.qubit_support.is_global: + pulse = analog_rot_pulse(a, w, p, d, global_channel, config) + sequence.add(pulse, GLOBAL_CHANNEL, protocol="wait-for-all") + else: + pulse = analog_rot_pulse(a, w, p, d, local_channel, config) + sequence.target(qubit_support, LOCAL_CHANNEL) + sequence.add(pulse, LOCAL_CHANNEL, protocol="wait-for-all") + + elif isinstance(block, AnalogEntanglement): + (uuid, duration) = block.parameters.uuid_param("duration") + t = evaluate(duration) if duration.is_number else sequence.declare_variable(uuid) + sequence.add( + entangle_pulse(t, global_channel, config), GLOBAL_CHANNEL, protocol="wait-for-all" + ) + + elif isinstance(block, (RX, RY)): + (uuid, p) = block.parameters.uuid_param("parameter") + angle = evaluate(p) if p.is_number else sequence.declare_variable(uuid) + pulse = rx(angle) if isinstance(block, RX) else ry(angle) + sequence.target(qubit_support, LOCAL_CHANNEL) + sequence.add(pulse, LOCAL_CHANNEL, protocol="wait-for-all") + + elif isinstance(block, CompositeBlock) or isinstance(block, AnalogComposite): + for block in block.blocks: + add_pulses(sequence, block, config, qc_register, spacing) + + else: + msg = f"The pulser backend currently does not support blocks of type: {type(block)}" + raise NotImplementedError(msg) + + +def analog_rot_pulse( + alpha: TVar | float, + omega: TVar | float, + phase: TVar | float, + detuning: TVar | float, + channel: Channel, + config: Configuration | None = None, +) -> Pulse: + # omega in rad/us; detuning in rad/us + if config is not None: + if channel.addressing == "Global": + max_amp = config.amplitude_global if config.amplitude_global is not None else omega + elif channel.addressing == "Local": + max_amp = config.amplitude_local if config.amplitude_local is not None else omega + max_det = config.detuning if config.detuning is not None else detuning + else: + max_amp = omega + max_det = detuning + + # get pulse duration in ns + duration = 1000 * abs(alpha) / np.sqrt(omega**2 + detuning**2) + + # create amplitude waveform + amp_wf = SquareWaveform.from_duration( + duration=duration, # type: ignore + max_amp=max_amp, # type: ignore[arg-type] + duration_steps=channel.clock_period, # type: ignore[attr-defined] + min_duration=channel.min_duration, + ) + + # create detuning waveform + det_wf = SquareWaveform.from_duration( + duration=duration, # type: ignore + max_amp=max_det, # type: ignore[arg-type] + duration_steps=channel.clock_period, # type: ignore[attr-defined] + min_duration=channel.min_duration, + ) + + return Pulse(amplitude=amp_wf, detuning=det_wf, phase=abs(phase)) + + +def entangle_pulse( + duration: TVar | float, channel: Channel, config: Configuration | None = None +) -> Pulse: + if config is None: + max_amp = channel.max_amp + else: + max_amp = ( + config.amplitude_global if config.amplitude_global is not None else channel.max_amp + ) + + clock = channel.clock_period + delay_wf = ConstantWaveform(clock * np.ceil(duration / clock), 0) # type: ignore + half_pi_wf = SquareWaveform.from_area( + area=np.pi / 2, + max_amp=max_amp, # type: ignore[arg-type] + duration_steps=clock, # type: ignore[attr-defined] + min_duration=channel.min_duration, + ) + + detuning_wf = RampWaveform(duration=half_pi_wf.duration, start=0, stop=np.pi) + amplitude = CompositeWaveform(half_pi_wf, delay_wf) + detuning = CompositeWaveform(detuning_wf, delay_wf) + return Pulse(amplitude=amplitude, detuning=detuning, phase=np.pi / 2) + + +def digital_rot_pulse( + angle: TVar | float, phase: float, channel: Channel, config: Configuration | None = None +) -> Pulse: + if config is None: + max_amp = channel.max_amp + else: + max_amp = config.amplitude_local if config.amplitude_local is not None else channel.max_amp + + # TODO: Implement reverse rotation for angles bigger than π + amplitude_wf = SquareWaveform.from_area( + area=abs(angle), # type: ignore + max_amp=max_amp, # type: ignore[arg-type] + duration_steps=channel.clock_period, # type: ignore[attr-defined] + min_duration=channel.min_duration, + ) + + return Pulse.ConstantDetuning(amplitude=amplitude_wf, detuning=0, phase=phase) diff --git a/qadence/backends/pulser/waveforms.py b/qadence/backends/pulser/waveforms.py new file mode 100644 index 00000000..aac0a653 --- /dev/null +++ b/qadence/backends/pulser/waveforms.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import numpy as np +from pulser.parametrized.decorators import parametrize +from pulser.waveforms import ConstantWaveform + +# determined by hardware team as a safe resolution +MAX_AMPLITUDE_SCALING = 0.1 +EPS = 1e-9 + + +class SquareWaveform(ConstantWaveform): + def __init__(self, duration: int, value: float): + super().__init__(duration, value) + + @classmethod + @parametrize + def from_area( + cls, + area: float, + max_amp: float, + duration_steps: int = 1, + min_duration: int = 1, + ) -> SquareWaveform: + amp_steps = MAX_AMPLITUDE_SCALING * max_amp + + duration = max( + duration_steps * np.round(area / (duration_steps * max_amp) * 1e3), + min_duration, + ) + amplitude = min( + amp_steps * np.ceil(area / (amp_steps * duration) * 1e3), + max_amp, + ) + delta = np.abs(1e-3 * duration * amplitude - area) + + new_duration = duration + duration_steps + new_amplitude = max( + amp_steps * np.ceil(area / (amp_steps * new_duration) * 1e3), + max_amp, + ) + new_delta = np.abs(1e-3 * new_duration * new_amplitude - area) + + while new_delta < delta: + duration = new_duration + amplitude = new_amplitude + delta = new_delta + + new_duration = duration + duration_steps + new_amplitude = max( + amp_steps * np.ceil(area / (amp_steps * new_duration) * 1e3), + max_amp, + ) + new_delta = np.abs(1e-3 * new_duration * new_amplitude - area) + + return cls(duration, amplitude) + + @classmethod + @parametrize + def from_duration( + cls, + duration: int, + max_amp: float, + duration_steps: int = 1, + min_duration: int = 1, + ) -> SquareWaveform: + amp_steps = MAX_AMPLITUDE_SCALING * max_amp + + duration = max( + duration_steps * np.round(duration / duration_steps), + min_duration, + ) + amplitude = min( + amp_steps * np.ceil(max_amp / (amp_steps + EPS) * 1e3), + max_amp, + ) + + return cls(duration, amplitude)