Skip to content

Commit

Permalink
add 'inverse' diff test case
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBuchanan314 committed Dec 11, 2024
1 parent 51fb0f6 commit 8f724e5
Showing 1 changed file with 60 additions and 13 deletions.
73 changes: 60 additions & 13 deletions arroba/tests/mst_test_suite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import List, Tuple

import os
import io
import json

import dag_cbor
Expand All @@ -7,7 +10,7 @@

from tqdm import tqdm

from ..diff import Change, Diff
from ..diff import Change, Diff, null_diff
from ..mst import MST
from ..storage import MemoryStorage, Block
from . import testutil
Expand All @@ -16,31 +19,47 @@ class MSTSuiteTest(testutil.TestCase):

def setUp(self):
# Test fixture: collects every JSON test case whose "$type" is "mst-diff"
# from the files found under test_suite_base + "/tests/" and exposes them
# as self.diff_testcases, a dict keyed by the file path.
# NOTE(review): this span comes from a rendered diff, so the lines that
# write to self.diff_testcases directly and the lines that fill the local
# diff_testcases look like the before/after versions of the same step —
# confirm against the real file before relying on either.
super().setUp()
self.diff_testcases = {}
# recursively search for test cases in JSON format.
# for now we only know how to process "mst-diff" test cases - more types will be added
# in the future
self.test_suite_base = "./mst-test-suite/"
diff_testcases = {}
# os.walk() yields (dirpath, dirnames, filenames); join each filename with
# its directory to get one path per file in the tree.
for path in [os.path.join(dp, f) for dp, _, fn in os.walk(self.test_suite_base + "/tests/") for f in fn]:
# only files with a ".json" suffix are considered
if not path.endswith(".json"):
continue
with open(path) as json_file:
testcase = json.load(json_file)
# keep only the test-case type this suite knows how to run
if testcase.get("$type") == "mst-diff":
self.diff_testcases[path] = testcase
diff_testcases[path] = testcase
self.diff_testcases = dict(sorted(diff_testcases.items())) # sort them because os.walk() uses a weird order

def parse_car(self, stream) -> Tuple[CID, List[Tuple[CID, bytes]]]:
    """Read a CAR from a binary stream.

    The stream is expected to start with a varint-length-prefixed DAG-CBOR
    header, followed by varint-length-prefixed entries. Each entry is split
    into its first 36 bytes (decoded as the CID) and the remainder (the
    block payload).

    Returns a tuple of (first root listed in the header, list of
    (cid, payload) pairs in the order they were read).
    """
    header_length = varint.decode(stream)
    header = dag_cbor.decode(stream.read(header_length))

    entries = []
    while True:
        try:
            entry = stream.read(varint.decode(stream))
        except ValueError:
            # A ValueError while reading the next entry is treated as the
            # end of the data (presumably the stream is exhausted).
            break
        cid_bytes, payload = entry[:36], entry[36:]
        entries.append((CID.decode(cid_bytes), payload))

    return header["roots"][0], entries

def populate_storage_from_car(self, storage: MemoryStorage, car_path: str) -> CID:
# Reads the CAR file at test_suite_base + car_path (opened in binary mode),
# stores each of its blocks in storage.blocks as Block(cid=..., encoded=...)
# keyed by CID, and returns the first root CID from the CAR header.
# NOTE(review): this span comes from a rendered diff, so the inline
# header/varint parsing loop and the parse_car() call below look like the
# before/after versions of the same logic — confirm against the real file.
# ad-hoc CAR parser, returns the root CID
with open(self.test_suite_base + car_path, "rb") as carfile:
car_header = dag_cbor.decode(carfile.read(varint.decode(carfile)))
while True:
try:
block = carfile.read(varint.decode(carfile))
except ValueError:
# a ValueError while reading the next entry ends the loop
break
# the first 36 bytes of each entry are decoded as the CID; the rest is the payload
cid = CID.decode(block[:36])
storage.blocks[cid] = Block(cid=cid, encoded=block[36:])
return car_header["roots"][0]
root, blocks = self.parse_car(carfile)
for cid, value in blocks:
storage.blocks[cid] = Block(cid=cid, encoded=value)
return root

