diff --git a/xls/modules/zstd/memory/BUILD b/xls/modules/zstd/memory/BUILD index f1c53d78f4..3a58f8aa72 100644 --- a/xls/modules/zstd/memory/BUILD +++ b/xls/modules/zstd/memory/BUILD @@ -15,6 +15,7 @@ load("@rules_hdl//place_and_route:build_defs.bzl", "place_and_route") load("@rules_hdl//synthesis:build_defs.bzl", "benchmark_synth", "synthesize_rtl") load("@rules_hdl//verilog:providers.bzl", "verilog_library") +load("@xls_pip_deps//:requirements.bzl", "requirement") load( "//xls/build_rules:xls_build_defs.bzl", "xls_benchmark_ir", @@ -487,6 +488,33 @@ place_and_route( target_die_utilization_percentage = "10", ) +py_test( + name = "mem_reader_cocotb_test", + srcs = ["mem_reader_cocotb_test.py"], + data = [ + ":mem_reader_adv.v", + ":mem_reader_wrapper.v", + "@com_icarus_iverilog//:iverilog", + "@com_icarus_iverilog//:vvp", + ], + env = {"BUILD_WORKING_DIRECTORY": "sim_build"}, + tags = ["manual"], + visibility = ["//xls:xls_users"], + deps = [ + requirement("cocotb"), + requirement("cocotbext-axi"), + requirement("pytest"), + "//xls/common:runfiles", + "//xls/modules/zstd/cocotb:channel", + "//xls/modules/zstd/cocotb:memory", + "//xls/modules/zstd/cocotb:utils", + "//xls/modules/zstd/cocotb:xlsstruct", + "@com_google_absl_py//absl:app", + "@com_google_absl_py//absl/flags", + "@com_google_protobuf//:protobuf_python", + ], +) + xls_dslx_library( name = "axi_writer_dslx", srcs = ["axi_writer.x"], diff --git a/xls/modules/zstd/memory/mem_reader_cocotb_test.py b/xls/modules/zstd/memory/mem_reader_cocotb_test.py new file mode 100644 index 0000000000..845a321159 --- /dev/null +++ b/xls/modules/zstd/memory/mem_reader_cocotb_test.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python +# Copyright 2024 The XLS Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +import sys +import warnings +from pathlib import Path + +import cocotb +from cocotb.clock import Clock +from cocotb.triggers import ClockCycles, Event +from cocotb_bus.scoreboard import Scoreboard +from cocotbext.axi.axi_channels import AxiARBus, AxiRBus, AxiReadBus, AxiRTransaction, AxiRSource, AxiRSink, AxiRMonitor +from cocotbext.axi.axi_ram import AxiRamRead +from cocotbext.axi.sparse_memory import SparseMemory + +from xls.modules.zstd.cocotb.channel import ( + XLSChannel, + XLSChannelDriver, + XLSChannelMonitor, +) +from xls.modules.zstd.cocotb.utils import reset, run_test +from xls.modules.zstd.cocotb.xlsstruct import XLSStruct, xls_dataclass + +# to disable warnings from hexdiff used by cocotb's Scoreboard +warnings.filterwarnings("ignore", category=DeprecationWarning) + +DSLX_DATA_W = 64 +DSLX_ADDR_W = 16 + +AXI_DATA_W = 128 +AXI_ADDR_W = 16 + +LAST_W = 1 +STATUS_W = 1 +ERROR_W = 1 +ID_W = 4 +DEST_W = 4 + +# AXI +AXI_AR_PREFIX = "axi_ar" +AXI_R_PREFIX = "axi_r" + +# MemReader +MEM_READER_REQ_CHANNEL = "req" +MEM_READER_RESP_CHANNEL = "resp" + +# Override default widths of AXI response signals +signal_widths = {"rresp": 3, "rlast": 1} +AxiRBus._signal_widths = signal_widths +AxiRTransaction._signal_widths = signal_widths +AxiRSource._signal_widths = signal_widths +AxiRSink._signal_widths = signal_widths +AxiRMonitor._signal_widths = signal_widths + +@xls_dataclass +class MemReaderReq(XLSStruct): + addr: DSLX_ADDR_W + length: DSLX_ADDR_W + + +@xls_dataclass +class MemReaderResp(XLSStruct): + status: STATUS_W + data: DSLX_DATA_W + length: DSLX_ADDR_W + last: LAST_W + + +@xls_dataclass +class AxiReaderReq(XLSStruct): + addr: AXI_ADDR_W + len: AXI_ADDR_W + + +@xls_dataclass +class AxiStream(XLSStruct): + data: AXI_DATA_W + str: AXI_DATA_W // 8 + keep: AXI_DATA_W // 8 = 0 + last: LAST_W = 0 + id: ID_W = 0 + dest: DEST_W = 0 + + +@xls_dataclass +class AxiReaderError(XLSStruct): + error: ERROR_W + + +@xls_dataclass +class AxiAr(XLSStruct): + id: ID_W + addr: AXI_ADDR_W + region: 4 + len: 8 + size: 3 + burst: 2 + cache: 4 + prot: 3 + qos: 4 + + +@xls_dataclass +class AxiR(XLSStruct): + id: ID_W + data: AXI_DATA_W + resp: 3 + last: 1 + + +def print_callback(name: str = "monitor"): + def _print_callback(transaction): + print(f" [{name}]: {transaction}") + + return _print_callback + + +def set_termination_event(monitor, event, transactions): + def terminate_cb(_): + if monitor.stats.received_transactions == transactions: + print("all transactions received") + event.set() + + monitor.add_callback(terminate_cb) + + +def generate_test_data(test_cases, xfer_base=0x0, seed=1234): + random.seed(seed) + mem_size = 2**AXI_ADDR_W + data_w_div8 = DSLX_DATA_W // 8 + + assert xfer_base < mem_size, "Base address outside the memory span" + + req = [] + resp = [] + mem_writes = {} + + for xfer_offset, xfer_length in test_cases: + xfer_addr = xfer_base + xfer_offset + xfer_max_addr = xfer_addr + xfer_length + + if xfer_length == 0: + req += [MemReaderReq(addr=xfer_addr, length=0)] + resp += [MemReaderResp(status=0, data=0, length=0, last=1)] + + assert xfer_max_addr < mem_size, "Max address outside the memory span" + req += [MemReaderReq(addr=xfer_addr, length=xfer_length)] + + rem = xfer_length % data_w_div8 + for addr in range(xfer_addr, xfer_max_addr - (data_w_div8 - 1), data_w_div8): + last = ((addr + data_w_div8) >= xfer_max_addr) & (rem == 0) + data = random.randint(0, 1 << (data_w_div8 * 8)) + mem_writes.update({addr: data}) + resp += [MemReaderResp(status=0, data=data, length=data_w_div8, last=last)] + + if rem > 0: + addr = xfer_max_addr - rem + mask = (1 << (rem * 8)) - 1 + data = random.randint(0, 1 << (data_w_div8 * 8)) + mem_writes.update({addr: data}) + resp += [MemReaderResp(status=0, data=data & mask, length=rem, last=1)] + + return (req, resp, mem_writes) + + +async def test_mem_reader(dut, req_input, resp_output, mem_contents={}): + clock = Clock(dut.clk, 10, units="us") + cocotb.start_soon(clock.start()) + + mem_reader_resp_bus = XLSChannel( + dut, MEM_READER_RESP_CHANNEL, dut.clk, start_now=True + ) + mem_reader_req_driver = XLSChannelDriver(dut, MEM_READER_REQ_CHANNEL, dut.clk) + mem_reader_resp_monitor = XLSChannelMonitor( + dut, MEM_READER_RESP_CHANNEL, dut.clk, MemReaderResp, callback=print_callback() + ) + + terminate = Event() + set_termination_event(mem_reader_resp_monitor, terminate, len(resp_output)) + + scoreboard = Scoreboard(dut) + scoreboard.add_interface(mem_reader_resp_monitor, resp_output) + + ar_bus = AxiARBus.from_prefix(dut, AXI_AR_PREFIX) + r_bus = AxiRBus.from_prefix(dut, AXI_R_PREFIX) + axi_read_bus = AxiReadBus(ar=ar_bus, r=r_bus) + + mem_size = 2**AXI_ADDR_W + sparse_mem = SparseMemory(mem_size) + for addr, data in mem_contents.items(): + sparse_mem.write(addr, (data).to_bytes(8, "little")) + + memory = AxiRamRead(axi_read_bus, dut.clk, dut.rst, size=mem_size, mem=sparse_mem) + + await reset(dut.clk, dut.rst, cycles=10) + await mem_reader_req_driver.send(req_input) + await terminate.wait() + + +@cocotb.test(timeout_time=500, timeout_unit="ms") +async def mem_reader_zero_length_req(dut): + req, resp, _ = generate_test_data( + xfer_base=0xFFF, test_cases=[(0x101, 0)] + ) + await test_mem_reader(dut, req, resp) + + +@cocotb.test(timeout_time=500, timeout_unit="ms") +async def mem_reader_aligned_transfer_shorter_than_bus(dut): + req, resp, mem_contents = generate_test_data( + xfer_base=0xFFF, test_cases=[(0x101, 1)] + ) + await test_mem_reader(dut, req, resp, mem_contents) + + +@cocotb.test(timeout_time=500, timeout_unit="ms") +async def mem_reader_aligned_transfer_shorter_than_bus1(dut): + req, resp, mem_contents = generate_test_data( + xfer_base=0xFFF, test_cases=[(0x2, 1)] + ) + await test_mem_reader(dut, req, resp, mem_contents) + + +@cocotb.test(timeout_time=500, timeout_unit="ms") +async def mem_reader_aligned_transfer_shorter_than_bus2(dut): + req, resp, mem_contents = generate_test_data( + xfer_base=0xFFF, test_cases=[(0x2, 17)] + ) + await test_mem_reader(dut, req, resp, mem_contents) + + +@cocotb.test(timeout_time=500, timeout_unit="ms") +async def mem_reader_aligned_transfer_shorter_than_bus3(dut): + req, resp, mem_contents = generate_test_data( + xfer_base=0xFFF, test_cases=[(0x0, 0x1000)] + ) + await test_mem_reader(dut, req, resp, mem_contents) + + +@cocotb.test(timeout_time=500, timeout_unit="ms") +async def mem_reader_aligned_transfer_shorter_than_bus4(dut): + req, resp, mem_contents = generate_test_data( + xfer_base=0x1, test_cases=[(0x0, 0xFFF), (0x1000, 0x1)] + ) + await test_mem_reader(dut, req, resp, mem_contents) + + +if __name__ == "__main__": + sys.path.append(str(Path(__file__).parent)) + + toplevel = "mem_reader_wrapper" + verilog_sources = [ + "xls/modules/zstd/memory/mem_reader_adv.v", + "xls/modules/zstd/memory/mem_reader_wrapper.v", + ] + test_module = [Path(__file__).stem] + run_test(toplevel, test_module, verilog_sources) diff --git a/xls/modules/zstd/memory/mem_reader_wrapper.v b/xls/modules/zstd/memory/mem_reader_wrapper.v new file mode 100644 index 0000000000..3601bcbb0e --- /dev/null +++ b/xls/modules/zstd/memory/mem_reader_wrapper.v @@ -0,0 +1,111 @@ +// Copyright 2024 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +`default_nettype none + +module mem_reader_wrapper #( + parameter DSLX_DATA_W = 64, + parameter DSLX_ADDR_W = 16, + parameter AXI_DATA_W = 128, + parameter AXI_ADDR_W = 16, + parameter AXI_DEST_W = 8, + parameter AXI_ID_W = 8, + + parameter CTRL_W = (DSLX_ADDR_W), + parameter REQ_W = (2 * DSLX_ADDR_W), + parameter RESP_W = (1 + DSLX_DATA_W + DSLX_ADDR_W + 1), + parameter AXI_AR_W = (AXI_ID_W + AXI_ADDR_W + 28), + parameter AXI_R_W = (AXI_ID_W + AXI_DATA_W + 4) +) ( + input wire clk, + input wire rst, + + output wire req_rdy, + input wire req_vld, + input wire [REQ_W-1:0] req_data, + + output wire resp_vld, + input wire resp_rdy, + output wire [RESP_W-1:0] resp_data, + + output wire axi_ar_arvalid, + input wire axi_ar_arready, + output wire [ AXI_ID_W-1:0] axi_ar_arid, + output wire [AXI_ADDR_W-1:0] axi_ar_araddr, + output wire [ 3:0] axi_ar_arregion, + output wire [ 7:0] axi_ar_arlen, + output wire [ 2:0] axi_ar_arsize, + output wire [ 1:0] axi_ar_arburst, + output wire [ 3:0] axi_ar_arcache, + output wire [ 2:0] axi_ar_arprot, + output wire [ 3:0] axi_ar_arqos, + + input wire axi_r_rvalid, + output wire axi_r_rready, + input wire [ AXI_ID_W-1:0] axi_r_rid, + input wire [AXI_DATA_W-1:0] axi_r_rdata, + input wire [ 2:0] axi_r_rresp, + input wire axi_r_rlast +); + + wire [AXI_AR_W-1:0] axi_ar_data; + wire axi_ar_rdy; + wire axi_ar_vld; + + assign axi_ar_rdy = axi_ar_arready; + + assign axi_ar_arvalid = axi_ar_vld; + assign { + axi_ar_arid, + axi_ar_araddr, + axi_ar_arregion, + axi_ar_arlen, + axi_ar_arsize, + axi_ar_arburst, + axi_ar_arcache, + axi_ar_arprot, + axi_ar_arqos +} = axi_ar_data; + + wire [AXI_R_W-1:0] axi_r_data; + wire axi_r_vld; + wire axi_r_rdy; + + assign axi_r_data = {axi_r_rid, axi_r_rdata, axi_r_rresp, axi_r_rlast}; + assign axi_r_vld = axi_r_rvalid; + + assign axi_r_rready = axi_r_rdy; + + mem_reader_adv mem_reader_adv ( + .clk(clk), + .rst(rst), + + .mem_reader__req_r_data(req_data), + .mem_reader__req_r_rdy (req_rdy), + .mem_reader__req_r_vld (req_vld), + + .mem_reader__resp_s_data(resp_data), + .mem_reader__resp_s_rdy (resp_rdy), + .mem_reader__resp_s_vld (resp_vld), + + .mem_reader__axi_ar_s_data(axi_ar_data), + .mem_reader__axi_ar_s_rdy (axi_ar_rdy), + .mem_reader__axi_ar_s_vld (axi_ar_vld), + + .mem_reader__axi_r_r_data(axi_r_data), + .mem_reader__axi_r_r_vld (axi_r_vld), + .mem_reader__axi_r_r_rdy (axi_r_rdy) + ); + +endmodule