Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support loading *.aig files in binary format #130

Merged
merged 4 commits into from
Jan 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiger/aig.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def to_aig(circ, *, allow_lazy=False) -> AIG:
if isinstance(circ, pathlib.Path) and circ.is_file():
circ = parser.load(circ)
elif isinstance(circ, str):
if circ.startswith('aag '):
if circ.startswith('aag ') or circ.startswith('aig '):
circ = parser.parse(circ) # Assume it is an AIGER string.
else:
circ = parser.load(circ) # Assume it is a file path.
Expand Down
203 changes: 152 additions & 51 deletions aiger/parser.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
import io
import re
from collections import defaultdict
from functools import reduce
from typing import Mapping, List, Optional
from uuid import uuid1

import attr
import funcy as fn
from bidict import bidict
from sortedcontainers import SortedDict
from toposort import toposort_flatten
from uuid import uuid1
from sortedcontainers import SortedList, SortedSet, SortedDict

import aiger as A


@attr.s(auto_attribs=True, repr=False)
class Header:
binary_mode: bool
max_var_index: int
num_inputs: int
num_latches: int
num_outputs: int
num_ands: int

def __repr__(self):
return f"aag {self.max_var_index} {self.num_inputs} " \
f"{self.num_latches} {self.num_outputs} {self.num_ands}"
mode = 'aig' if self.binary_mode else 'aag'
return f"{mode} {self.max_var_index} {self.num_inputs} " \
f"{self.num_latches} {self.num_outputs} {self.num_ands}"


NOT_DONE_PARSING_ERROR = "Parsing rules exhausted at line {}!\n{}"
Expand All @@ -38,6 +39,13 @@ class Latch:
init: bool = attr.ib(converter=bool)


@attr.s(auto_attribs=True, frozen=True)
class And:
lhs: int
rhs0: int
rhs1: int


@attr.s(auto_attribs=True, frozen=True)
class Symbol:
kind: str
Expand All @@ -59,9 +67,10 @@ class SymbolTable:
@attr.s(auto_attribs=True)
class State:
header: Optional[Header] = None
inputs: List[str] = attr.ib(factory=list)
outputs: List[str] = attr.ib(factory=list)
latches: List[str] = attr.ib(factory=list)
inputs: List[int] = attr.ib(factory=list)
outputs: List[int] = attr.ib(factory=list)
latches: List[Latch] = attr.ib(factory=list)
ands: List[And] = attr.ib(factory=list)
symbols: SymbolTable = attr.ib(factory=SymbolTable)
comments: Optional[List[str]] = None
nodes: SortedDict = attr.ib(factory=SortedDict)
Expand All @@ -78,29 +87,46 @@ def remaining_outputs(self):
def remaining_inputs(self):
return self.header.num_inputs - len(self.inputs)

@property
def remaining_ands(self):
return self.header.num_ands - len(self.ands)

HEADER_PATTERN = re.compile(r"aag (\d+) (\d+) (\d+) (\d+) (\d+)\n")

def _consume_stream(stream, delim) -> str:
line = bytearray()
ch = -1
delim = ord(delim)
while ch != delim:
ch = next(stream, delim)
line.append(ch)
return line.decode('ascii')

def parse_header(state, line) -> bool:

HEADER_PATTERN = re.compile(r"(a[ai]g) (\d+) (\d+) (\d+) (\d+) (\d+)\n")


def parse_header(state, stream) -> bool:
if state.header is not None:
return False

line = _consume_stream(stream, '\n')
match = HEADER_PATTERN.match(line)
if not match:
raise ValueError(f"Failed to parse aag HEADER. {line}")
raise ValueError(f"Failed to parse aag/aig HEADER. {line}")

try:
ids = fn.lmap(int, match.groups())
binary_mode = match.group(1) == 'aig'
ids = fn.lmap(int, match.groups()[1:])

if any(x < 0 for x in ids):
raise ValueError("Indicies must be positive!")
raise ValueError("Indices must be positive!")

