Skip to content

Commit

Permalink
simulate_module_main: allow to read testvector textproto.
Browse files Browse the repository at this point in the history
Add new --testvector_proto flag that allows to read a
testvector::SampleInputsProto textproto.
(in preparation of deprecating
--channel_values_file, --args_file.)

Issues: #1645
PiperOrigin-RevId: 689435990
  • Loading branch information
hzeller authored and copybara-github committed Oct 24, 2024
1 parent 5526735 commit 6958511
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 15 deletions.
1 change: 1 addition & 0 deletions xls/tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,7 @@ cc_binary(
"//xls/simulation:module_simulator",
"//xls/simulation:verilog_simulator",
"//xls/simulation:verilog_simulators",
"//xls/tests:testvector_cc_proto",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down
76 changes: 61 additions & 15 deletions xls/tools/simulate_module_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstdint>
#include <filesystem> // NOLINT
#include <iostream>
Expand Down Expand Up @@ -46,6 +47,7 @@
#include "xls/simulation/module_simulator.h"
#include "xls/simulation/verilog_simulator.h"
#include "xls/simulation/verilog_simulators.h"
#include "xls/tests/testvector.pb.h"
#include "xls/tools/eval_utils.h"

static constexpr std::string_view kUsage = R"(
Expand All @@ -67,6 +69,11 @@ ABSL_FLAG(
std::string, signature_file, "",
"The path to the file containing the text-encoded ModuleSignatureProto "
"describing the interface of the module to simulate.");
ABSL_FLAG(std::string, testvector_textproto, "",
"A textproto file containing the function argument or proc "
"channel test vectors.");

// Deprecated. Soon to be removed in favor of --testvector_textproto
ABSL_FLAG(std::string, args_file, "",
"Batch of arguments to pass to the module, one set per line. Each "
"line should contain a semicolon-separated set of arguments. Cannot "
Expand All @@ -87,6 +94,8 @@ ABSL_FLAG(
"in human-readable form. There is one VALUE per line. There may be zero or "
"more occurrences of VALUE for a channel. The file may contain one or more "
"channels. Cannot be specified with --args_file or --args.");
// end deprecated

ABSL_FLAG(std::vector<std::string>, output_channel_counts, {},
"Comma separated list of output_channel_name=count pairs, for "
"example: result=2. 'output_channel_name' represents an output "
Expand Down Expand Up @@ -115,6 +124,31 @@ struct ProcInput {

using InputType = std::variant<FunctionInput, ProcInput>;

absl::StatusOr<InputType> ConvertTestVector(
const testvector::SampleInputsProto& testvector) {
if (testvector.has_channel_inputs()) {
ProcInput result;
for (const testvector::ChannelInputProto& channel_input :
testvector.channel_inputs().inputs()) {
auto inserted =
result.channel_inputs.insert({channel_input.channel_name(), {}});
QCHECK(inserted.second) << "Multiple channel inputs with same name?";
std::vector<Value>& channel_values = inserted.first->second;
for (std::string_view value_str : channel_input.values()) {
XLS_ASSIGN_OR_RETURN(Value value, Parser::ParseTypedValue(value_str));
channel_values.push_back(value);
}
}
return result;
}
// If not proc, then interpret as function result.
FunctionInput result;
std::copy(testvector.function_args().args().begin(),
testvector.function_args().args().end(),
std::back_inserter(result.args_strings));
return result;
}

absl::Status RunProc(const verilog::ModuleSimulator& simulator,
const verilog::ModuleSignature& signature,
const ProcInput& proc_input) {
Expand Down Expand Up @@ -209,18 +243,22 @@ int main(int argc, char** argv) {

int64_t arg_count = absl::GetFlag(FLAGS_args).empty() ? 0 : 1;
arg_count += absl::GetFlag(FLAGS_args_file).empty() ? 0 : 1;
arg_count += absl::GetFlag(FLAGS_channel_values_file).empty() ? 0 : 1;
arg_count += (absl::GetFlag(FLAGS_channel_values_file).empty() &&
absl::GetFlag(FLAGS_testvector_textproto).empty())
? 0
: 1;
QCHECK_EQ(arg_count, 1)
<< "Must specify one of: --args_file or --args or --channel_values_file.";

if (absl::GetFlag(FLAGS_channel_values_file).empty()) {
QCHECK(absl::GetFlag(FLAGS_output_channel_counts).empty())
<< "'--output_channel_counts' can only be specified with "
"'--channel_values_file'.";
}

xls::InputType input;
if (!absl::GetFlag(FLAGS_args).empty()) {
if (!absl::GetFlag(FLAGS_testvector_textproto).empty()) {
xls::testvector::SampleInputsProto data;
QCHECK_OK(xls::ParseTextProtoFile(absl::GetFlag(FLAGS_testvector_textproto),
&data));
auto converted = xls::ConvertTestVector(data);
QCHECK_OK(converted.status());
input = converted.value();
} else if (!absl::GetFlag(FLAGS_args).empty()) {
input =
xls::FunctionInput{std::vector<std::string>{absl::GetFlag(FLAGS_args)}};
} else if (!absl::GetFlag(FLAGS_args_file).empty()) {
Expand All @@ -229,14 +267,26 @@ int main(int argc, char** argv) {
QCHECK_OK(args_file_contents_or.status());
input = xls::FunctionInput{absl::StrSplit(args_file_contents_or.value(),
'\n', absl::SkipWhitespace())};
} else {
} else if (!absl::GetFlag(FLAGS_channel_values_file).empty()) {
absl::StatusOr<std::string> channel_values_file_contents =
xls::GetFileContents(absl::GetFlag(FLAGS_channel_values_file));
QCHECK_OK(channel_values_file_contents.status());
absl::StatusOr<absl::btree_map<std::string, std::vector<xls::Value>>>
channel_values_or =
xls::ParseChannelValues(channel_values_file_contents.value());
QCHECK_OK(channel_values_or.status());
absl::flat_hash_map<std::string, std::vector<xls::Value>> channel_values;
channel_values.reserve(channel_values_or->size());
absl::c_move(*std::move(channel_values_or),
std::inserter(channel_values, channel_values.end()));
input = xls::ProcInput{.channel_inputs = channel_values};
}

// Expected count of outputs per out-channel. They are separate from the
// input file.
if (!absl::GetFlag(FLAGS_output_channel_counts).empty()) {
QCHECK(std::holds_alternative<xls::ProcInput>(input))
<< "Output channel counts only makes sense on proc inputs";
absl::flat_hash_map<std::string, int64_t> output_channel_counts;
for (std::string_view output_channel_count :
absl::GetFlag(FLAGS_output_channel_counts)) {
Expand All @@ -251,12 +301,8 @@ int main(int argc, char** argv) {
output_channel_count, split[1]);
output_channel_counts[split[0]] = count;
}
absl::flat_hash_map<std::string, std::vector<xls::Value>> channel_values;
channel_values.reserve(channel_values_or->size());
absl::c_move(*std::move(channel_values_or),
std::inserter(channel_values, channel_values.end()));
input = xls::ProcInput{.channel_inputs = channel_values,
.output_channel_counts = output_channel_counts};
std::get<xls::ProcInput>(input).output_channel_counts =
std::move(output_channel_counts);
}

QCHECK(!absl::GetFlag(FLAGS_signature_file).empty())
Expand Down

0 comments on commit 6958511

Please sign in to comment.