diff --git a/CHANGELOG.md b/CHANGELOG.md index afa4465ce..6b7c1278d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Fixed compilation with clang 16.0.6 - Added Threads::Threads to `EXT_LIBS` - Updates to pymarian: building for multiple python versions; disabling tcmalloc; hosting gated COMETs on HuggingFace +- Add "pymarian" CLI, a proxy to "marian" binary, but made available in PATH after "pip install pymarian" ### Added - Added `--normalize-gradient-by-ratio` to mildly adapt gradient magnitude if effective batch size diverges from running average effective batch size. diff --git a/src/command/marian_main.cpp b/src/command/marian_main.cpp index e838fe808..3897ac6bb 100644 --- a/src/command/marian_main.cpp +++ b/src/command/marian_main.cpp @@ -38,25 +38,51 @@ #include "marian_conv.cpp" #undef main +#include +#include +#include #include "3rd_party/ExceptionWithCallStack.h" +#include "3rd_party/spdlog/details/format.h" int main(int argc, char** argv) { using namespace marian; + using MainFunc = int(*)(int, char**); + std::map> subcmds = { + {"train", {&mainTrainer, "Train a model (default)"}}, + {"decode", {&mainDecoder, "Decode or translate text"}}, + {"score", {&mainScorer, "Score translations"}}, + {"embed", {&mainEmbedder, "Embed text"}}, + {"evaluate", {&mainEvaluator, "Run Evaluator metric"}}, + {"vocab", {&mainVocab, "Create vocabulary"}}, + {"convert", {&mainConv, "Convert model file format"}} + }; + // no arguments, or the first arg is "?"", print help message + if (argc == 1 || (argc == 2 && (std::string(argv[1]) == "?") )) { + std::cout << "Usage: " << argv[0] << " COMMAND [ARGS]" << std::endl; + std::cout << "Commands:" << std::endl; + for (auto&& [name, val] : subcmds) { + std::cerr << fmt::format("{:10} : {}\n", name, std::get<1>(val)); + } + return 0; + } - if(argc > 1 && argv[1][0] != '-') { + if (argc > 1 && argv[1][0] != '-') { std::string cmd = argv[1]; argc--; argv[1] = argv[0]; argv++; - if(cmd == "train") return mainTrainer(argc, argv); - else if(cmd == "decode") return mainDecoder(argc, argv); - else if (cmd == "score") return mainScorer(argc, argv); - else if (cmd == "embed") return mainEmbedder(argc, argv); - else if (cmd == "evaluate") return mainEvaluator(argc, argv); - else if (cmd == "vocab") return mainVocab(argc, argv); - else if (cmd == "convert") return mainConv(argc, argv); - std::cerr << "Command must be train, decode, score, embed, vocab, or convert." << std::endl; - exit(1); - } else + if (subcmds.count(cmd) > 0) { + auto [func, desc] = subcmds[cmd]; + return func(argc, argv); + } + else { + std::cerr << "Unknown command: " << cmd << ". Known commands are:" << std::endl; + for (auto&& [name, val] : subcmds) { + std::cerr << fmt::format("{:10} : {}\n", name, std::get<1>(val)); + } + return 1; + } + } + else return mainTrainer(argc, argv); } diff --git a/src/python/binding/bind.cpp b/src/python/binding/bind.cpp index 38a1e3429..e42fd4ff5 100644 --- a/src/python/binding/bind.cpp +++ b/src/python/binding/bind.cpp @@ -1,3 +1,5 @@ +#define PYBIND11_DETAILED_ERROR_MESSAGES + #include "pybind11/pybind11.h" #include "pybind11/stl.h" // if your IDE/vscode complains about missing paths @@ -6,13 +8,30 @@ #include "evaluator.hpp" #include "trainer.hpp" #include "translator.hpp" - - -#define PYBIND11_DETAILED_ERROR_MESSAGES +#include "command/marian_main.cpp" namespace py = pybind11; using namespace pymarian; +/** + * @brief Wrapper function to call Marian main entry point from Python + * + * Calls Marian main entry point from Python. + * It converts args from a vector of strings (Python-ic API) to char* (C API) + * before passsing on to the main function. + * @param args vector of strings + * @return int return code + */ +int main_wrap(std::vector args) { + // Convert vector of strings to vector of char* + std::vector argv; + argv.push_back(const_cast("pymarian")); + for (auto& arg : args) { + argv.push_back(const_cast(arg.c_str())); + } + argv.push_back(nullptr); + return main(argv.size() - 1, argv.data()); +} PYBIND11_MODULE(_pymarian, m) { m.doc() = "Marian C++ API bindings via pybind11"; @@ -44,5 +63,7 @@ PYBIND11_MODULE(_pymarian, m) { .def("embed", py::overload_cast<>(&PyEmbedder::embed)) ; + m.def("main", &main_wrap, "Marian main entry point"); + } diff --git a/src/python/pymarian/__init__.py b/src/python/pymarian/__init__.py index 36011c203..7816acc89 100644 --- a/src/python/pymarian/__init__.py +++ b/src/python/pymarian/__init__.py @@ -1,6 +1,7 @@ import logging from itertools import islice from pathlib import Path +import sys from typing import Iterator, List, Optional, Tuple, Union import _pymarian @@ -46,8 +47,8 @@ def model_type(self) -> str: @classmethod def new( cls, - model_file: Path, - vocab_file: Path = None, + model_file: Union[Path, str], + vocab_file: Union[Path, str] = None, devices: Optional[List[int]] = None, width=Defaults.FLOAT_PRECISION, mini_batch=Defaults.MINI_BATCH, @@ -76,8 +77,8 @@ def new( :return: iterator of scores """ - assert model_file.exists(), f'Model file {model_file} does not exist' - assert vocab_file.exists(), f'Vocab file {vocab_file} does not exist' + assert Path(model_file).exists(), f'Model file {model_file} does not exist' + assert Path(vocab_file).exists(), f'Vocab file {vocab_file} does not exist' assert like in Defaults.MODEL_TYPES, f'Unknown model type: {like}' n_inputs = len(Defaults.MODEL_TYPES[like]) vocabs = [vocab_file] * n_inputs @@ -97,7 +98,7 @@ def new( cpu_threads=cpu_threads, average=average, ) - if kwargs.pop('fp16'): + if kwargs.pop('fp16', False): kwargs['fp16'] = '' # empty string for flag; i.e, "--fp16" and not "--fp16=true" # TODO: remove this when c++ bindings supports iterator @@ -171,3 +172,19 @@ def __init__(self, cli_string='', **kwargs): """ cli_string += ' ' + kwargs_to_cli(**kwargs) super().__init__(cli_string.stip()) + +def main(): + """proxy to marian main function""" + code = _pymarian.main(sys.argv[1:]) + sys.exit(code) + +def help(*vargs): + """print help text""" + args = [] + args += vargs + if '--help' not in args and '-h' not in args: + args.append('--help') + # note: this will print to stdout + _pymarian.main(args) + # do not exit, as this is a library function + diff --git a/src/python/pyproject.toml b/src/python/pyproject.toml index 30eb16f36..34445648e 100644 --- a/src/python/pyproject.toml +++ b/src/python/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ ] [project.scripts] +pymarian = "pymarian:main" pymarian-eval = "pymarian.eval:main" pymarian-qtdemo = "pymarian.qtdemo:main" pymarian-mtapi = "pymarian.mtapi_server:main"