From 3476bd71987631e5bb5240383dc608ce18895158 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 22 May 2024 15:07:01 +0100 Subject: [PATCH] feat: hugr binary cli tool (#1096) Closes #1095 currently validates against std extensions can read from stdin with `echo "" | hugr -` binary depends on optional feature to limit dependencies, but not put in sub crate to allow `cargo install hugr` add integration testing of binary behaviour based on https://docs.rs/assert_cmd/2.0.14/assert_cmd/cmd/struct.Command.html#method.write_stdin --- hugr/Cargo.toml | 17 +++++++-- hugr/src/cli.rs | 54 +++++++++++++++++++++++++++ hugr/src/lib.rs | 3 ++ hugr/src/main.rs | 34 +++++++++++++++++ hugr/tests/cli.rs | 93 +++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 198 insertions(+), 3 deletions(-) create mode 100644 hugr/src/cli.rs create mode 100644 hugr/src/main.rs create mode 100644 hugr/tests/cli.rs diff --git a/hugr/Cargo.toml b/hugr/Cargo.toml index 03b643550..89a5f8949 100644 --- a/hugr/Cargo.toml +++ b/hugr/Cargo.toml @@ -23,6 +23,7 @@ path = "src/lib.rs" [features] extension_inference = [] +cli = ["dep:clap", "dep:clap-stdin"] [dependencies] portgraph = { workspace = true, features = ["serde", "petgraph"] } @@ -52,6 +53,8 @@ delegate = "0.12.0" paste = "1.0" strum = "0.26.1" strum_macros = "0.26.1" +clap = { version = "4.5.4", features = ["derive"], optional = true } +clap-stdin = { version = "0.4.0", optional = true } [dev-dependencies] criterion = { version = "0.5.1", features = ["html_reports"] } @@ -61,10 +64,18 @@ urlencoding = "2.1.2" cool_asserts = "2.0.3" insta = { workspace = true, features = ["yaml"] } jsonschema = "0.18.0" -proptest = { version = "1.4.0" } -proptest-derive = { version = "0.4.0"} -regex-syntax = { version = "0.8.3"} +proptest = { version = "1.4.0" } +proptest-derive = { version = "0.4.0" } +regex-syntax = { version = "0.8.3" } +assert_cmd = "2.0.14" +predicates = "3.1.0" +assert_fs = "1.1.1" [[bench]] name = "bench_main" harness = false + + +[[bin]] +name = "hugr" +required-features = ["cli"] diff --git a/hugr/src/cli.rs b/hugr/src/cli.rs new file mode 100644 index 000000000..a58278188 --- /dev/null +++ b/hugr/src/cli.rs @@ -0,0 +1,54 @@ +//! Standard command line tools, used by the hugr binary. + +use crate::{extension::ExtensionRegistry, Hugr, HugrView}; +use clap::Parser; +use clap_stdin::FileOrStdin; +use thiserror::Error; +/// Validate and visualise a HUGR file. +#[derive(Parser, Debug)] +#[clap(version = "1.0", long_about = None)] +#[clap(about = "Validate a HUGR.")] +pub struct CmdLineArgs { + input: FileOrStdin, + /// Visualise with mermaid. + #[arg(short, long, value_name = "MERMAID", help = "Visualise with mermaid.")] + mermaid: bool, + /// Skip validation. + #[arg(short, long, help = "Skip validation.")] + no_validate: bool, + // TODO YAML extensions +} + +/// Error type for the CLI. +#[derive(Error, Debug)] +pub enum CliError { + /// Error reading input. + #[error("Error reading input: {0}")] + Input(#[from] clap_stdin::StdinError), + /// Error parsing input. + #[error("Error parsing input: {0}")] + Parse(#[from] serde_json::Error), + /// Error validating HUGR. + #[error("Error validating HUGR: {0}")] + Validate(#[from] crate::hugr::ValidationError), +} + +/// String to print when validation is successful. +pub const VALID_PRINT: &str = "HUGR valid!"; + +impl CmdLineArgs { + /// Run the HUGR cli and validate against an extension registry. + pub fn run(&self, registry: &ExtensionRegistry) -> Result<(), CliError> { + let mut hugr: Hugr = serde_json::from_reader(self.input.into_reader()?)?; + if self.mermaid { + println!("{}", hugr.mermaid_string()); + } + + if !self.no_validate { + hugr.update_validate(registry)?; + + println!("{}", VALID_PRINT); + } + Ok(()) + } +} diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index 7d1dd2499..7c1cdca5f 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -157,5 +157,8 @@ pub use crate::core::{ pub use crate::extension::Extension; pub use crate::hugr::{Hugr, HugrView, SimpleReplacement}; +#[cfg(feature = "cli")] +pub mod cli; + #[cfg(test)] pub mod proptest; diff --git a/hugr/src/main.rs b/hugr/src/main.rs new file mode 100644 index 000000000..7055446fa --- /dev/null +++ b/hugr/src/main.rs @@ -0,0 +1,34 @@ +//! Validate serialized HUGR on the command line + +use hugr::std_extensions::arithmetic::{ + conversions::EXTENSION as CONVERSIONS_EXTENSION, float_ops::EXTENSION as FLOAT_OPS_EXTENSION, + float_types::EXTENSION as FLOAT_TYPES_EXTENSION, int_ops::EXTENSION as INT_OPS_EXTENSION, + int_types::EXTENSION as INT_TYPES_EXTENSION, +}; +use hugr::std_extensions::logic::EXTENSION as LOGICS_EXTENSION; + +use hugr::extension::{ExtensionRegistry, PRELUDE}; + +use clap::Parser; +use hugr::cli::CmdLineArgs; + +fn main() { + let opts = CmdLineArgs::parse(); + + // validate with all std extensions + let reg = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + INT_OPS_EXTENSION.to_owned(), + INT_TYPES_EXTENSION.to_owned(), + CONVERSIONS_EXTENSION.to_owned(), + FLOAT_OPS_EXTENSION.to_owned(), + FLOAT_TYPES_EXTENSION.to_owned(), + LOGICS_EXTENSION.to_owned(), + ]) + .unwrap(); + + if let Err(e) = opts.run(®) { + eprintln!("{}", e); + std::process::exit(1); + } +} diff --git a/hugr/tests/cli.rs b/hugr/tests/cli.rs new file mode 100644 index 000000000..601762081 --- /dev/null +++ b/hugr/tests/cli.rs @@ -0,0 +1,93 @@ +#![cfg(feature = "cli")] +use assert_cmd::Command; +use assert_fs::{fixture::FileWriteStr, NamedTempFile}; +use hugr::{ + builder::{Container, Dataflow, DataflowHugr}, + extension::prelude::{BOOL_T, QB_T}, + type_row, + types::FunctionType, + Hugr, +}; +use predicates::{prelude::*, str::contains}; +use rstest::{fixture, rstest}; + +use hugr::builder::DFGBuilder; +use hugr::cli::VALID_PRINT; +#[fixture] +fn cmd() -> Command { + Command::cargo_bin(env!("CARGO_PKG_NAME")).unwrap() +} + +#[fixture] +fn test_hugr() -> Hugr { + let df = DFGBuilder::new(FunctionType::new_endo(type_row![BOOL_T])).unwrap(); + let [i] = df.input_wires_arr(); + df.finish_prelude_hugr_with_outputs([i]).unwrap() +} + +#[fixture] +fn test_hugr_string(test_hugr: Hugr) -> String { + serde_json::to_string(&test_hugr).unwrap() +} + +#[fixture] +fn test_hugr_file(test_hugr_string: String) -> NamedTempFile { + let file = assert_fs::NamedTempFile::new("sample.hugr").unwrap(); + file.write_str(&test_hugr_string).unwrap(); + file +} + +#[rstest] +fn test_doesnt_exist(mut cmd: Command) { + cmd.arg("foobar"); + cmd.assert() + .failure() + .stderr(contains("No such file or directory").and(contains("Error reading input"))); +} + +#[rstest] +fn test_validate(test_hugr_file: NamedTempFile, mut cmd: Command) { + cmd.arg(test_hugr_file.path()); + cmd.assert().success().stdout(contains(VALID_PRINT)); +} + +#[rstest] +fn test_stdin(test_hugr_string: String, mut cmd: Command) { + cmd.write_stdin(test_hugr_string); + cmd.arg("-"); + + cmd.assert().success().stdout(contains(VALID_PRINT)); +} + +#[rstest] +fn test_mermaid(test_hugr_file: NamedTempFile, mut cmd: Command) { + const MERMAID: &str = "graph LR\n subgraph 0 [\"(0) DFG\"]"; + cmd.arg(test_hugr_file.path()); + cmd.arg("--mermaid"); + cmd.arg("--no-validate"); + cmd.assert().success().stdout(contains(MERMAID)); +} + +#[rstest] +fn test_bad_hugr(mut cmd: Command) { + let df = DFGBuilder::new(FunctionType::new_endo(type_row![QB_T])).unwrap(); + let bad_hugr = df.hugr().clone(); + + let bad_hugr_string = serde_json::to_string(&bad_hugr).unwrap(); + cmd.write_stdin(bad_hugr_string); + cmd.arg("-"); + + cmd.assert() + .failure() + .stderr(contains("Error validating HUGR").and(contains("unconnected port"))); +} + +#[rstest] +fn test_bad_json(mut cmd: Command) { + cmd.write_stdin(r#"{"foo": "bar"}"#); + cmd.arg("-"); + + cmd.assert() + .failure() + .stderr(contains("Error parsing input")); +}