Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MST Test Suite integration #42

Merged
merged 4 commits into from
Dec 12, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 60 additions & 13 deletions arroba/tests/mst_test_suite.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import List, Tuple

import os
import io
import json

import dag_cbor
Expand All @@ -7,7 +10,7 @@

from tqdm import tqdm

from ..diff import Change, Diff
from ..diff import Change, Diff, null_diff
from ..mst import MST
from ..storage import MemoryStorage, Block
from . import testutil
Expand All @@ -16,31 +19,47 @@ class MSTSuiteTest(testutil.TestCase):

def setUp(self):
super().setUp()
self.diff_testcases = {}
# recursively search for test cases in JSON format.
# for now we only know how to process "mst-diff" test cases - more types will be added
# in the future
self.test_suite_base = "./mst-test-suite/"
diff_testcases = {}
for path in [os.path.join(dp, f) for dp, _, fn in os.walk(self.test_suite_base + "/tests/") for f in fn]:
if not path.endswith(".json"):
continue
with open(path) as json_file:
testcase = json.load(json_file)
if testcase.get("$type") == "mst-diff":
self.diff_testcases[path] = testcase
diff_testcases[path] = testcase
self.diff_testcases = dict(sorted(diff_testcases.items())) # sort them because os.walk() uses a weird order

def parse_car(self, stream) -> Tuple[CID, List[Tuple[CID, bytes]]]:
car_header = dag_cbor.decode(stream.read(varint.decode(stream)))
blocks = []
while True:
try:
block = stream.read(varint.decode(stream))
except ValueError:
break
blocks.append((CID.decode(block[:36]), block[36:]))
return car_header["roots"][0], blocks

def populate_storage_from_car(self, storage: MemoryStorage, car_path: str) -> CID:
# ad-hoc CAR parser, returns the root CID
with open(self.test_suite_base + car_path, "rb") as carfile:
car_header = dag_cbor.decode(carfile.read(varint.decode(carfile)))
while True:
try:
block = carfile.read(varint.decode(carfile))
except ValueError:
break
cid = CID.decode(block[:36])
storage.blocks[cid] = Block(cid=cid, encoded=block[36:])
return car_header["roots"][0]
root, blocks = self.parse_car(carfile)
for cid, value in blocks:
storage.blocks[cid] = Block(cid=cid, encoded=value)
return root

def serialise_canonical_car(self, root: CID, blocks: List[Tuple[CID, bytes]]) -> bytes:
car = io.BytesIO()
header = dag_cbor.encode({"version": 1, "roots": [root]})
car.write(varint.encode(len(header)) + header)
for cid, value in sorted(blocks, key=lambda x: bytes(x[0])):
entry = bytes(cid) + value
car.write(varint.encode(len(entry)) + entry)
return car.getvalue()

def test_diffs(self):
for testname, testcase in tqdm(self.diff_testcases.items()):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, and I used tqdm for a progress bar - not super necessary. I should either take it out or add it as an optional dependency

Expand Down Expand Up @@ -72,7 +91,8 @@ def test_diffs(self):
"new_value": None
})

# sort the lists for comparison, per mst-test-suite's rules
# sort the lists for comparison, per mst-test-suite's rules.
# NOTE: maybe we should just compare set()s instead?
created_list = sorted(cid.encode("base32") for cid in diff.new_cids)
deleted_list = sorted(cid.encode("base32") for cid in diff.removed_cids)
ops_list.sort(key=lambda x: x["rpath"])
Expand All @@ -81,3 +101,30 @@ def test_diffs(self):
self.assertEqual(created_list, testcase["results"]["created_nodes"], f"{testname} created_nodes") # currently fails!
self.assertEqual(deleted_list, testcase["results"]["deleted_nodes"], f"{testname} deleted_nodes")
# TODO: implement checks for proof_nodes, firehose_cids (test data hasn't been generated yet)

def test_diffs_inverse(self):
# we re-use the diff test cases but "backwards" - applying the op list
# to the initial MST see if we end up at the correct final MST
for testname, testcase in tqdm(self.diff_testcases.items()):
storage = MemoryStorage()
root_a = self.populate_storage_from_car(storage, testcase["inputs"]["mst_a"])
mst = MST.load(storage=storage, cid=root_a)

for op in testcase["results"]["record_ops"]:
if op["old_value"] and op["new_value"]: # update
mst = mst.update(op["rpath"], CID.decode(op["new_value"]))
elif op["old_value"]: # delete
mst = mst.delete(op["rpath"])
else: # create
mst = mst.add(op["rpath"], CID.decode(op["new_value"]))

diff = null_diff(mst) # should get us a map of the complete new mst
root_b = mst.get_pointer()

with open(self.test_suite_base + testcase["inputs"]["mst_b"], "rb") as car_b:
reference_root, reference_blocks = self.parse_car(car_b)

reference_cid_set = set(x[0] for x in reference_blocks) # just look at the cids from the car

self.assertEqual(root_b, reference_root, f"{testname} inverse: new root") # fails occasionally
self.assertEqual(diff.new_cids, reference_cid_set, f"{testname} inverse: new cid set") # basically always fails, I think I'm doing something wrong
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the gist of this test case is: take mst_a, apply the list of ops to it, and compare the result to mst_b. The list of CIDs I'm getting from null_diff(mst).new_cids looks completely different to what I expect, so I suspect I'm doing something wrong here (or maybe null_diff() is just very broken heh)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hah. Both are possible I guess! arroba's diff code is used in production, but definitely not as broadly as it could be, and not all parts, so big bugs are very possible.