def serialise_canonical_car(self, root: CID, blocks: List[Tuple[CID, bytes]]) -> bytes:
    """Serialise a root CID and (cid, payload) pairs into CAR bytes.

    The output is a varint-length-prefixed DAG-CBOR header
    ({"version": 1, "roots": [root]}) followed by one varint-length-prefixed
    entry per block (CID bytes + payload). Blocks are emitted sorted by the
    byte representation of their CID, so the result does not depend on the
    order of `blocks`.
    """
    header = dag_cbor.encode({"version": 1, "roots": [root]})
    pieces = [varint.encode(len(header)), header]

    ordered = sorted(blocks, key=lambda pair: bytes(pair[0]))
    for cid, value in ordered:
        entry = bytes(cid) + value
        pieces.append(varint.encode(len(entry)))
        pieces.append(entry)

    return b"".join(pieces)

def test_diffs(self):
for testname, testcase in tqdm(self.diff_testcases.items()):
Expand Down Expand Up @@ -72,7 +91,8 @@ def test_diffs(self):
"new_value": None
})

# sort the lists for comparison, per mst-test-suite's rules
# sort the lists for comparison, per mst-test-suite's rules.
# NOTE: maybe we should just compare set()s instead?
created_list = sorted(cid.encode("base32") for cid in diff.new_cids)
deleted_list = sorted(cid.encode("base32") for cid in diff.removed_cids)
ops_list.sort(key=lambda x: x["rpath"])
Expand All @@ -81,3 +101,30 @@ def test_diffs(self):
self.assertEqual(created_list, testcase["results"]["created_nodes"], f"{testname} created_nodes") # currently fails!
self.assertEqual(deleted_list, testcase["results"]["deleted_nodes"], f"{testname} deleted_nodes")
# TODO: implement checks for proof_nodes, firehose_cids (test data hasn't been generated yet)

def test_diffs_inverse(self):
    """Run the diff test cases "backwards".

    For each loaded test case, the expected record ops are applied to the
    initial MST (inputs.mst_a). The resulting tree is then compared with
    the reference CAR for inputs.mst_b: the root must match, and the CID
    set reported by null_diff() must equal the set of CIDs in that CAR.
    """
    for testname, testcase in tqdm(self.diff_testcases.items()):
        storage = MemoryStorage()
        inputs = testcase["inputs"]
        initial_root = self.populate_storage_from_car(storage, inputs["mst_a"])
        mst = MST.load(storage=storage, cid=initial_root)

        # An op with no old value is a create; with both values it is an
        # update; with only an old value it is a delete.
        for op in testcase["results"]["record_ops"]:
            if not op["old_value"]:
                mst = mst.add(op["rpath"], CID.decode(op["new_value"]))
            elif op["new_value"]:
                mst = mst.update(op["rpath"], CID.decode(op["new_value"]))
            else:
                mst = mst.delete(op["rpath"])

        # should get us a map of the complete new mst
        diff = null_diff(mst)
        root_b = mst.get_pointer()

        reference_path = self.test_suite_base + inputs["mst_b"]
        with open(reference_path, "rb") as car_b:
            reference_root, reference_blocks = self.parse_car(car_b)

        # only the CIDs from the reference CAR matter here, not the payloads
        reference_cid_set = {block[0] for block in reference_blocks}

        # NOTE: the original author reports this check fails occasionally
        self.assertEqual(root_b, reference_root, f"{testname} inverse: new root")
        # NOTE: the original author reports this check basically always
        # fails and suspects the test itself is doing something wrong
        self.assertEqual(diff.new_cids, reference_cid_set, f"{testname} inverse: new cid set")

0 comments on commit 8f724e5

Please sign in to comment.