max_idx, nin, nlatch, nout, nand = ids
if nin + nlatch + nand > max_idx:
raise ValueError("Sum of claimed indices greater than max.")

state.header = Header(
binary_mode=binary_mode,
max_var_index=max_idx,
num_inputs=nin,
num_latches=nlatch,
Expand All @@ -116,21 +142,38 @@ def parse_header(state, line) -> bool:
IO_PATTERN = re.compile(r"(\d+)\s*\n")


def parse_input(state, line) -> bool:
match = IO_PATTERN.match(line)

if match is None or state.remaining_inputs <= 0:
return False
lit = int(line)
def _add_input(state, lit):
state.inputs.append(lit)
state.nodes[lit] = set()
return True


def parse_output(state, line) -> bool:
def parse_input(state, stream) -> bool:
if state.remaining_inputs <= 0:
return False

if state.header.binary_mode:
for lit in range(2, 2 * (state.header.num_inputs + 1), 2):
_add_input(state, lit)
return False

line = _consume_stream(stream, '\n')
match = IO_PATTERN.match(line)
if match is None or state.remaining_outputs <= 0:

if match is None:
raise ValueError(f"Expecting an input: {line}")

_add_input(state, int(line))
return True


def parse_output(state, stream) -> bool:
if state.remaining_outputs <= 0:
return False

line = _consume_stream(stream, '\n')
match = IO_PATTERN.match(line)
if match is None:
raise ValueError(f"Expecting an output: {line}")
lit = int(line)
state.outputs.append(lit)
if lit & 1:
Expand All @@ -139,17 +182,28 @@ def parse_output(state, line) -> bool:


LATCH_PATTERN = re.compile(r"(\d+) (\d+)(?: (\d+))?\n")
LATCH_PATTERN_BINARY = re.compile(r"(\d+)(?: (\d+))?\n")


def parse_latch(state, line) -> bool:
def parse_latch(state, stream) -> bool:
if state.remaining_latches <= 0:
return False

match = LATCH_PATTERN.match(line)
if match is None:
raise ValueError("Expecting a latch: {line}")
line = _consume_stream(stream, '\n')

if state.header.binary_mode:
match = LATCH_PATTERN_BINARY.match(line)
if match is None:
raise ValueError(f"Expecting a latch: {line}")
idx = state.header.num_inputs + len(state.latches) + 1
lit = 2 * idx
elems = (lit,) + match.groups()
else:
match = LATCH_PATTERN.match(line)
if match is None:
raise ValueError(f"Expecting a latch: {line}")
elems = match.groups()

elems = match.groups()
if elems[2] is None:
elems = elems[:2] + (0,)
elems = fn.lmap(int, elems)
Expand All @@ -165,30 +219,69 @@ def parse_latch(state, line) -> bool:
AND_PATTERN = re.compile(r"(\d+) (\d+) (\d+)\s*\n")


def parse_and(state, line) -> bool:
if state.header.num_ands <= 0:
return False

match = AND_PATTERN.match(line)
if match is None:
return False

elems = fn.lmap(int, match.groups())
state.header.num_ands -= 1
deps = set(elems[1:])
state.nodes[elems[0]] = deps
def _read_delta(data):
ch = next(data)
i = 0
delta = 0
while (ch & 0x80) != 0:
if i == 5:
raise ValueError("Invalid byte in delta encoding")
delta |= (ch & 0x7f) << (7 * i)
i += 1
ch = next(data)
if i == 5 and ch >= 8:
raise ValueError("Invalid byte in delta encoding")

delta |= ch << (7 * i)
return delta


def _add_and(state, elems):
lhs, rhs0, rhs1 = fn.lmap(int, elems)
state.ands.append(And(lhs, rhs0, rhs1))
deps = {rhs0, rhs1}
state.nodes[lhs] = deps
for dep in deps:
if dep & 1:
state.nodes[dep] = {dep ^ 1}


def parse_and(state, stream) -> bool:
if state.remaining_ands <= 0:
return False

if state.header.binary_mode:
idx = state.header.num_inputs + state.header.num_latches + len(state.ands) + 1
lhs = 2 * idx
delta = _read_delta(stream)
if delta > lhs:
raise ValueError(f"Invalid lhs {lhs} or delta {delta}")
rhs0 = lhs - delta
delta = _read_delta(stream)
if delta > rhs0:
raise ValueError(f"Invalid rhs0 {rhs0} or delta {delta}")
rhs1 = rhs0 - delta
else:
line = _consume_stream(stream, '\n')
match = AND_PATTERN.match(line)
if match is None:
raise ValueError(f"Expecting an and: {line}")
lhs, rhs0, rhs1 = match.groups()

_add_and(state, (lhs, rhs0, rhs1))
return True


SYM_PATTERN = re.compile(r"([ilo])(\d+) (.*)\s*\n")


def parse_symbol(state, line) -> bool:
def parse_symbol(state, stream) -> bool:
line = _consume_stream(stream, '\n')
match = SYM_PATTERN.match(line)
if match is None:
# We might have consumed the 'c' starting the comments section
if line.rstrip() == 'c':
state.comments = []
return False

kind, idx, name = match.groups()
Expand All @@ -202,7 +295,8 @@ def parse_symbol(state, line) -> bool:
return True


def parse_comment(state, line) -> bool:
def parse_comment(state, stream) -> bool:
line = _consume_stream(stream, '\n')
if state.comments is not None:
state.comments.append(line.rstrip())
elif line.rstrip() == 'c':
Expand All @@ -227,25 +321,30 @@ def finish_table(table, keys):
return {table[i]: key for i, key in enumerate(keys)}


def parse(lines, to_aig: bool = True):
if isinstance(lines, str):
lines = io.StringIO(lines)
def parse(stream):
if isinstance(stream, list):
stream = ''.join(stream)
if isinstance(stream, str):
stream = bytes(stream, 'ascii')
stream = iter(stream)

state = State()
parsers = parse_seq()
parser = next(parsers)

for i, line in enumerate(lines):
while not parser(state, line):
i = 0
while stream.__length_hint__() > 0:
i += 1
while not parser(state, stream):
parser = next(parsers, None)

if parser is None:
raise ValueError(NOT_DONE_PARSING_ERROR.format(i + 1, state))
raise ValueError(NOT_DONE_PARSING_ERROR.format(i, state))

if parser not in (parse_header, parse_output, parse_comment, parse_symbol):
raise ValueError(DONE_PARSING_ERROR.format(state))

assert state.header.num_ands == 0
assert state.remaining_ands == 0
assert state.remaining_inputs == 0
assert state.remaining_outputs == 0
assert state.remaining_latches == 0
Expand All @@ -260,6 +359,7 @@ def parse(lines, to_aig: bool = True):

# Create expression DAG.
latch_ids = {latch.id: name for name, latch in latches.items()}
and_ids = {and_.lhs: and_ for and_ in state.ands}
lit2expr = {0: A.aig.ConstFalse()}
for lit in toposort_flatten(state.nodes):
if lit == 0:
Expand All @@ -272,6 +372,7 @@ def parse(lines, to_aig: bool = True):
elif lit & 1:
lit2expr[lit] = A.aig.Inverter(lit2expr[lit & -2])
else:
assert lit in and_ids
nodes = [lit2expr[lit2] for lit2 in state.nodes[lit]]
lit2expr[lit] = reduce(A.aig.AndGate, nodes)

Expand All @@ -284,9 +385,9 @@ def parse(lines, to_aig: bool = True):
)


def load(path: str, to_aig: bool = True):
with open(path, 'r') as f:
return parse(''.join(f.readlines()), to_aig=to_aig)
def load(path: str):
with open(path, 'rb') as f:
return parse(f.read())


__all__ = ['load', 'parse']
Loading