diff --git a/examples/common.cpp b/examples/common.cpp index 7b2556eb486..22da7b15800 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -894,7 +894,11 @@ int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) { bool is_file_exist(const char *fileName) { - std::ifstream infile(fileName); + #ifdef _WIN32 + std::wifstream infile(console::UTF8toUTF16(fileName).c_str()); + #else + std::ifstream infile(fileName); + #endif return infile.good(); } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 78d2e6e7ab6..774198e34d6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -99,12 +99,11 @@ struct whisper_params { void whisper_print_usage(int argc, const char ** argv, const whisper_params & params); -char* whisper_param_turn_lowercase(char* in){ - int string_len = strlen(in); - for(int i = 0; i < string_len; i++){ - *(in+i) = tolower((unsigned char)*(in+i)); - } - return in; +std::string toLowerCase(const std::string& input) { + std::string result = input; // Create a copy of the input string + std::transform(result.begin(), result.end(), result.begin(), + [](unsigned char c){ return std::tolower(c); }); + return result; } bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) { @@ -163,7 +162,7 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(argv[++i]); } + else if (arg == "-l" || arg == "--language") { params.language = toLowerCase(argv[++i]); } else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } else if ( arg == "--prompt") { params.prompt = argv[++i]; } else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } @@ -954,35 +953,6 @@ void cb_log_disable(enum ggml_log_level , const char * , void * ) { } int run(int argc, const char ** argv) { whisper_params params; - // If the only argument starts with "@", read arguments line-by-line - // from the given file. - std::vector vec_args; - if (argc == 2 && argv != nullptr && argv[1] != nullptr && argv[1][0] == '@') { - // Save the name of the executable. - vec_args.push_back(argv[0]); - - // Open the response file. - char const * rspfile = argv[1] + sizeof(char); - std::ifstream fin(rspfile); - if (fin.is_open() == false) { - fprintf(stderr, "error: response file '%s' not found\n", rspfile); - return 1; - } - - // Read the entire response file. - std::string line; - while (std::getline(fin, line)) { - vec_args.push_back(line); - } - - // Use the contents of the response file as the command-line arguments. - argc = static_cast(vec_args.size()); - argv = static_cast(alloca(argc * sizeof (char *))); - for (int i = 0; i < argc; ++i) { - argv[i] = const_cast(vec_args[i].c_str()); - } - } - if (whisper_params_parse(argc, argv, params) == false) { whisper_print_usage(argc, argv, params); return 1; @@ -1151,7 +1121,6 @@ int run(int argc, const char ** argv) { wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; wparams.heuristic = params.heuristic; - wparams.split_on_word = params.split_on_word; wparams.audio_ctx = params.audio_ctx; wparams.debug_mode = params.debug_mode; @@ -1295,22 +1264,53 @@ int run(int argc, const char ** argv) { return 0; } -#if _WIN32 -int wmain(int argc, const wchar_t ** argv_UTF16LE) { - console::init(true, true); - atexit([]() { console::cleanup(); }); - std::vector buffer(argc); - std::vector argv_UTF8(argc); - for (int i = 0; i < argc; ++i) { - buffer[i] = console::UTF16toUTF8(argv_UTF16LE[i]); - argv_UTF8[i] = buffer[i].c_str(); - } - return run(argc, argv_UTF8.data()); +// Platform-specific function to convert UTF-16 to UTF-8 +#ifdef _WIN32 +std::string UTF16toUTF8(const wchar_t* utf16str) { + return console::UTF16toUTF8(utf16str); } +#define MAIN wmain +#define CHAR_TYPE const wchar_t #else -int main(int argc, const char ** argv_UTF8) { +#define MAIN main +#define CHAR_TYPE const char +#endif + +int MAIN(int argc, CHAR_TYPE** argv) { console::init(true, true); atexit([]() { console::cleanup(); }); - return run(argc, argv_UTF8); -} + +#ifdef _WIN32 + auto convert_to_utf8 = UTF16toUTF8; +#else + auto convert_to_utf8 = [](const char* str) { return std::string(str); }; #endif + + std::vector args; + if (argc == 2 && argv != nullptr && argv[1] != nullptr && convert_to_utf8(argv[1])[0] == '@') { + args.push_back(convert_to_utf8(argv[0])); + const char* rspfile = convert_to_utf8(argv[1]).c_str() + 1; // skip '@' + std::ifstream fin(rspfile); + + if (!fin.is_open()) { + fprintf(stderr, "error: response file '%s' not found\n", rspfile); + return 1; + } + + std::string line; + while (std::getline(fin, line)) { + args.push_back(line); + } + } else { + for (int i = 0; i < argc; ++i) { + args.push_back(convert_to_utf8(argv[i])); + } + } + + std::vector argv_converted(args.size()); + for (size_t i = 0; i < args.size(); ++i) { + argv_converted[i] = args[i].c_str(); + } + + return run(static_cast(args.size()), argv_converted.data()); +} diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3f540c71394..a1d86e8adef 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -751,7 +751,6 @@ int main(int argc, char ** argv) { wparams.thold_pt = params.word_thold; wparams.max_len = params.max_len == 0 ? 60 : params.max_len; - wparams.split_on_word = params.split_on_word; wparams.audio_ctx = params.audio_ctx; wparams.debug_mode = params.debug_mode;