Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compatibility issue #12

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,11 @@ int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) {

bool is_file_exist(const char *fileName)
{
std::ifstream infile(fileName);
#ifdef _WIN32
std::wifstream infile(console::UTF8toUTF16(fileName).c_str());
#else
std::ifstream infile(fileName);
#endif
return infile.good();
}

Expand Down
102 changes: 51 additions & 51 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,11 @@ struct whisper_params {

void whisper_print_usage(int argc, const char ** argv, const whisper_params & params);

char* whisper_param_turn_lowercase(char* in){
int string_len = strlen(in);
for(int i = 0; i < string_len; i++){
*(in+i) = tolower((unsigned char)*(in+i));
}
return in;
std::string toLowerCase(const std::string& input) {
std::string result = input; // Create a copy of the input string
std::transform(result.begin(), result.end(), result.begin(),
[](unsigned char c){ return std::tolower(c); });
return result;
}

bool whisper_params_parse(int argc, const char ** argv, whisper_params & params) {
Expand Down Expand Up @@ -163,7 +162,7 @@ bool whisper_params_parse(int argc, const char ** argv, whisper_params & params)
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(argv[++i]); }
else if (arg == "-l" || arg == "--language") { params.language = toLowerCase(argv[++i]); }
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
Expand Down Expand Up @@ -954,35 +953,6 @@ void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
int run(int argc, const char ** argv) {
whisper_params params;

// If the only argument starts with "@", read arguments line-by-line
// from the given file.
std::vector<std::string> vec_args;
if (argc == 2 && argv != nullptr && argv[1] != nullptr && argv[1][0] == '@') {
// Save the name of the executable.
vec_args.push_back(argv[0]);

// Open the response file.
char const * rspfile = argv[1] + sizeof(char);
std::ifstream fin(rspfile);
if (fin.is_open() == false) {
fprintf(stderr, "error: response file '%s' not found\n", rspfile);
return 1;
}

// Read the entire response file.
std::string line;
while (std::getline(fin, line)) {
vec_args.push_back(line);
}

// Use the contents of the response file as the command-line arguments.
argc = static_cast<int>(vec_args.size());
argv = static_cast<char **>(alloca(argc * sizeof (char *)));
for (int i = 0; i < argc; ++i) {
argv[i] = const_cast<char *>(vec_args[i].c_str());
}
}

if (whisper_params_parse(argc, argv, params) == false) {
whisper_print_usage(argc, argv, params);
return 1;
Expand Down Expand Up @@ -1151,7 +1121,6 @@ int run(int argc, const char ** argv) {
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;

wparams.heuristic = params.heuristic;
wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx;

wparams.debug_mode = params.debug_mode;
Expand Down Expand Up @@ -1295,22 +1264,53 @@ int run(int argc, const char ** argv) {
return 0;
}

#if _WIN32
int wmain(int argc, const wchar_t ** argv_UTF16LE) {
console::init(true, true);
atexit([]() { console::cleanup(); });
std::vector<std::string> buffer(argc);
std::vector<const char*> argv_UTF8(argc);
for (int i = 0; i < argc; ++i) {
buffer[i] = console::UTF16toUTF8(argv_UTF16LE[i]);
argv_UTF8[i] = buffer[i].c_str();
}
return run(argc, argv_UTF8.data());
// Platform-specific function to convert UTF-16 to UTF-8
#ifdef _WIN32
std::string UTF16toUTF8(const wchar_t* utf16str) {
return console::UTF16toUTF8(utf16str);
}
#define MAIN wmain
#define CHAR_TYPE const wchar_t
#else
int main(int argc, const char ** argv_UTF8) {
#define MAIN main
#define CHAR_TYPE const char
#endif

int MAIN(int argc, CHAR_TYPE** argv) {
console::init(true, true);
atexit([]() { console::cleanup(); });
return run(argc, argv_UTF8);
}

#ifdef _WIN32
auto convert_to_utf8 = UTF16toUTF8;
#else
auto convert_to_utf8 = [](const char* str) { return std::string(str); };
#endif

std::vector<std::string> args;
if (argc == 2 && argv != nullptr && argv[1] != nullptr && convert_to_utf8(argv[1])[0] == '@') {
args.push_back(convert_to_utf8(argv[0]));
const char* rspfile = convert_to_utf8(argv[1]).c_str() + 1; // skip '@'
std::ifstream fin(rspfile);

if (!fin.is_open()) {
fprintf(stderr, "error: response file '%s' not found\n", rspfile);
return 1;
}

std::string line;
while (std::getline(fin, line)) {
args.push_back(line);
}
} else {
for (int i = 0; i < argc; ++i) {
args.push_back(convert_to_utf8(argv[i]));
}
}

std::vector<const char*> argv_converted(args.size());
for (size_t i = 0; i < args.size(); ++i) {
argv_converted[i] = args[i].c_str();
}

return run(static_cast<int>(args.size()), argv_converted.data());
}
1 change: 0 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,6 @@ int main(int argc, char ** argv) {

wparams.thold_pt = params.word_thold;
wparams.max_len = params.max_len == 0 ? 60 : params.max_len;
wparams.split_on_word = params.split_on_word;
wparams.audio_ctx = params.audio_ctx;

wparams.debug_mode = params.debug_mode;
Expand Down
Loading