Skip to content

Commit

Permalink
Unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
HGSilveri committed Oct 16, 2024
1 parent 93c15d2 commit 98b7d92
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 15 deletions.
16 changes: 2 additions & 14 deletions pulser-core/pulser/register/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import warnings
from collections.abc import Mapping
from typing import Any, Optional, Union, cast, TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional, Union, cast

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -331,7 +331,6 @@ def max_connectivity(
def with_automatic_layout(
self,
device: Device,
validate: bool = True,
layout_slug: str | None = None,
) -> Register:
"""Replicates the register with an automatically generated layout.
Expand All @@ -340,8 +339,6 @@ def with_automatic_layout(
Args:
device: The device constraints for the layout generation.
validate: Whether to validate the generated RegisterLayout
against the provided device.
layout_slug: An optional slug for the generated layout.
Raises:
Expand All @@ -367,16 +364,7 @@ def with_automatic_layout(
)
layout = pulser.register.RegisterLayout(trap_coords, slug=layout_slug)
trap_ids = layout.get_traps_from_coordinates(*self.sorted_coords)
reg_from_layout = layout.define_register(*trap_ids)
if validate:
try:
device.validate_register(reg_from_layout)
except Exception as e:
raise RuntimeError(
"When defined from the layout, the register fails device "
"validation."
) from e
return cast(Register, reg_from_layout)
return cast(Register, layout.define_register(*trap_ids))

def rotated(self, degrees: float) -> Register:
"""Makes a new rotated register.
Expand Down
59 changes: 58 additions & 1 deletion tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.
from __future__ import annotations

import dataclasses
from unittest.mock import patch

import numpy as np
import pytest

from pulser import Register, Register3D
from pulser.devices import DigitalAnalogDevice, MockDevice
from pulser.devices import AnalogDevice, DigitalAnalogDevice, MockDevice
from pulser.register import RegisterLayout


def test_creation():
Expand Down Expand Up @@ -587,3 +589,58 @@ def test_register_recipes_torch(
}
reg = reg_classmethod(**kwargs)
_assert_reg_requires_grad(reg, invert=not requires_grad)


@pytest.mark.parametrize("optimal_filling", [None, 0.4, 0.1])
def test_automatic_layout(optimal_filling):
reg = Register.square(4, spacing=5)
max_layout_filling = 0.5
min_traps = int(np.ceil(len(reg.qubits) / max_layout_filling))
optimal_traps = int(
np.ceil(len(reg.qubits) / (optimal_filling or max_layout_filling))
)
device = dataclasses.replace(
AnalogDevice,
max_atom_num=20,
max_layout_filling=max_layout_filling,
optimal_layout_filling=optimal_filling,
pre_calibrated_layouts=(),
)
device.validate_register(reg)

# On its own, it works
new_reg = reg.with_automatic_layout(device, layout_slug="foo")
assert isinstance(new_reg.layout, RegisterLayout)
assert str(new_reg.layout) == "foo"
trap_num = new_reg.layout.number_of_traps
assert min_traps <= trap_num <= optimal_traps
# To test the device limits on trap number are enforced
if not optimal_filling:
assert trap_num == min_traps
bound_below_dev = dataclasses.replace(
device, min_layout_traps=trap_num + 1
)
assert (
reg.with_automatic_layout(bound_below_dev).layout.number_of_traps
== bound_below_dev.min_layout_traps
)
elif trap_num < optimal_traps:
assert trap_num > min_traps
bound_above_dev = dataclasses.replace(
device, max_layout_traps=trap_num - 1
)
assert (
reg.with_automatic_layout(
bound_above_dev, validate=False
).layout.number_of_traps
== bound_above_dev.max_layout_traps
)

with pytest.raises(TypeError, match="must be of type Device"):
reg.with_automatic_layout(MockDevice)

# Minimum number of traps is too high
with pytest.raises(RuntimeError, match="Failed to find a site"):
reg.with_automatic_layout(
dataclasses.replace(device, min_layout_traps=200)
)

0 comments on commit 98b7d92

Please sign in to comment.