Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Pengfei Chen committed Oct 18, 2023
1 parent e8cce5a commit 576b13c
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 90 deletions.
77 changes: 34 additions & 43 deletions unitary/examples/quantum_chinese_chess/move_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from unitary.examples.quantum_chinese_chess.move import Move, Jump
from unitary.examples.quantum_chinese_chess.board import Board
from unitary.examples.quantum_chinese_chess.piece import Piece
import pytest
from unitary import alpha
from typing import List
from string import ascii_lowercase, digits
from unitary.examples.quantum_chinese_chess.enums import (
MoveType,
MoveVariant,
Expand All @@ -31,35 +35,8 @@
sample_board,
get_board_probability_distribution,
print_samples,
set_board,
)
import pytest
from unitary import alpha
from typing import List
from string import ascii_lowercase, digits


_EMPTY_FEN = "9/9/9/9/9/9/9/9/9/9 w---1"


def global_names():
pass


# global board
# board = Board.from_fen(_EMPTY_FEN)


def set_board(positions: List[str]):
global board
board = Board.from_fen(_EMPTY_FEN)
for col in ascii_lowercase[:9]:
for row in digits:
globals()[f"{col}{row}"] = board.board[f"{col}{row}"]
for position in positions:
board.board[position].reset(
Piece(position, SquareState.OCCUPIED, Type.ROOK, Color.RED)
)
alpha.Flip()(board.board[position])


def test_move_eq():
Expand Down Expand Up @@ -168,10 +145,12 @@ def test_to_str():


def test_jump_classical():
global_names()

# Target is empty.
set_board(["a1", "b1"])
board = set_board(["a1", "b1"])
# TODO(): try move the following varaibles declarations into a function.
for col in ascii_lowercase[:9]:
for row in digits:
globals()[f"{col}{row}"] = board.board[f"{col}{row}"]
Jump(MoveVariant.CLASSICAL)(a1, b2)
assert_samples_in(board, [locations_to_bitboard(["b2", "b1"])])

Expand All @@ -182,8 +161,10 @@ def test_jump_classical():

def test_jump_capture():
# Source is in quantum state.
global_names()
set_board(["a1", "b1"])
board = set_board(["a1", "b1"])
for col in ascii_lowercase[:9]:
for row in digits:
globals()[f"{col}{row}"] = board.board[f"{col}{row}"]
alpha.PhasedSplit()(a1, a2, a3)
board_probabilities = get_board_probability_distribution(board, 1000)
assert len(board_probabilities) == 2
Expand All @@ -199,8 +180,10 @@ def test_jump_capture():
assert_samples_in(board, [locations_to_bitboard(["a3", "b1"])])

# Target is in quantum state.
global_names()
set_board(["a1", "b1"])
board = set_board(["a1", "b1"])
for col in ascii_lowercase[:9]:
for row in digits:
globals()[f"{col}{row}"] = board.board[f"{col}{row}"]
alpha.PhasedSplit()(b1, b2, b3)
Jump(MoveVariant.CAPTURE)(a1, b2)
board_probabilities = get_board_probability_distribution(board, 1000)
Expand All @@ -209,8 +192,10 @@ def test_jump_capture():
assert_fifty_fifty(board_probabilities, locations_to_bitboard(["b2", "b3"]))

