Skip to content

Commit

Permalink
dtw: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
denersc committed Mar 12, 2024
1 parent 87f2620 commit add2db7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 190 deletions.
53 changes: 0 additions & 53 deletions tests/test-dtw.py

This file was deleted.

135 changes: 5 additions & 130 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1068,10 +1068,10 @@ static void whisper_kv_cache_seq_cp(

// [EXPERIMENTAL] Token-level timestamps with DTW
static bool aheads_masks_init(
const whisper_context_params & cparams,
const whisper_hparams & hparams,
struct whisper_aheads_masks & aheads_masks,
ggml_backend_t backend) {
const whisper_context_params & cparams,
const whisper_hparams & hparams,
struct whisper_aheads_masks & aheads_masks,
ggml_backend_t backend) {

const int32_t n_text_layer = hparams.n_text_layer;
const int32_t n_head = hparams.n_text_head;
Expand Down Expand Up @@ -6983,40 +6983,6 @@ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor *
}
}

/*static ggml_tensor * median_filter(ggml_context * ctx, ggml_tensor * x, int filter_width) {
WHISPER_ASSERT(filter_width < x->ne[2]);
WHISPER_ASSERT(filter_width % 2);
WHISPER_ASSERT(ggml_n_dims(x) == 3);
WHISPER_ASSERT(x->type == GGML_TYPE_F32);
std::vector<float> filter;
filter.reserve(filter_width);
ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, x->ne[0], x->ne[1], x->ne[2]);
for (int64_t i = 0; i < x->ne[0]; ++i) {
for (int64_t j = 0; j < x->ne[1]; ++j) {
for (int64_t k = 0; k < x->ne[2]; ++k) {
for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
// "reflect" padding
int64_t idx = k + off;
if (idx < 0)
idx = -idx;
else if (idx >= x->ne[2])
idx = 2*(x->ne[2] - 1) - idx;
filter.push_back(ggml_get_f32_nd(x, i, j, idx, 0));
}
std::sort(filter.begin(), filter.end());
const float v = filter[filter.size()/2];
ggml_set_f32_nd(r, i, j, k, 0, v);
filter.clear();
}
}
}
return r;
}*/

static void whisper_exp_compute_token_level_timestamps_dtw(
struct whisper_context * ctx,
struct whisper_state * state,
Expand Down Expand Up @@ -7082,7 +7048,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
const auto n_tokens = state->aheads_cross_QKs->ne[0];
const auto n_heads = state->aheads_cross_QKs->ne[2];

// Copy data from decoder buffer to a local CPU tensor, discarding unused audio
// Copy data from decoder buffer to a local CPU tensor, discarding unused audio
// tokens (i.e. discarding rows at the end of tensor)
// IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims
// OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
Expand Down Expand Up @@ -7177,97 +7143,6 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
ggml_free(gctx);
}

//void whisper_test_dtw(float* in, size_t in_ne0, size_t in_ne1, int32_t **out, size_t *out_ne0, size_t *out_ne1) {
// struct ggml_init_params params = {
// /*.mem_size =*/ 32*1024*1024,
// /*.mem_buffer =*/ NULL,
// /*.no_alloc =*/ false,
// };
/* struct ggml_context * ctx = ggml_init(params);
struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, in_ne0, in_ne1);
for (size_t i = 0; i < in_ne0; i++) {
for (size_t j = 0; j < in_ne1; j++) {
ggml_set_f32_nd(x, i, j, 0, 0, in[j + i * in_ne1]);
}
}
struct ggml_tensor * r = dtw_and_backtrace(ctx, x);
*out = (int32_t*) malloc(sizeof(int32_t) * r->ne[0] * r->ne[1]);
for (int i = 0; i < r->ne[0]; ++i) {
for (int j = 0; j < r->ne[1]; ++j) {
(*out)[j + i * r->ne[1]] = ggml_get_i32_nd(r, i, j, 0, 0);
}
}
*out_ne0 = r->ne[0];
*out_ne1 = r->ne[1];
ggml_free(ctx);
}*/

//void whisper_test_dtw_timestamp_funcs(float* in, size_t in_ne0, size_t in_ne1, size_t in_ne2, float **out, size_t *out_ne0, size_t *out_ne1, size_t *out_ne2) {
// struct ggml_init_params params = {
// /*.mem_size =*/ 32*1024*1024,
// /*.mem_buffer =*/ NULL,
// /*.no_alloc =*/ false,
// };
/* struct ggml_context * ctx = ggml_init(params);
struct ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, in_ne0, in_ne1, in_ne2);
for (int64_t idx = 0; idx < in_ne0*in_ne1*in_ne2; ++idx) {
int64_t k = idx % in_ne2;
int64_t j = (idx / in_ne2) % in_ne1;
int64_t i = idx / (in_ne1*in_ne2);
//fprintf(stderr, "idx=%ld i=%ld j=%ld k=%ld\n", idx, i, j, k);
ggml_set_f32_nd(x, i, j, k, 0, in[idx]);
}
// Testing normalization
// Change dimensions so first is N_TOKENS to normalize over that
// Permute to correct shape for next computations
ggml_tensor * w = ggml_permute(ctx, x, 1, 0, 2, 3); // N_TOKENS*N_AUDIO_TOKENS*ALIGNMENT_HEADS
w = ggml_cont(ctx, w);
w = ggml_norm(ctx, w, 0);
w = ggml_permute(ctx, w, 2, 1, 0, 3);
w = ggml_permute(ctx, w, 0, 2, 1, 3);
struct ggml_cgraph * gf = ggml_new_graph(ctx);
ggml_build_forward_expand(gf, w);
ggml_graph_compute_with_ctx(ctx, gf, 4);
// Pass median filter, dimensions unchanged
struct ggml_context * gctx2 = ggml_init(params);
ggml_tensor * w_medfilt = median_filter(gctx2, w, 7);
// - Take mean over rows (matrix = weights.mean(axis=0))
//
// Out dimension is N_TOKENS*N_AUDIO_TOKENS
ggml_tensor * w_mean = ggml_mean(gctx2, w_medfilt);
ggml_tensor * scale = ggml_new_tensor_1d(gctx2, GGML_TYPE_F32, 1);
ggml_set_f32_1d(scale, 0, -1);
ggml_tensor * w_negative = ggml_scale(gctx2, w_mean, scale);
ggml_tensor * w_reshape = ggml_reshape_2d(gctx2, w_negative, w_negative->ne[1], w_negative->ne[2]);
struct ggml_cgraph * gf2 = ggml_new_graph(gctx2);
ggml_build_forward_expand(gf2, w_reshape);
ggml_graph_compute_with_ctx(gctx2, gf2, 4);
// Find alignment
ggml_tensor * alignment = dtw_and_backtrace(gctx2, w_reshape);
// Copy output
ggml_tensor * r = alignment;
*out = (float*) malloc(sizeof(float) * r->ne[0] * r->ne[1] * r->ne[2]);
for (int idx = 0; idx < r->ne[0] * r->ne[1] * r->ne[2]; ++idx) {
int64_t k = idx % r->ne[2];
int64_t j = (idx / r->ne[2]) % r->ne[1];
int64_t i = idx / (r->ne[2]*r->ne[1]);
(*out)[idx] = ggml_get_f32_nd(r, i, j, k, 0);
}
*out_ne0 = r->ne[0];
*out_ne1 = r->ne[1];
*out_ne2 = r->ne[2];
ggml_free(ctx);
}*/


void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
g_state.log_callback_user_data = user_data;
Expand Down
10 changes: 3 additions & 7 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ extern "C" {
bool use_gpu;
int gpu_device; // CUDA device

// [EXPERIMENTAL] Token-level timestamps with DTW
// FIXME: not sure if the way dtw_n_top_most and dtw_custom are structured is comfortable?
// [EXPERIMENTAL] DTW-based token-level timestamps
bool dtw_token_timestamps;
enum whisper_alignment_heads_preset dtw_aheads_preset;
struct {
Expand All @@ -142,9 +142,9 @@ extern "C" {
int64_t t0; // start time of the token
int64_t t1; // end time of the token

// dtw token-level timestamp data
// [EXPERIMENTAL] Token-level timestamps with DTW
// Do not use unless token-level timestamps have been computed with DTW
// (I think) roughly corresponds to the moment in audio in which the token was output
// Roughly corresponds to the moment in audio in which the token was output
int64_t t_dtw;

float vlen; // voice length of the token
Expand Down Expand Up @@ -659,10 +659,6 @@ extern "C" {

WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);

// test dtw
//WHISPER_API void whisper_test_dtw(float* in, size_t in_ne0, size_t in_ne1, int32_t **out, size_t *out_ne0, size_t *out_ne1);
//WHISPER_API void whisper_test_dtw_timestamp_funcs(float* in, size_t in_ne0, size_t in_ne1, size_t in_ne2, float **out, size_t *out_ne0, size_t *out_ne1, size_t *out_ne2);

#ifdef __cplusplus
}
#endif
Expand Down

0 comments on commit add2db7

Please sign in to comment.