# Both source and target are in quantum state.
global_names()
set_board(["a1", "b1"])
board = set_board(["a1", "b1"])
for col in ascii_lowercase[:9]:
for row in digits:
globals()[f"{col}{row}"] = board.board[f"{col}{row}"]
alpha.PhasedSplit()(a1, a2, a3)
alpha.PhasedSplit()(b1, b2, b3)
assert_sample_distribution(
Expand Down Expand Up @@ -238,8 +223,10 @@ def test_jump_capture():

def test_jump_excluded():
# Target is in quantum state.
global_names()
set_board(["a1", "b1"])
board = set_board(["a1", "b1"])
for col in ascii_lowercase[:9]:
for row in digits:
globals()[f"{col}{row}"] = board.board[f"{col}{row}"]
alpha.PhasedSplit()(b1, b2, b3)
Jump(MoveVariant.EXCLUDED)(a1, b2)
# pop() will break the supersition and only one of the following two states are possible.
Expand All @@ -252,8 +239,10 @@ def test_jump_excluded():
assert_samples_in(board, [locations_to_bitboard(["b2", "b3"])])

# Both source and target are in quantum state.
global_names()
set_board(["a1", "b1"])
board = set_board(["a1", "b1"])
for col in ascii_lowercase[:9]:
for row in digits:
globals()[f"{col}{row}"] = board.board[f"{col}{row}"]
alpha.PhasedSplit()(a1, a2, a3)
alpha.PhasedSplit()(b1, b2, b3)
Jump(MoveVariant.EXCLUDED)(a2, b2)
Expand All @@ -272,8 +261,10 @@ def test_jump_excluded():

def test_jump_basic():
# Souce is in quantum state.
global_names()
set_board(["a1"])
board = set_board(["a1"])
for col in ascii_lowercase[:9]:
for row in digits:
globals()[f"{col}{row}"] = board.board[f"{col}{row}"]
alpha.PhasedSplit()(a1, a2, a3)
Jump(MoveVariant.BASIC)(a2, d1)
board_probabilities = get_board_probability_distribution(board, 1000)
Expand Down
124 changes: 77 additions & 47 deletions unitary/examples/quantum_chinese_chess/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unitary.alpha import QuantumObject, QuantumWorld
from unitary.examples.quantum_chinese_chess.enums import SquareState
from unitary.examples.quantum_chinese_chess.enums import SquareState, Type, Color
from unitary.examples.quantum_chinese_chess.board import Board
from string import ascii_lowercase, digits
from unitary.examples.quantum_chinese_chess.piece import Piece
from unitary import alpha
from typing import List, Dict
from collections import defaultdict
from scipy.stats import chisquare


_EMPTY_FEN = "9/9/9/9/9/9/9/9/9/9 w---1"


def set_board(positions: List[str]) -> Board:
"""Returns a board with the specified positions filled with
RED ROOKs.
"""
board = Board.from_fen(_EMPTY_FEN)
for position in positions:
board.board[position].reset(
Piece(position, SquareState.OCCUPIED, Type.ROOK, Color.RED)
)
alpha.Flip()(board.board[position])
return board


def location_to_bit(location: str) -> int:
"""Transform location notation (e.g. "a3") into a bitboard bit number."""
"""Transform location notation (e.g. "a3") into a bitboard bit number.
The return value ranges from 0 to 89.
"""
x = ord(location[0]) - ord("a")
y = int(location[1])
return y * 9 + x


def locations_to_bitboard(locations: List[str]) -> int:
"""Transform a list of locations into a 90-bit board bitstring."""
"""Transform a list of locations into a 90-bit board bitstring.
Each nonzero bit of the bitstring indicates that the corresponding
piece is occupied.
"""
bitboard = 0
for location in locations:
bitboard += 1 << location_to_bit(location)
return bitboard


def nth_bit_of(n: int, bit_board: int) -> bool:
"""Returns the n-th bit of a 90-bit bitstring."""
"""Returns the `n`-th (zero-based) bit of a 90-bit bitstring `bit_board`."""
return (bit_board >> n) % 2 == 1


Expand All @@ -48,7 +70,7 @@ def bit_to_location(bit: int) -> str:


def bitboard_to_locations(bitboard: int) -> List[str]:
"""Transform a 90-bit bitstring into a list of locations."""
"""Transform a 90-bit bitstring `bitboard` into a list of locations."""
locations = []
for n in range(90):
if nth_bit_of(n, bitboard):
Expand All @@ -57,22 +79,26 @@ def bitboard_to_locations(bitboard: int) -> List[str]:


def sample_board(board: Board, repetitions: int) -> List[int]:
"""Sample the given `board` by the given `repetitions`.
Returns a list of 90-bit bitstring, each corresponding to one sample.
"""
samples = board.board.peek(count=repetitions, convert_to_enum=False)
# Convert peek results (in List[List[int]]) into bitstring.
# Convert peek results (in List[List[int]]) into List[int].
samples = [
int("0b" + "".join([str(i) for i in sample[::-1]]), base=2)
for sample in samples
]
return samples


def print_samples(samples):
"""Prints all the samples as lists of locations."""
def print_samples(samples: List[int]) -> None:
"""Aggregate all the samples and print the dictionary of {locations: count}."""
sample_dict = {}
for sample in samples:
if sample not in sample_dict:
sample_dict[sample] = 0
sample_dict[sample] += 1
print("Actual samples:")
for key in sample_dict:
print(f"{bitboard_to_locations(key)}: {sample_dict[key]}")

Expand All @@ -81,17 +107,11 @@ def get_board_probability_distribution(
board: Board, repetitions: int = 1000
) -> Dict[int, float]:
"""Returns the probability distribution for each board found in the sample.
The values are returned as a dict{bitboard(int): probability(float)}.
"""
board_probabilities: Dict[int, float] = {}

samples = board.board.peek(count=repetitions, convert_to_enum=False)
# Convert peek results (in List[List[int]]) into bitstring.
samples = [
int("0b" + "".join([str(i) for i in sample[::-1]]), base=2)
for sample in samples
]
samples = sample_board(board, repetitions)
for sample in samples:
if sample not in board_probabilities:
board_probabilities[sample] = 0.0
Expand All @@ -103,63 +123,73 @@ def get_board_probability_distribution(
return board_probabilities


def assert_samples_in(board: Board, possibilities):
def assert_samples_in(board: Board, probabilities: Dict[int, float]) -> None:
"""Samples the given `board` and asserts that all samples are within
the given `probabilities` (i.e. a map from bitstring into its possibility),
and that each possibility is represented at least once in the samples.
"""
samples = sample_board(board, 500)
assert len(samples) == 500
all_in = all(sample in possibilities for sample in samples)
print(possibilities)
print(set(samples))
all_in = all(sample in probabilities for sample in samples)
assert all_in, print_samples(samples)
# Make sure each possibility is represented at least once.
for possibility in possibilities:
for possibility in probabilities:
any_in = any(sample == possibility for sample in samples)
assert any_in, print_samples(samples)


def assert_sample_distribution(board: Board, probability_map, p_significant=1e-6):
def assert_sample_distribution(
board: Board, probabilities: Dict[int, float], p_significant: float = 1e-6
) -> None:
"""Performs a chi-squared test that samples follow an expected distribution.
probability_map is a map from bitboards to expected probability. An
assertion is raised if one of the samples is not in the map, or if the
probability that the samples are at least as different from the expected
ones as the observed sampless is less than p_significant.
`probabilities` is a map from bitboards to expected probability. An
AssertionError is raised if any of the samples is not in the map, or if the
expected versus observed samples fails the chi-squared test.
"""
assert abs(sum(probability_map.values()) - 1) < 1e-9
samples = sample_board(board, 500)
assert len(samples) == 500
n_samples = 500
assert abs(sum(probabilities.values()) - 1) < 1e-9
samples = sample_board(board, n_samples)
counts = defaultdict(int)
for sample in samples:
assert sample in probability_map, bitboard_to_locations(sample)
assert sample in probabilities, bitboard_to_locations(sample)
counts[sample] += 1
observed = []
expected = []
for position, probability in probability_map.items():
for position, probability in probabilities.items():
observed.append(counts[position])
expected.append(500 * probability)
expected.append(n_samples * probability)
p = chisquare(observed, expected).pvalue
assert (
p > p_significant
), f"Observed {observed} far from expected {expected} (p = {p})"
), f"Observed {observed} is far from expected {expected} (p = {p})"


def assert_this_or_that(samples, this, that):
"""Asserts all the samples are either equal to this or that,
and that one of each exists in the samples.
def assert_this_or_that(samples: List[int], this: int, that: int) -> None:
"""Asserts all the samples are either equal to `this` or `that`,
and that at least one of them exists in the samples.
"""
# assert any(sample == this for sample in samples), print_samples(samples)
# assert any(sample == that for sample in samples), print_samples(samples)
assert any(sample == this for sample in samples), print_samples(samples)
assert any(sample == that for sample in samples), print_samples(samples)
assert all(sample == this or sample == that for sample in samples), print_samples(
samples
)


def assert_prob_about(probs, that, expected, atol=0.05):
"""Checks that the probability is within atol of the expected value."""
assert that in probs, print_samples(list(probs.keys()))
assert probs[that] > expected - atol, print_samples(list(probs.keys()))
assert probs[that] < expected + atol, print_samples(list(probs.keys()))
def assert_prob_about(
probabilities: Dict[int, float], that: int, expected: float, atol: float = 0.05
) -> None:
"""Checks that the probability of `that` is within `atol` of the value of `expected`."""
assert that in probabilities, print_samples(list(probabilities.keys()))
assert probabilities[that] > expected - atol, print_samples(
list(probabilities.keys())
)
assert probabilities[that] < expected + atol, print_samples(
list(probabilities.keys())
)


def assert_fifty_fifty(probs, that):
"""Checks that the probability is close to 50%."""
assert_prob_about(probs, that, 0.5), print_samples(list(probs.keys()))
def assert_fifty_fifty(probabilities, that):
"""Checks that the probability of `that` is close to 50%."""
assert_prob_about(probabilities, that, 0.5), print_samples(
list(probabilities.keys())
)

0 comments on commit 576b13c

Please sign in to comment.