diff --git a/include/whisper.h b/include/whisper.h index fcd756a9fe2..430a8231a2a 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -9,34 +9,35 @@ #include #ifdef __GNUC__ -# define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) #elif defined(_MSC_VER) -# define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func #else -# define WHISPER_DEPRECATED(func, hint) func +#define WHISPER_DEPRECATED(func, hint) func #endif #ifdef WHISPER_SHARED -# ifdef _WIN32 -# ifdef WHISPER_BUILD -# define WHISPER_API __declspec(dllexport) -# else -# define WHISPER_API __declspec(dllimport) -# endif -# else -# define WHISPER_API __attribute__ ((visibility ("default"))) -# endif +#ifdef _WIN32 +#ifdef WHISPER_BUILD +#define WHISPER_API __declspec(dllexport) #else -# define WHISPER_API +#define WHISPER_API __declspec(dllimport) +#endif +#else +#define WHISPER_API __attribute__((visibility("default"))) +#endif +#else +#define WHISPER_API #endif #define WHISPER_SAMPLE_RATE 16000 -#define WHISPER_N_FFT 400 -#define WHISPER_HOP_LENGTH 160 -#define WHISPER_CHUNK_SIZE 30 +#define WHISPER_N_FFT 400 +#define WHISPER_HOP_LENGTH 160 +#define WHISPER_CHUNK_SIZE 30 #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif // @@ -85,9 +86,10 @@ extern "C" { typedef int32_t whisper_token; typedef int32_t whisper_seq_id; - enum whisper_alignment_heads_preset { + enum whisper_alignment_heads_preset + { WHISPER_AHEADS_NONE, - WHISPER_AHEADS_N_TOP_MOST, // All heads from the N-top-most text-layers + WHISPER_AHEADS_N_TOP_MOST, // All heads from the N-top-most text-layers WHISPER_AHEADS_CUSTOM, WHISPER_AHEADS_TINY_EN, WHISPER_AHEADS_TINY, @@ -103,20 +105,23 @@ extern "C" { WHISPER_AHEADS_LARGE_V3_TURBO, }; - typedef struct whisper_ahead { + typedef struct whisper_ahead + { int n_text_layer; int n_head; } whisper_ahead; - typedef struct whisper_aheads { + typedef struct whisper_aheads + { size_t n_heads; - const whisper_ahead * heads; + const whisper_ahead *heads; } whisper_aheads; - struct whisper_context_params { - bool use_gpu; - bool flash_attn; - int gpu_device; // CUDA device + struct whisper_context_params + { + bool use_gpu; + bool flash_attn; + int gpu_device; // CUDA device // [EXPERIMENTAL] Token-level timestamps with DTW bool dtw_token_timestamps; @@ -128,52 +133,55 @@ extern "C" { size_t dtw_mem_size; // TODO: remove }; - typedef struct whisper_token_data { + typedef struct whisper_token_data + { whisper_token id; // token id whisper_token tid; // forced timestamp token id - float p; // probability of the token - float plog; // log probability of the token - float pt; // probability of the timestamp token - float ptsum; // sum of probabilities of all timestamp tokens + float p; // probability of the token + float plog; // log probability of the token + float pt; // probability of the timestamp token + float ptsum; // sum of probabilities of all timestamp tokens // token-level timestamp data // do not use if you haven't computed token-level timestamps - int64_t t0; // start time of the token - int64_t t1; // end time of the token + int64_t t0; // start time of the token + int64_t t1; // end time of the token // [EXPERIMENTAL] Token-level timestamps with DTW // do not use if you haven't computed token-level timestamps with dtw // Roughly corresponds to the moment in audio in which the token was output int64_t t_dtw; - float vlen; // 
voice length of the token + float vlen; // voice length of the token } whisper_token_data; - typedef struct whisper_model_loader { - void * context; + typedef struct whisper_model_loader + { + void *context; - size_t (*read)(void * ctx, void * output, size_t read_size); - bool (*eof)(void * ctx); - void (*close)(void * ctx); + size_t (*read)(void *ctx, void *output, size_t read_size); + bool (*eof)(void *ctx); + void (*close)(void *ctx); } whisper_model_loader; // grammar element type - enum whisper_gretype { + enum whisper_gretype + { // end of rule definition - WHISPER_GRETYPE_END = 0, + WHISPER_GRETYPE_END = 0, // start of alternate definition for rule - WHISPER_GRETYPE_ALT = 1, + WHISPER_GRETYPE_ALT = 1, // non-terminal element: reference to rule - WHISPER_GRETYPE_RULE_REF = 2, + WHISPER_GRETYPE_RULE_REF = 2, // terminal element: character (code point) - WHISPER_GRETYPE_CHAR = 3, + WHISPER_GRETYPE_CHAR = 3, // inverse char(s) ([^a], [^a-b] [^abc]) - WHISPER_GRETYPE_CHAR_NOT = 4, + WHISPER_GRETYPE_CHAR_NOT = 4, // modifies a preceding WHISPER_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to // be an inclusive range ([a-z]) @@ -181,64 +189,60 @@ extern "C" { // modifies a preceding WHISPER_GRETYPE_CHAR or // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - WHISPER_GRETYPE_CHAR_ALT = 6, + WHISPER_GRETYPE_CHAR_ALT = 6, }; - typedef struct whisper_grammar_element { + typedef struct whisper_grammar_element + { enum whisper_gretype type; - uint32_t value; // Unicode code point or rule ID + uint32_t value; // Unicode code point or rule ID } whisper_grammar_element; - typedef struct whisper_vad_params { - float threshold; // Probability threshold to consider as speech. - int min_speech_duration_ms; // Min duration for a valid speech segment. - int min_silence_duration_ms; // Min silence duration to consider speech as ended. - float max_speech_duration_s; // Max duration of a speech segment before forcing a new segment. - int speech_pad_ms; // Padding added before and after speech segments. - float samples_overlap; // Overlap in seconds when copying audio samples from speech segment. + typedef struct whisper_vad_params + { + float threshold; // Probability threshold to consider as speech. + int min_speech_duration_ms; // Min duration for a valid speech segment. + int min_silence_duration_ms; // Min silence duration to consider speech as ended. + float max_speech_duration_s; // Max duration of a speech segment before forcing a new segment. + int speech_pad_ms; // Padding added before and after speech segments. + float samples_overlap; // Overlap in seconds when copying audio samples from speech segment. } whisper_vad_params; - WHISPER_API const char * whisper_version(void); + WHISPER_API const char *whisper_version(void); // Various functions for loading a ggml whisper model. // Allocate (almost) all memory needed for the model. 
// Return NULL on failure - WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params); - WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params); - WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params); + WHISPER_API struct whisper_context *whisper_init_from_file_with_params(const char *path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context *whisper_init_from_buffer_with_params(void *buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context *whisper_init_with_params(struct whisper_model_loader *loader, struct whisper_context_params params); // These are the same as the above, but the internal state of the context is not allocated automatically // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523) - WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params); - WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params); - WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params); + WHISPER_API struct whisper_context *whisper_init_from_file_with_params_no_state(const char *path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context *whisper_init_from_buffer_with_params_no_state(void *buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context *whisper_init_with_params_no_state(struct whisper_model_loader *loader, struct whisper_context_params params); WHISPER_DEPRECATED( - WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model), - "use whisper_init_from_file_with_params instead" - ); + WHISPER_API struct whisper_context *whisper_init_from_file(const char *path_model), + "use whisper_init_from_file_with_params instead"); WHISPER_DEPRECATED( - WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size), - "use whisper_init_from_buffer_with_params instead" - ); + WHISPER_API struct whisper_context *whisper_init_from_buffer(void *buffer, size_t buffer_size), + "use whisper_init_from_buffer_with_params instead"); WHISPER_DEPRECATED( - WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader), - "use whisper_init_with_params instead" - ); + WHISPER_API struct whisper_context *whisper_init(struct whisper_model_loader *loader), + "use whisper_init_with_params instead"); WHISPER_DEPRECATED( - WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model), - "use whisper_init_from_file_with_params_no_state instead" - ); + WHISPER_API struct whisper_context *whisper_init_from_file_no_state(const char *path_model), + "use whisper_init_from_file_with_params_no_state instead"); WHISPER_DEPRECATED( - WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size), - "use whisper_init_from_buffer_with_params_no_state instead" - ); + WHISPER_API struct whisper_context *whisper_init_from_buffer_no_state(void *buffer, size_t 
buffer_size), + "use whisper_init_from_buffer_with_params_no_state instead"); WHISPER_DEPRECATED( - WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader), - "use whisper_init_with_params_no_state instead" - ); + WHISPER_API struct whisper_context *whisper_init_no_state(struct whisper_model_loader *loader), + "use whisper_init_with_params_no_state instead"); - WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); + WHISPER_API struct whisper_state *whisper_init_state(struct whisper_context *ctx); // Given a context, enable use of OpenVINO for encode inference. // model_path: Optional path to OpenVINO encoder IR model. If set to nullptr, @@ -252,71 +256,71 @@ extern "C" { // Set to nullptr if not used. // Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1. WHISPER_API int whisper_ctx_init_openvino_encoder_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - const char * model_path, - const char * device, - const char * cache_dir); + struct whisper_context *ctx, + struct whisper_state *state, + const char *model_path, + const char *device, + const char *cache_dir); WHISPER_API int whisper_ctx_init_openvino_encoder( - struct whisper_context * ctx, - const char * model_path, - const char * device, - const char * cache_dir); + struct whisper_context *ctx, + const char *model_path, + const char *device, + const char *cache_dir); // Frees all allocated memory - WHISPER_API void whisper_free (struct whisper_context * ctx); - WHISPER_API void whisper_free_state(struct whisper_state * state); - WHISPER_API void whisper_free_params(struct whisper_full_params * params); - WHISPER_API void whisper_free_context_params(struct whisper_context_params * params); + WHISPER_API void whisper_free(struct whisper_context *ctx); + WHISPER_API void whisper_free_state(struct whisper_state *state); + WHISPER_API void whisper_free_params(struct whisper_full_params *params); + WHISPER_API void whisper_free_context_params(struct whisper_context_params *params); // Convert RAW PCM audio to log mel spectrogram. // The resulting spectrogram is stored inside the default state of the provided whisper context. // Returns 0 on success WHISPER_API int whisper_pcm_to_mel( - struct whisper_context * ctx, - const float * samples, - int n_samples, - int n_threads); + struct whisper_context *ctx, + const float *samples, + int n_samples, + int n_threads); WHISPER_API int whisper_pcm_to_mel_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - const float * samples, - int n_samples, - int n_threads); + struct whisper_context *ctx, + struct whisper_state *state, + const float *samples, + int n_samples, + int n_threads); // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. 
// n_mel must be 80 // Returns 0 on success WHISPER_API int whisper_set_mel( - struct whisper_context * ctx, - const float * data, - int n_len, - int n_mel); + struct whisper_context *ctx, + const float *data, + int n_len, + int n_mel); WHISPER_API int whisper_set_mel_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - const float * data, - int n_len, - int n_mel); + struct whisper_context *ctx, + struct whisper_state *state, + const float *data, + int n_len, + int n_mel); // Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context. // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. // offset can be used to specify the offset of the first frame in the spectrogram. // Returns 0 on success WHISPER_API int whisper_encode( - struct whisper_context * ctx, - int offset, - int n_threads); + struct whisper_context *ctx, + int offset, + int n_threads); WHISPER_API int whisper_encode_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - int offset, - int n_threads); + struct whisper_context *ctx, + struct whisper_state *state, + int offset, + int n_threads); // Run the Whisper decoder to obtain the logits and probabilities for the next token. // Make sure to call whisper_encode() first. @@ -325,19 +329,19 @@ extern "C" { // Returns 0 on success // TODO: add support for multiple decoders WHISPER_API int whisper_decode( - struct whisper_context * ctx, - const whisper_token * tokens, - int n_tokens, - int n_past, - int n_threads); + struct whisper_context *ctx, + const whisper_token *tokens, + int n_tokens, + int n_past, + int n_threads); WHISPER_API int whisper_decode_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - const whisper_token * tokens, - int n_tokens, - int n_past, - int n_threads); + struct whisper_context *ctx, + struct whisper_state *state, + const whisper_token *tokens, + int n_tokens, + int n_past, + int n_threads); // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. @@ -345,14 +349,14 @@ extern "C" { // Returns a negative number on failure - the number of tokens that would have been returned // TODO: not sure if correct WHISPER_API int whisper_tokenize( - struct whisper_context * ctx, - const char * text, - whisper_token * tokens, - int n_max_tokens); + struct whisper_context *ctx, + const char *text, + whisper_token *tokens, + int n_max_tokens); // Return the number of tokens in the provided text // Equivalent to: -whisper_tokenize(ctx, text, NULL, 0) - int whisper_token_count(struct whisper_context * ctx, const char * text); + int whisper_token_count(struct whisper_context *ctx, const char *text); // Largest language id (i.e. number of available languages - 1) WHISPER_API int whisper_lang_max_id(void); @@ -361,13 +365,13 @@ extern "C" { // Examples: // "de" -> 2 // "german" -> 2 - WHISPER_API int whisper_lang_id(const char * lang); + WHISPER_API int whisper_lang_id(const char *lang); // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found - WHISPER_API const char * whisper_lang_str(int id); + WHISPER_API const char *whisper_lang_str(int id); // Return the short string of the specified language name (e.g. 
2 -> "german"), returns nullptr if not found - WHISPER_API const char * whisper_lang_str_full(int id); + WHISPER_API const char *whisper_lang_str_full(int id); // Use mel data at offset_ms to try and auto-detect the spoken language // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first @@ -376,83 +380,84 @@ extern "C" { // The array must be whisper_lang_max_id() + 1 in size // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 WHISPER_API int whisper_lang_auto_detect( - struct whisper_context * ctx, - int offset_ms, - int n_threads, - float * lang_probs); + struct whisper_context *ctx, + int offset_ms, + int n_threads, + float *lang_probs); WHISPER_API int whisper_lang_auto_detect_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - int offset_ms, - int n_threads, - float * lang_probs); - - WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length - WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length - WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); - WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx); - - WHISPER_API int whisper_model_n_vocab (struct whisper_context * ctx); - WHISPER_API int whisper_model_n_audio_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx); - WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx); - WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx); - WHISPER_API int whisper_model_n_text_ctx (struct whisper_context * ctx); - WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx); - WHISPER_API int whisper_model_n_text_head (struct whisper_context * ctx); - WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx); - WHISPER_API int whisper_model_n_mels (struct whisper_context * ctx); - WHISPER_API int whisper_model_ftype (struct whisper_context * ctx); - WHISPER_API int whisper_model_type (struct whisper_context * ctx); + struct whisper_context *ctx, + struct whisper_state *state, + int offset_ms, + int n_threads, + float *lang_probs); + + WHISPER_API int whisper_n_len(struct whisper_context *ctx); // mel length + WHISPER_API int whisper_n_len_from_state(struct whisper_state *state); // mel length + WHISPER_API int whisper_n_vocab(struct whisper_context *ctx); + WHISPER_API int whisper_n_text_ctx(struct whisper_context *ctx); + WHISPER_API int whisper_n_audio_ctx(struct whisper_context *ctx); + WHISPER_API int whisper_is_multilingual(struct whisper_context *ctx); + + WHISPER_API int whisper_model_n_vocab(struct whisper_context *ctx); + WHISPER_API int whisper_model_n_audio_ctx(struct whisper_context *ctx); + WHISPER_API int whisper_model_n_audio_state(struct whisper_context *ctx); + WHISPER_API int whisper_model_n_audio_head(struct whisper_context *ctx); + WHISPER_API int whisper_model_n_audio_layer(struct whisper_context *ctx); + WHISPER_API int whisper_model_n_text_ctx(struct whisper_context *ctx); + WHISPER_API int whisper_model_n_text_state(struct whisper_context *ctx); + WHISPER_API int whisper_model_n_text_head(struct whisper_context *ctx); + WHISPER_API int whisper_model_n_text_layer(struct whisper_context *ctx); + WHISPER_API int whisper_model_n_mels(struct whisper_context *ctx); + WHISPER_API int whisper_model_ftype(struct 
whisper_context *ctx); + WHISPER_API int whisper_model_type(struct whisper_context *ctx); // Token logits obtained from the last call to whisper_decode() // The logits for the last token are stored in the last row // Rows: n_tokens // Cols: n_vocab - WHISPER_API float * whisper_get_logits (struct whisper_context * ctx); - WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state); + WHISPER_API float *whisper_get_logits(struct whisper_context *ctx); + WHISPER_API float *whisper_get_logits_from_state(struct whisper_state *state); // Token Id -> String. Uses the vocabulary in the provided context - WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); - WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx); - + WHISPER_API const char *whisper_token_to_str(struct whisper_context *ctx, whisper_token token); + WHISPER_API const char *whisper_model_type_readable(struct whisper_context *ctx); // Special tokens - WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); + WHISPER_API whisper_token whisper_token_eot(struct whisper_context *ctx); + WHISPER_API whisper_token whisper_token_sot(struct whisper_context *ctx); + WHISPER_API whisper_token whisper_token_solm(struct whisper_context *ctx); + WHISPER_API whisper_token whisper_token_prev(struct whisper_context *ctx); + WHISPER_API whisper_token whisper_token_nosp(struct whisper_context *ctx); + WHISPER_API whisper_token whisper_token_not(struct whisper_context *ctx); + WHISPER_API whisper_token whisper_token_beg(struct whisper_context *ctx); + WHISPER_API whisper_token whisper_token_lang(struct whisper_context *ctx, int lang_id); // Task tokens - WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx); - WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_translate(struct whisper_context *ctx); + WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context *ctx); // Performance information from the default state. 
- struct whisper_timings { + struct whisper_timings + { float sample_ms; float encode_ms; float decode_ms; float batchd_ms; float prompt_ms; }; - WHISPER_API struct whisper_timings * whisper_get_timings(struct whisper_context * ctx); - WHISPER_API void whisper_print_timings(struct whisper_context * ctx); - WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); + WHISPER_API struct whisper_timings *whisper_get_timings(struct whisper_context *ctx); + WHISPER_API void whisper_print_timings(struct whisper_context *ctx); + WHISPER_API void whisper_reset_timings(struct whisper_context *ctx); // Print system information - WHISPER_API const char * whisper_print_system_info(void); + WHISPER_API const char *whisper_print_system_info(void); //////////////////////////////////////////////////////////////////////////// // Available sampling strategies - enum whisper_sampling_strategy { + enum whisper_sampling_strategy + { WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreedyDecoder WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder }; @@ -460,157 +465,160 @@ extern "C" { // Text segment callback // Called on every newly generated text segment // Use the whisper_full_...() functions to obtain the text segments - typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data); + typedef void (*whisper_new_segment_callback)(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data); // Progress callback - typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data); + typedef void (*whisper_progress_callback)(struct whisper_context *ctx, struct whisper_state *state, int progress, void *user_data); // Encoder begin callback // If not NULL, called before the encoder starts // If it returns false, the computation is aborted - typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data); + typedef bool (*whisper_encoder_begin_callback)(struct whisper_context *ctx, struct whisper_state *state, void *user_data); // Logits filter callback // Can be used to modify the logits before sampling // If not NULL, called after applying temperature to logits typedef void (*whisper_logits_filter_callback)( - struct whisper_context * ctx, - struct whisper_state * state, - const whisper_token_data * tokens, - int n_tokens, - float * logits, - void * user_data); + struct whisper_context *ctx, + struct whisper_state *state, + const whisper_token_data *tokens, + int n_tokens, + float *logits, + void *user_data); // Parameters for the whisper_full() function // If you change the order or add new parameters, make sure to update the default values in whisper.cpp: // whisper_full_default_params() - struct whisper_full_params { + struct whisper_full_params + { enum whisper_sampling_strategy strategy; int n_threads; - int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder - int offset_ms; // start offset in ms - int duration_ms; // audio duration to process in ms + int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms bool translate; - bool no_context; // do not use past transcription (if any) as initial prompt for the decoder - bool no_timestamps; // do not generate timestamps - bool single_segment; // force single segment 
output (useful for streaming) - bool print_special; // print special tokens (e.g. , , , etc.) - bool print_progress; // print progress information - bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead) - bool print_timestamps; // print timestamps for each text segment when printing realtime + bool no_context; // do not use past transcription (if any) as initial prompt for the decoder + bool no_timestamps; // do not generate timestamps + bool single_segment; // force single segment output (useful for streaming) + bool print_special; // print special tokens (e.g. , , , etc.) + bool print_progress; // print progress information + bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead) + bool print_timestamps; // print timestamps for each text segment when printing realtime // [EXPERIMENTAL] token-level timestamps - bool token_timestamps; // enable token-level timestamps - float thold_pt; // timestamp token probability threshold (~0.01) - float thold_ptsum; // timestamp token sum probability threshold (~0.01) - int max_len; // max segment length in characters - bool split_on_word; // split on word rather than on token (when used with max_len) - int max_tokens; // max tokens per segment (0 = no limit) + bool token_timestamps; // enable token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters + bool split_on_word; // split on word rather than on token (when used with max_len) + int max_tokens; // max tokens per segment (0 = no limit) // [EXPERIMENTAL] speed-up techniques // note: these can significantly reduce the quality of the output - bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel) - int audio_ctx; // overwrite the audio context size (0 = use default) + bool debug_mode; // enable debug_mode provides extra info (eg. 
Dump log_mel) + int audio_ctx; // overwrite the audio context size (0 = use default) // [EXPERIMENTAL] [TDRZ] tinydiarize - bool tdrz_enable; // enable tinydiarize speaker turn detection + bool tdrz_enable; // enable tinydiarize speaker turn detection // A regular expression that matches tokens to suppress - const char * suppress_regex; + const char *suppress_regex; // tokens to provide to the whisper decoder as initial prompt // these are prepended to any existing text context from a previous call // use whisper_tokenize() to convert text to tokens // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224) - const char * initial_prompt; - const whisper_token * prompt_tokens; + const char *initial_prompt; + const whisper_token *prompt_tokens; int prompt_n_tokens; // for auto-detection, set to nullptr, "" or "auto" - const char * language; + const char *language; bool detect_language; // common decoding parameters: bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 bool suppress_nst; // non-speech tokens, ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 - float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 - float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 - float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267 + float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 + float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 + float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267 // fallback parameters // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 float temperature_inc; - float entropy_thold; // similar to OpenAI's "compression_ratio_threshold" + float entropy_thold; // similar to OpenAI's "compression_ratio_threshold" float logprob_thold; float no_speech_thold; - struct { - int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 + struct + { + int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 } greedy; - struct { - int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 + struct + { + int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf } beam_search; // called for every newly generated text segment whisper_new_segment_callback new_segment_callback; - void * new_segment_callback_user_data; + void *new_segment_callback_user_data; // called on each progress update whisper_progress_callback progress_callback; - void * progress_callback_user_data; + void *progress_callback_user_data; // called each time before the encoder starts whisper_encoder_begin_callback encoder_begin_callback; - void * encoder_begin_callback_user_data; + void *encoder_begin_callback_user_data; // called each time 
before ggml computation starts ggml_abort_callback abort_callback; - void * abort_callback_user_data; + void *abort_callback_user_data; // called by each decoder to filter obtained logits whisper_logits_filter_callback logits_filter_callback; - void * logits_filter_callback_user_data; + void *logits_filter_callback_user_data; - const whisper_grammar_element ** grammar_rules; - size_t n_grammar_rules; - size_t i_start_rule; - float grammar_penalty; + const whisper_grammar_element **grammar_rules; + size_t n_grammar_rules; + size_t i_start_rule; + float grammar_penalty; // Voice Activity Detection (VAD) params - bool vad; // Enable VAD - const char * vad_model_path; // Path to VAD model + bool vad; // Enable VAD + const char *vad_model_path; // Path to VAD model whisper_vad_params vad_params; }; // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params() - WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(void); - WHISPER_API struct whisper_context_params whisper_context_default_params (void); + WHISPER_API struct whisper_context_params *whisper_context_default_params_by_ref(void); + WHISPER_API struct whisper_context_params whisper_context_default_params(void); - WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy); - WHISPER_API struct whisper_full_params whisper_full_default_params (enum whisper_sampling_strategy strategy); + WHISPER_API struct whisper_full_params *whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy); + WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Not thread safe for same context // Uses the specified decoding strategy to obtain the text. WHISPER_API int whisper_full( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples); + struct whisper_context *ctx, + struct whisper_full_params params, + const float *samples, + int n_samples); WHISPER_API int whisper_full_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - struct whisper_full_params params, - const float * samples, - int n_samples); + struct whisper_context *ctx, + struct whisper_state *state, + struct whisper_full_params params, + const float *samples, + int n_samples); // Split the input audio in chunks and process each chunk separately using whisper_full_with_state() // Result is stored in the default state of the context @@ -618,57 +626,64 @@ extern "C" { // It seems this approach can offer some speedup in some cases. // However, the transcription accuracy can be worse at the beginning and end of each chunk. WHISPER_API int whisper_full_parallel( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples, - int n_processors); + struct whisper_context *ctx, + struct whisper_full_params params, + const float *samples, + int n_samples, + int n_processors); + + WHISPER_API int whisper_full_batch_parallel(struct whisper_context *ctx, + struct whisper_full_params params, + const float *const *batches, + const int *size_per_batch, + int n_batches, + int n_processors); // Number of generated text segments // A segment can be a few words, a sentence, or even a paragraph. 
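    // A minimal usage sketch (not part of this patch) for the whisper_full_batch_parallel()
    // entry point declared above, which is the one functional addition in this diff. It assumes
    // each entry of `batches` is an independent 16 kHz mono float PCM buffer and that results
    // are collected in the context's default state, analogous to whisper_full_parallel(); the
    // model path, buffers, and thread/processor counts below are placeholders, and the actual
    // semantics are defined by the implementation in whisper.cpp, not by this sketch.
    //
    // #include <stdio.h>
    // #include "whisper.h"
    //
    // static int transcribe_two_clips(const float * pcm_a, int n_a,
    //                                 const float * pcm_b, int n_b) {
    //     struct whisper_context * wctx = whisper_init_from_file_with_params(
    //         "ggml-base.en.bin", whisper_context_default_params());
    //     if (!wctx) {
    //         return 1;
    //     }
    //
    //     struct whisper_full_params fparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    //
    //     const float * batches[2]        = { pcm_a, pcm_b }; // hypothetical PCM buffers
    //     const int     size_per_batch[2] = { n_a, n_b };     // samples in each buffer
    //
    //     int ret = whisper_full_batch_parallel(wctx, fparams, batches, size_per_batch,
    //                                           /*n_batches*/ 2, /*n_processors*/ 4);
    //     if (ret == 0) {
    //         for (int i = 0; i < whisper_full_n_segments(wctx); ++i) {
    //             printf("%s\n", whisper_full_get_segment_text(wctx, i));
    //         }
    //     }
    //
    //     whisper_free(wctx);
    //     return ret;
    // }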
- WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx); - WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state); + WHISPER_API int whisper_full_n_segments(struct whisper_context *ctx); + WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state *state); // Language id associated with the context's default state - WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx); + WHISPER_API int whisper_full_lang_id(struct whisper_context *ctx); // Language id associated with the provided state - WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state); + WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state *state); // Get the start and end time of the specified segment - WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment); - WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t0(struct whisper_context *ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state *state, int i_segment); - WHISPER_API int64_t whisper_full_get_segment_t1 (struct whisper_context * ctx, int i_segment); - WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t1(struct whisper_context *ctx, int i_segment); + WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state *state, int i_segment); // Get whether the next segment is predicted as a speaker turn - WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment); - WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment); + WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context *ctx, int i_segment); + WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state *state, int i_segment); // Get the text of the specified segment - WHISPER_API const char * whisper_full_get_segment_text (struct whisper_context * ctx, int i_segment); - WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment); + WHISPER_API const char *whisper_full_get_segment_text(struct whisper_context *ctx, int i_segment); + WHISPER_API const char *whisper_full_get_segment_text_from_state(struct whisper_state *state, int i_segment); // Get number of tokens in the specified segment - WHISPER_API int whisper_full_n_tokens (struct whisper_context * ctx, int i_segment); - WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment); + WHISPER_API int whisper_full_n_tokens(struct whisper_context *ctx, int i_segment); + WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state *state, int i_segment); // Get the token text of the specified token in the specified segment - WHISPER_API const char * whisper_full_get_token_text (struct whisper_context * ctx, int i_segment, int i_token); - WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token); + WHISPER_API const char *whisper_full_get_token_text(struct whisper_context *ctx, int i_segment, int i_token); + WHISPER_API const char *whisper_full_get_token_text_from_state(struct whisper_context *ctx, struct whisper_state 
*state, int i_segment, int i_token); - WHISPER_API whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token); - WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id(struct whisper_context *ctx, int i_segment, int i_token); + WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state *state, int i_segment, int i_token); // Get token data for the specified token in the specified segment // This contains probabilities, timestamps, etc. - WHISPER_API whisper_token_data whisper_full_get_token_data (struct whisper_context * ctx, int i_segment, int i_token); - WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context *ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state *state, int i_segment, int i_token); // Get the probability of the specified token in the specified segment - WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token); - WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token); + WHISPER_API float whisper_full_get_token_p(struct whisper_context *ctx, int i_segment, int i_token); + WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state *state, int i_segment, int i_token); // // Voice Activity Detection (VAD) @@ -678,61 +693,62 @@ extern "C" { WHISPER_API struct whisper_vad_params whisper_vad_default_params(void); - struct whisper_vad_context_params { - int n_threads; // The number of threads to use for processing. - bool use_gpu; - int gpu_device; // CUDA device + struct whisper_vad_context_params + { + int n_threads; // The number of threads to use for processing. 
+ bool use_gpu; + int gpu_device; // CUDA device }; WHISPER_API struct whisper_vad_context_params whisper_vad_default_context_params(void); - WHISPER_API struct whisper_vad_context * whisper_vad_init_from_file_with_params(const char * path_model, struct whisper_vad_context_params params); - WHISPER_API struct whisper_vad_context * whisper_vad_init_with_params (struct whisper_model_loader * loader, struct whisper_vad_context_params params); + WHISPER_API struct whisper_vad_context *whisper_vad_init_from_file_with_params(const char *path_model, struct whisper_vad_context_params params); + WHISPER_API struct whisper_vad_context *whisper_vad_init_with_params(struct whisper_model_loader *loader, struct whisper_vad_context_params params); WHISPER_API bool whisper_vad_detect_speech( - struct whisper_vad_context * vctx, - const float * samples, - int n_samples); + struct whisper_vad_context *vctx, + const float *samples, + int n_samples); - WHISPER_API int whisper_vad_n_probs(struct whisper_vad_context * vctx); - WHISPER_API float * whisper_vad_probs (struct whisper_vad_context * vctx); + WHISPER_API int whisper_vad_n_probs(struct whisper_vad_context *vctx); + WHISPER_API float *whisper_vad_probs(struct whisper_vad_context *vctx); struct whisper_vad_segments; - WHISPER_API struct whisper_vad_segments * whisper_vad_segments_from_probs( - struct whisper_vad_context * vctx, - struct whisper_vad_params params); + WHISPER_API struct whisper_vad_segments *whisper_vad_segments_from_probs( + struct whisper_vad_context *vctx, + struct whisper_vad_params params); - WHISPER_API struct whisper_vad_segments * whisper_vad_segments_from_samples( - struct whisper_vad_context * vctx, - struct whisper_vad_params params, - const float * samples, - int n_samples); + WHISPER_API struct whisper_vad_segments *whisper_vad_segments_from_samples( + struct whisper_vad_context *vctx, + struct whisper_vad_params params, + const float *samples, + int n_samples); - WHISPER_API int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments); + WHISPER_API int whisper_vad_segments_n_segments(struct whisper_vad_segments *segments); - WHISPER_API float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment); - WHISPER_API float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment); + WHISPER_API float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments *segments, int i_segment); + WHISPER_API float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments *segments, int i_segment); - WHISPER_API void whisper_vad_free_segments(struct whisper_vad_segments * segments); - WHISPER_API void whisper_vad_free (struct whisper_vad_context * ctx); + WHISPER_API void whisper_vad_free_segments(struct whisper_vad_segments *segments); + WHISPER_API void whisper_vad_free(struct whisper_vad_context *ctx); //////////////////////////////////////////////////////////////////////////// // Temporary helpers needed for exposing ggml interface - WHISPER_API int whisper_bench_memcpy (int n_threads); - WHISPER_API const char * whisper_bench_memcpy_str (int n_threads); - WHISPER_API int whisper_bench_ggml_mul_mat (int n_threads); - WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads); + WHISPER_API int whisper_bench_memcpy(int n_threads); + WHISPER_API const char *whisper_bench_memcpy_str(int n_threads); + WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads); + WHISPER_API const char *whisper_bench_ggml_mul_mat_str(int n_threads); // 
Control logging output; default behavior is to print to stderr - WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data); + WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void *user_data); // Get the no_speech probability for the specified segment - WHISPER_API float whisper_full_get_segment_no_speech_prob (struct whisper_context * ctx, int i_segment); - WHISPER_API float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment); + WHISPER_API float whisper_full_get_segment_no_speech_prob(struct whisper_context *ctx, int i_segment); + WHISPER_API float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state *state, int i_segment); #ifdef __cplusplus } #endif diff --git a/src/whisper.cpp b/src/whisper.cpp index d99dd7be68c..e8caf1db285 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -37,65 +37,88 @@ #include #if defined(WHISPER_BIG_ENDIAN) -template -static T byteswap(T value) { +template +static T byteswap(T value) +{ T value_swapped; - char * source = reinterpret_cast(&value); - char * target = reinterpret_cast(&value_swapped); + char *source = reinterpret_cast(&value); + char *target = reinterpret_cast(&value_swapped); int size = sizeof(T); - for (int i = 0; i < size; i++) { + for (int i = 0; i < size; i++) + { target[size - 1 - i] = source[i]; } return value_swapped; } -template -static void byteswap_tensor_data(ggml_tensor * tensor) { - T * datum = reinterpret_cast(tensor->data); - for (int i = 0; i < ggml_nelements(tensor); i++) { +template +static void byteswap_tensor_data(ggml_tensor *tensor) +{ + T *datum = reinterpret_cast(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) + { datum[i] = byteswap(datum[i]); } } -static void byteswap_tensor(ggml_tensor * tensor) { - switch (tensor->type) { - case GGML_TYPE_I16: { - byteswap_tensor_data(tensor); - break; - } - case GGML_TYPE_F16: { - byteswap_tensor_data(tensor); - break; - } - case GGML_TYPE_I32: { - byteswap_tensor_data(tensor); - break; - } - case GGML_TYPE_F32: { - byteswap_tensor_data(tensor); - break; - } - default: { // GML_TYPE_I8 - break; - } +static void byteswap_tensor(ggml_tensor *tensor) +{ + switch (tensor->type) + { + case GGML_TYPE_I16: + { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F16: + { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_I32: + { + byteswap_tensor_data(tensor); + break; + } + case GGML_TYPE_F32: + { + byteswap_tensor_data(tensor); + break; + } + default: + { // GML_TYPE_I8 + break; + } } } #define BYTESWAP_VALUE(d) d = byteswap(d) -#define BYTESWAP_FILTERS(f) \ - do { \ - for (auto & datum : f.data) { \ - datum = byteswap(datum); \ - } \ +#define BYTESWAP_FILTERS(f) \ + do \ + { \ + for (auto &datum : f.data) \ + { \ + datum = byteswap(datum); \ + } \ } while (0) #define BYTESWAP_TENSOR(t) \ - do { \ + do \ + { \ byteswap_tensor(t); \ } while (0) #else -#define BYTESWAP_VALUE(d) do {} while (0) -#define BYTESWAP_FILTERS(f) do {} while (0) -#define BYTESWAP_TENSOR(t) do {} while (0) +#define BYTESWAP_VALUE(d) \ + do \ + { \ + } while (0) +#define BYTESWAP_FILTERS(f) \ + do \ + { \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do \ + { \ + } while (0) #endif #ifdef __GNUC__ @@ -113,15 +136,15 @@ static void byteswap_tensor(ggml_tensor * tensor) { // WHISPER_ATTRIBUTE_FORMAT(2, 3) -static void whisper_log_internal (ggml_log_level level, const char * format, ...); -static void whisper_log_callback_default(ggml_log_level level, const char * text, void * 
user_data); +static void whisper_log_internal(ggml_log_level level, const char *format, ...); +static void whisper_log_callback_default(ggml_log_level level, const char *text, void *user_data); #define WHISPER_LOG_ERROR(...) whisper_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) -#define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) -#define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) +#define WHISPER_LOG_WARN(...) whisper_log_internal(GGML_LOG_LEVEL_WARN, __VA_ARGS__) +#define WHISPER_LOG_INFO(...) whisper_log_internal(GGML_LOG_LEVEL_INFO, __VA_ARGS__) // define this to enable verbose trace logging - useful for debugging purposes -//#define WHISPER_DEBUG +// #define WHISPER_DEBUG #if defined(WHISPER_DEBUG) #define WHISPER_LOG_DEBUG(...) whisper_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) @@ -129,18 +152,21 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text #define WHISPER_LOG_DEBUG(...) #endif -#define WHISPER_ASSERT(x) \ - do { \ - if (!(x)) { \ +#define WHISPER_ASSERT(x) \ + do \ + { \ + if (!(x)) \ + { \ WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ - abort(); \ - } \ + abort(); \ + } \ } while (0) #define WHISPER_MAX_DECODERS 8 #define WHISPER_MAX_NODES 4096 -static std::string format(const char * fmt, ...) { +static std::string format(const char *fmt, ...) +{ va_list ap; va_list ap2; va_start(ap, fmt); @@ -160,21 +186,24 @@ static std::string format(const char * fmt, ...) { // static bool ggml_graph_compute_helper( - struct ggml_cgraph * graph, - int n_threads, - ggml_abort_callback abort_callback, - void * abort_callback_data) { - ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; + struct ggml_cgraph *graph, + int n_threads, + ggml_abort_callback abort_callback, + void *abort_callback_data) +{ + ggml_backend_ptr backend{ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr)}; - auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + auto *reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); - auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); - if (set_abort_callback_fn) { + auto *set_abort_callback_fn = (ggml_backend_set_abort_callback_t)ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) + { set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data); } - auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); - if (ggml_backend_set_n_threads_fn) { + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t)ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) + { ggml_backend_set_n_threads_fn(backend.get(), n_threads); } @@ -182,24 +211,28 @@ static bool ggml_graph_compute_helper( } static bool ggml_graph_compute_helper( - ggml_backend_sched_t sched, - struct ggml_cgraph * graph, - int n_threads, - bool sched_reset = true) { - for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { + ggml_backend_sched_t sched, + struct ggml_cgraph *graph, + int n_threads, + bool sched_reset = true) +{ + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) + { ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); 
ggml_backend_dev_t dev = ggml_backend_get_device(backend); ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; - auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); - if (fn_set_n_threads) { + auto *fn_set_n_threads = (ggml_backend_set_n_threads_t)ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (fn_set_n_threads) + { fn_set_n_threads(backend, n_threads); } } const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS); - if (!t || sched_reset) { + if (!t || sched_reset) + { ggml_backend_sched_reset(sched); } @@ -208,52 +241,61 @@ static bool ggml_graph_compute_helper( // TODO: move these functions to ggml-base with support for ggml-backend? -static ggml_tensor * whisper_set_f32(struct ggml_tensor * t, float v) { +static ggml_tensor *whisper_set_f32(struct ggml_tensor *t, float v) +{ GGML_ASSERT(t->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(t)); size_t nels = ggml_nelements(t); - for (size_t i = 0; i < nels; ++i) { - ((float *) t->data)[i] = v; + for (size_t i = 0; i < nels; ++i) + { + ((float *)t->data)[i] = v; } return t; } -static ggml_tensor * whisper_set_i32(struct ggml_tensor * t, int32_t v) { +static ggml_tensor *whisper_set_i32(struct ggml_tensor *t, int32_t v) +{ GGML_ASSERT(t->type == GGML_TYPE_I32); GGML_ASSERT(ggml_is_contiguous(t)); size_t nels = ggml_nelements(t); - for (size_t i = 0; i < nels; ++i) { - ((int32_t *) t->data)[i] = v; + for (size_t i = 0; i < nels; ++i) + { + ((int32_t *)t->data)[i] = v; } return t; } -static float whisper_get_f32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { +static float whisper_get_f32_nd(const struct ggml_tensor *t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) +{ GGML_ASSERT(t->type == GGML_TYPE_F32); - void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; - return *(float *) data; + void *data = (char *)t->data + i0 * t->nb[0] + i1 * t->nb[1] + i2 * t->nb[2] + i3 * t->nb[3]; + return *(float *)data; } -static void whisper_set_f32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) { +static void whisper_set_f32_nd(struct ggml_tensor *t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, float v) +{ GGML_ASSERT(t->type == GGML_TYPE_F32); - void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; - *(float *) data = v; + void *data = (char *)t->data + i0 * t->nb[0] + i1 * t->nb[1] + i2 * t->nb[2] + i3 * t->nb[3]; + *(float *)data = v; } -static int32_t whisper_get_i32_nd(const struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) { +static int32_t whisper_get_i32_nd(const struct ggml_tensor *t, int64_t i0, int64_t i1, int64_t i2, int64_t i3) +{ GGML_ASSERT(t->type == GGML_TYPE_I32); - void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; - return *(int32_t *) data; + void *data = (char *)t->data + i0 * t->nb[0] + i1 * t->nb[1] + i2 * t->nb[2] + i3 * t->nb[3]; + return *(int32_t *)data; } -static void whisper_set_i32_nd(struct ggml_tensor * t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) { +static void whisper_set_i32_nd(struct ggml_tensor *t, int64_t i0, int64_t i1, int64_t i2, int64_t i3, int32_t v) +{ GGML_ASSERT(t->type == GGML_TYPE_I32); - void * data = (char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]; - *(int32_t *) data = v; + void *data = (char *)t->data 
+ i0 * t->nb[0] + i1 * t->nb[1] + i2 * t->nb[2] + i3 * t->nb[3]; + *(int32_t *)data = v; } // available whisper models -enum e_model { +enum e_model +{ MODEL_UNKNOWN, MODEL_TINY, MODEL_BASE, @@ -263,149 +305,450 @@ enum e_model { }; static const std::map g_model_name = { - { MODEL_UNKNOWN, "unknown" }, - { MODEL_TINY, "tiny" }, - { MODEL_BASE, "base" }, - { MODEL_SMALL, "small" }, - { MODEL_MEDIUM, "medium" }, - { MODEL_LARGE, "large" }, + {MODEL_UNKNOWN, "unknown"}, + {MODEL_TINY, "tiny"}, + {MODEL_BASE, "base"}, + {MODEL_SMALL, "small"}, + {MODEL_MEDIUM, "medium"}, + {MODEL_LARGE, "large"}, }; static const std::map> g_lang = { - { "en", { 0, "english", } }, - { "zh", { 1, "chinese", } }, - { "de", { 2, "german", } }, - { "es", { 3, "spanish", } }, - { "ru", { 4, "russian", } }, - { "ko", { 5, "korean", } }, - { "fr", { 6, "french", } }, - { "ja", { 7, "japanese", } }, - { "pt", { 8, "portuguese", } }, - { "tr", { 9, "turkish", } }, - { "pl", { 10, "polish", } }, - { "ca", { 11, "catalan", } }, - { "nl", { 12, "dutch", } }, - { "ar", { 13, "arabic", } }, - { "sv", { 14, "swedish", } }, - { "it", { 15, "italian", } }, - { "id", { 16, "indonesian", } }, - { "hi", { 17, "hindi", } }, - { "fi", { 18, "finnish", } }, - { "vi", { 19, "vietnamese", } }, - { "he", { 20, "hebrew", } }, - { "uk", { 21, "ukrainian", } }, - { "el", { 22, "greek", } }, - { "ms", { 23, "malay", } }, - { "cs", { 24, "czech", } }, - { "ro", { 25, "romanian", } }, - { "da", { 26, "danish", } }, - { "hu", { 27, "hungarian", } }, - { "ta", { 28, "tamil", } }, - { "no", { 29, "norwegian", } }, - { "th", { 30, "thai", } }, - { "ur", { 31, "urdu", } }, - { "hr", { 32, "croatian", } }, - { "bg", { 33, "bulgarian", } }, - { "lt", { 34, "lithuanian", } }, - { "la", { 35, "latin", } }, - { "mi", { 36, "maori", } }, - { "ml", { 37, "malayalam", } }, - { "cy", { 38, "welsh", } }, - { "sk", { 39, "slovak", } }, - { "te", { 40, "telugu", } }, - { "fa", { 41, "persian", } }, - { "lv", { 42, "latvian", } }, - { "bn", { 43, "bengali", } }, - { "sr", { 44, "serbian", } }, - { "az", { 45, "azerbaijani", } }, - { "sl", { 46, "slovenian", } }, - { "kn", { 47, "kannada", } }, - { "et", { 48, "estonian", } }, - { "mk", { 49, "macedonian", } }, - { "br", { 50, "breton", } }, - { "eu", { 51, "basque", } }, - { "is", { 52, "icelandic", } }, - { "hy", { 53, "armenian", } }, - { "ne", { 54, "nepali", } }, - { "mn", { 55, "mongolian", } }, - { "bs", { 56, "bosnian", } }, - { "kk", { 57, "kazakh", } }, - { "sq", { 58, "albanian", } }, - { "sw", { 59, "swahili", } }, - { "gl", { 60, "galician", } }, - { "mr", { 61, "marathi", } }, - { "pa", { 62, "punjabi", } }, - { "si", { 63, "sinhala", } }, - { "km", { 64, "khmer", } }, - { "sn", { 65, "shona", } }, - { "yo", { 66, "yoruba", } }, - { "so", { 67, "somali", } }, - { "af", { 68, "afrikaans", } }, - { "oc", { 69, "occitan", } }, - { "ka", { 70, "georgian", } }, - { "be", { 71, "belarusian", } }, - { "tg", { 72, "tajik", } }, - { "sd", { 73, "sindhi", } }, - { "gu", { 74, "gujarati", } }, - { "am", { 75, "amharic", } }, - { "yi", { 76, "yiddish", } }, - { "lo", { 77, "lao", } }, - { "uz", { 78, "uzbek", } }, - { "fo", { 79, "faroese", } }, - { "ht", { 80, "haitian creole", } }, - { "ps", { 81, "pashto", } }, - { "tk", { 82, "turkmen", } }, - { "nn", { 83, "nynorsk", } }, - { "mt", { 84, "maltese", } }, - { "sa", { 85, "sanskrit", } }, - { "lb", { 86, "luxembourgish", } }, - { "my", { 87, "myanmar", } }, - { "bo", { 88, "tibetan", } }, - { "tl", { 89, "tagalog", } }, - { "mg", { 90, "malagasy", } }, - { 
"as", { 91, "assamese", } }, - { "tt", { 92, "tatar", } }, - { "haw", { 93, "hawaiian", } }, - { "ln", { 94, "lingala", } }, - { "ha", { 95, "hausa", } }, - { "ba", { 96, "bashkir", } }, - { "jw", { 97, "javanese", } }, - { "su", { 98, "sundanese", } }, - { "yue", { 99, "cantonese", } }, + {"en", { + 0, + "english", + }}, + {"zh", { + 1, + "chinese", + }}, + {"de", { + 2, + "german", + }}, + {"es", { + 3, + "spanish", + }}, + {"ru", { + 4, + "russian", + }}, + {"ko", { + 5, + "korean", + }}, + {"fr", { + 6, + "french", + }}, + {"ja", { + 7, + "japanese", + }}, + {"pt", { + 8, + "portuguese", + }}, + {"tr", { + 9, + "turkish", + }}, + {"pl", { + 10, + "polish", + }}, + {"ca", { + 11, + "catalan", + }}, + {"nl", { + 12, + "dutch", + }}, + {"ar", { + 13, + "arabic", + }}, + {"sv", { + 14, + "swedish", + }}, + {"it", { + 15, + "italian", + }}, + {"id", { + 16, + "indonesian", + }}, + {"hi", { + 17, + "hindi", + }}, + {"fi", { + 18, + "finnish", + }}, + {"vi", { + 19, + "vietnamese", + }}, + {"he", { + 20, + "hebrew", + }}, + {"uk", { + 21, + "ukrainian", + }}, + {"el", { + 22, + "greek", + }}, + {"ms", { + 23, + "malay", + }}, + {"cs", { + 24, + "czech", + }}, + {"ro", { + 25, + "romanian", + }}, + {"da", { + 26, + "danish", + }}, + {"hu", { + 27, + "hungarian", + }}, + {"ta", { + 28, + "tamil", + }}, + {"no", { + 29, + "norwegian", + }}, + {"th", { + 30, + "thai", + }}, + {"ur", { + 31, + "urdu", + }}, + {"hr", { + 32, + "croatian", + }}, + {"bg", { + 33, + "bulgarian", + }}, + {"lt", { + 34, + "lithuanian", + }}, + {"la", { + 35, + "latin", + }}, + {"mi", { + 36, + "maori", + }}, + {"ml", { + 37, + "malayalam", + }}, + {"cy", { + 38, + "welsh", + }}, + {"sk", { + 39, + "slovak", + }}, + {"te", { + 40, + "telugu", + }}, + {"fa", { + 41, + "persian", + }}, + {"lv", { + 42, + "latvian", + }}, + {"bn", { + 43, + "bengali", + }}, + {"sr", { + 44, + "serbian", + }}, + {"az", { + 45, + "azerbaijani", + }}, + {"sl", { + 46, + "slovenian", + }}, + {"kn", { + 47, + "kannada", + }}, + {"et", { + 48, + "estonian", + }}, + {"mk", { + 49, + "macedonian", + }}, + {"br", { + 50, + "breton", + }}, + {"eu", { + 51, + "basque", + }}, + {"is", { + 52, + "icelandic", + }}, + {"hy", { + 53, + "armenian", + }}, + {"ne", { + 54, + "nepali", + }}, + {"mn", { + 55, + "mongolian", + }}, + {"bs", { + 56, + "bosnian", + }}, + {"kk", { + 57, + "kazakh", + }}, + {"sq", { + 58, + "albanian", + }}, + {"sw", { + 59, + "swahili", + }}, + {"gl", { + 60, + "galician", + }}, + {"mr", { + 61, + "marathi", + }}, + {"pa", { + 62, + "punjabi", + }}, + {"si", { + 63, + "sinhala", + }}, + {"km", { + 64, + "khmer", + }}, + {"sn", { + 65, + "shona", + }}, + {"yo", { + 66, + "yoruba", + }}, + {"so", { + 67, + "somali", + }}, + {"af", { + 68, + "afrikaans", + }}, + {"oc", { + 69, + "occitan", + }}, + {"ka", { + 70, + "georgian", + }}, + {"be", { + 71, + "belarusian", + }}, + {"tg", { + 72, + "tajik", + }}, + {"sd", { + 73, + "sindhi", + }}, + {"gu", { + 74, + "gujarati", + }}, + {"am", { + 75, + "amharic", + }}, + {"yi", { + 76, + "yiddish", + }}, + {"lo", { + 77, + "lao", + }}, + {"uz", { + 78, + "uzbek", + }}, + {"fo", { + 79, + "faroese", + }}, + {"ht", { + 80, + "haitian creole", + }}, + {"ps", { + 81, + "pashto", + }}, + {"tk", { + 82, + "turkmen", + }}, + {"nn", { + 83, + "nynorsk", + }}, + {"mt", { + 84, + "maltese", + }}, + {"sa", { + 85, + "sanskrit", + }}, + {"lb", { + 86, + "luxembourgish", + }}, + {"my", { + 87, + "myanmar", + }}, + {"bo", { + 88, + "tibetan", + }}, + {"tl", { + 89, + "tagalog", + }}, + {"mg", { + 90, + 
"malagasy", + }}, + {"as", { + 91, + "assamese", + }}, + {"tt", { + 92, + "tatar", + }}, + {"haw", { + 93, + "hawaiian", + }}, + {"ln", { + 94, + "lingala", + }}, + {"ha", { + 95, + "hausa", + }}, + {"ba", { + 96, + "bashkir", + }}, + {"jw", { + 97, + "javanese", + }}, + {"su", { + 98, + "sundanese", + }}, + {"yue", { + 99, + "cantonese", + }}, }; // [EXPERIMENTAL] Token-level timestamps with DTW -static const whisper_ahead g_aheads_tiny_en[] = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} }; -static const whisper_ahead g_aheads_tiny[] = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} }; -static const whisper_ahead g_aheads_base_en[] = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} }; -static const whisper_ahead g_aheads_base[] = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} }; -static const whisper_ahead g_aheads_small_en[] = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} }; -static const whisper_ahead g_aheads_small[] = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} }; -static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} }; -static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} }; -static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} }; -static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} }; -static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} }; -static const whisper_ahead g_aheads_large_v3_turbo[] = { {2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14} }; - -static const std::map g_aheads { - { WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } }, - { WHISPER_AHEADS_TINY, { 6, g_aheads_tiny } }, - { WHISPER_AHEADS_BASE_EN, { 5, g_aheads_base_en } }, - { WHISPER_AHEADS_BASE, { 8, g_aheads_base } }, - { WHISPER_AHEADS_SMALL_EN, { 19, g_aheads_small_en } }, - { WHISPER_AHEADS_SMALL, { 10, g_aheads_small } }, - { WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } }, - { WHISPER_AHEADS_MEDIUM, { 6, g_aheads_medium } }, - { WHISPER_AHEADS_LARGE_V1, { 9, g_aheads_large_v1 } }, - { WHISPER_AHEADS_LARGE_V2, { 23, g_aheads_large_v2 } }, - { WHISPER_AHEADS_LARGE_V3, { 10, g_aheads_large_v3 } }, - { WHISPER_AHEADS_LARGE_V3_TURBO, { 6, g_aheads_large_v3_turbo } }, +static const whisper_ahead g_aheads_tiny_en[] = {{1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4}}; +static const whisper_ahead g_aheads_tiny[] = {{2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5}}; +static const whisper_ahead g_aheads_base_en[] = {{3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7}}; +static const whisper_ahead g_aheads_base[] = {{3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6}}; +static const whisper_ahead g_aheads_small_en[] = {{6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4}}; +static const whisper_ahead 
g_aheads_small[] = {{5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5}}; +static const whisper_ahead g_aheads_medium_en[] = {{11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12}}; +static const whisper_ahead g_aheads_medium[] = {{13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4}}; +static const whisper_ahead g_aheads_large_v1[] = {{9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15}}; +static const whisper_ahead g_aheads_large_v2[] = {{10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15}}; +static const whisper_ahead g_aheads_large_v3[] = {{7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6}}; +static const whisper_ahead g_aheads_large_v3_turbo[] = {{2, 4}, {2, 11}, {3, 3}, {3, 6}, {3, 11}, {3, 14}}; + +static const std::map g_aheads{ + {WHISPER_AHEADS_TINY_EN, {8, g_aheads_tiny_en}}, + {WHISPER_AHEADS_TINY, {6, g_aheads_tiny}}, + {WHISPER_AHEADS_BASE_EN, {5, g_aheads_base_en}}, + {WHISPER_AHEADS_BASE, {8, g_aheads_base}}, + {WHISPER_AHEADS_SMALL_EN, {19, g_aheads_small_en}}, + {WHISPER_AHEADS_SMALL, {10, g_aheads_small}}, + {WHISPER_AHEADS_MEDIUM_EN, {18, g_aheads_medium_en}}, + {WHISPER_AHEADS_MEDIUM, {6, g_aheads_medium}}, + {WHISPER_AHEADS_LARGE_V1, {9, g_aheads_large_v1}}, + {WHISPER_AHEADS_LARGE_V2, {23, g_aheads_large_v2}}, + {WHISPER_AHEADS_LARGE_V3, {10, g_aheads_large_v3}}, + {WHISPER_AHEADS_LARGE_V3_TURBO, {6, g_aheads_large_v3_turbo}}, }; -static std::vector get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head); +static std::vector get_alignment_heads_by_layer(const whisper_context_params &cparams, int il, int32_t n_text_layer, int32_t n_head); -struct whisper_mel { +struct whisper_mel +{ int n_len; int n_len_org; int n_mel; @@ -413,15 +756,17 @@ struct whisper_mel { std::vector data; }; -struct whisper_filters { +struct whisper_filters +{ int32_t n_mel; int32_t n_fft; std::vector data; }; -struct whisper_vocab { - using id = int32_t; +struct whisper_vocab +{ + using id = int32_t; using token = std::string; int n_vocab = 51864; @@ -430,28 +775,31 @@ struct whisper_vocab { std::map id_to_token; // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349 - id token_eot = 50256; - id token_sot = 50257; + id token_eot = 50256; + id token_sot = 50257; // task tokens (used only for multilingual models) - id token_translate = 50357; + id token_translate = 50357; id token_transcribe = 50358; // other special tokens - id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn - id token_prev = 50360; - id token_nosp = 50361; - id token_not = 50362; // no timestamps - id token_beg = 50363; // begin timestamps + id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn + id token_prev = 50360; + id token_nosp = 50361; + id token_not = 50362; // no timestamps + id token_beg = 50363; // begin timestamps - bool is_multilingual() const { + bool is_multilingual() const + { return n_vocab >= 51865; } - int num_languages() const { + int num_languages() const + { return n_vocab - 51765 - (is_multilingual() ? 
1 : 0); } }; -struct whisper_segment { +struct whisper_segment +{ int64_t t0; int64_t t1; @@ -463,81 +811,105 @@ struct whisper_segment { bool speaker_turn_next; }; -struct whisper_batch { +struct whisper_batch +{ int32_t n_tokens; - whisper_token * token; - whisper_pos * pos; - int32_t * n_seq_id; // always 1, here for consistency with llama.cpp - whisper_seq_id ** seq_id; // null terminated - int8_t * logits; + whisper_token *token; + whisper_pos *pos; + int32_t *n_seq_id; // always 1, here for consistency with llama.cpp + whisper_seq_id **seq_id; // null terminated + int8_t *logits; }; -static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) { - whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, }; +static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) +{ + whisper_batch batch = { + 0, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + }; - batch.token = (whisper_token * ) malloc(sizeof(whisper_token) * (n_tokens)); - batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * (n_tokens)); - batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); - batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1)); - for (int i = 0; i < n_tokens; ++i) { - batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max); + batch.token = (whisper_token *)malloc(sizeof(whisper_token) * (n_tokens)); + batch.pos = (whisper_pos *)malloc(sizeof(whisper_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *)malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (whisper_seq_id **)malloc(sizeof(whisper_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) + { + batch.seq_id[i] = (whisper_seq_id *)malloc(sizeof(whisper_seq_id) * n_seq_max); } batch.seq_id[n_tokens] = nullptr; - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + batch.logits = (int8_t *)malloc(sizeof(int8_t) * n_tokens); return batch; } -static void whisper_batch_free(struct whisper_batch batch) { - if (batch.token) free(batch.token); - if (batch.pos) free(batch.pos); - if (batch.n_seq_id) free(batch.n_seq_id); - if (batch.seq_id) { - for (int i = 0; batch.seq_id[i]; ++i) { +static void whisper_batch_free(struct whisper_batch batch) +{ + if (batch.token) + free(batch.token); + if (batch.pos) + free(batch.pos); + if (batch.n_seq_id) + free(batch.n_seq_id); + if (batch.seq_id) + { + for (int i = 0; batch.seq_id[i]; ++i) + { free(batch.seq_id[i]); } free(batch.seq_id); } - if (batch.logits) free(batch.logits); + if (batch.logits) + free(batch.logits); } -static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) { +static void whisper_batch_prep_legacy(whisper_batch &batch, const whisper_token *tokens, int n_tokens, int n_past, int seq_id) +{ batch.n_tokens = n_tokens; - for (int i = 0; i < n_tokens; ++i) { - if (tokens) { + for (int i = 0; i < n_tokens; ++i) + { + if (tokens) + { batch.token[i] = tokens[i]; } - batch.pos [i] = n_past + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i][0] = seq_id; - batch.logits [i] = 0; + batch.pos[i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = seq_id; + batch.logits[i] = 0; } batch.logits[n_tokens - 1] = 1; } // replace std::pair by using customized pair struct (reason: std::pair is very slow) -template -struct whisper_pair { +template +struct whisper_pair +{ A first; B second; // Define a constructor that takes two arguments. 
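// The two-argument constructor lets callers build (first, second) pairs in place,
// and the no-argument constructor below is needed so the pair can be stored in
// containers that value-initialize their elements (e.g. std::vector::resize).
// Minimal usage sketch (illustrative only; the vector name and the values are
// assumptions, not taken from this change, and <vector>/<algorithm> are assumed
// to be included):
//
//     std::vector<whisper_pair<double, whisper_token>> probs_id;
//     probs_id.emplace_back(0.95, 50363); // (probability, token id)
//     probs_id.emplace_back(0.05, 50256);
//     std::sort(probs_id.begin(), probs_id.end(),
//         [](const whisper_pair<double, whisper_token> & a,
//            const whisper_pair<double, whisper_token> & b) {
//             return a.first > b.first; // descending probability
//         });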
- whisper_pair(const A& a, const B& b) : first(a), second(b) {} + whisper_pair(const A &a, const B &b) : first(a), second(b) {} // Define a constructor that takes no argument. whisper_pair() : first(A()), second(B()) {} }; // ggml_backend_sched wrapper for whisper usage -struct whisper_sched { +struct whisper_sched +{ ggml_backend_sched_t sched = nullptr; std::vector meta; }; -static size_t whisper_sched_size(struct whisper_sched & allocr) { +static size_t whisper_sched_size(struct whisper_sched &allocr) +{ size_t size = allocr.meta.size(); - for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) { + for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) + { ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); } @@ -545,17 +917,19 @@ static size_t whisper_sched_size(struct whisper_sched & allocr) { } // measure the memory usage of a graph and prepare the allocr's internal data buffer -static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector backends, std::function && get_graph) { - auto & sched = allocr.sched; - auto & meta = allocr.meta; +static bool whisper_sched_graph_init(struct whisper_sched &allocr, std::vector backends, std::function &&get_graph) +{ + auto &sched = allocr.sched; + auto &meta = allocr.meta; sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), WHISPER_MAX_NODES, false, true); - meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); + meta.resize(ggml_tensor_overhead() * WHISPER_MAX_NODES + ggml_graph_overhead()); // since there are dependencies between the different graphs, // we need to allocate them instead of only reserving to get the correct compute buffer size - if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { + if (!ggml_backend_sched_alloc_graph(sched, get_graph())) + { // failed to allocate the compute buffer WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); return false; @@ -581,119 +955,125 @@ static bool whisper_sched_graph_init(struct whisper_sched & allocr, std::vector< // } // // default hparams (Whisper tiny) -struct whisper_hparams { - int32_t n_vocab = 51864; - int32_t n_audio_ctx = 1500; +struct whisper_hparams +{ + int32_t n_vocab = 51864; + int32_t n_audio_ctx = 1500; int32_t n_audio_state = 384; - int32_t n_audio_head = 6; + int32_t n_audio_head = 6; int32_t n_audio_layer = 4; - int32_t n_text_ctx = 448; - int32_t n_text_state = 384; - int32_t n_text_head = 6; - int32_t n_text_layer = 4; - int32_t n_mels = 80; - int32_t ftype = 1; - float eps = 1e-5f; + int32_t n_text_ctx = 448; + int32_t n_text_state = 384; + int32_t n_text_head = 6; + int32_t n_text_layer = 4; + int32_t n_mels = 80; + int32_t ftype = 1; + float eps = 1e-5f; }; // audio encoding layer -struct whisper_layer_encoder { +struct whisper_layer_encoder +{ // encoder.blocks.*.attn_ln - struct ggml_tensor * attn_ln_0_w; - struct ggml_tensor * attn_ln_0_b; + struct ggml_tensor *attn_ln_0_w; + struct ggml_tensor *attn_ln_0_b; // encoder.blocks.*.attn.out - struct ggml_tensor * attn_ln_1_w; - struct ggml_tensor * attn_ln_1_b; + struct ggml_tensor *attn_ln_1_w; + struct ggml_tensor *attn_ln_1_b; // encoder.blocks.*.attn.query - struct ggml_tensor * attn_q_w; - struct ggml_tensor * attn_q_b; + struct ggml_tensor *attn_q_w; + struct ggml_tensor *attn_q_b; // encoder.blocks.*.attn.key - struct ggml_tensor * attn_k_w; + struct ggml_tensor *attn_k_w; // 
encoder.blocks.*.attn.value - struct ggml_tensor * attn_v_w; - struct ggml_tensor * attn_v_b; + struct ggml_tensor *attn_v_w; + struct ggml_tensor *attn_v_b; // encoder.blocks.*.mlp_ln - struct ggml_tensor * mlp_ln_w; - struct ggml_tensor * mlp_ln_b; + struct ggml_tensor *mlp_ln_w; + struct ggml_tensor *mlp_ln_b; // encoder.blocks.*.mlp.0 - struct ggml_tensor * mlp_0_w; - struct ggml_tensor * mlp_0_b; + struct ggml_tensor *mlp_0_w; + struct ggml_tensor *mlp_0_b; // encoder.blocks.*.mlp.2 - struct ggml_tensor * mlp_1_w; - struct ggml_tensor * mlp_1_b; + struct ggml_tensor *mlp_1_w; + struct ggml_tensor *mlp_1_b; }; // token decoding layer -struct whisper_layer_decoder { +struct whisper_layer_decoder +{ // decoder.blocks.*.attn_ln - struct ggml_tensor * attn_ln_0_w; - struct ggml_tensor * attn_ln_0_b; + struct ggml_tensor *attn_ln_0_w; + struct ggml_tensor *attn_ln_0_b; // decoder.blocks.*.attn.out - struct ggml_tensor * attn_ln_1_w; - struct ggml_tensor * attn_ln_1_b; + struct ggml_tensor *attn_ln_1_w; + struct ggml_tensor *attn_ln_1_b; // decoder.blocks.*.attn.query - struct ggml_tensor * attn_q_w; - struct ggml_tensor * attn_q_b; + struct ggml_tensor *attn_q_w; + struct ggml_tensor *attn_q_b; // decoder.blocks.*.attn.key - struct ggml_tensor * attn_k_w; + struct ggml_tensor *attn_k_w; // decoder.blocks.*.attn.value - struct ggml_tensor * attn_v_w; - struct ggml_tensor * attn_v_b; + struct ggml_tensor *attn_v_w; + struct ggml_tensor *attn_v_b; // decoder.blocks.*.cross_attn_ln - struct ggml_tensor * cross_attn_ln_0_w; - struct ggml_tensor * cross_attn_ln_0_b; + struct ggml_tensor *cross_attn_ln_0_w; + struct ggml_tensor *cross_attn_ln_0_b; // decoder.blocks.*.cross_attn.out - struct ggml_tensor * cross_attn_ln_1_w; - struct ggml_tensor * cross_attn_ln_1_b; + struct ggml_tensor *cross_attn_ln_1_w; + struct ggml_tensor *cross_attn_ln_1_b; // decoder.blocks.*.cross_attn.query - struct ggml_tensor * cross_attn_q_w; - struct ggml_tensor * cross_attn_q_b; + struct ggml_tensor *cross_attn_q_w; + struct ggml_tensor *cross_attn_q_b; // decoder.blocks.*.cross_attn.key - struct ggml_tensor * cross_attn_k_w; + struct ggml_tensor *cross_attn_k_w; // decoder.blocks.*.cross_attn.value - struct ggml_tensor * cross_attn_v_w; - struct ggml_tensor * cross_attn_v_b; + struct ggml_tensor *cross_attn_v_w; + struct ggml_tensor *cross_attn_v_b; // decoder.blocks.*.mlp_ln - struct ggml_tensor * mlp_ln_w; - struct ggml_tensor * mlp_ln_b; + struct ggml_tensor *mlp_ln_w; + struct ggml_tensor *mlp_ln_b; // decoder.blocks.*.mlp.0 - struct ggml_tensor * mlp_0_w; - struct ggml_tensor * mlp_0_b; + struct ggml_tensor *mlp_0_w; + struct ggml_tensor *mlp_0_b; // decoder.blocks.*.mlp.2 - struct ggml_tensor * mlp_1_w; - struct ggml_tensor * mlp_1_b; + struct ggml_tensor *mlp_1_w; + struct ggml_tensor *mlp_1_b; }; -struct whisper_kv_cell { +struct whisper_kv_cell +{ whisper_pos pos = -1; std::set seq_id; - bool has_seq_id(const whisper_seq_id & id) const { + bool has_seq_id(const whisper_seq_id &id) const + { return seq_id.find(id) != seq_id.end(); } }; -struct whisper_kv_cache { +struct whisper_kv_cache +{ uint32_t head = 0; uint32_t size = 0; @@ -702,44 +1082,45 @@ struct whisper_kv_cache { std::vector cells; - struct ggml_tensor * k; - struct ggml_tensor * v; + struct ggml_tensor *k; + struct ggml_tensor *v; ggml_backend_buffer_t buffer = nullptr; std::vector ctx_buf; }; -struct whisper_model { +struct whisper_model +{ e_model type = MODEL_UNKNOWN; whisper_hparams hparams; whisper_filters filters; // 
encoder.positional_embedding - struct ggml_tensor * e_pe; + struct ggml_tensor *e_pe; // encoder.conv1 - struct ggml_tensor * e_conv_1_w; - struct ggml_tensor * e_conv_1_b; + struct ggml_tensor *e_conv_1_w; + struct ggml_tensor *e_conv_1_b; // encoder.conv2 - struct ggml_tensor * e_conv_2_w; - struct ggml_tensor * e_conv_2_b; + struct ggml_tensor *e_conv_2_w; + struct ggml_tensor *e_conv_2_b; // encoder.ln_post - struct ggml_tensor * e_ln_w; - struct ggml_tensor * e_ln_b; + struct ggml_tensor *e_ln_w; + struct ggml_tensor *e_ln_b; // decoder.positional_embedding - struct ggml_tensor * d_pe; + struct ggml_tensor *d_pe; // decoder.token_embedding - struct ggml_tensor * d_te; + struct ggml_tensor *d_te; // decoder.ln - struct ggml_tensor * d_ln_w; - struct ggml_tensor * d_ln_b; + struct ggml_tensor *d_ln_w; + struct ggml_tensor *d_ln_b; std::vector layers_encoder; std::vector layers_decoder; @@ -755,26 +1136,30 @@ struct whisper_model { std::map tensors; }; -struct whisper_partial_utf8 { - uint32_t value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence +struct whisper_partial_utf8 +{ + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence }; -struct whisper_grammar { +struct whisper_grammar +{ /*const*/ std::vector> rules; - std::vector> stacks; + std::vector> stacks; // buffer for partially generated UTF-8 sequence from accepted tokens whisper_partial_utf8 partial_utf8; }; -struct whisper_grammar_candidate { - whisper_token id; - const uint32_t * code_points; - whisper_partial_utf8 partial_utf8; +struct whisper_grammar_candidate +{ + whisper_token id; + const uint32_t *code_points; + whisper_partial_utf8 partial_utf8; }; -struct whisper_sequence { +struct whisper_sequence +{ std::vector tokens; // the accumulated transcription in the current iteration (used to truncate the tokens array) @@ -788,12 +1173,13 @@ struct whisper_sequence { }; // TAGS: WHISPER_DECODER_INIT -struct whisper_decoder { +struct whisper_decoder +{ // the currently generated sequence of tokens whisper_sequence sequence; // grammar parse state of generated sequence of tokens - whisper_grammar grammar; + whisper_grammar grammar; int i_batch; // the index of the token in the current batch int seek_delta; // the window shift found so far based on the decoded timestamp tokens @@ -814,18 +1200,21 @@ struct whisper_decoder { }; // [EXPERIMENTAL] Token-level timestamps with DTW -struct whisper_aheads_masks { - std::vector m; // One mask per text layer. - struct ggml_context * ctx = nullptr; +struct whisper_aheads_masks +{ + std::vector m; // One mask per text layer. 
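// Layout of each mask in m (filled in aheads_masks_init below): a 2D F32 tensor
// with ne[0] = n_text_head columns and one row per alignment head selected for
// that text layer; row ih holds a single 1.0f in the column of the selected head,
// and layers with no selected head get a nullptr entry instead of a tensor.
// Worked example (assuming the g_aheads_tiny preset above and the default tiny
// hyperparameters n_text_head = 6, n_text_layer = 4): text layer 3 selects heads
// {0, 2, 3, 4, 5}, so its mask is the 5x6 matrix
//
//     1 0 0 0 0 0
//     0 0 1 0 0 0
//     0 0 0 1 0 0
//     0 0 0 0 1 0
//     0 0 0 0 0 1
//
// A standalone sketch of the same fill logic (array and variable names are
// illustrative):
//
//     std::vector<float> mask_data(5 * 6, 0.0f);
//     const int heads[] = {0, 2, 3, 4, 5};        // heads selected for layer 3
//     for (size_t ih = 0; ih < 5; ++ih) {
//         mask_data[heads[ih] + ih * 6] = 1.0f;   // one 1.0f per row
//     }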
+ struct ggml_context *ctx = nullptr; ggml_backend_buffer_t buffer = nullptr; }; -struct vad_time_mapping { - int64_t processed_time; // Time in processed (VAD) audio - int64_t original_time; // Corresponding time in original audio +struct vad_time_mapping +{ + int64_t processed_time; // Time in processed (VAD) audio + int64_t original_time; // Corresponding time in original audio }; -struct whisper_state { +struct whisper_state +{ int64_t t_sample_us = 0; int64_t t_encode_us = 0; int64_t t_decode_us = 0; @@ -869,8 +1258,8 @@ struct whisper_state { whisper_sched sched_decode; // result of the encoder - struct ggml_tensor * embd_conv = nullptr; - struct ggml_tensor * embd_enc = nullptr; + struct ggml_tensor *embd_conv = nullptr; + struct ggml_tensor *embd_enc = nullptr; // helpers for GPU offloading std::vector inp_mel; @@ -880,22 +1269,22 @@ struct whisper_state { std::vector logits; std::vector result_all; - std::vector prompt_past; + std::vector prompt_past; int lang_id = 0; // english by default std::string path_model; // populated by whisper_init_from_file_with_params() #ifdef WHISPER_USE_COREML - whisper_coreml_context * ctx_coreml = nullptr; + whisper_coreml_context *ctx_coreml = nullptr; #endif #ifdef WHISPER_USE_OPENVINO - whisper_openvino_context * ctx_openvino = nullptr; + whisper_openvino_context *ctx_openvino = nullptr; #endif // [EXPERIMENTAL] token-level timestamps data - int64_t t_beg = 0; + int64_t t_beg = 0; int64_t t_last = 0; whisper_token tid_last; @@ -905,15 +1294,16 @@ struct whisper_state { // [EXPERIMENTAL] Token-level timestamps with DTW whisper_aheads_masks aheads_masks; - ggml_tensor * aheads_cross_QKs = nullptr; + ggml_tensor *aheads_cross_QKs = nullptr; std::vector aheads_cross_QKs_data; // [EXPERIMENTAL] speed-up techniques int32_t exp_n_audio_ctx = 0; // 0 - use default - whisper_vad_context * vad_context = nullptr; + whisper_vad_context *vad_context = nullptr; - struct vad_segment_info { + struct vad_segment_info + { int64_t orig_start; int64_t orig_end; int64_t vad_start; @@ -925,8 +1315,9 @@ struct whisper_state { std::vector vad_mapping_table; }; -struct whisper_context { - int64_t t_load_us = 0; +struct whisper_context +{ + int64_t t_load_us = 0; int64_t t_start_us = 0; ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX) @@ -937,41 +1328,44 @@ struct whisper_context { whisper_model model; whisper_vocab vocab; - whisper_state * state = nullptr; + whisper_state *state = nullptr; std::string path_model; // populated by whisper_init_from_file_with_params() }; -struct whisper_global { +struct whisper_global +{ // We save the log callback globally ggml_log_callback log_callback = whisper_log_callback_default; - void * log_callback_user_data = nullptr; + void *log_callback_user_data = nullptr; }; static whisper_global g_state; -template -static void read_safe(whisper_model_loader * loader, T & dest) { +template +static void read_safe(whisper_model_loader *loader, T &dest) +{ loader->read(loader->context, &dest, sizeof(T)); BYTESWAP_VALUE(dest); } static bool whisper_kv_cache_init( - struct whisper_kv_cache & cache, - ggml_backend_t backend, - ggml_type wtype, - int64_t n_text_state, - int64_t n_text_layer, - int n_ctx) { - const int64_t n_mem = n_text_layer*n_ctx; - const int64_t n_elements = n_text_state*n_mem; + struct whisper_kv_cache &cache, + ggml_backend_t backend, + ggml_type wtype, + int64_t n_text_state, + int64_t n_text_layer, + int n_ctx) +{ + const int64_t n_mem = n_text_layer * n_ctx; + const int64_t n_elements = 
n_text_state * n_mem; - cache.ctx_buf.resize(2*ggml_tensor_overhead()); + cache.ctx_buf.resize(2 * ggml_tensor_overhead()); struct ggml_init_params params = { - /*.mem_size =*/ cache.ctx_buf.size(), - /*.mem_buffer =*/ cache.ctx_buf.data(), - /*.no_alloc =*/ true, + /*.mem_size =*/cache.ctx_buf.size(), + /*.mem_buffer =*/cache.ctx_buf.data(), + /*.no_alloc =*/true, }; cache.head = 0; @@ -980,9 +1374,10 @@ static bool whisper_kv_cache_init( cache.cells.clear(); cache.cells.resize(n_ctx); - struct ggml_context * ctx = ggml_init(params); + struct ggml_context *ctx = ggml_init(params); - if (!ctx) { + if (!ctx) + { WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__); return false; } @@ -991,7 +1386,8 @@ static bool whisper_kv_cache_init( cache.v = ggml_new_tensor_1d(ctx, wtype, n_elements); cache.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); - if (!cache.buffer) { + if (!cache.buffer) + { WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__); return false; } @@ -1003,54 +1399,65 @@ static bool whisper_kv_cache_init( return true; } -static void whisper_kv_cache_free(struct whisper_kv_cache & cache) { +static void whisper_kv_cache_free(struct whisper_kv_cache &cache) +{ ggml_backend_buffer_free(cache.buffer); } static bool whisper_kv_cache_find_slot( - struct whisper_kv_cache & cache, - const struct whisper_batch & batch) { - const uint32_t n_ctx = cache.size; + struct whisper_kv_cache &cache, + const struct whisper_batch &batch) +{ + const uint32_t n_ctx = cache.size; const uint32_t n_tokens = batch.n_tokens; - if (n_tokens > n_ctx) { + if (n_tokens > n_ctx) + { WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); return false; } uint32_t n_tested = 0; - while (true) { - if (cache.head + n_tokens > n_ctx) { + while (true) + { + if (cache.head + n_tokens > n_ctx) + { n_tested += n_ctx - cache.head; cache.head = 0; continue; } bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.cells[cache.head + i].pos >= 0) { + for (uint32_t i = 0; i < n_tokens; i++) + { + if (cache.cells[cache.head + i].pos >= 0) + { found = false; cache.head += i + 1; - n_tested += i + 1; + n_tested += i + 1; break; } } - if (found) { + if (found) + { break; } - if (n_tested >= n_ctx) { - //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + if (n_tested >= n_ctx) + { + // WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); return false; } } - for (uint32_t i = 0; i < n_tokens; i++) { + for (uint32_t i = 0; i < n_tokens; i++) + { cache.cells[cache.head + i].pos = batch.pos[i]; - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) + { cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); } } @@ -1059,9 +1466,12 @@ static bool whisper_kv_cache_find_slot( } // find how many cells are currently in use -static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) { - for (uint32_t i = cache.size - 1; i > 0; --i) { - if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) { +static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache &cache) +{ + for (uint32_t i = cache.size - 1; i > 0; --i) + { + if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) + { return i + 1; } } @@ -1069,8 +1479,10 @@ static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) return 1; } -static void whisper_kv_cache_clear(struct whisper_kv_cache & 
cache) { - for (int32_t i = 0; i < (int32_t) cache.size; ++i) { +static void whisper_kv_cache_clear(struct whisper_kv_cache &cache) +{ + for (int32_t i = 0; i < (int32_t)cache.size; ++i) + { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); } @@ -1080,66 +1492,88 @@ static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) { } static void whisper_kv_cache_seq_rm( - struct whisper_kv_cache & cache, - whisper_seq_id seq_id, - whisper_pos p0, - whisper_pos p1) { + struct whisper_kv_cache &cache, + whisper_seq_id seq_id, + whisper_pos p0, + whisper_pos p1) +{ uint32_t new_head = cache.size; - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) + p0 = 0; + if (p1 < 0) + p1 = std::numeric_limits::max(); - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - if (seq_id < 0) { + for (uint32_t i = 0; i < cache.size; ++i) + { + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) + { + if (seq_id < 0) + { cache.cells[i].seq_id.clear(); - } else if (cache.cells[i].has_seq_id(seq_id)) { + } + else if (cache.cells[i].has_seq_id(seq_id)) + { cache.cells[i].seq_id.erase(seq_id); - } else { + } + else + { continue; } - if (cache.cells[i].seq_id.empty()) { + if (cache.cells[i].seq_id.empty()) + { cache.cells[i].pos = -1; - if (new_head == cache.size) new_head = i; + if (new_head == cache.size) + new_head = i; } } } // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size) cache.head = new_head; + if (new_head != cache.size) + cache.head = new_head; } static void whisper_kv_cache_seq_cp( - struct whisper_kv_cache & cache, - whisper_seq_id seq_id_src, - whisper_seq_id seq_id_dst, - whisper_pos p0, - whisper_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + struct whisper_kv_cache &cache, + whisper_seq_id seq_id_src, + whisper_seq_id seq_id_dst, + whisper_pos p0, + whisper_pos p1) +{ + if (p0 < 0) + p0 = 0; + if (p1 < 0) + p1 = std::numeric_limits::max(); cache.head = 0; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + for (uint32_t i = 0; i < cache.size; ++i) + { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) + { cache.cells[i].seq_id.insert(seq_id_dst); } } } -static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) { - if (!wctx.params.flash_attn || !wctx.params.use_gpu) { +static uint32_t whisper_kv_cache_get_padding(const struct whisper_context &wctx) +{ + if (!wctx.params.flash_attn || !wctx.params.use_gpu) + { return 1u; } #ifdef GGML_USE_METAL - if (wctx.params.use_gpu) { + if (wctx.params.use_gpu) + { return 32u; } #endif #ifdef GGML_USE_CUDA - if (wctx.params.use_gpu) { + if (wctx.params.use_gpu) + { return 256u; } #endif @@ -1149,49 +1583,64 @@ static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx // [EXPERIMENTAL] Token-level timestamps with DTW static bool aheads_masks_init( - const whisper_context_params & cparams, - const whisper_hparams & hparams, - struct whisper_aheads_masks & aheads_masks, - ggml_backend_t backend) { + const whisper_context_params &cparams, + const whisper_hparams &hparams, + struct whisper_aheads_masks &aheads_masks, + ggml_backend_t backend) +{ const int32_t n_text_layer = hparams.n_text_layer; const int32_t n_head = hparams.n_text_head; // Sanity checks - if (cparams.dtw_aheads_preset == 
WHISPER_AHEADS_NONE) { + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) + { WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__); return false; - } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) { - if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) { + } + else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) + { + if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) + { WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer); return false; } - } else { + } + else + { const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset); - if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) { - if (aheads.n_heads == 0) { + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) + { + if (aheads.n_heads == 0) + { WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__); return false; } - if (aheads.heads == NULL) { + if (aheads.heads == NULL) + { WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__); return false; } } - for (size_t i = 0; i < aheads.n_heads; ++i) { - if (aheads.heads[i].n_text_layer >= n_text_layer) { + for (size_t i = 0; i < aheads.n_heads; ++i) + { + if (aheads.heads[i].n_text_layer >= n_text_layer) + { WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer); return false; } - if (aheads.heads[i].n_text_layer < 0) { + if (aheads.heads[i].n_text_layer < 0) + { WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__); return false; } - if (aheads.heads[i].n_head >= n_head) { + if (aheads.heads[i].n_head >= n_head) + { WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head); return false; } - if (aheads.heads[i].n_head < 0) { + if (aheads.heads[i].n_head < 0) + { WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__); return false; } @@ -1199,29 +1648,35 @@ static bool aheads_masks_init( } struct ggml_init_params params = { - /*.mem_size =*/ (size_t) static_cast(n_text_layer)*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, + /*.mem_size =*/(size_t)static_cast(n_text_layer) * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, }; aheads_masks.ctx = ggml_init(params); - if (!aheads_masks.ctx) { + if (!aheads_masks.ctx) + { WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__); return false; } - for (int64_t il = 0; il < n_text_layer; ++il) { + for (int64_t il = 0; il < n_text_layer; ++il) + { auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head); - if (!aheads.empty()) { + if (!aheads.empty()) + { aheads_masks.m.push_back(ggml_new_tensor_2d(aheads_masks.ctx, GGML_TYPE_F32, n_head, aheads.size())); - } else { + } + else + { aheads_masks.m.push_back(nullptr); } } aheads_masks.buffer = ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend); - if (!aheads_masks.buffer) { + if (!aheads_masks.buffer) + { WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__); return false; } @@ -1237,8 +1692,10 @@ static bool aheads_masks_init( // 0 0 0 0 0 1 0 0 0 0 // 0 0 0 0 0 0 1 0 0 0 std::vector mask_data; - for (int64_t il = 0; il < n_text_layer; ++il) { - if (aheads_masks.m[il] != nullptr) { + for 
(int64_t il = 0; il < n_text_layer; ++il) + { + if (aheads_masks.m[il] != nullptr) + { auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head); size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1]; @@ -1246,7 +1703,8 @@ static bool aheads_masks_init( mask_data.resize(data_size); std::fill(mask_data.begin(), mask_data.end(), 0); - for (size_t ih = 0; ih < aheads.size(); ++ih) { + for (size_t ih = 0; ih < aheads.size(); ++ih) + { size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0])); mask_data[pos] = 1.0f; } @@ -1255,7 +1713,8 @@ static bool aheads_masks_init( } } - if (aheads_masks.m.empty()) { + if (aheads_masks.m.empty()) + { WHISPER_LOG_ERROR("%s: \n", __func__); return false; } @@ -1263,72 +1722,88 @@ static bool aheads_masks_init( return true; } -static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) { +static void aheads_masks_free(struct whisper_aheads_masks &aheads_masks) +{ ggml_free(aheads_masks.ctx); ggml_backend_buffer_free(aheads_masks.buffer); aheads_masks.ctx = nullptr; } -static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) { +static size_t aheads_masks_nbytes(struct whisper_aheads_masks &aheads_masks) +{ size_t size = 0; - for (size_t i = 0; i < aheads_masks.m.size(); ++i) { + for (size_t i = 0; i < aheads_masks.m.size(); ++i) + { if (aheads_masks.m[i] != nullptr) size += ggml_nbytes(aheads_masks.m[i]); } return size; } -static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & params) { +static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params ¶ms) +{ ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); ggml_backend_dev_t dev = nullptr; int cnt = 0; - if (params.use_gpu) { - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + if (params.use_gpu) + { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) + { ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) { - if (cnt == params.gpu_device) { + if (ggml_backend_dev_type(dev_cur) == GGML_BACKEND_DEVICE_TYPE_GPU) + { + if (cnt == params.gpu_device) + { dev = dev_cur; } - if (++cnt > params.gpu_device) { + if (++cnt > params.gpu_device) + { break; } } } } - if (dev == nullptr) { + if (dev == nullptr) + { WHISPER_LOG_INFO("%s: no GPU found\n", __func__); return nullptr; } WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); - if (!result) { + if (!result) + { WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); } return result; } -static std::vector whisper_backend_init(const whisper_context_params & params) { +static std::vector whisper_backend_init(const whisper_context_params ¶ms) +{ std::vector result; ggml_backend_t backend_gpu = whisper_backend_init_gpu(params); - if (backend_gpu) { + if (backend_gpu) + { result.push_back(backend_gpu); } // ACCEL backends - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) + { ggml_backend_dev_t dev = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) + { WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); - if (!backend) { + if (!backend) + { WHISPER_LOG_ERROR("%s: 
failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); continue; } @@ -1337,7 +1812,8 @@ static std::vector whisper_backend_init(const whisper_context_pa } ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); - if (backend_cpu == nullptr) { + if (backend_cpu == nullptr) + { throw std::runtime_error("failed to initialize CPU backend"); } result.push_back(backend_cpu); @@ -1347,24 +1823,31 @@ static std::vector whisper_backend_init(const whisper_context_pa using buft_list_t = std::vector>; -static buft_list_t make_buft_list(whisper_context_params & params) { +static buft_list_t make_buft_list(whisper_context_params ¶ms) +{ // Prio order: GPU -> CPU Extra -> CPU buft_list_t buft_list; // GPU - if (params.use_gpu) { + if (params.use_gpu) + { int cnt = 0; - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) + { ggml_backend_dev_t dev = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) { - if (cnt == params.gpu_device) { - auto * buft = ggml_backend_dev_buffer_type(dev); - if (buft) { + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) + { + if (cnt == params.gpu_device) + { + auto *buft = ggml_backend_dev_buffer_type(dev); + if (buft) + { buft_list.emplace_back(dev, buft); } } - if (++cnt > params.gpu_device) { + if (++cnt > params.gpu_device) + { break; } } @@ -1372,13 +1855,15 @@ static buft_list_t make_buft_list(whisper_context_params & params) { } // CPU Extra - auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto *cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto *cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); - if (get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev); - while (extra_bufts && *extra_bufts) { + if (get_extra_bufts_fn) + { + ggml_backend_buffer_type_t *extra_bufts = get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) + { buft_list.emplace_back(cpu_dev, *extra_bufts); ++extra_bufts; } @@ -1390,66 +1875,80 @@ static buft_list_t make_buft_list(whisper_context_params & params) { return buft_list; } -static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { +static bool weight_buft_supported(const whisper_hparams &hparams, ggml_tensor *w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) +{ bool op_supported = true; if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || - (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) + { // GPU and default CPU backend support all operators op_supported = true; - } else { - switch (op) { - // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS - case GGML_OP_GET_ROWS: - case GGML_OP_MUL_MAT: { - ggml_init_params params = { - /*.mem_size =*/ 2 * ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - ggml_context_ptr ctx_ptr { ggml_init(params) }; - if (!ctx_ptr) { - throw std::runtime_error("failed to create 
ggml context"); - } - ggml_context * ctx = ctx_ptr.get(); - - ggml_tensor * op_tensor = nullptr; - - if (op == GGML_OP_MUL_MAT) { - int64_t n_ctx = hparams.n_audio_ctx; - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); - op_tensor = ggml_mul_mat(ctx, w, b); - } else if (op == GGML_OP_GET_ROWS) { - int64_t num_indices = 8; - ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices); - op_tensor = ggml_get_rows(ctx, w, indices); - } + } + else + { + switch (op) + { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS + case GGML_OP_GET_ROWS: + case GGML_OP_MUL_MAT: + { + ggml_init_params params = { + /*.mem_size =*/2 * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, + }; - // create a temporary dummy buffer for the weight so that supports_op can check the buffer type - GGML_ASSERT(w->buffer == nullptr); - w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); - op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - ggml_backend_buffer_free(w->buffer); - w->buffer = nullptr; - break; + ggml_context_ptr ctx_ptr{ggml_init(params)}; + if (!ctx_ptr) + { + throw std::runtime_error("failed to create ggml context"); } - default: { - op_supported = false; - break; + ggml_context *ctx = ctx_ptr.get(); + + ggml_tensor *op_tensor = nullptr; + + if (op == GGML_OP_MUL_MAT) + { + int64_t n_ctx = hparams.n_audio_ctx; + ggml_tensor *b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } + else if (op == GGML_OP_GET_ROWS) + { + int64_t num_indices = 8; + ggml_tensor *indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices); + op_tensor = ggml_get_rows(ctx, w, indices); } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: + { + op_supported = false; + break; + } }; } return op_supported; } -static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { +static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams &hparams, ggml_tensor *w, ggml_op op, buft_list_t buft_list) +{ GGML_ASSERT(!buft_list.empty()); - for (const auto & p : buft_list) { + for (const auto &p : buft_list) + { ggml_backend_dev_t dev = p.first; ggml_backend_buffer_type_t buft = p.second; - if (weight_buft_supported(hparams, w, op, buft, dev)) { + if (weight_buft_supported(hparams, w, op, buft, dev)) + { return buft; } } @@ -1468,29 +1967,31 @@ static ggml_backend_buffer_type_t select_weight_buft(const whisper_hparams & hpa // // see the convert-pt-to-ggml.py script for details // -static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { +static bool whisper_model_load(struct whisper_model_loader *loader, whisper_context &wctx) +{ WHISPER_LOG_INFO("%s: loading model\n", __func__); const int64_t t_start_us = ggml_time_us(); wctx.t_start_us = t_start_us; - auto & model = wctx.model; - auto & vocab = wctx.vocab; + auto &model = wctx.model; + auto &vocab = wctx.vocab; // verify magic { uint32_t magic; read_safe(loader, magic); - if (magic != GGML_FILE_MAGIC) { + if (magic != GGML_FILE_MAGIC) + { WHISPER_LOG_ERROR("%s: 
invalid model data (bad magic)\n", __func__); return false; } } - //load hparams + // load hparams { - auto & hparams = model.hparams; + auto &hparams = model.hparams; read_safe(loader, hparams.n_vocab); read_safe(loader, hparams.n_audio_ctx); @@ -1508,26 +2009,32 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con std::string mver = ""; - if (hparams.n_audio_layer == 4) { + if (hparams.n_audio_layer == 4) + { model.type = e_model::MODEL_TINY; } - if (hparams.n_audio_layer == 6) { + if (hparams.n_audio_layer == 6) + { model.type = e_model::MODEL_BASE; } - if (hparams.n_audio_layer == 12) { + if (hparams.n_audio_layer == 12) + { model.type = e_model::MODEL_SMALL; } - if (hparams.n_audio_layer == 24) { + if (hparams.n_audio_layer == 24) + { model.type = e_model::MODEL_MEDIUM; } - if (hparams.n_audio_layer == 32) { + if (hparams.n_audio_layer == 32) + { model.type = e_model::MODEL_LARGE; - if (hparams.n_vocab == 51866) { + if (hparams.n_vocab == 51866) + { mver = " v3"; } } @@ -1538,8 +2045,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // for the big tensors, we have the option to store the data in 16-bit floats or quantized // in order to save memory and also to speed up the computation - wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); - if (wctx.wtype == GGML_TYPE_COUNT) { + wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype)(model.hparams.ftype)); + if (wctx.wtype == GGML_TYPE_COUNT) + { WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); return false; } @@ -1561,7 +2069,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // load mel filters { - auto & filters = wctx.model.filters; + auto &filters = wctx.model.filters; read_safe(loader, filters.n_mel); read_safe(loader, filters.n_fft); @@ -1576,80 +2084,110 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con int32_t n_vocab = 0; read_safe(loader, n_vocab); - //if (n_vocab != model.hparams.n_vocab) { - // WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n", - // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); - // return false; - //} + // if (n_vocab != model.hparams.n_vocab) { + // WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + // return false; + // } std::string word; std::vector tmp; tmp.reserve(128); - for (int i = 0; i < n_vocab; i++) { + for (int i = 0; i < n_vocab; i++) + { uint32_t len; read_safe(loader, len); - if (len > 0) { + if (len > 0) + { tmp.resize(len); loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer word.assign(&tmp[0], tmp.size()); - } else { + } + else + { // seems like we have an empty-string token in multi-language models (i = 50256) - //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + // WHISPER_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); word = ""; } vocab.token_to_id[word] = i; vocab.id_to_token[i] = word; - //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); + // printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); } vocab.n_vocab = model.hparams.n_vocab; - if (vocab.is_multilingual()) { + if (vocab.is_multilingual()) + { vocab.token_eot++; vocab.token_sot++; // account for variable number of language tokens const int dt = vocab.num_languages() - 98; - vocab.token_translate += 
dt; + vocab.token_translate += dt; vocab.token_transcribe += dt; - vocab.token_solm += dt; - vocab.token_prev += dt; - vocab.token_nosp += dt; - vocab.token_not += dt; - vocab.token_beg += dt; + vocab.token_solm += dt; + vocab.token_prev += dt; + vocab.token_nosp += dt; + vocab.token_not += dt; + vocab.token_beg += dt; } - if (n_vocab < model.hparams.n_vocab) { + if (n_vocab < model.hparams.n_vocab) + { WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); - for (int i = n_vocab; i < model.hparams.n_vocab; i++) { - if (i > vocab.token_beg) { + for (int i = n_vocab; i < model.hparams.n_vocab; i++) + { + if (i > vocab.token_beg) + { word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; - } else if (i == vocab.token_eot) { + } + else if (i == vocab.token_eot) + { word = "[_EOT_]"; - } else if (i == vocab.token_sot) { + } + else if (i == vocab.token_sot) + { word = "[_SOT_]"; - } else if (i == vocab.token_translate) { + } + else if (i == vocab.token_translate) + { word = "[_TRANSLATE_]"; - } else if (i == vocab.token_transcribe) { + } + else if (i == vocab.token_transcribe) + { word = "[_TRANSCRIBE_]"; - } else if (i == vocab.token_solm) { + } + else if (i == vocab.token_solm) + { word = "[_SOLM_]"; - } else if (i == vocab.token_prev) { + } + else if (i == vocab.token_prev) + { word = "[_PREV_]"; - } else if (i == vocab.token_nosp) { + } + else if (i == vocab.token_nosp) + { word = "[_NOSP_]"; - } else if (i == vocab.token_not) { + } + else if (i == vocab.token_not) + { word = "[_NOT_]"; - } else if (i == vocab.token_beg) { + } + else if (i == vocab.token_beg) + { word = "[_BEG_]"; - } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) { + } + else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) + { word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]"; - } else { + } + else + { word = "[_extra_token_" + std::to_string(i) + "]"; } vocab.token_to_id[word] = i; @@ -1663,25 +2201,28 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con const ggml_type wtype = wctx.wtype; const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? 
GGML_TYPE_F32 : GGML_TYPE_F16; // conv type - const auto & hparams = model.hparams; + const auto &hparams = model.hparams; const int n_audio_layer = hparams.n_audio_layer; - const int n_text_layer = hparams.n_text_layer; + const int n_text_layer = hparams.n_text_layer; - const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer; + const size_t n_tensors = 10 /* input */ + 15 + 15 * n_audio_layer + 24 * n_text_layer; std::map ctx_map; - auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * + { auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { + if (it == ctx_map.end()) + { ggml_init_params params = { - /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, + /*.mem_size =*/n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, }; - ggml_context * ctx = ggml_init(params); - if (!ctx) { + ggml_context *ctx = ggml_init(params); + if (!ctx) + { throw std::runtime_error("failed to create ggml context"); } @@ -1697,41 +2238,42 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // Create a list of available bufts, in priority order buft_list_t buft_list = make_buft_list(wctx.params); - auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor * meta, int layer = 0) -> ggml_tensor * { + auto create_tensor = [&](asr_tensor type, asr_system system, ggml_tensor *meta, int layer = 0) -> ggml_tensor * + { ggml_op op = ASR_TENSOR_INFO.at(type); ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list); - if (!buft) { + if (!buft) + { throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", ASR_TENSOR_NAMES.at(system).at(type))); } - ggml_context * ctx = get_ctx(buft); - ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + ggml_context *ctx = get_ctx(buft); + ggml_tensor *tensor = ggml_dup_tensor(ctx, meta); model.tensors[format(ASR_TENSOR_NAMES.at(system).at(type), layer)] = tensor; return tensor; }; - // prepare tensors for the weights { ggml_init_params params = { - /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, + /*.mem_size =*/n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, }; - ggml_context * ctx = ggml_init(params); + ggml_context *ctx = ggml_init(params); - const auto & hparams = model.hparams; + const auto &hparams = model.hparams; const int n_vocab = hparams.n_vocab; - const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_ctx = hparams.n_audio_ctx; const int n_audio_state = hparams.n_audio_state; const int n_audio_layer = hparams.n_audio_layer; - const int n_text_ctx = hparams.n_text_ctx; + const int n_text_ctx = hparams.n_text_ctx; const int n_text_state = hparams.n_text_state; const int n_text_layer = hparams.n_text_layer; @@ -1752,17 +2294,18 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.e_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state)); model.e_ln_b = create_tensor(ASR_TENSOR_LN_POST_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state)); - for (int i = 0; i < n_audio_layer; ++i) { - auto & layer = model.layers_encoder[i]; + for (int i = 0; i < n_audio_layer; ++i) + { + auto &layer = model.layers_encoder[i]; layer.mlp_ln_w = 
create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); - layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); - layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); - layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state), i); + layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4 * n_audio_state), i); + layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4 * n_audio_state), i); - layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); - layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, wtype, 4 * n_audio_state, n_audio_state), i); + layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); layer.attn_ln_0_b = create_tensor(ASR_TENSOR_ATTN_LN_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); @@ -1787,16 +2330,17 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.d_ln_w = create_tensor(ASR_TENSOR_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state)); model.d_ln_b = create_tensor(ASR_TENSOR_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state)); - for (int i = 0; i < n_text_layer; ++i) { - auto & layer = model.layers_decoder[i]; + for (int i = 0; i < n_text_layer; ++i) + { + auto &layer = model.layers_decoder[i]; layer.mlp_ln_w = create_tensor(ASR_TENSOR_MLP_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); layer.mlp_ln_b = create_tensor(ASR_TENSOR_MLP_LN_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); - layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state), i); - layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state), i); + layer.mlp_0_w = create_tensor(ASR_TENSOR_MLP_0_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, n_text_state, 4 * n_text_state), i); + layer.mlp_0_b = create_tensor(ASR_TENSOR_MLP_0_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4 * n_text_state), i); - layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state), i); + layer.mlp_1_w = create_tensor(ASR_TENSOR_MLP_2_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_2d(ctx, wtype, 4 * n_text_state, n_text_state), i); layer.mlp_1_b = create_tensor(ASR_TENSOR_MLP_2_BIAS, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 
n_text_state), i); layer.attn_ln_0_w = create_tensor(ASR_TENSOR_ATTN_LN_WEIGHT, ASR_SYSTEM_DECODER, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state), i); @@ -1832,11 +2376,13 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } // allocate tensors in the backend buffers - for (auto & p : ctx_map) { + for (auto &p : ctx_map) + { ggml_backend_buffer_type_t buft = p.first; - ggml_context * ctx = p.second; + ggml_context *ctx = p.second; ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); - if (buf) { + if (buf) + { model.buffers.emplace_back(buf); size_t size_main = ggml_backend_buffer_get_size(buf); @@ -1852,7 +2398,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con std::vector read_buf; - while (true) { + while (true) + { int32_t n_dims; int32_t length; int32_t ttype; @@ -1861,55 +2408,64 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con read_safe(loader, length); read_safe(loader, ttype); - if (loader->eof(loader->context)) { + if (loader->eof(loader->context)) + { break; } int32_t nelements = 1; - int32_t ne[4] = { 1, 1, 1, 1 }; - for (int i = 0; i < n_dims; ++i) { + int32_t ne[4] = {1, 1, 1, 1}; + for (int i = 0; i < n_dims; ++i) + { read_safe(loader, ne[i]); nelements *= ne[i]; } std::string name; - std::vector tmp(length); // create a buffer + std::vector tmp(length); // create a buffer loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer name.assign(&tmp[0], tmp.size()); - if (model.tensors.find(name) == model.tensors.end()) { + if (model.tensors.find(name) == model.tensors.end()) + { WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); return false; } auto tensor = model.tensors[name.data()]; - if (ggml_nelements(tensor) != nelements) { + if (ggml_nelements(tensor) != nelements) + { WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", - __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + __func__, ne[0], ne[1], ne[2], (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2]); return false; } - if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) + { WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", - __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); + __func__, name.data(), (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2], ne[0], ne[1], ne[2]); return false; } const size_t bpe = ggml_type_size(ggml_type(ttype)); - if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) + { WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + __func__, name.data(), ggml_nbytes(tensor), nelements * bpe); return false; } - if (ggml_backend_buffer_is_host(tensor->buffer)) { + if (ggml_backend_buffer_is_host(tensor->buffer)) + { // for the CPU and Metal backend, we can read directly into the tensor loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); BYTESWAP_TENSOR(tensor); - } else { + } + else + { // read into a temporary 
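Aside on the size validation in the tensor-loading loop above: for every tensor the loader recomputes the expected byte count as nelements * type_size / block_size and rejects the file if that does not match the allocated tensor. A minimal standalone sketch of that arithmetic (the per-type numbers below are illustrative; the real values come from ggml_type_size() and ggml_blck_size()):

    #include <cstdio>
    #include <cstddef>

    // Hypothetical per-type parameters; block-quantized types pack block_size
    // elements into type_size bytes, while F16/F32 have block_size == 1.
    struct type_info { size_t type_size; size_t block_size; };

    static size_t expected_nbytes(size_t nelements, const type_info & t) {
        return nelements * t.type_size / t.block_size;
    }

    int main() {
        const type_info f16     = {   2,   1 }; // 2 bytes per element
        const type_info blocked = { 176, 256 }; // illustrative 256-element block

        std::printf("F16,     1024 elems -> %zu bytes\n", expected_nbytes(1024, f16));     // 2048
        std::printf("blocked, 1024 elems -> %zu bytes\n", expected_nbytes(1024, blocked)); // 704
        return 0;
    }

A mismatch here usually means the file was produced with a different quantization or model variant, so the loader aborts instead of reading past the end of the tensor.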
buffer first, then copy to device memory read_buf.resize(ggml_nbytes(tensor)); @@ -1922,17 +2478,21 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.n_loaded++; } - WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); + WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size / 1e6); - if (model.n_loaded == 0) { + if (model.n_loaded == 0) + { WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); - } else if (model.n_loaded != (int) model.tensors.size()) { + } + else if (model.n_loaded != (int)model.tensors.size()) + { WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); return false; } } - for (auto & buf : model.buffers) { + for (auto &buf : model.buffers) + { ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); } @@ -1941,7 +2501,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con return true; } -static bool whisper_encode_external(const whisper_state & wstate) { +static bool whisper_encode_external(const whisper_state &wstate) +{ GGML_UNUSED(wstate); #ifndef WHISPER_USE_COREML @@ -1959,34 +2520,37 @@ static bool whisper_encode_external(const whisper_state & wstate) { return use_coreml || use_openvino; } -static struct ggml_cgraph * whisper_build_graph_conv( - whisper_context & wctx, - whisper_state & wstate) { - const auto & model = wctx.model; - const auto & hparams = model.hparams; +static struct ggml_cgraph *whisper_build_graph_conv( + whisper_context &wctx, + whisper_state &wstate) +{ + const auto &model = wctx.model; + const auto &hparams = model.hparams; - const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; - const int n_state = hparams.n_audio_state; GGML_UNUSED(n_state); + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + GGML_UNUSED(n_state); const int n_mels = hparams.n_mels; struct ggml_init_params params = { - /*.mem_size =*/ wstate.sched_conv.meta.size(), - /*.mem_buffer =*/ wstate.sched_conv.meta.data(), - /*.no_alloc =*/ true, + /*.mem_size =*/wstate.sched_conv.meta.size(), + /*.mem_buffer =*/wstate.sched_conv.meta.data(), + /*.no_alloc =*/true, }; - struct ggml_context * ctx0 = ggml_init(params); + struct ggml_context *ctx0 = ggml_init(params); - ggml_cgraph * gf = ggml_new_graph(ctx0); + ggml_cgraph *gf = ggml_new_graph(ctx0); - struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); + struct ggml_tensor *mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2 * n_ctx, n_mels); ggml_set_name(mel, "mel"); ggml_set_input(mel); - struct ggml_tensor * cur = nullptr; + struct ggml_tensor *cur = nullptr; - if (!whisper_encode_external(wstate)) { + if (!whisper_encode_external(wstate)) + { // convolution + gelu { cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); @@ -2002,7 +2566,9 @@ static struct ggml_cgraph * whisper_build_graph_conv( ggml_set_name(cur, "embd_conv"); wstate.embd_conv = cur; - } else { + } + else + { ggml_build_forward_expand(gf, mel); cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); @@ -2021,68 +2587,70 @@ static struct ggml_cgraph * whisper_build_graph_conv( return gf; } -static struct ggml_cgraph * whisper_build_graph_encoder( - whisper_context & wctx, - whisper_state & wstate) { - const auto & model = wctx.model; - const auto & hparams = model.hparams; +static struct ggml_cgraph *whisper_build_graph_encoder( + whisper_context &wctx, + whisper_state &wstate) +{ + const auto &model = wctx.model; + const auto &hparams = model.hparams; - const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
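The convolutional stem above consumes 2*n_ctx mel frames; in the reference Whisper architecture the second convolution uses stride 2, so the encoder ends up with n_ctx positions (that second conv call sits outside this hunk, so the stride here is an assumption). A small helper for the output length of a same-padded 1-D convolution, with illustrative default sizes:

    #include <cstdio>

    // Output length of a 1-D convolution padded by kernel/2 on each side:
    // out = (in + 2*pad - kernel) / stride + 1.
    static int conv1d_out_len(int in_len, int kernel, int stride) {
        const int pad = kernel / 2;
        return (in_len + 2 * pad - kernel) / stride + 1;
    }

    int main() {
        const int n_mel_frames = 3000; // 2 * n_ctx for the default n_audio_ctx = 1500
        const int after_conv1  = conv1d_out_len(n_mel_frames, 3, 1); // stride 1 keeps the length
        const int after_conv2  = conv1d_out_len(after_conv1,  3, 2); // stride 2 halves it (assumed)
        std::printf("%d -> %d -> %d\n", n_mel_frames, after_conv1, after_conv2); // 3000 -> 3000 -> 1500
        return 0;
    }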
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; const int n_state = hparams.n_audio_state; - const int n_head = hparams.n_audio_head; + const int n_head = hparams.n_audio_head; const int n_layer = hparams.n_audio_layer; - const int n_state_head = n_state/n_head; + const int n_state_head = n_state / n_head; - auto & kv_pad = wstate.kv_pad; + auto &kv_pad = wstate.kv_pad; WHISPER_ASSERT(!!kv_pad.buffer); const int n_ctx_pad = GGML_PAD(n_ctx, 256); struct ggml_init_params params = { - /*.mem_size =*/ wstate.sched_encode.meta.size(), - /*.mem_buffer =*/ wstate.sched_encode.meta.data(), - /*.no_alloc =*/ true, + /*.mem_size =*/wstate.sched_encode.meta.size(), + /*.mem_buffer =*/wstate.sched_encode.meta.data(), + /*.no_alloc =*/true, }; - struct ggml_context * ctx0 = ggml_init(params); + struct ggml_context *ctx0 = ggml_init(params); - ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); + ggml_cgraph *gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); - struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv); + struct ggml_tensor *cur = ggml_view_tensor(ctx0, wstate.embd_conv); - const float KQscale = 1.0f/sqrtf(float(n_state_head)); + const float KQscale = 1.0f / sqrtf(float(n_state_head)); // =================================================================== // NOTE: experimenting with partial evaluation of the encoder (ignore) - //static int iter = -1; - //const int n_iter = 1500/n_ctx; + // static int iter = -1; + // const int n_iter = 1500/n_ctx; - //iter = (iter + 1) % n_iter; + // iter = (iter + 1) % n_iter; - //if (iter == 0) { - // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); - // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); - //} + // if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + // } static int iter = 0; - const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); - const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; + const size_t e_pe_stride = model.e_pe->ne[0] * ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0] * ggml_element_size(model.e_pe) * n_ctx * iter; - struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + struct ggml_tensor *e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); // =================================================================== // original: - //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + // cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); - struct ggml_tensor * inpL = cur; + struct ggml_tensor *inpL = cur; - for (int il = 0; il < n_layer; ++il) { - const auto & layer = model.layers_encoder[il]; + for (int il = 0; il < n_layer; ++il) + { + const auto &layer = model.layers_encoder[il]; // norm { @@ -2090,86 +2658,89 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctx0, - ggml_mul(ctx0, cur, layer.attn_ln_0_w), - layer.attn_ln_0_b); + ggml_mul(ctx0, cur, layer.attn_ln_0_w), + layer.attn_ln_0_b); } // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.attn_q_w, - cur); + struct ggml_tensor *Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); Qcur = ggml_add(ctx0, Qcur, 
layer.attn_q_b); - //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25)); + // Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25)); // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, - layer.attn_k_w, - cur); + struct ggml_tensor *Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); - //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25)); + // Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25)); - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, - layer.attn_v_w, - cur); + struct ggml_tensor *Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b); // ------ - struct ggml_tensor * Q = + struct ggml_tensor *Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx), - 0, 2, 1, 3); + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx), + 0, 2, 1, 3); - if (wctx.params.flash_attn) { - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0))); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0))); + if (wctx.params.flash_attn) + { + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx * n_state, 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx * n_state, 0))); - struct ggml_tensor * K = + struct ggml_tensor *K = ggml_view_3d(ctx0, kv_pad.k, - n_state_head, n_ctx_pad, n_head, - ggml_element_size(kv_pad.k)*n_state, - ggml_element_size(kv_pad.k)*n_state_head, - 0); + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.k) * n_state, + ggml_element_size(kv_pad.k) * n_state_head, + 0); - struct ggml_tensor * V = + struct ggml_tensor *V = ggml_view_3d(ctx0, kv_pad.v, - n_state_head, n_ctx_pad, n_head, - ggml_element_size(kv_pad.v)*n_state, - ggml_element_size(kv_pad.v)*n_state_head, - 0); + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.v) * n_state, + ggml_element_size(kv_pad.v) * n_state_head, + 0); cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f); cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx); - } else { - struct ggml_tensor * K = + } + else + { + struct ggml_tensor *K = ggml_permute(ctx0, - ggml_cast(ctx0, - ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx), - wctx.itype), - 0, 2, 1, 3); + ggml_cast(ctx0, + ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx), + wctx.itype), + 0, 2, 1, 3); // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + struct ggml_tensor *KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); - struct ggml_tensor * V = + struct ggml_tensor *V = ggml_cast(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state_head, n_head, n_ctx), - 1, 2, 0, 3), - wctx.itype); + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state_head, n_head, n_ctx), + 1, 2, 0, 3), + wctx.itype); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor *KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx); } @@ -2178,8 +2749,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // projection { cur = ggml_mul_mat(ctx0, - layer.attn_ln_1_w, - cur); + 
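Both branches of the encoder self-attention above compute the same quantity, softmax(QK^T / sqrt(d_head)) V per head; the flash-attention path fuses the steps and requires K/V padded to n_ctx_pad, while the other path materializes KQ and the softmax explicitly. A naive single-head reference in plain C++ (row-major, purely illustrative) of what either branch produces:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Naive single-head scaled dot-product attention over row-major matrices
    // Q[nq x d], K[nk x d], V[nk x d]; returns [nq x d]. Only a reference for
    // what the two ggml branches above compute.
    static std::vector<float> attention(const std::vector<float> & Q,
                                        const std::vector<float> & K,
                                        const std::vector<float> & V,
                                        int nq, int nk, int d) {
        const float scale = 1.0f / std::sqrt((float) d);
        std::vector<float> out(nq * d, 0.0f);
        std::vector<float> s(nk);

        for (int i = 0; i < nq; ++i) {
            float smax = -1e30f;
            for (int j = 0; j < nk; ++j) {
                float dot = 0.0f;
                for (int c = 0; c < d; ++c) dot += Q[i*d + c] * K[j*d + c];
                s[j] = dot * scale;
                smax = std::max(smax, s[j]);
            }
            float sum = 0.0f;
            for (int j = 0; j < nk; ++j) { s[j] = std::exp(s[j] - smax); sum += s[j]; }
            for (int j = 0; j < nk; ++j) {
                const float w = s[j] / sum;
                for (int c = 0; c < d; ++c) out[i*d + c] += w * V[j*d + c];
            }
        }
        return out;
    }

    int main() {
        const std::vector<float> Q = { 1, 0 }, K = { 1, 0, 0, 1 }, V = { 2, 2, -2, -2 };
        const auto out = attention(Q, K, V, /*nq=*/1, /*nk=*/2, /*d=*/2);
        std::printf("%.3f %.3f\n", out[0], out[1]); // weighted toward the first value row
        return 0;
    }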
layer.attn_ln_1_w, + cur); cur = ggml_add(ctx0, cur, layer.attn_ln_1_b); } @@ -2187,7 +2758,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // add the input cur = ggml_add(ctx0, cur, inpL); - struct ggml_tensor * inpFF = cur; + struct ggml_tensor *inpFF = cur; // feed-forward network { @@ -2197,14 +2768,14 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, - ggml_mul(ctx0, cur, layer.mlp_ln_w), - layer.mlp_ln_b); + ggml_mul(ctx0, cur, layer.mlp_ln_w), + layer.mlp_ln_b); } // fully connected cur = ggml_mul_mat(ctx0, - layer.mlp_0_w, - cur); + layer.mlp_0_w, + cur); cur = ggml_add(ctx0, cur, layer.mlp_0_b); @@ -2213,8 +2784,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // projection cur = ggml_mul_mat(ctx0, - layer.mlp_1_w, - cur); + layer.mlp_1_w, + cur); cur = ggml_add(ctx0, cur, layer.mlp_1_b); } @@ -2230,24 +2801,24 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // cur = ln_f_g*cur + ln_f_b cur = ggml_add(ctx0, - ggml_mul(ctx0, cur, model.e_ln_w), - model.e_ln_b); + ggml_mul(ctx0, cur, model.e_ln_w), + model.e_ln_b); } ggml_build_forward_expand(gf, cur); wstate.embd_enc = cur; - //ggml_graph_print(gf); + // ggml_graph_print(gf); //////////////////////////////////////////////////////////////////////////// - //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, - // ggml_used_mem(ctx0)/1e6, - // wstate.get_buf_max_mem(0)/1e6, - // wstate.get_buf_max_mem(1)/1e6, - // wstate.get_buf_max_mem(2)/1e6, - // wstate.get_buf_max_mem(3)/1e6); + // printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1e6, + // wstate.get_buf_max_mem(0)/1e6, + // wstate.get_buf_max_mem(1)/1e6, + // wstate.get_buf_max_mem(2)/1e6, + // wstate.get_buf_max_mem(3)/1e6); ggml_free(ctx0); @@ -2255,76 +2826,81 @@ static struct ggml_cgraph * whisper_build_graph_encoder( } // pre-compute cross-attention memory -static struct ggml_cgraph * whisper_build_graph_cross( - whisper_context & wctx, - whisper_state & wstate) { - const auto & model = wctx.model; - const auto & hparams = model.hparams; +static struct ggml_cgraph *whisper_build_graph_cross( + whisper_context &wctx, + whisper_state &wstate) +{ + const auto &model = wctx.model; + const auto &hparams = model.hparams; - const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
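For reference, each encoder layer built above is the standard pre-norm transformer block. Writing LN for the affine layer norm (the ln_w / ln_b pairs) and d for n_state, the residual structure is roughly:

    \begin{aligned}
    x' &= x + W_o\,\mathrm{Attn}\big(\mathrm{LN}_{\text{attn}}(x)\big) + b_o \\
    y  &= x' + W_2\,\mathrm{GELU}\big(W_1\,\mathrm{LN}_{\text{mlp}}(x') + b_1\big) + b_2,
    \end{aligned}
    \qquad W_1 \in \mathbb{R}^{4d \times d},\; W_2 \in \mathbb{R}^{d \times 4d},

which matches the 4x expansion in the mlp_0 / mlp_1 tensor shapes created earlier.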
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; const int n_state = hparams.n_audio_state; - const int n_head = hparams.n_audio_head; + const int n_head = hparams.n_audio_head; - const int n_state_head = n_state/n_head; + const int n_state_head = n_state / n_head; const int n_ctx_pad = GGML_PAD(n_ctx, 256); struct ggml_init_params params = { - /*.mem_size =*/ wstate.sched_cross.meta.size(), - /*.mem_buffer =*/ wstate.sched_cross.meta.data(), - /*.no_alloc =*/ true, + /*.mem_size =*/wstate.sched_cross.meta.size(), + /*.mem_buffer =*/wstate.sched_cross.meta.data(), + /*.no_alloc =*/true, }; - struct ggml_context * ctx0 = ggml_init(params); + struct ggml_context *ctx0 = ggml_init(params); - ggml_cgraph * gf = ggml_new_graph(ctx0); + ggml_cgraph *gf = ggml_new_graph(ctx0); - struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc); + struct ggml_tensor *cur = ggml_view_tensor(ctx0, wstate.embd_enc); - const float Kscale = pow(float(n_state_head), -0.25); + const float Kscale = pow(float(n_state_head), -0.25); - for (int il = 0; il < model.hparams.n_text_layer; ++il) { - auto & layer = model.layers_decoder[il]; + for (int il = 0; il < model.hparams.n_text_layer; ++il) + { + auto &layer = model.layers_decoder[il]; - struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, - layer.cross_attn_k_w, - cur); + struct ggml_tensor *Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); Kcross = ggml_scale(ctx0, Kcross, Kscale); - struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, - layer.cross_attn_v_w, - cur); + struct ggml_tensor *Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); Vcross = ggml_add(ctx0, - Vcross, - layer.cross_attn_v_b); + Vcross, + layer.cross_attn_v_b); - struct ggml_tensor * k; - struct ggml_tensor * v; + struct ggml_tensor *k; + struct ggml_tensor *v; - if (wctx.params.flash_attn) { - k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, - (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad)); + if (wctx.params.flash_attn) + { + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state * n_ctx, + (ggml_element_size(wstate.kv_cross.k) * n_state) * (il * n_ctx_pad)); - v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, - (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad)); - } else { + v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state * n_ctx, + (ggml_element_size(wstate.kv_cross.v) * n_state) * (il * n_ctx_pad)); + } + else + { Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); - k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, - (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state * n_ctx, + (ggml_element_size(wstate.kv_cross.k) * n_state) * (il * n_ctx)); v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, - ( n_ctx)*ggml_element_size(wstate.kv_cross.v), - (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + (n_ctx)*ggml_element_size(wstate.kv_cross.v), + (il * n_ctx) * ggml_element_size(wstate.kv_cross.v) * n_state); } ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v)); } - //ggml_graph_print(gf); + // ggml_graph_print(gf); ggml_free(ctx0); @@ -2342,61 +2918,69 @@ static struct ggml_cgraph * whisper_build_graph_cross( // - mel_offset: offset in the mel spectrogram (i.e. 
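A note on the Kscale = pow(n_state_head, -0.25) factor applied to Kcross above: the decoder applies the same factor to its cross-attention queries, so the usual 1/sqrt(d_head) attention scaling is split symmetrically between the two sides:

    \left(d^{-1/4} q\right) \cdot \left(d^{-1/4} k\right) = \frac{q \cdot k}{\sqrt{d}},
    \qquad d = \texttt{n\_state\_head}.

Pre-scaling K once here means the cached cross-attention keys never need to be rescaled during decoding.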
audio offset) // static bool whisper_encode_internal( - whisper_context & wctx, - whisper_state & wstate, - const int mel_offset, - const int n_threads, - ggml_abort_callback abort_callback, - void * abort_callback_data) { + whisper_context &wctx, + whisper_state &wstate, + const int mel_offset, + const int n_threads, + ggml_abort_callback abort_callback, + void *abort_callback_data) +{ const int64_t t_start_us = ggml_time_us(); // conv { - auto & sched = wstate.sched_conv.sched; + auto &sched = wstate.sched_conv.sched; - ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate); + ggml_cgraph *gf = whisper_build_graph_conv(wctx, wstate); - if (!ggml_backend_sched_alloc_graph(sched, gf)) { + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { // should never happen as we pre-allocate the memory return false; } - struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); + struct ggml_tensor *mel = ggml_graph_get_tensor(gf, "mel"); // set the input { - const auto & mel_inp = wstate.mel; - const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx; + const auto &mel_inp = wstate.mel; + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx; assert(mel->type == GGML_TYPE_F32); assert(mel_inp.n_mel == wctx.model.hparams.n_mels); wstate.inp_mel.resize(ggml_nelements(mel)); - float * dst = wstate.inp_mel.data(); + float *dst = wstate.inp_mel.data(); memset(dst, 0, ggml_nbytes(mel)); - const int i0 = std::min(mel_offset, mel_inp.n_len); - const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + 2 * n_ctx, mel_inp.n_len); - for (int j = 0; j < mel_inp.n_mel; ++j) { - for (int i = i0; i < i1; ++i) { - dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; + for (int j = 0; j < mel_inp.n_mel; ++j) + { + for (int i = i0; i < i1; ++i) + { + dst[j * 2 * n_ctx + (i - i0)] = mel_inp.data[j * mel_inp.n_len + i]; } } - ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float)); + ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel) * sizeof(float)); } - if (!whisper_encode_external(wstate)) { - if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + if (!whisper_encode_external(wstate)) + { + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { return false; } - } else { + } + else + { ggml_backend_sched_reset(sched); #if defined(WHISPER_USE_COREML) - whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data); + whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *)mel->data, (float *)wstate.embd_enc->data); #elif defined(WHISPER_USE_OPENVINO) whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc); #endif @@ -2404,33 +2988,38 @@ static bool whisper_encode_internal( } // encoder - if (!whisper_encode_external(wstate)) { - auto & sched = wstate.sched_encode.sched; + if (!whisper_encode_external(wstate)) + { + auto &sched = wstate.sched_encode.sched; - ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate); + ggml_cgraph *gf = whisper_build_graph_encoder(wctx, wstate); - if (!ggml_backend_sched_alloc_graph(sched, gf)) { + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { // should never happen as we pre-allocate the memory return false; } - if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { 
return false; } } // cross { - auto & sched = wstate.sched_cross.sched; + auto &sched = wstate.sched_cross.sched; - ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate); + ggml_cgraph *gf = whisper_build_graph_cross(wctx, wstate); - if (!ggml_backend_sched_alloc_graph(sched, gf)) { + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { // should never happen as we pre-allocate the memory return false; } - if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { return false; } } @@ -2441,75 +3030,77 @@ static bool whisper_encode_internal( return !(abort_callback && abort_callback(abort_callback_data)); } -static struct ggml_cgraph * whisper_build_graph_decoder( - whisper_context & wctx, - whisper_state & wstate, - const whisper_batch & batch, - bool save_alignment_heads_QKs, - bool worst_case) { - const auto & model = wctx.model; - const auto & hparams = model.hparams; +static struct ggml_cgraph *whisper_build_graph_decoder( + whisper_context &wctx, + whisper_state &wstate, + const whisper_batch &batch, + bool save_alignment_heads_QKs, + bool worst_case) +{ + const auto &model = wctx.model; + const auto &hparams = model.hparams; - auto & kv_self = wstate.kv_self; + auto &kv_self = wstate.kv_self; WHISPER_ASSERT(!!kv_self.buffer); - const int n_ctx = kv_self.size; + const int n_ctx = kv_self.size; const int n_state = hparams.n_text_state; - const int n_head = hparams.n_text_head; + const int n_head = hparams.n_text_head; const int n_layer = hparams.n_text_layer; - const int n_state_head = n_state/n_head; + const int n_state_head = n_state / n_head; - const int n_tokens = batch.n_tokens; + const int n_tokens = batch.n_tokens; const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256); - const int32_t n_kv = worst_case ? n_ctx : kv_self.n; + const int32_t n_kv = worst_case ? n_ctx : kv_self.n; const int32_t kv_head = worst_case ? 
n_ctx - n_tokens : kv_self.head; - //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); + // WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); struct ggml_init_params params = { - /*.mem_size =*/ wstate.sched_decode.meta.size(), - /*.mem_buffer =*/ wstate.sched_decode.meta.data(), - /*.no_alloc =*/ true, + /*.mem_size =*/wstate.sched_decode.meta.size(), + /*.mem_buffer =*/wstate.sched_decode.meta.data(), + /*.no_alloc =*/true, }; - struct ggml_context * ctx0 = ggml_init(params); + struct ggml_context *ctx0 = ggml_init(params); - ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); + ggml_cgraph *gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + struct ggml_tensor *embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_name(embd, "embd"); ggml_set_input(embd); - struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + struct ggml_tensor *position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_name(position, "position"); ggml_set_input(position); const float KQscale = pow(float(n_state_head), -0.25); - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1); + struct ggml_tensor *KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1); ggml_set_name(KQ_mask, "KQ_mask"); ggml_set_input(KQ_mask); - struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); + struct ggml_tensor *KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); // token encoding + position encoding - struct ggml_tensor * cur = + struct ggml_tensor *cur = ggml_add(ctx0, - ggml_get_rows(ctx0, model.d_te, embd), - ggml_get_rows(ctx0, model.d_pe, position)); + ggml_get_rows(ctx0, model.d_te, embd), + ggml_get_rows(ctx0, model.d_pe, position)); - struct ggml_tensor * inpL = cur; + struct ggml_tensor *inpL = cur; // [EXPERIMENTAL] Token-level timestamps with DTW - struct ggml_tensor * aheads_cross_QKs = nullptr; + struct ggml_tensor *aheads_cross_QKs = nullptr; - for (int il = 0; il < n_layer; ++il) { - const auto & layer = model.layers_decoder[il]; + for (int il = 0; il < n_layer; ++il) + { + const auto &layer = model.layers_decoder[il]; // norm { @@ -2517,59 +3108,62 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctx0, - ggml_mul(ctx0, - cur, - layer.attn_ln_0_w), - layer.attn_ln_0_b); + ggml_mul(ctx0, + cur, + layer.attn_ln_0_w), + layer.attn_ln_0_b); } // self-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.attn_q_w, - cur); + struct ggml_tensor *Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); Qcur = ggml_add(ctx0, - Qcur, - layer.attn_q_b); + Qcur, + layer.attn_q_b); Qcur = ggml_scale(ctx0, Qcur, KQscale); // note: no bias for Key - struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, - layer.attn_k_w, - cur); + struct ggml_tensor *Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); Kcur = ggml_scale(ctx0, Kcur, KQscale); // store key and value to memory { - struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, - layer.attn_v_w, - cur); + struct ggml_tensor *Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); Vcur = ggml_add(ctx0, - Vcur, - layer.attn_v_b); + Vcur, + layer.attn_v_b); - struct ggml_tensor * k; - 
struct ggml_tensor * v; + struct ggml_tensor *k; + struct ggml_tensor *v; - if (wctx.params.flash_attn) { - k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, - (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + if (wctx.params.flash_attn) + { + k = ggml_view_1d(ctx0, kv_self.k, n_tokens * n_state, + (ggml_element_size(kv_self.k) * n_state) * (il * n_ctx + kv_head)); - v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state, - (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head)); - } else { + v = ggml_view_1d(ctx0, kv_self.v, n_tokens * n_state, + (ggml_element_size(kv_self.v) * n_state) * (il * n_ctx + kv_head)); + } + else + { Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); - k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, - (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + k = ggml_view_1d(ctx0, kv_self.k, n_tokens * n_state, + (ggml_element_size(kv_self.k) * n_state) * (il * n_ctx + kv_head)); v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v)); + (n_ctx)*ggml_element_size(kv_self.v), + (il * n_ctx) * ggml_element_size(kv_self.v) * n_state + kv_head * ggml_element_size(kv_self.v)); } ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); @@ -2578,45 +3172,48 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // ------ - struct ggml_tensor * Q = + struct ggml_tensor *Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), - 0, 2, 1, 3); + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), + 0, 2, 1, 3); - struct ggml_tensor * K = + struct ggml_tensor *K = ggml_view_3d(ctx0, kv_self.k, - n_state_head, n_kv, n_head, - ggml_element_size(kv_self.k)*n_state, - ggml_element_size(kv_self.k)*n_state_head, - ggml_element_size(kv_self.k)*n_state*n_ctx*il); + n_state_head, n_kv, n_head, + ggml_element_size(kv_self.k) * n_state, + ggml_element_size(kv_self.k) * n_state_head, + ggml_element_size(kv_self.k) * n_state * n_ctx * il); - if (wctx.params.flash_attn) { - struct ggml_tensor * V = + if (wctx.params.flash_attn) + { + struct ggml_tensor *V = ggml_view_3d(ctx0, kv_self.v, - n_state_head, n_kv, n_head, - ggml_element_size(kv_self.v)*n_state, - ggml_element_size(kv_self.v)*n_state_head, - ggml_element_size(kv_self.v)*n_state*n_ctx*il); + n_state_head, n_kv, n_head, + ggml_element_size(kv_self.v) * n_state, + ggml_element_size(kv_self.v) * n_state_head, + ggml_element_size(kv_self.v) * n_state * n_ctx * il); cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f, 0.0f); cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); - } else { + } + else + { // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); + struct ggml_tensor *KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); - struct ggml_tensor * V = + struct ggml_tensor *V = ggml_view_3d(ctx0, kv_self.v, - n_kv, n_state_head, n_head, - n_ctx*ggml_element_size(kv_self.v), - n_ctx*ggml_element_size(kv_self.v)*n_state_head, - n_ctx*ggml_element_size(kv_self.v)*n_state*il); + n_kv, n_state_head, n_head, + n_ctx * ggml_element_size(kv_self.v), + n_ctx * ggml_element_size(kv_self.v) * n_state_head, + n_ctx * ggml_element_size(kv_self.v) * n_state * il); - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + struct ggml_tensor *KQV = 
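The self-attention KV cache writes above reduce to offset arithmetic into one flat buffer per tensor: in the flash-attention path each layer owns a contiguous [n_state x n_ctx] slab and newly decoded tokens land at column kv_head. A small sketch of that offset computation (element size and dimensions below are hypothetical):

    #include <cstddef>
    #include <cstdio>

    // Byte offset of the view written for layer `il` starting at cache column
    // `kv_head`, mirroring: elem_size * n_state * (il * n_ctx + kv_head).
    static size_t kv_view_offset(size_t elem_size, int n_state, int n_ctx, int il, int kv_head) {
        return elem_size * (size_t) n_state * ((size_t) il * n_ctx + kv_head);
    }

    int main() {
        // e.g. an F16 cache (2 bytes/element) with base-model-like sizes
        std::printf("layer 3, cache column 17 -> byte offset %zu\n",
                    kv_view_offset(2, 512, 448, 3, 17));
        return 0;
    }

The non-flash path stores V transposed instead, which is why its view uses a different stride layout.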
ggml_mul_mat(ctx0, V, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor *KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens); } @@ -2625,16 +3222,16 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // projection { cur = ggml_mul_mat(ctx0, - layer.attn_ln_1_w, - cur); + layer.attn_ln_1_w, + cur); cur = ggml_add(ctx0, - cur, - layer.attn_ln_1_b); + cur, + layer.attn_ln_1_b); } // add the input - struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); + struct ggml_tensor *inpCA = ggml_add(ctx0, cur, inpL); // norm { @@ -2642,88 +3239,96 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // cur = ln_0_w*cur + ln_0_b cur = ggml_add(ctx0, - ggml_mul(ctx0, - cur, - layer.cross_attn_ln_0_w), - layer.cross_attn_ln_0_b); + ggml_mul(ctx0, + cur, + layer.cross_attn_ln_0_w), + layer.cross_attn_ln_0_b); } // cross-attention { - struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, - layer.cross_attn_q_w, - cur); + struct ggml_tensor *Qcur = ggml_mul_mat(ctx0, + layer.cross_attn_q_w, + cur); Qcur = ggml_add(ctx0, - Qcur, - layer.cross_attn_q_b); + Qcur, + layer.cross_attn_q_b); - struct ggml_tensor * Q = + struct ggml_tensor *Q = ggml_permute(ctx0, - ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), - 0, 2, 1, 3); + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), + 0, 2, 1, 3); - if (wctx.params.flash_attn) { - struct ggml_tensor * Kcross = + if (wctx.params.flash_attn) + { + struct ggml_tensor *Kcross = ggml_view_3d(ctx0, wstate.kv_cross.k, - n_state_head, n_audio_ctx_pad, n_head, - ggml_element_size(wstate.kv_cross.k)*n_state, - ggml_element_size(wstate.kv_cross.k)*n_state_head, - ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il); + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.k) * n_state, + ggml_element_size(wstate.kv_cross.k) * n_state_head, + ggml_element_size(wstate.kv_cross.k) * n_state * n_audio_ctx_pad * il); - struct ggml_tensor * Vcross = + struct ggml_tensor *Vcross = ggml_view_3d(ctx0, wstate.kv_cross.v, - n_state_head, n_audio_ctx_pad, n_head, - ggml_element_size(wstate.kv_cross.v)*n_state, - ggml_element_size(wstate.kv_cross.v)*n_state_head, - ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il); + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.v) * n_state, + ggml_element_size(wstate.kv_cross.v) * n_state_head, + ggml_element_size(wstate.kv_cross.v) * n_state * n_audio_ctx_pad * il); cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f, 0.0f); cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); - } else { - struct ggml_tensor * Kcross = + } + else + { + struct ggml_tensor *Kcross = ggml_view_3d(ctx0, wstate.kv_cross.k, - n_state_head, n_audio_ctx, n_head, - ggml_element_size(wstate.kv_cross.k)*n_state, - ggml_element_size(wstate.kv_cross.k)*n_state_head, - ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il); + n_state_head, n_audio_ctx, n_head, + ggml_element_size(wstate.kv_cross.k) * n_state, + ggml_element_size(wstate.kv_cross.k) * n_state_head, + ggml_element_size(wstate.kv_cross.k) * n_state * n_audio_ctx * il); - struct ggml_tensor * Vcross = + struct ggml_tensor *Vcross = ggml_view_3d(ctx0, wstate.kv_cross.v, - n_audio_ctx, n_state_head, n_head, - n_audio_ctx*ggml_element_size(wstate.kv_cross.v), - n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head, - 
n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il); + n_audio_ctx, n_state_head, n_head, + n_audio_ctx * ggml_element_size(wstate.kv_cross.v), + n_audio_ctx * ggml_element_size(wstate.kv_cross.v) * n_state_head, + n_audio_ctx * ggml_element_size(wstate.kv_cross.v) * n_state * il); // ------ // K * Q - struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); + struct ggml_tensor *KQ = ggml_mul_mat(ctx0, Kcross, Q); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + struct ggml_tensor *KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); // [EXPERIMENTAL] Token-level timestamps with DTW - if (wctx.params.dtw_token_timestamps) { - if (wstate.aheads_masks.m[il] != nullptr) { - struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]); + if (wctx.params.dtw_token_timestamps) + { + if (wstate.aheads_masks.m[il] != nullptr) + { + struct ggml_tensor *aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]); aheads_KQs = ggml_transpose(ctx0, aheads_KQs); aheads_KQs = ggml_cont(ctx0, aheads_KQs); aheads_KQs = ggml_mul_mat(ctx0, wstate.aheads_masks.m[il], aheads_KQs); aheads_KQs = ggml_transpose(ctx0, aheads_KQs); aheads_KQs = ggml_cont(ctx0, aheads_KQs); aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]); - if (aheads_cross_QKs == NULL) { + if (aheads_cross_QKs == NULL) + { aheads_cross_QKs = aheads_KQs; - } else { + } + else + { aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs, 2); } } } - struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max); + struct ggml_tensor *KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max); - struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + struct ggml_tensor *KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens); } @@ -2732,18 +3337,18 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // projection { cur = ggml_mul_mat(ctx0, - layer.cross_attn_ln_1_w, - cur); + layer.cross_attn_ln_1_w, + cur); cur = ggml_add(ctx0, - cur, - layer.cross_attn_ln_1_b); + cur, + layer.cross_attn_ln_1_b); } // add the input cur = ggml_add(ctx0, cur, inpCA); - struct ggml_tensor * inpFF = cur; + struct ggml_tensor *inpFF = cur; // feed-forward network { @@ -2753,32 +3358,32 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // cur = mlp_ln_w*cur + mlp_ln_b cur = ggml_add(ctx0, - ggml_mul(ctx0, - cur, - layer.mlp_ln_w), - layer.mlp_ln_b); + ggml_mul(ctx0, + cur, + layer.mlp_ln_w), + layer.mlp_ln_b); } // fully connected cur = ggml_mul_mat(ctx0, - layer.mlp_0_w, - cur); + layer.mlp_0_w, + cur); cur = ggml_add(ctx0, - cur, - layer.mlp_0_b); + cur, + layer.mlp_0_b); // GELU activation cur = ggml_gelu(ctx0, cur); // projection cur = ggml_mul_mat(ctx0, - layer.mlp_1_w, - cur); + layer.mlp_1_w, + cur); cur = ggml_add(ctx0, - cur, - layer.mlp_1_b); + cur, + layer.mlp_1_b); } inpL = ggml_add(ctx0, cur, inpFF); @@ -2791,24 +3396,26 @@ static struct ggml_cgraph * whisper_build_graph_decoder( cur = ggml_norm(ctx0, cur, hparams.eps); cur = ggml_add(ctx0, - ggml_mul(ctx0, - cur, - model.d_ln_w), - model.d_ln_b); + ggml_mul(ctx0, + cur, + model.d_ln_w), + model.d_ln_b); } // compute logits only for the last token // comment this line to compute logits for all n_tokens // might be useful in the future - //cur = ggml_view_2d(ctx0, cur, 
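The [EXPERIMENTAL] DTW block above keeps only the designated alignment heads of each layer's cross-attention weights by multiplying the flattened per-head matrix with a 0/1 selection mask and then concatenating the survivors across layers. Schematically, selecting head rows with a mask matrix is just:

    A_{\text{sel}} = M^{\top} A, \qquad
    M \in \{0,1\}^{n_{\text{head}} \times n_{\text{sel}}},

e.g. with three heads and M having ones at (0,0) and (2,1), M^T A stacks rows 0 and 2 of A. The exact reshape/transpose order in the ggml graph differs, but the effect is this row selection.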
cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); + // cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); - struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); + struct ggml_tensor *logits = ggml_mul_mat(ctx0, model.d_te, cur); // [EXPERIMENTAL] Token-level timestamps with DTW - if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) { + if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) + { aheads_cross_QKs = ggml_transpose(ctx0, aheads_cross_QKs); aheads_cross_QKs = ggml_cont(ctx0, aheads_cross_QKs); - if (save_alignment_heads_QKs) { + if (save_alignment_heads_QKs) + { ggml_build_forward_expand(gf, aheads_cross_QKs); wstate.aheads_cross_QKs = aheads_cross_QKs; } @@ -2832,130 +3439,149 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // - n_past: number of past tokens to prefix the prompt with // static bool whisper_decode_internal( - whisper_context & wctx, - whisper_state & wstate, - const whisper_batch & batch, - const int n_threads, - bool save_alignment_heads_QKs, - ggml_abort_callback abort_callback, - void * abort_callback_data) { + whisper_context &wctx, + whisper_state &wstate, + const whisper_batch &batch, + const int n_threads, + bool save_alignment_heads_QKs, + ggml_abort_callback abort_callback, + void *abort_callback_data) +{ const int64_t t_start_us = ggml_time_us(); - const auto & model = wctx.model; - const auto & hparams = model.hparams; + const auto &model = wctx.model; + const auto &hparams = model.hparams; - const int n_vocab = hparams.n_vocab; + const int n_vocab = hparams.n_vocab; const int n_tokens = batch.n_tokens; - auto & logits_out = wstate.logits; + auto &logits_out = wstate.logits; - struct ggml_tensor * logits; + struct ggml_tensor *logits; // find KV slot for the batch { - auto & kv_self = wstate.kv_self; + auto &kv_self = wstate.kv_self; - if (!whisper_kv_cache_find_slot(kv_self, batch)) { + if (!whisper_kv_cache_find_slot(kv_self, batch)) + { return false; } const uint32_t pad = whisper_kv_cache_get_padding(wctx); kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), pad))); - //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); - //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]); + // kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); + // printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]); } // decoder { - auto & sched = wstate.sched_decode.sched; + auto &sched = wstate.sched_decode.sched; - ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false); + ggml_cgraph *gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false); - if (!ggml_backend_sched_alloc_graph(sched, gf)) { + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { // should never happen as we pre-allocate the memory return false; } // set the inputs { - struct ggml_tensor * embd = ggml_graph_get_tensor(gf, "embd"); - ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd)); + struct ggml_tensor *embd = ggml_graph_get_tensor(gf, "embd"); + ggml_backend_tensor_set(embd, batch.token, 0, n_tokens * ggml_element_size(embd)); } { - struct ggml_tensor * position = ggml_graph_get_tensor(gf, 
"position"); - for (int i = 0; i < n_tokens; ++i) { + struct ggml_tensor *position = ggml_graph_get_tensor(gf, "position"); + for (int i = 0; i < n_tokens; ++i) + { const int32_t val = batch.pos[i]; - ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t)); + ggml_backend_tensor_set(position, &val, i * sizeof(int32_t), sizeof(int32_t)); } } { - struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask"); + struct ggml_tensor *KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask"); - auto & kv_self = wstate.kv_self; + auto &kv_self = wstate.kv_self; const int32_t n_kv = kv_self.n; wstate.inp_mask.resize(ggml_nelements(KQ_mask)); - float * data = wstate.inp_mask.data(); + float *data = wstate.inp_mask.data(); memset(data, 0, ggml_nbytes(KQ_mask)); - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - const whisper_pos pos = batch.pos[j]; + for (int h = 0; h < 1; ++h) + { + for (int j = 0; j < n_tokens; ++j) + { + const whisper_pos pos = batch.pos[j]; const whisper_seq_id seq_id = batch.seq_id[j][0]; - for (int i = 0; i < n_kv; ++i) { - if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + for (int i = 0; i < n_kv; ++i) + { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) + { + data[h * (n_kv * n_tokens) + j * n_kv + i] = -INFINITY; } } } - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (int j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) + { + for (int j = 0; j < n_kv; ++j) + { + data[h * (n_kv * n_tokens) + i * n_kv + j] = -INFINITY; } } } - ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float)); + ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask) * sizeof(float)); } logits = ggml_graph_node(gf, -1); - if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + if (!ggml_graph_compute_helper(sched, gf, n_threads)) + { return false; } } - logits_out.resize(n_tokens*n_vocab); - for (int i = 0; i < n_tokens; i++) { - if (batch.logits[i] == 0) { + logits_out.resize(n_tokens * n_vocab); + for (int i = 0; i < n_tokens; i++) + { + if (batch.logits[i] == 0) + { continue; } - ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab); + ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab * i), sizeof(float) * (n_vocab * i), sizeof(float) * n_vocab); } - if (batch.n_tokens > 1) { - //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, - // ggml_used_mem(ctx0)/1e6, - // wstate.get_buf_max_mem(0)/1e6, - // wstate.get_buf_max_mem(1)/1e6, - // wstate.get_buf_max_mem(2)/1e6, - // wstate.get_buf_max_mem(3)/1e6); + if (batch.n_tokens > 1) + { + // printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1e6, + // wstate.get_buf_max_mem(0)/1e6, + // wstate.get_buf_max_mem(1)/1e6, + // wstate.get_buf_max_mem(2)/1e6, + // wstate.get_buf_max_mem(3)/1e6); } - if (batch.n_tokens == 1) { + if (batch.n_tokens == 1) + { wstate.t_decode_us += ggml_time_us() - t_start_us; wstate.n_decode++; - } else if (batch.n_tokens < 16) { + } + else if (batch.n_tokens < 16) + { wstate.t_batchd_us += ggml_time_us() - t_start_us; wstate.n_batchd += n_tokens; - } else { + } + else + { wstate.t_prompt_us += ggml_time_us() - t_start_us; wstate.n_prompt += n_tokens; } @@ -2965,7 
+3591,8 @@ static bool whisper_decode_internal( // 500 -> 00:05.000 // 6000 -> 01:00.000 -static std::string to_timestamp(int64_t t, bool comma = false) { +static std::string to_timestamp(int64_t t, bool comma = false) +{ int64_t msec = t * 10; int64_t hr = msec / (1000 * 60 * 60); msec = msec - hr * (1000 * 60 * 60); @@ -2975,67 +3602,78 @@ static std::string to_timestamp(int64_t t, bool comma = false) { msec = msec - sec * 1000; char buf[32]; - snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec); + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int)hr, (int)min, (int)sec, comma ? "," : ".", (int)msec); return std::string(buf); } #define SIN_COS_N_COUNT WHISPER_N_FFT -namespace { -struct whisper_global_cache { - // In FFT, we frequently use sine and cosine operations with the same values. - // We can use precalculated values to speed up the process. - float sin_vals[SIN_COS_N_COUNT]; - float cos_vals[SIN_COS_N_COUNT]; - - // Hann window (Use cosf to eliminate difference) - // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html - // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 - float hann_window[WHISPER_N_FFT]; +namespace +{ + struct whisper_global_cache + { + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. + float sin_vals[SIN_COS_N_COUNT]; + float cos_vals[SIN_COS_N_COUNT]; - whisper_global_cache() { - fill_sin_cos_table(); - fill_hann_window(sizeof(hann_window)/sizeof(hann_window[0]), true, hann_window); - } + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + float hann_window[WHISPER_N_FFT]; - void fill_sin_cos_table() { - for (int i = 0; i < SIN_COS_N_COUNT; i++) { - double theta = (2 * M_PI * i) / SIN_COS_N_COUNT; - sin_vals[i] = sinf(theta); - cos_vals[i] = cosf(theta); + whisper_global_cache() + { + fill_sin_cos_table(); + fill_hann_window(sizeof(hann_window) / sizeof(hann_window[0]), true, hann_window); } - } - void fill_hann_window(int length, bool periodic, float * output) { - int offset = -1; - if (periodic) { - offset = 0; + void fill_sin_cos_table() + { + for (int i = 0; i < SIN_COS_N_COUNT; i++) + { + double theta = (2 * M_PI * i) / SIN_COS_N_COUNT; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } } - for (int i = 0; i < length; i++) { - output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + + void fill_hann_window(int length, bool periodic, float *output) + { + int offset = -1; + if (periodic) + { + offset = 0; + } + for (int i = 0; i < length; i++) + { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } } - } -} global_cache; + } global_cache; } // naive Discrete Fourier Transform // input is real-valued // output is complex-valued -static void dft(const float* in, int N, float* out) { +static void dft(const float *in, int N, float *out) +{ const int sin_cos_step = SIN_COS_N_COUNT / N; - for (int k = 0; k < N; k++) { + for (int k = 0; k < N; k++) + { float re = 0; float im = 0; - for (int n = 0; n < N; n++) { + for (int n = 0; n < N; n++) + { int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N - re += in[n]*global_cache.cos_vals[idx]; // cos(t) - im -= in[n]*global_cache.sin_vals[idx]; // sin(t) + re += in[n] * global_cache.cos_vals[idx]; // cos(t) 
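The Hann window cached above uses the "periodic" convention (denominator N rather than N - 1), matching torch.hann_window(periodic=True) and the referenced Python implementation. A standalone sketch of the same fill:

    #include <cmath>
    #include <cstdio>

    // Hann window: w[i] = 0.5 * (1 - cos(2*pi*i / (N + offset))), where
    // offset = 0 for the periodic form and -1 for the symmetric form.
    static void fill_hann(int length, bool periodic, float * out) {
        const int offset = periodic ? 0 : -1;
        for (int i = 0; i < length; ++i) {
            out[i] = 0.5f * (1.0f - std::cos((2.0f * (float) M_PI * i) / (length + offset)));
        }
    }

    int main() {
        float w[8];
        fill_hann(8, true, w);
        for (float v : w) std::printf("%.3f ", v); // 0.000 0.146 0.500 0.854 1.000 0.854 0.500 0.146
        std::printf("\n");
        return 0;
    }

The sin/cos tables cached alongside it serve the same purpose for the DFT/FFT below: the angles are always multiples of 2*pi/SIN_COS_N_COUNT, so they can be looked up instead of recomputed.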
+ im -= in[n] * global_cache.sin_vals[idx]; // sin(t) } - out[k*2 + 0] = re; - out[k*2 + 1] = im; + out[k * 2 + 0] = re; + out[k * 2 + 1] = im; } } @@ -3043,53 +3681,60 @@ static void dft(const float* in, int N, float* out) { // poor man's implementation - use something better // input is real-valued // output is complex-valued -static void fft(float* in, int N, float* out) { - if (N == 1) { +static void fft(float *in, int N, float *out) +{ + if (N == 1) + { out[0] = in[0]; out[1] = 0; return; } const int half_N = N / 2; - if (N - half_N*2 == 1) { + if (N - half_N * 2 == 1) + { dft(in, N, out); return; } - float* even = in + N; - for (int i = 0; i < half_N; ++i) { - even[i]= in[2*i]; + float *even = in + N; + for (int i = 0; i < half_N; ++i) + { + even[i] = in[2 * i]; } - float* even_fft = out + 2 * N; + float *even_fft = out + 2 * N; fft(even, half_N, even_fft); - float* odd = even; - for (int i = 0; i < half_N; ++i) { - odd[i] = in[2*i + 1]; + float *odd = even; + for (int i = 0; i < half_N; ++i) + { + odd[i] = in[2 * i + 1]; } - float* odd_fft = even_fft + N; + float *odd_fft = even_fft + N; fft(odd, half_N, odd_fft); const int sin_cos_step = SIN_COS_N_COUNT / N; - for (int k = 0; k < half_N; k++) { - int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = global_cache.cos_vals[idx]; // cos(t) + for (int k = 0; k < half_N; k++) + { + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = global_cache.cos_vals[idx]; // cos(t) float im = -global_cache.sin_vals[idx]; // sin(t) - float re_odd = odd_fft[2*k + 0]; - float im_odd = odd_fft[2*k + 1]; + float re_odd = odd_fft[2 * k + 0]; + float im_odd = odd_fft[2 * k + 1]; - out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; - out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + out[2 * k + 0] = even_fft[2 * k + 0] + re * re_odd - im * im_odd; + out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; - out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + out[2 * (k + half_N) + 0] = even_fft[2 * k + 0] - re * re_odd + im * im_odd; + out[2 * (k + half_N) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; } } -static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, +static void log_mel_spectrogram_worker_thread(int ith, const float *hann, const std::vector &samples, int n_samples, int frame_size, int frame_step, int n_threads, - const whisper_filters & filters, whisper_mel & mel) { + const whisper_filters &filters, whisper_mel &mel) +{ std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); @@ -3100,16 +3745,19 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const assert(n_fft == 1 + (frame_size / 2)); // calculate FFT only when fft_in are not all zero - for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { + for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) + { const int offset = i * frame_step; // apply Hann window (~10% faster) - for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { + for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) + { fft_in[j] = hann[j] * samples[offset + j]; } // fill the rest with zeros - if (n_samples - offset < frame_size) { + if (n_samples - offset < frame_size) + { std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0); } @@ -3118,24 +3766,28 @@ static void log_mel_spectrogram_worker_thread(int 
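The recursive fft() above is the textbook radix-2 Cooley-Tukey split, falling back to the naive DFT when the length is odd. The N/2-point transforms E and O of the even- and odd-indexed samples are combined with one twiddle factor per bin:

    \begin{aligned}
    X[k]       &= E[k] + e^{-2\pi i k / N}\, O[k], \\
    X[k + N/2] &= E[k] - e^{-2\pi i k / N}\, O[k],
    \qquad k = 0, \dots, N/2 - 1,
    \end{aligned}

where the real and imaginary parts of the twiddle factor are exactly the precomputed cos/sin table entries indexed by k * sin_cos_step.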
ith, const float * hann, const // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. - for (int j = 0; j < n_fft; j++) { + for (int j = 0; j < n_fft; j++) + { fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); } // mel spectrogram - for (int j = 0; j < mel.n_mel; j++) { + for (int j = 0; j < mel.n_mel; j++) + { double sum = 0.0; // unroll loop (suggested by GH user @lunixbochs) int k = 0; - for (k = 0; k < n_fft - 3; k += 4) { + for (k = 0; k < n_fft - 3; k += 4) + { sum += - fft_out[k + 0] * filters.data[j * n_fft + k + 0] + - fft_out[k + 1] * filters.data[j * n_fft + k + 1] + - fft_out[k + 2] * filters.data[j * n_fft + k + 2] + - fft_out[k + 3] * filters.data[j * n_fft + k + 3]; + fft_out[k + 0] * filters.data[j * n_fft + k + 0] + + fft_out[k + 1] * filters.data[j * n_fft + k + 1] + + fft_out[k + 2] * filters.data[j * n_fft + k + 2] + + fft_out[k + 3] * filters.data[j * n_fft + k + 3]; } // handle n_fft remainder - for (; k < n_fft; k++) { + for (; k < n_fft; k++) + { sum += fft_out[k] * filters.data[j * n_fft + k]; } sum = log10(std::max(sum, 1e-10)); @@ -3145,8 +3797,10 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const // Otherwise fft_out are all zero double sum = log10(1e-10); - for (; i < mel.n_len; i += n_threads) { - for (int j = 0; j < mel.n_mel; j++) { + for (; i < mel.n_len; i += n_threads) + { + for (int j = 0; j < mel.n_mel; j++) + { mel.data[j * mel.n_len + i] = sum; } } @@ -3154,22 +3808,23 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 static bool log_mel_spectrogram( - whisper_state & wstate, - const float * samples, - const int n_samples, - const int /*sample_rate*/, - const int frame_size, - const int frame_step, - const int n_mel, - const int n_threads, - const whisper_filters & filters, - const bool debug, - whisper_mel & mel) { + whisper_state &wstate, + const float *samples, + const int n_samples, + const int /*sample_rate*/, + const int frame_size, + const int frame_step, + const int n_mel, + const int n_threads, + const whisper_filters &filters, + const bool debug, + whisper_mel &mel) +{ const int64_t t_start_us = ggml_time_us(); // Hann window WHISPER_ASSERT(frame_size == WHISPER_N_FFT && "Unsupported frame_size"); - const float * hann = global_cache.hann_window; + const float *hann = global_cache.hann_window; // Calculate the length of padding int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; @@ -3186,56 +3841,64 @@ static bool log_mel_spectrogram( // reflective pad 200 samples at the beginning of audio std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); - mel.n_mel = n_mel; + mel.n_mel = n_mel; // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 // Calculate number of frames + remove the last frame - mel.n_len = (samples_padded.size() - frame_size) / frame_step; + mel.n_len = (samples_padded.size() - frame_size) / frame_step; // Calculate semi-padded sample length to ensure compatibility mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; mel.data.resize(mel.n_mel * mel.n_len); { std::vector workers(n_threads - 1); - for (int iw = 0; iw < n_threads - 1; ++iw) { + for (int iw = 0; iw < n_threads - 1; ++iw) + { workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, hann, 
std::cref(samples_padded), - n_samples + stage_2_pad, frame_size, frame_step, n_threads, - std::cref(filters), std::ref(mel)); + log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), + n_samples + stage_2_pad, frame_size, frame_step, n_threads, + std::cref(filters), std::ref(mel)); } // main thread log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); - for (int iw = 0; iw < n_threads - 1; ++iw) { + for (int iw = 0; iw < n_threads - 1; ++iw) + { workers[iw].join(); } } // clamping and normalization double mmax = -1e20; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] > mmax) { + for (int i = 0; i < mel.n_mel * mel.n_len; i++) + { + if (mel.data[i] > mmax) + { mmax = mel.data[i]; } } mmax -= 8.0; - for (int i = 0; i < mel.n_mel*mel.n_len; i++) { - if (mel.data[i] < mmax) { + for (int i = 0; i < mel.n_mel * mel.n_len; i++) + { + if (mel.data[i] < mmax) + { mel.data[i] = mmax; } - mel.data[i] = (mel.data[i] + 4.0)/4.0; + mel.data[i] = (mel.data[i] + 4.0) / 4.0; } wstate.t_mel_us += ggml_time_us() - t_start_us; // Dump log_mel_spectrogram - if (debug) { + if (debug) + { std::ofstream outFile("log_mel_spectrogram.json"); outFile << "["; - for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + for (uint64_t i = 0; i < mel.data.size() - 1; i++) + { outFile << mel.data[i] << ", "; } outFile << mel.data[mel.data.size() - 1] << "]"; @@ -3255,7 +3918,8 @@ static bool log_mel_spectrogram( // Regex (C++): // R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)" // -static std::vector tokenize(const whisper_vocab & vocab, const std::string & text) { +static std::vector tokenize(const whisper_vocab &vocab, const std::string &text) +{ std::vector words; // first split the text into words @@ -3266,8 +3930,10 @@ static std::vector tokenize(const whisper_vocab & vocab, cons std::regex re(pat); std::smatch m; - while (std::regex_search(str, m, re)) { - for (auto x : m) { + while (std::regex_search(str, m, re)) + { + for (auto x : m) + { words.push_back(x); } str = m.suffix(); @@ -3276,18 +3942,23 @@ static std::vector tokenize(const whisper_vocab & vocab, cons // find the longest tokens that form the words: std::vector tokens; - for (const auto & word : words) { - if (word.empty()) continue; + for (const auto &word : words) + { + if (word.empty()) + continue; int i = 0; int n = word.size(); - while (i < n) { + while (i < n) + { int j = n; bool found = false; - while (j > i) { - auto sub = word.substr(i, j-i); + while (j > i) + { + auto sub = word.substr(i, j - i); auto it = vocab.token_to_id.find(sub); - if (it != vocab.token_to_id.end()) { + if (it != vocab.token_to_id.end()) + { tokens.push_back(it->second); i = j; found = true; @@ -3295,7 +3966,8 @@ static std::vector tokenize(const whisper_vocab & vocab, cons } --j; } - if (!found) { + if (!found) + { WHISPER_LOG_ERROR("unknown token\n"); ++i; } @@ -3311,17 +3983,21 @@ static std::vector tokenize(const whisper_vocab & vocab, cons #ifdef WHISPER_USE_COREML // replace .bin with -encoder.mlmodelc -static std::string whisper_get_coreml_path_encoder(std::string path_bin) { +static std::string whisper_get_coreml_path_encoder(std::string path_bin) +{ auto pos = path_bin.rfind('.'); - if (pos != std::string::npos) { + if (pos != std::string::npos) + { path_bin = path_bin.substr(0, pos); } // match "-qx_x" pos = path_bin.rfind('-'); - if (pos != std::string::npos) { + if (pos != std::string::npos) + 
{ auto sub = path_bin.substr(pos); - if (sub.size() == 5 && sub[1] == 'q' && sub[3] == '_') { + if (sub.size() == 5 && sub[1] == 'q' && sub[3] == '_') + { path_bin = path_bin.substr(0, pos); } } @@ -3334,9 +4010,11 @@ static std::string whisper_get_coreml_path_encoder(std::string path_bin) { #ifdef WHISPER_USE_OPENVINO // replace .bin with-encoder-openvino.xml -static std::string whisper_openvino_get_path_encoder(std::string path_bin) { +static std::string whisper_openvino_get_path_encoder(std::string path_bin) +{ auto pos = path_bin.rfind('.'); - if (pos != std::string::npos) { + if (pos != std::string::npos) + { path_bin = path_bin.substr(0, pos); } @@ -3345,9 +4023,11 @@ static std::string whisper_openvino_get_path_encoder(std::string path_bin) { return path_bin; } -static std::string whisper_openvino_get_path_cache(std::string path_bin) { +static std::string whisper_openvino_get_path_cache(std::string path_bin) +{ auto pos = path_bin.rfind('.'); - if (pos != std::string::npos) { + if (pos != std::string::npos) + { path_bin = path_bin.substr(0, pos); } @@ -3357,11 +4037,13 @@ static std::string whisper_openvino_get_path_cache(std::string path_bin) { } #endif -struct whisper_state * whisper_init_state(whisper_context * ctx) { - whisper_state * state = new whisper_state; +struct whisper_state *whisper_init_state(whisper_context *ctx) +{ + whisper_state *state = new whisper_state; state->backends = whisper_backend_init(ctx->params); - if (state->backends.empty()) { + if (state->backends.empty()) + { WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__); whisper_free_state(state); return nullptr; @@ -3371,9 +4053,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // later during decoding, if more decoders are used, we will recreate the KV cache respectively state->kv_self_n_dec = 1; if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype, - ctx->model.hparams.n_text_state, - ctx->model.hparams.n_text_layer, - GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) { + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) + { WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__); whisper_free_state(state); return nullptr; @@ -3385,9 +4068,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { } if (!whisper_kv_cache_init(state->kv_cross, state->backends[0], ctx->itype, - ctx->model.hparams.n_text_state, - ctx->model.hparams.n_text_layer, - GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) + { WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for cross-attention cache\n", __func__); whisper_free_state(state); return nullptr; @@ -3399,9 +4083,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { } if (!whisper_kv_cache_init(state->kv_pad, state->backends[0], ctx->itype, - ctx->model.hparams.n_audio_state, - 1, - GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { + ctx->model.hparams.n_audio_state, + 1, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) + { WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__); whisper_free_state(state); return nullptr; @@ -3413,8 +4098,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { } // [EXPERIMENTAL] Token-level timestamps with DTW - if (ctx->params.dtw_token_timestamps) { - if 
(!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) { + if (ctx->params.dtw_token_timestamps) + { + if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, state->backends[0])) + { WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for alignment heads masks\n", __func__); whisper_free_state(state); return nullptr; @@ -3430,13 +4117,16 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); state->ctx_coreml = whisper_coreml_init(path_coreml.c_str()); - if (!state->ctx_coreml) { + if (!state->ctx_coreml) + { WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); #ifndef WHISPER_COREML_ALLOW_FALLBACK whisper_free_state(state); return nullptr; #endif - } else { + } + else + { WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__); } #endif @@ -3448,9 +4138,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // TAGS: WHISPER_DECODER_INIT state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); - state->decoders[0].probs.reserve (ctx->vocab.n_vocab); - state->decoders[0].logits.reserve (ctx->vocab.n_vocab); - state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab); + state->decoders[0].probs.reserve(ctx->vocab.n_vocab); + state->decoders[0].logits.reserve(ctx->vocab.n_vocab); + state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab); state->decoders[0].rng = std::mt19937(0); @@ -3458,11 +4148,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // conv allocator { bool ok = whisper_sched_graph_init(state->sched_conv, state->backends, - [&]() { - return whisper_build_graph_conv(*ctx, *state); - }); + [&]() + { + return whisper_build_graph_conv(*ctx, *state); + }); - if (!ok) { + if (!ok) + { WHISPER_LOG_ERROR("%s: failed to init conv allocator\n", __func__); whisper_free_state(state); return nullptr; @@ -3472,13 +4164,16 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { } // encoder allocator - if (!whisper_encode_external(*state)) { + if (!whisper_encode_external(*state)) + { bool ok = whisper_sched_graph_init(state->sched_encode, state->backends, - [&]() { - return whisper_build_graph_encoder(*ctx, *state); - }); + [&]() + { + return whisper_build_graph_encoder(*ctx, *state); + }); - if (!ok) { + if (!ok) + { WHISPER_LOG_ERROR("%s: failed to init encoder allocator\n", __func__); whisper_free_state(state); return nullptr; @@ -3490,11 +4185,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // cross allocator { bool ok = whisper_sched_graph_init(state->sched_cross, state->backends, - [&]() { - return whisper_build_graph_cross(*ctx, *state); - }); + [&]() + { + return whisper_build_graph_cross(*ctx, *state); + }); - if (!ok) { + if (!ok) + { WHISPER_LOG_ERROR("%s: failed to init cross allocator\n", __func__); whisper_free_state(state); return nullptr; @@ -3506,19 +4203,21 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { // decoder allocator { bool ok = whisper_sched_graph_init(state->sched_decode, state->backends, - [&]() { - const auto & hparams = ctx->model.hparams; + [&]() + { + const auto &hparams = ctx->model.hparams; - // TODO: make sure this is the worst-case scenario - const int n_tokens = hparams.n_text_ctx; - const int n_past = 0; + // TODO: make sure this is the worst-case scenario + const int n_tokens 
= hparams.n_text_ctx; + const int n_past = 0; - whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0); + whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0); - return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true); - }); + return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true); + }); - if (!ok) { + if (!ok) + { WHISPER_LOG_ERROR("%s: failed to init decoder allocator\n", __func__); whisper_free_state(state); return nullptr; @@ -3531,11 +4230,12 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { } int whisper_ctx_init_openvino_encoder_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - const char * model_path, - const char * device, - const char * cache_dir) { + struct whisper_context *ctx, + struct whisper_state *state, + const char *model_path, + const char *device, + const char *cache_dir) +{ #ifndef WHISPER_USE_OPENVINO (void)(ctx); (void)(state); @@ -3545,24 +4245,31 @@ int whisper_ctx_init_openvino_encoder_with_state( return 1; #else - if (!model_path && ctx->path_model.empty()) { + if (!model_path && ctx->path_model.empty()) + { WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__); return 1; } std::string path_encoder; - if (!model_path) { - //if model_path is not set, attempt to find it in the same directory as ggml-.bin model + if (!model_path) + { + // if model_path is not set, attempt to find it in the same directory as ggml-.bin model path_encoder = whisper_openvino_get_path_encoder(ctx->path_model); - } else { + } + else + { path_encoder = model_path; } std::string path_cache; - if (!cache_dir) { - //if cache_dir is not set, set it as a dir residing next to ggml-.bin + if (!cache_dir) + { + // if cache_dir is not set, set it as a dir residing next to ggml-.bin path_cache = whisper_openvino_get_path_cache(ctx->path_model); - } else { + } + else + { path_cache = cache_dir; } @@ -3570,10 +4277,13 @@ int whisper_ctx_init_openvino_encoder_with_state( WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str()); - if (!state->ctx_openvino) { + if (!state->ctx_openvino) + { WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str()); return 1; - } else { + } + else + { WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__); } @@ -3582,32 +4292,35 @@ int whisper_ctx_init_openvino_encoder_with_state( } int whisper_ctx_init_openvino_encoder( - struct whisper_context * ctx, - const char * model_path, - const char * device, - const char * cache_dir) { + struct whisper_context *ctx, + const char *model_path, + const char *device, + const char *cache_dir) +{ return whisper_ctx_init_openvino_encoder_with_state(ctx, ctx->state, model_path, device, cache_dir); } -struct whisper_context_params whisper_context_default_params() { +struct whisper_context_params whisper_context_default_params() +{ struct whisper_context_params result = { - /*.use_gpu =*/ true, - /*.flash_attn =*/ true, - /*.gpu_device =*/ 0, - - /*.dtw_token_timestamps =*/ false, - /*.dtw_aheads_preset =*/ WHISPER_AHEADS_NONE, - /*.dtw_n_top =*/ -1, - /*.dtw_aheads =*/ { - /*.n_heads =*/ 0, - /*.heads =*/ NULL, + /*.use_gpu =*/true, + /*.flash_attn =*/true, + /*.gpu_device =*/0, + + /*.dtw_token_timestamps =*/false, + /*.dtw_aheads_preset =*/WHISPER_AHEADS_NONE, + 
/*.dtw_n_top =*/-1, + /*.dtw_aheads =*/{ + /*.n_heads =*/0, + /*.heads =*/NULL, }, - /*.dtw_mem_size =*/ 1024*1024*128, + /*.dtw_mem_size =*/1024 * 1024 * 128, }; return result; } -struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) { +struct whisper_context *whisper_init_from_file_with_params_no_state(const char *path_model, struct whisper_context_params params) +{ WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); #ifdef _MSC_VER // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. @@ -3617,7 +4330,8 @@ struct whisper_context * whisper_init_from_file_with_params_no_state(const char #else auto fin = std::ifstream(path_model, std::ios::binary); #endif - if (!fin) { + if (!fin) + { WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); return nullptr; } @@ -3626,39 +4340,45 @@ struct whisper_context * whisper_init_from_file_with_params_no_state(const char loader.context = &fin; - loader.read = [](void * ctx, void * output, size_t read_size) { - std::ifstream * fin = (std::ifstream*)ctx; + loader.read = [](void *ctx, void *output, size_t read_size) + { + std::ifstream *fin = (std::ifstream *)ctx; fin->read((char *)output, read_size); return read_size; }; - loader.eof = [](void * ctx) { - std::ifstream * fin = (std::ifstream*)ctx; + loader.eof = [](void *ctx) + { + std::ifstream *fin = (std::ifstream *)ctx; return fin->eof(); }; - loader.close = [](void * ctx) { - std::ifstream * fin = (std::ifstream*)ctx; + loader.close = [](void *ctx) + { + std::ifstream *fin = (std::ifstream *)ctx; fin->close(); }; auto ctx = whisper_init_with_params_no_state(&loader, params); - if (ctx) { + if (ctx) + { ctx->path_model = path_model; } return ctx; } -struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) { - struct buf_context { - uint8_t* buffer; +struct whisper_context *whisper_init_from_buffer_with_params_no_state(void *buffer, size_t buffer_size, struct whisper_context_params params) +{ + struct buf_context + { + uint8_t *buffer; size_t size; size_t current_offset; }; - buf_context ctx = { reinterpret_cast(buffer), buffer_size, 0 }; + buf_context ctx = {reinterpret_cast(buffer), buffer_size, 0}; WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__); @@ -3666,8 +4386,9 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu loader.context = &ctx; - loader.read = [](void * ctx, void * output, size_t read_size) { - buf_context * buf = reinterpret_cast(ctx); + loader.read = [](void *ctx, void *output, size_t read_size) + { + buf_context *buf = reinterpret_cast(ctx); size_t size_to_copy = buf->current_offset + read_size < buf->size ? 
read_size : buf->size - buf->current_offset; @@ -3677,21 +4398,24 @@ struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * bu return size_to_copy; }; - loader.eof = [](void * ctx) { - buf_context * buf = reinterpret_cast(ctx); + loader.eof = [](void *ctx) + { + buf_context *buf = reinterpret_cast(ctx); return buf->current_offset >= buf->size; }; - loader.close = [](void * /*ctx*/) { }; + loader.close = [](void * /*ctx*/) {}; return whisper_init_with_params_no_state(&loader, params); } -struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) { +struct whisper_context *whisper_init_with_params_no_state(struct whisper_model_loader *loader, struct whisper_context_params params) +{ ggml_time_init(); - if (params.flash_attn && params.dtw_token_timestamps) { + if (params.flash_attn && params.dtw_token_timestamps) + { WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__); params.dtw_token_timestamps = false; } @@ -3703,10 +4427,11 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_ WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count()); WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count()); - whisper_context * ctx = new whisper_context; + whisper_context *ctx = new whisper_context; ctx->params = params; - if (!whisper_model_load(loader, *ctx)) { + if (!whisper_model_load(loader, *ctx)) + { loader->close(loader->context); WHISPER_LOG_ERROR("%s: failed to load model\n", __func__); delete ctx; @@ -3718,14 +4443,17 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_ return ctx; } -struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) { - whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params); - if (!ctx) { +struct whisper_context *whisper_init_from_file_with_params(const char *path_model, struct whisper_context_params params) +{ + whisper_context *ctx = whisper_init_from_file_with_params_no_state(path_model, params); + if (!ctx) + { return nullptr; } ctx->state = whisper_init_state(ctx); - if (!ctx->state) { + if (!ctx->state) + { whisper_free(ctx); return nullptr; } @@ -3733,14 +4461,17 @@ struct whisper_context * whisper_init_from_file_with_params(const char * path_mo return ctx; } -struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) { - whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params); - if (!ctx) { +struct whisper_context *whisper_init_from_buffer_with_params(void *buffer, size_t buffer_size, struct whisper_context_params params) +{ + whisper_context *ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params); + if (!ctx) + { return nullptr; } ctx->state = whisper_init_state(ctx); - if (!ctx->state) { + if (!ctx->state) + { whisper_free(ctx); return nullptr; } @@ -3748,14 +4479,17 @@ struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, siz return ctx; } -struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) { - whisper_context * ctx = whisper_init_with_params_no_state(loader, params); - if (!ctx) { +struct whisper_context *whisper_init_with_params(struct whisper_model_loader *loader, struct 
whisper_context_params params) +{ + whisper_context *ctx = whisper_init_with_params_no_state(loader, params); + if (!ctx) + { return nullptr; } ctx->state = whisper_init_state(ctx); - if (!ctx->state) { + if (!ctx->state) + { whisper_free(ctx); return nullptr; } @@ -3763,45 +4497,55 @@ struct whisper_context * whisper_init_with_params(struct whisper_model_loader * return ctx; } -struct whisper_context * whisper_init_from_file(const char * path_model) { +struct whisper_context *whisper_init_from_file(const char *path_model) +{ return whisper_init_from_file_with_params(path_model, whisper_context_default_params()); } -struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { +struct whisper_context *whisper_init_from_buffer(void *buffer, size_t buffer_size) +{ return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params()); } -struct whisper_context * whisper_init(struct whisper_model_loader * loader) { +struct whisper_context *whisper_init(struct whisper_model_loader *loader) +{ return whisper_init_with_params(loader, whisper_context_default_params()); } -struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { +struct whisper_context *whisper_init_from_file_no_state(const char *path_model) +{ return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params()); } -struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) { +struct whisper_context *whisper_init_from_buffer_no_state(void *buffer, size_t buffer_size) +{ return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params()); } -struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) { +struct whisper_context *whisper_init_no_state(struct whisper_model_loader *loader) +{ return whisper_init_with_params_no_state(loader, whisper_context_default_params()); } -void whisper_free_state(struct whisper_state * state) { - if (state) { +void whisper_free_state(struct whisper_state *state) +{ + if (state) + { whisper_kv_cache_free(state->kv_self); whisper_kv_cache_free(state->kv_cross); whisper_kv_cache_free(state->kv_pad); #ifdef WHISPER_USE_COREML - if (state->ctx_coreml != nullptr) { + if (state->ctx_coreml != nullptr) + { whisper_coreml_free(state->ctx_coreml); state->ctx_coreml = nullptr; } #endif #ifdef WHISPER_USE_OPENVINO - if (state->ctx_openvino != nullptr) { + if (state->ctx_openvino != nullptr) + { whisper_openvino_free(state->ctx_openvino); state->ctx_openvino = nullptr; } @@ -3814,14 +4558,16 @@ void whisper_free_state(struct whisper_state * state) { ggml_backend_sched_free(state->sched_cross.sched); ggml_backend_sched_free(state->sched_decode.sched); - for (auto & backend : state->backends) { + for (auto &backend : state->backends) + { ggml_backend_free(backend); } // [EXPERIMENTAL] Token-level timestamps with DTW aheads_masks_free(state->aheads_masks); - if (state->vad_context != nullptr) { + if (state->vad_context != nullptr) + { whisper_vad_free(state->vad_context); state->vad_context = nullptr; } @@ -3830,13 +4576,17 @@ void whisper_free_state(struct whisper_state * state) { } } -void whisper_free(struct whisper_context * ctx) { - if (ctx) { - for (ggml_context * context : ctx->model.ctxs) { +void whisper_free(struct whisper_context *ctx) +{ + if (ctx) + { + for (ggml_context *context : ctx->model.ctxs) + { ggml_free(context); } - for (ggml_backend_buffer_t buf : ctx->model.buffers) { + for 
(ggml_backend_buffer_t buf : ctx->model.buffers) + { ggml_backend_buffer_free(buf); } @@ -3846,20 +4596,26 @@ void whisper_free(struct whisper_context * ctx) { } } -void whisper_free_context_params(struct whisper_context_params * params) { - if (params) { +void whisper_free_context_params(struct whisper_context_params *params) +{ + if (params) + { delete params; } } -void whisper_free_params(struct whisper_full_params * params) { - if (params) { +void whisper_free_params(struct whisper_full_params *params) +{ + if (params) + { delete params; } } -int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { +int whisper_pcm_to_mel_with_state(struct whisper_context *ctx, struct whisper_state *state, const float *samples, int n_samples, int n_threads) +{ + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) + { WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -3867,41 +4623,47 @@ int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_s return 0; } -int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { +int whisper_pcm_to_mel(struct whisper_context *ctx, const float *samples, int n_samples, int n_threads) +{ return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); } int whisper_set_mel_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - const float * data, - int n_len, - int n_mel) { - if (n_mel != ctx->model.filters.n_mel) { + struct whisper_context *ctx, + struct whisper_state *state, + const float *data, + int n_len, + int n_mel) +{ + if (n_mel != ctx->model.filters.n_mel) + { WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); return -1; } - state->mel.n_len = n_len; + state->mel.n_len = n_len; state->mel.n_len_org = n_len; - state->mel.n_mel = n_mel; + state->mel.n_mel = n_mel; - state->mel.data.resize(n_len*n_mel); - memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + state->mel.data.resize(n_len * n_mel); + memcpy(state->mel.data.data(), data, n_len * n_mel * sizeof(float)); return 0; } int whisper_set_mel( - struct whisper_context * ctx, - const float * data, - int n_len, - int n_mel) { + struct whisper_context *ctx, + const float *data, + int n_len, + int n_mel) +{ return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); } -int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { - if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { +int whisper_encode_with_state(struct whisper_context *ctx, struct whisper_state *state, int offset, int n_threads) +{ + if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) + { WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return -1; } @@ -3909,8 +4671,10 @@ int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state return 0; } -int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { - if (!whisper_encode_internal(*ctx, 
*ctx->state, offset, n_threads, nullptr, nullptr)) { +int whisper_encode(struct whisper_context *ctx, int offset, int n_threads) +{ + if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) + { WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return -1; } @@ -3918,12 +4682,14 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { return 0; } -int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { +int whisper_decode_with_state(struct whisper_context *ctx, struct whisper_state *state, const whisper_token *tokens, int n_tokens, int n_past, int n_threads) +{ whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0); whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1); - if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) { + if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) + { WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -3931,8 +4697,10 @@ int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state return 0; } -int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - if (ctx->state == nullptr) { +int whisper_decode(struct whisper_context *ctx, const whisper_token *tokens, int n_tokens, int n_past, int n_threads) +{ + if (ctx->state == nullptr) + { WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__); return -1; } @@ -3940,38 +4708,48 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads); } -int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) { +int whisper_tokenize(struct whisper_context *ctx, const char *text, whisper_token *tokens, int n_max_tokens) +{ const auto res = tokenize(ctx->vocab, text); - if (n_max_tokens < (int) res.size()) { - WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); - return -(int) res.size(); + if (n_max_tokens < (int)res.size()) + { + WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int)res.size(), n_max_tokens); + return -(int)res.size(); } - for (int i = 0; i < (int) res.size(); i++) { + for (int i = 0; i < (int)res.size(); i++) + { tokens[i] = res[i]; } return res.size(); } -int whisper_token_count(struct whisper_context * ctx, const char * text) { +int whisper_token_count(struct whisper_context *ctx, const char *text) +{ return -whisper_tokenize(ctx, text, NULL, 0); } -int whisper_lang_max_id(void) { +int whisper_lang_max_id(void) +{ auto max_id = 0; - for (const auto & kv : g_lang) { + for (const auto &kv : g_lang) + { max_id = std::max(max_id, kv.second.first); } return max_id; } -int whisper_lang_id(const char * lang) { - if (!g_lang.count(lang)) { - for (const auto & kv : g_lang) { - if (kv.second.second == lang) { +int whisper_lang_id(const char *lang) +{ + if (!g_lang.count(lang)) + { + for (const auto &kv : g_lang) + { + if (kv.second.second == lang) + { return kv.second.first; } } @@ -3982,9 +4760,12 @@ int whisper_lang_id(const char * lang) { return g_lang.at(lang).first; } -const char * whisper_lang_str(int id) { - for (const auto & kv : g_lang) { - if (kv.second.first == id) { +const char 
*whisper_lang_str(int id) +{ + for (const auto &kv : g_lang) + { + if (kv.second.first == id) + { return kv.first.c_str(); } } @@ -3993,9 +4774,12 @@ const char * whisper_lang_str(int id) { return nullptr; } -const char * whisper_lang_str_full(int id) { - for (const auto & kv : g_lang) { - if (kv.second.first == id) { +const char *whisper_lang_str_full(int id) +{ + for (const auto &kv : g_lang) + { + if (kv.second.first == id) + { return kv.second.second.c_str(); } } @@ -4005,40 +4789,46 @@ const char * whisper_lang_str_full(int id) { } int whisper_lang_auto_detect_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - int offset_ms, - int n_threads, - float * lang_probs) { - const int seek = offset_ms/10; - - if (seek < 0) { + struct whisper_context *ctx, + struct whisper_state *state, + int offset_ms, + int n_threads, + float *lang_probs) +{ + const int seek = offset_ms / 10; + + if (seek < 0) + { WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms); return -1; } - if (seek >= state->mel.n_len_org) { - WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10); + if (seek >= state->mel.n_len_org) + { + WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org * 10); return -2; } // run the encoder - if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) { + if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) + { WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); return -6; } - const std::vector prompt = { whisper_token_sot(ctx) }; + const std::vector prompt = {whisper_token_sot(ctx)}; - if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) { + if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) + { WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -7; } - auto & logits_id = state->decoders[0].logits_id; + auto &logits_id = state->decoders[0].logits_id; logits_id.clear(); - for (const auto & kv : g_lang) { + for (const auto &kv : g_lang) + { const auto token_lang = whisper_token_lang(ctx, kv.second.first); logits_id.emplace_back(state->logits[token_lang], kv.second.first); } @@ -4046,9 +4836,8 @@ int whisper_lang_auto_detect_with_state( // sort descending { using pair_type = std::remove_reference::type::value_type; - std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) { - return a.first > b.first; - }); + std::sort(logits_id.begin(), logits_id.end(), [](const pair_type &a, const pair_type &b) + { return a.first > b.first; }); } // softmax @@ -4056,23 +4845,27 @@ int whisper_lang_auto_detect_with_state( const auto max = logits_id[0].first; double sum = 0.0f; - for (auto & kv : logits_id) { + for (auto &kv : logits_id) + { kv.first = exp(kv.first - max); sum += kv.first; } - for (auto & kv : logits_id) { + for (auto &kv : logits_id) + { kv.first /= sum; } } { - for (const auto & prob : logits_id) { - if (lang_probs) { + for (const auto &prob : logits_id) + { + if (lang_probs) + { lang_probs[prob.second] = prob.first; } - //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first); + // printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first); } } @@ -4080,63 +4873,78 @@ int whisper_lang_auto_detect_with_state( } int whisper_lang_auto_detect( - struct whisper_context * 
ctx, - int offset_ms, - int n_threads, - float * lang_probs) { + struct whisper_context *ctx, + int offset_ms, + int n_threads, + float *lang_probs) +{ return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs); } -int whisper_model_n_vocab(struct whisper_context * ctx) { +int whisper_model_n_vocab(struct whisper_context *ctx) +{ return ctx->model.hparams.n_vocab; } -int whisper_model_n_audio_ctx(struct whisper_context * ctx) { +int whisper_model_n_audio_ctx(struct whisper_context *ctx) +{ return ctx->model.hparams.n_audio_ctx; } -int whisper_model_n_audio_state(struct whisper_context * ctx) { +int whisper_model_n_audio_state(struct whisper_context *ctx) +{ return ctx->model.hparams.n_audio_state; } -int whisper_model_n_audio_head(struct whisper_context * ctx) { +int whisper_model_n_audio_head(struct whisper_context *ctx) +{ return ctx->model.hparams.n_audio_head; } -int whisper_model_n_audio_layer(struct whisper_context * ctx) { +int whisper_model_n_audio_layer(struct whisper_context *ctx) +{ return ctx->model.hparams.n_audio_layer; } -int whisper_model_n_text_ctx(struct whisper_context * ctx) { +int whisper_model_n_text_ctx(struct whisper_context *ctx) +{ return ctx->model.hparams.n_text_ctx; } -int whisper_model_n_text_state(struct whisper_context * ctx) { +int whisper_model_n_text_state(struct whisper_context *ctx) +{ return ctx->model.hparams.n_text_state; } -int whisper_model_n_text_head(struct whisper_context * ctx) { +int whisper_model_n_text_head(struct whisper_context *ctx) +{ return ctx->model.hparams.n_text_head; } -int whisper_model_n_text_layer(struct whisper_context * ctx) { +int whisper_model_n_text_layer(struct whisper_context *ctx) +{ return ctx->model.hparams.n_text_layer; } -int whisper_model_n_mels(struct whisper_context * ctx) { +int whisper_model_n_mels(struct whisper_context *ctx) +{ return ctx->model.hparams.n_mels; } -int whisper_model_ftype(struct whisper_context * ctx) { +int whisper_model_ftype(struct whisper_context *ctx) +{ return ctx->model.hparams.ftype; } -int whisper_model_type(struct whisper_context * ctx) { +int whisper_model_type(struct whisper_context *ctx) +{ return ctx->model.type; } -const char *whisper_model_type_readable(struct whisper_context * ctx) { - switch (ctx->model.type) { +const char *whisper_model_type_readable(struct whisper_context *ctx) +{ + switch (ctx->model.type) + { case e_model::MODEL_TINY: return "tiny"; case e_model::MODEL_BASE: @@ -4152,87 +4960,108 @@ const char *whisper_model_type_readable(struct whisper_context * ctx) { } } -int whisper_n_len_from_state(struct whisper_state * state) { +int whisper_n_len_from_state(struct whisper_state *state) +{ return state->mel.n_len_org; } -int whisper_n_len(struct whisper_context * ctx) { +int whisper_n_len(struct whisper_context *ctx) +{ return ctx->state->mel.n_len_org; } -int whisper_n_vocab(struct whisper_context * ctx) { +int whisper_n_vocab(struct whisper_context *ctx) +{ return ctx->vocab.n_vocab; } -int whisper_n_text_ctx(struct whisper_context * ctx) { +int whisper_n_text_ctx(struct whisper_context *ctx) +{ return ctx->model.hparams.n_text_ctx; } -int whisper_n_audio_ctx(struct whisper_context * ctx) { +int whisper_n_audio_ctx(struct whisper_context *ctx) +{ return ctx->model.hparams.n_audio_ctx; } -int whisper_is_multilingual(struct whisper_context * ctx) { +int whisper_is_multilingual(struct whisper_context *ctx) +{ return ctx->vocab.is_multilingual() ? 
1 : 0; } -float * whisper_get_logits(struct whisper_context * ctx) { +float *whisper_get_logits(struct whisper_context *ctx) +{ return ctx->state->logits.data(); } -float * whisper_get_logits_from_state(struct whisper_state * state) { +float *whisper_get_logits_from_state(struct whisper_state *state) +{ return state->logits.data(); } -const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { +const char *whisper_token_to_str(struct whisper_context *ctx, whisper_token token) +{ return ctx->vocab.id_to_token.at(token).c_str(); } -whisper_token whisper_token_eot(struct whisper_context * ctx) { +whisper_token whisper_token_eot(struct whisper_context *ctx) +{ return ctx->vocab.token_eot; } -whisper_token whisper_token_sot(struct whisper_context * ctx) { +whisper_token whisper_token_sot(struct whisper_context *ctx) +{ return ctx->vocab.token_sot; } -whisper_token whisper_token_solm(struct whisper_context * ctx) { +whisper_token whisper_token_solm(struct whisper_context *ctx) +{ return ctx->vocab.token_solm; } -whisper_token whisper_token_prev(struct whisper_context * ctx) { +whisper_token whisper_token_prev(struct whisper_context *ctx) +{ return ctx->vocab.token_prev; } -whisper_token whisper_token_nosp(struct whisper_context * ctx) { +whisper_token whisper_token_nosp(struct whisper_context *ctx) +{ return ctx->vocab.token_nosp; } -whisper_token whisper_token_not(struct whisper_context * ctx) { +whisper_token whisper_token_not(struct whisper_context *ctx) +{ return ctx->vocab.token_not; } -whisper_token whisper_token_beg(struct whisper_context * ctx) { +whisper_token whisper_token_beg(struct whisper_context *ctx) +{ return ctx->vocab.token_beg; } -whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { +whisper_token whisper_token_lang(struct whisper_context *ctx, int lang_id) +{ return whisper_token_sot(ctx) + 1 + lang_id; } -whisper_token whisper_token_translate(struct whisper_context * ctx) { +whisper_token whisper_token_translate(struct whisper_context *ctx) +{ return ctx->vocab.token_translate; } -whisper_token whisper_token_transcribe(struct whisper_context * ctx) { +whisper_token whisper_token_transcribe(struct whisper_context *ctx) +{ return ctx->vocab.token_transcribe; } -struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) { - if (ctx->state == nullptr) { +struct whisper_timings *whisper_get_timings(struct whisper_context *ctx) +{ + if (ctx->state == nullptr) + { return nullptr; } - whisper_timings * timings = new whisper_timings; + whisper_timings *timings = new whisper_timings; timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample); timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode); timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode); @@ -4241,12 +5070,14 @@ struct whisper_timings * whisper_get_timings(struct whisper_context * ctx) { return timings; } -void whisper_print_timings(struct whisper_context * ctx) { +void whisper_print_timings(struct whisper_context *ctx) +{ const int64_t t_end_us = ggml_time_us(); WHISPER_LOG_INFO("\n"); WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); - if (ctx->state != nullptr) { + if (ctx->state != nullptr) + { const int32_t n_sample = std::max(1, ctx->state->n_sample); const int32_t n_encode = std::max(1, ctx->state->n_encode); @@ -4262,12 +5093,14 @@ void whisper_print_timings(struct whisper_context * ctx) { WHISPER_LOG_INFO("%s: 
batchd time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd); WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); } - WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); + WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us) / 1000.0f); } -void whisper_reset_timings(struct whisper_context * ctx) { +void whisper_reset_timings(struct whisper_context *ctx) +{ ctx->t_start_us = ggml_time_us(); - if (ctx->state != nullptr) { + if (ctx->state != nullptr) + { ctx->state->t_mel_us = 0; ctx->state->t_sample_us = 0; ctx->state->t_encode_us = 0; @@ -4282,7 +5115,8 @@ void whisper_reset_timings(struct whisper_context * ctx) { } } -static int whisper_has_coreml(void) { +static int whisper_has_coreml(void) +{ #ifdef WHISPER_USE_COREML return 1; #else @@ -4290,7 +5124,8 @@ static int whisper_has_coreml(void) { #endif } -static int whisper_has_openvino(void) { +static int whisper_has_openvino(void) +{ #ifdef WHISPER_USE_OPENVINO return 1; #else @@ -4298,22 +5133,26 @@ static int whisper_has_openvino(void) { #endif } -const char * whisper_print_system_info(void) { +const char *whisper_print_system_info(void) +{ static std::string s; - s = ""; + s = ""; s += "WHISPER : "; - s += "COREML = " + std::to_string(whisper_has_coreml()) + " | "; - s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | "; - - for (size_t i = 0; i < ggml_backend_reg_count(); i++) { - auto * reg = ggml_backend_reg_get(i); - auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features"); - if (get_features_fn) { - ggml_backend_feature * features = get_features_fn(reg); + s += "COREML = " + std::to_string(whisper_has_coreml()) + " | "; + s += "OPENVINO = " + std::to_string(whisper_has_openvino()) + " | "; + + for (size_t i = 0; i < ggml_backend_reg_count(); i++) + { + auto *reg = ggml_backend_reg_get(i); + auto *get_features_fn = (ggml_backend_get_features_t)ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features"); + if (get_features_fn) + { + ggml_backend_feature *features = get_features_fn(reg); s += ggml_backend_reg_name(reg); s += " : "; - for (; features->name; features++) { + for (; features->name; features++) + { s += features->name; s += " = "; s += features->value; @@ -4328,49 +5167,51 @@ const char * whisper_print_system_info(void) { // Voice Activity Detection (VAD) ////////////////////////////////// -struct whisper_vad_hparams { - int32_t n_encoder_layers; - int32_t * encoder_in_channels; - int32_t * encoder_out_channels; - int32_t * kernel_sizes; - int32_t lstm_input_size; - int32_t lstm_hidden_size; - int32_t final_conv_in; - int32_t final_conv_out; +struct whisper_vad_hparams +{ + int32_t n_encoder_layers; + int32_t *encoder_in_channels; + int32_t *encoder_out_channels; + int32_t *kernel_sizes; + int32_t lstm_input_size; + int32_t lstm_hidden_size; + int32_t final_conv_in; + int32_t final_conv_out; }; -struct whisper_vad_model { +struct whisper_vad_model +{ std::string type; std::string version; whisper_vad_hparams hparams; - struct ggml_tensor * stft_forward_basis; // [256, 1, 258] + struct ggml_tensor *stft_forward_basis; // [256, 1, 258] // Encoder tensors - 4 convolutional layers - struct ggml_tensor * encoder_0_weight; // [3, 129, 128] - 
struct ggml_tensor * encoder_0_bias; // [128] + struct ggml_tensor *encoder_0_weight; // [3, 129, 128] + struct ggml_tensor *encoder_0_bias; // [128] // Second encoder layer - struct ggml_tensor * encoder_1_weight; // [3, 128, 64] - struct ggml_tensor * encoder_1_bias; // [64] + struct ggml_tensor *encoder_1_weight; // [3, 128, 64] + struct ggml_tensor *encoder_1_bias; // [64] // Third encoder layer - struct ggml_tensor * encoder_2_weight; // [3, 64, 64] - struct ggml_tensor * encoder_2_bias; // [64] + struct ggml_tensor *encoder_2_weight; // [3, 64, 64] + struct ggml_tensor *encoder_2_bias; // [64] // Fourth encoder layer - struct ggml_tensor * encoder_3_weight; // [3, 64, 128] - struct ggml_tensor * encoder_3_bias; // [128] + struct ggml_tensor *encoder_3_weight; // [3, 64, 128] + struct ggml_tensor *encoder_3_bias; // [128] // LSTM decoder tensors - struct ggml_tensor * lstm_ih_weight; // [128, 512] input-to-hidden - struct ggml_tensor * lstm_ih_bias; // [512] - struct ggml_tensor * lstm_hh_weight; // [128, 512] hidden-to-hidden - struct ggml_tensor * lstm_hh_bias; // [512] + struct ggml_tensor *lstm_ih_weight; // [128, 512] input-to-hidden + struct ggml_tensor *lstm_ih_bias; // [512] + struct ggml_tensor *lstm_hh_weight; // [128, 512] hidden-to-hidden + struct ggml_tensor *lstm_hh_bias; // [512] // Final conv layer - struct ggml_tensor * final_conv_weight; // [128] - struct ggml_tensor * final_conv_bias; // [1] + struct ggml_tensor *final_conv_weight; // [128] + struct ggml_tensor *final_conv_bias; // [1] // ggml contexts std::vector ctxs; @@ -4383,36 +5224,40 @@ struct whisper_vad_model { std::map tensors; }; -struct whisper_vad_segment { +struct whisper_vad_segment +{ int64_t start; int64_t end; }; -struct whisper_vad_segments { +struct whisper_vad_segments +{ std::vector data; }; -struct whisper_vad_context { +struct whisper_vad_context +{ int64_t t_vad_us = 0; - int n_window; - int n_context; - int n_threads; + int n_window; + int n_context; + int n_threads; std::vector backends; - ggml_backend_buffer_t buffer = nullptr; - whisper_context_params params; - std::vector ctx_buf; - whisper_sched sched; - - whisper_vad_model model; - std::string path_model; - struct ggml_tensor * h_state; - struct ggml_tensor * c_state; - std::vector probs; + ggml_backend_buffer_t buffer = nullptr; + whisper_context_params params; + std::vector ctx_buf; + whisper_sched sched; + + whisper_vad_model model; + std::string path_model; + struct ggml_tensor *h_state; + struct ggml_tensor *c_state; + std::vector probs; }; -struct whisper_vad_context_params whisper_vad_default_context_params(void) { +struct whisper_vad_context_params whisper_vad_default_context_params(void) +{ whisper_vad_context_params result = { /*.n_thread = */ 4, /*.use_gpu = */ false, @@ -4421,7 +5266,8 @@ struct whisper_vad_context_params whisper_vad_default_context_params(void) { return result; } -struct whisper_vad_params whisper_vad_default_params(void) { +struct whisper_vad_params whisper_vad_default_params(void) +{ whisper_vad_params result = { /* threshold = */ 0.5f, /* min_speech_duration_ms = */ 250, @@ -4434,66 +5280,79 @@ struct whisper_vad_params whisper_vad_default_params(void) { } // Time conversion utility functions for whisper VAD -static int cs_to_samples(int64_t cs) { +static int cs_to_samples(int64_t cs) +{ return (int)((cs / 100.0) * WHISPER_SAMPLE_RATE + 0.5); } -static int64_t samples_to_cs(int samples) { +static int64_t samples_to_cs(int samples) +{ return (int64_t)((samples / (double)WHISPER_SAMPLE_RATE) * 100.0 + 
0.5); } -static bool weight_buft_supported(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { +static bool weight_buft_supported(const whisper_vad_hparams &hparams, ggml_tensor *w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) +{ bool op_supported = true; if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || - (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) + { // GPU and default CPU backend support all operators op_supported = true; - } else { - switch (op) { - // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT - case GGML_OP_MUL_MAT: { - ggml_init_params params = { - /*.mem_size =*/ 2 * ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - ggml_context_ptr ctx_ptr { ggml_init(params) }; - if (!ctx_ptr) { - throw std::runtime_error("failed to create ggml context"); - } - ggml_context * ctx = ctx_ptr.get(); + } + else + { + switch (op) + { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT + case GGML_OP_MUL_MAT: + { + ggml_init_params params = { + /*.mem_size =*/2 * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, + }; + + ggml_context_ptr ctx_ptr{ggml_init(params)}; + if (!ctx_ptr) + { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context *ctx = ctx_ptr.get(); - ggml_tensor * op_tensor = nullptr; + ggml_tensor *op_tensor = nullptr; - int64_t n_ctx = hparams.lstm_hidden_size; - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); - op_tensor = ggml_mul_mat(ctx, w, b); + int64_t n_ctx = hparams.lstm_hidden_size; + ggml_tensor *b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); - // create a temporary dummy buffer for the weight so that supports_op can check the buffer type - GGML_ASSERT(w->buffer == nullptr); - w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); - op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - ggml_backend_buffer_free(w->buffer); - w->buffer = nullptr; - break; - } - default: { - op_supported = false; - break; - } + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: + { + op_supported = false; + break; + } }; } return op_supported; } -static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { +static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams &hparams, ggml_tensor *w, ggml_op op, buft_list_t buft_list) +{ GGML_ASSERT(!buft_list.empty()); - for (const auto & p : buft_list) { + for (const auto &p : buft_list) + { ggml_backend_dev_t dev = p.first; ggml_backend_buffer_type_t buft = p.second; - if (weight_buft_supported(hparams, w, op, buft, dev)) { + if (weight_buft_supported(hparams, w, op, buft, dev)) + { return buft; } } @@ -4501,31 +5360,33 @@ static ggml_backend_buffer_type_t select_weight_buft(const whisper_vad_hparams & return nullptr; } -static 
ggml_tensor * whisper_vad_build_stft_layer(ggml_context * ctx0, - const whisper_vad_model & model, ggml_tensor * cur) { +static ggml_tensor *whisper_vad_build_stft_layer(ggml_context *ctx0, + const whisper_vad_model &model, ggml_tensor *cur) +{ // Apply reflective padding to the input tensor - ggml_tensor * padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64); + ggml_tensor *padded = ggml_pad_reflect_1d(ctx0, cur, 64, 64); - struct ggml_tensor * stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1); + struct ggml_tensor *stft = ggml_conv_1d(ctx0, model.stft_forward_basis, padded, model.hparams.lstm_input_size, 0, 1); // Calculate cutoff for real/imaginary parts int cutoff = model.stft_forward_basis->ne[2] / 2; // Extract real part (first half of the STFT output). - struct ggml_tensor * real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0); + struct ggml_tensor *real_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], 0); // Extract imaginary part (second half of the STFT output). - struct ggml_tensor * img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]); + struct ggml_tensor *img_part = ggml_view_2d(ctx0, stft, 4, cutoff, stft->nb[1], cutoff * stft->nb[1]); // Calculate magnitude: sqrt(real^2 + imag^2) - struct ggml_tensor * real_squared = ggml_mul(ctx0, real_part, real_part); - struct ggml_tensor * img_squared = ggml_mul(ctx0, img_part, img_part); - struct ggml_tensor * sum_squares = ggml_add(ctx0, real_squared, img_squared); - struct ggml_tensor * magnitude = ggml_sqrt(ctx0, sum_squares); + struct ggml_tensor *real_squared = ggml_mul(ctx0, real_part, real_part); + struct ggml_tensor *img_squared = ggml_mul(ctx0, img_part, img_part); + struct ggml_tensor *sum_squares = ggml_add(ctx0, real_squared, img_squared); + struct ggml_tensor *magnitude = ggml_sqrt(ctx0, sum_squares); return magnitude; } -static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0, - const whisper_vad_model & model, ggml_tensor * cur) { +static ggml_tensor *whisper_vad_build_encoder_layer(ggml_context *ctx0, + const whisper_vad_model &model, ggml_tensor *cur) +{ // First Conv1D: expands to 128 channels. cur = ggml_conv_1d(ctx0, model.encoder_0_weight, cur, 1, 1, 1); cur = ggml_add(ctx0, cur, ggml_reshape_3d(ctx0, model.encoder_0_bias, 1, 128, 1)); @@ -4549,69 +5410,71 @@ static ggml_tensor * whisper_vad_build_encoder_layer(ggml_context * ctx0, return cur; } -static ggml_tensor * whisper_vad_build_lstm_layer(ggml_context * ctx0, - const whisper_vad_context & vctx, ggml_tensor * cur, ggml_cgraph * gf) { - const whisper_vad_model & model = vctx.model; +static ggml_tensor *whisper_vad_build_lstm_layer(ggml_context *ctx0, + const whisper_vad_context &vctx, ggml_tensor *cur, ggml_cgraph *gf) +{ + const whisper_vad_model &model = vctx.model; const int hdim = model.hparams.lstm_hidden_size; - struct ggml_tensor * x_t = ggml_transpose(ctx0, cur); + struct ggml_tensor *x_t = ggml_transpose(ctx0, cur); // Create operations using the input-to-hidden weights. - struct ggml_tensor * inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t); + struct ggml_tensor *inp_gate = ggml_mul_mat(ctx0, model.lstm_ih_weight, x_t); inp_gate = ggml_add(ctx0, inp_gate, model.lstm_ih_bias); // Create operations using the hidden-to-hidden weights. 
- struct ggml_tensor * hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state); + struct ggml_tensor *hid_gate = ggml_mul_mat(ctx0, model.lstm_hh_weight, vctx.h_state); hid_gate = ggml_add(ctx0, hid_gate, model.lstm_hh_bias); // Create add operation to get preactivations for all gates. - struct ggml_tensor * out_gate = ggml_add(ctx0, inp_gate, hid_gate); + struct ggml_tensor *out_gate = ggml_add(ctx0, inp_gate, hid_gate); const size_t hdim_size = ggml_row_size(out_gate->type, hdim); // Create sigmoid for input gate (using the first 128 bytes from the preactivations). - struct ggml_tensor * i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size)); + struct ggml_tensor *i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 0 * hdim_size)); // Create sigmoid for the forget gate (using the second 128 bytes from the preactivations). - struct ggml_tensor * f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size)); + struct ggml_tensor *f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 1 * hdim_size)); // Create sigmoid for the cell gate (using the third 128 bytes from the preactivations). - struct ggml_tensor * g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size)); + struct ggml_tensor *g_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 2 * hdim_size)); // Create sigmoid for the output gate (using the fourth 128 bytes from the preactivations). - struct ggml_tensor * o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size)); + struct ggml_tensor *o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gate, hdim, 3 * hdim_size)); // Update cell state - struct ggml_tensor * c_out = ggml_add(ctx0, - ggml_mul(ctx0, f_t, vctx.c_state), - ggml_mul(ctx0, i_t, g_t)); + struct ggml_tensor *c_out = ggml_add(ctx0, + ggml_mul(ctx0, f_t, vctx.c_state), + ggml_mul(ctx0, i_t, g_t)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_out, vctx.c_state)); // Update hidden state - struct ggml_tensor * out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state)); + struct ggml_tensor *out = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_out)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, out, vctx.h_state)); return out; } -static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) { - const auto & model = vctx.model; +static struct ggml_cgraph *whisper_vad_build_graph(whisper_vad_context &vctx) +{ + const auto &model = vctx.model; struct ggml_init_params params = { - /*.mem_size =*/ vctx.sched.meta.size(), - /*.mem_buffer =*/ vctx.sched.meta.data(), - /*.no_alloc =*/ true, + /*.mem_size =*/vctx.sched.meta.size(), + /*.mem_buffer =*/vctx.sched.meta.data(), + /*.no_alloc =*/true, }; - struct ggml_context * ctx0 = ggml_init(params); + struct ggml_context *ctx0 = ggml_init(params); - ggml_cgraph * gf = ggml_new_graph(ctx0); + ggml_cgraph *gf = ggml_new_graph(ctx0); - struct ggml_tensor * frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1); + struct ggml_tensor *frame = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, vctx.n_window, 1); ggml_set_name(frame, "frame"); ggml_set_input(frame); - struct ggml_tensor * cur = nullptr; + struct ggml_tensor *cur = nullptr; { cur = whisper_vad_build_stft_layer(ctx0, model, frame); @@ -4637,32 +5500,35 @@ static struct ggml_cgraph * whisper_vad_build_graph(whisper_vad_context & vctx) return gf; } -static bool whisper_vad_init_context(whisper_vad_context * vctx) { +static bool 
whisper_vad_init_context(whisper_vad_context *vctx) +{ auto whisper_context_params = whisper_context_default_params(); // TODO: GPU VAD is forced disabled until the performance is improved - //whisper_context_params.use_gpu = vctx->params.use_gpu; - whisper_context_params.use_gpu = false; + // whisper_context_params.use_gpu = vctx->params.use_gpu; + whisper_context_params.use_gpu = false; whisper_context_params.gpu_device = vctx->params.gpu_device; vctx->backends = whisper_backend_init(whisper_context_params); - if (vctx->backends.empty()) { + if (vctx->backends.empty()) + { WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__); return false; } const int32_t lstm_hidden_size = vctx->model.hparams.lstm_hidden_size; - vctx->ctx_buf.resize(2u*ggml_tensor_overhead()); + vctx->ctx_buf.resize(2u * ggml_tensor_overhead()); struct ggml_init_params params = { - /*.mem_size =*/ vctx->ctx_buf.size(), - /*.mem_buffer =*/ vctx->ctx_buf.data(), - /*.no_alloc =*/ true, + /*.mem_size =*/vctx->ctx_buf.size(), + /*.mem_buffer =*/vctx->ctx_buf.data(), + /*.no_alloc =*/true, }; - ggml_context * ctx = ggml_init(params); - if (!ctx) { + ggml_context *ctx = ggml_init(params); + if (!ctx) + { WHISPER_LOG_ERROR("%s: failed to init LSTM state ggml context\n", __func__); return false; } @@ -4676,18 +5542,21 @@ static bool whisper_vad_init_context(whisper_vad_context * vctx) { ggml_set_name(vctx->c_state, "c_state"); vctx->buffer = ggml_backend_alloc_ctx_tensors(ctx, vctx->backends[0]); - if (!vctx->buffer) { + if (!vctx->buffer) + { WHISPER_LOG_ERROR("%s: failed to allocate memory for the VAD state\n", __func__); return false; } { bool ok = whisper_sched_graph_init(vctx->sched, vctx->backends, - [&]() { - return whisper_vad_build_graph(*vctx); - }); + [&]() + { + return whisper_vad_build_graph(*vctx); + }); - if (!ok) { + if (!ok) + { WHISPER_LOG_ERROR("%s: failed to init VAD allocator\n", __func__); return false; } @@ -4698,9 +5567,10 @@ static bool whisper_vad_init_context(whisper_vad_context * vctx) { return true; } -struct whisper_vad_context * whisper_vad_init_from_file_with_params( - const char * path_model, - struct whisper_vad_context_params params) { +struct whisper_vad_context *whisper_vad_init_from_file_with_params( + const char *path_model, + struct whisper_vad_context_params params) +{ WHISPER_LOG_INFO("%s: loading VAD model from '%s'\n", __func__, path_model); #ifdef _MSC_VER std::wstring_convert> converter; @@ -4709,7 +5579,8 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params( #else auto fin = std::ifstream(path_model, std::ios::binary); #endif - if (!fin) { + if (!fin) + { WHISPER_LOG_ERROR("%s: failed to open VAD model '%s'\n", __func__, path_model); return nullptr; } @@ -4717,24 +5588,28 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params( whisper_model_loader loader = {}; loader.context = &fin; - loader.read = [](void * ctx, void * output, size_t read_size) { - std::ifstream * fin = (std::ifstream*)ctx; + loader.read = [](void *ctx, void *output, size_t read_size) + { + std::ifstream *fin = (std::ifstream *)ctx; fin->read((char *)output, read_size); return read_size; }; - loader.eof = [](void * ctx) { - std::ifstream * fin = (std::ifstream*)ctx; + loader.eof = [](void *ctx) + { + std::ifstream *fin = (std::ifstream *)ctx; return fin->eof(); }; - loader.close = [](void * ctx) { - std::ifstream * fin = (std::ifstream*)ctx; + loader.close = [](void *ctx) + { + std::ifstream *fin = (std::ifstream *)ctx; fin->close(); }; auto ctx = 
whisper_vad_init_with_params(&loader, params); - if (!ctx) { + if (!ctx) + { whisper_vad_free(ctx); return nullptr; } @@ -4742,26 +5617,28 @@ struct whisper_vad_context * whisper_vad_init_from_file_with_params( return ctx; } -struct whisper_vad_context * whisper_vad_init_with_params( - struct whisper_model_loader * loader, - struct whisper_vad_context_params params) { +struct whisper_vad_context *whisper_vad_init_with_params( + struct whisper_model_loader *loader, + struct whisper_vad_context_params params) +{ // Read the VAD model { uint32_t magic; read_safe(loader, magic); - if (magic != GGML_FILE_MAGIC) { + if (magic != GGML_FILE_MAGIC) + { WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); return nullptr; } } - whisper_vad_context * vctx = new whisper_vad_context; + whisper_vad_context *vctx = new whisper_vad_context; vctx->n_threads = params.n_threads; vctx->params.use_gpu = params.use_gpu; vctx->params.gpu_device = params.gpu_device; - auto & model = vctx->model; - auto & hparams = model.hparams; + auto &model = vctx->model; + auto &hparams = model.hparams; // load model context params. { @@ -4795,7 +5672,8 @@ struct whisper_vad_context * whisper_vad_init_with_params( hparams.encoder_out_channels = new int32_t[hparams.n_encoder_layers]; hparams.kernel_sizes = new int32_t[hparams.n_encoder_layers]; - for (int32_t i = 0; i < hparams.n_encoder_layers; i++) { + for (int32_t i = 0; i < hparams.n_encoder_layers; i++) + { read_safe(loader, hparams.encoder_in_channels[i]); read_safe(loader, hparams.encoder_out_channels[i]); read_safe(loader, hparams.kernel_sizes[i]); @@ -4807,10 +5685,12 @@ struct whisper_vad_context * whisper_vad_init_with_params( read_safe(loader, hparams.final_conv_out); WHISPER_LOG_INFO("%s: n_encoder_layers = %d\n", __func__, hparams.n_encoder_layers); - for (int32_t i = 0; i < hparams.n_encoder_layers; i++) { + for (int32_t i = 0; i < hparams.n_encoder_layers; i++) + { WHISPER_LOG_INFO("%s: encoder_in_channels[%d] = %d\n", __func__, i, hparams.encoder_in_channels[i]); } - for (int32_t i = 0; i < hparams.n_encoder_layers; i++) { + for (int32_t i = 0; i < hparams.n_encoder_layers; i++) + { WHISPER_LOG_INFO("%s: encoder_out_channels[%d] = %d\n", __func__, i, hparams.encoder_out_channels[i]); } WHISPER_LOG_INFO("%s: lstm_input_size = %d\n", __func__, hparams.lstm_input_size); @@ -4823,17 +5703,20 @@ struct whisper_vad_context * whisper_vad_init_with_params( const size_t n_tensors = hparams.n_encoder_layers * 2 + 4 + 2 + 1; std::map ctx_map; - auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * + { auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { + if (it == ctx_map.end()) + { ggml_init_params params = { - /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, + /*.mem_size =*/n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, }; - ggml_context * ctx = ggml_init(params); - if (!ctx) { + ggml_context *ctx = ggml_init(params); + if (!ctx) + { throw std::runtime_error("failed to create ggml context"); } @@ -4851,14 +5734,16 @@ struct whisper_vad_context * whisper_vad_init_with_params( wparams.gpu_device = params.gpu_device; buft_list_t buft_list = make_buft_list(wparams); - auto create_tensor = [&](vad_tensor type, ggml_tensor * meta) -> ggml_tensor * { + auto create_tensor = [&](vad_tensor type, ggml_tensor *meta) -> ggml_tensor * + { ggml_op op = VAD_TENSOR_OPS.at(type); 
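// [editor note] select_weight_buft() on the next line walks buft_list in
// priority order and returns the first buffer type whose device reports
// support for this weight's op; anything other than a GPU device or the
// default CPU buffer type is probed with a dummy GGML_OP_MUL_MAT node, as in
// weight_buft_supported() above.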
ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list); - if (!buft) { + if (!buft) + { throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", VAD_TENSOR_NAMES.at(type))); } - ggml_context * ctx = get_ctx(buft); - ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + ggml_context *ctx = get_ctx(buft); + ggml_tensor *tensor = ggml_dup_tensor(ctx, meta); model.tensors[VAD_TENSOR_NAMES.at(type)] = tensor; return tensor; @@ -4867,61 +5752,57 @@ struct whisper_vad_context * whisper_vad_init_with_params( // create tensors { ggml_init_params params = { - /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, + /*.mem_size =*/n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/nullptr, + /*.no_alloc =*/true, }; - ggml_context * ctx = ggml_init(params); - const auto & hparams = model.hparams; + ggml_context *ctx = ggml_init(params); + const auto &hparams = model.hparams; // SFTF precomputed basis matrix model.stft_forward_basis = create_tensor(VAD_TENSOR_STFT_BASIS, - ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258)); + ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 256, 1, 258)); model.encoder_0_weight = create_tensor(VAD_TENSOR_ENC_0_WEIGHT, - ggml_new_tensor_3d( - ctx, - GGML_TYPE_F16, - hparams.kernel_sizes[0], - hparams.encoder_in_channels[0], - hparams.encoder_out_channels[0] - )); + ggml_new_tensor_3d( + ctx, + GGML_TYPE_F16, + hparams.kernel_sizes[0], + hparams.encoder_in_channels[0], + hparams.encoder_out_channels[0])); model.encoder_0_bias = create_tensor(VAD_TENSOR_ENC_0_BIAS, - ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0])); + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[0])); model.encoder_1_weight = create_tensor(VAD_TENSOR_ENC_1_WEIGHT, - ggml_new_tensor_3d( - ctx, - GGML_TYPE_F16, - hparams.kernel_sizes[1], - hparams.encoder_in_channels[1], - hparams.encoder_out_channels[1] - )); + ggml_new_tensor_3d( + ctx, + GGML_TYPE_F16, + hparams.kernel_sizes[1], + hparams.encoder_in_channels[1], + hparams.encoder_out_channels[1])); model.encoder_1_bias = create_tensor(VAD_TENSOR_ENC_1_BIAS, - ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1])); + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[1])); model.encoder_2_weight = create_tensor(VAD_TENSOR_ENC_2_WEIGHT, - ggml_new_tensor_3d( - ctx, - GGML_TYPE_F16, - hparams.kernel_sizes[2], - hparams.encoder_in_channels[2], - hparams.encoder_out_channels[2] - )); + ggml_new_tensor_3d( + ctx, + GGML_TYPE_F16, + hparams.kernel_sizes[2], + hparams.encoder_in_channels[2], + hparams.encoder_out_channels[2])); model.encoder_2_bias = create_tensor(VAD_TENSOR_ENC_2_BIAS, - ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2])); + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[2])); model.encoder_3_weight = create_tensor(VAD_TENSOR_ENC_3_WEIGHT, - ggml_new_tensor_3d( - ctx, - GGML_TYPE_F16, - hparams.kernel_sizes[3], - hparams.encoder_in_channels[3], - hparams.encoder_out_channels[3] - )); + ggml_new_tensor_3d( + ctx, + GGML_TYPE_F16, + hparams.kernel_sizes[3], + hparams.encoder_in_channels[3], + hparams.encoder_out_channels[3])); model.encoder_3_bias = create_tensor(VAD_TENSOR_ENC_3_BIAS, - ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3])); + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.encoder_out_channels[3])); // Hidden State dimension (input gate, forget gate, cell gate, output gate) const int 
hstate_dim = hparams.lstm_hidden_size * 4; @@ -4929,42 +5810,38 @@ struct whisper_vad_context * whisper_vad_init_with_params( // LSTM weights - input to hidden model.lstm_ih_weight = create_tensor( VAD_TENSOR_LSTM_WEIGHT_IH, - ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim) - ); + ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)); model.lstm_ih_bias = create_tensor( VAD_TENSOR_LSTM_BIAS_IH, - ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim) - ); + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)); // LSTM weights - hidden to hidden model.lstm_hh_weight = create_tensor( VAD_TENSOR_LSTM_WEIGHT_HH, - ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim) - ); + ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.lstm_hidden_size, hstate_dim)); model.lstm_hh_bias = create_tensor( VAD_TENSOR_LSTM_BIAS_HH, - ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim) - ); + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hstate_dim)); // Final conv layer weight model.final_conv_weight = create_tensor( VAD_TENSOR_FINAL_CONV_WEIGHT, - ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1) - ); + ggml_new_tensor_2d(ctx, GGML_TYPE_F16, hparams.final_conv_in, 1)); model.final_conv_bias = create_tensor( VAD_TENSOR_FINAL_CONV_BIAS, - ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1) - ); + ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1)); ggml_free(ctx); } // allocate tensors in the backend buffers - for (auto & p : ctx_map) { + for (auto &p : ctx_map) + { ggml_backend_buffer_type_t buft = p.first; - ggml_context * ctx = p.second; + ggml_context *ctx = p.second; ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); - if (buf) { + if (buf) + { model.buffers.emplace_back(buf); size_t size_main = ggml_backend_buffer_get_size(buf); @@ -4978,7 +5855,8 @@ struct whisper_vad_context * whisper_vad_init_with_params( model.n_loaded = 0; std::vector read_buf; - while (true) { + while (true) + { int32_t n_dims; int32_t length; int32_t ttype; @@ -4987,13 +5865,15 @@ struct whisper_vad_context * whisper_vad_init_with_params( read_safe(loader, length); read_safe(loader, ttype); - if (loader->eof(loader->context)) { + if (loader->eof(loader->context)) + { break; } int32_t nelements = 1; - int32_t ne[4] = { 1, 1, 1, 1 }; - for (int i = 0; i < n_dims; ++i) { + int32_t ne[4] = {1, 1, 1, 1}; + for (int i = 0; i < n_dims; ++i) + { read_safe(loader, ne[i]); nelements *= ne[i]; } @@ -5003,39 +5883,46 @@ struct whisper_vad_context * whisper_vad_init_with_params( loader->read(loader->context, &tmp[0], tmp.size()); name.assign(&tmp[0], tmp.size()); - if (model.tensors.find(name) == model.tensors.end()) { + if (model.tensors.find(name) == model.tensors.end()) + { WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); return nullptr; } auto tensor = model.tensors[name.data()]; - if (ggml_nelements(tensor) != nelements) { + if (ggml_nelements(tensor) != nelements) + { WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", - __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + __func__, ne[0], ne[1], ne[2], (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2]); return nullptr; } - if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) + { 
WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", - __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); + __func__, name.data(), (int)tensor->ne[0], (int)tensor->ne[1], (int)tensor->ne[2], ne[0], ne[1], ne[2]); return nullptr; } const size_t bpe = ggml_type_size(ggml_type(ttype)); - if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + if ((nelements * bpe) / ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) + { WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + __func__, name.data(), ggml_nbytes(tensor), nelements * bpe); return nullptr; } - if (ggml_backend_buffer_is_host(tensor->buffer)) { + if (ggml_backend_buffer_is_host(tensor->buffer)) + { // for the CPU and Metal backend, we can read directly into the tensor loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); BYTESWAP_TENSOR(tensor); - } else { + } + else + { // read into a temporary buffer first, then copy to device memory read_buf.resize(ggml_nbytes(tensor)); @@ -5048,18 +5935,21 @@ struct whisper_vad_context * whisper_vad_init_with_params( model.n_loaded++; } - WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); + WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size / 1e6); - if (model.n_loaded == 0) { + if (model.n_loaded == 0) + { WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); - } else if (model.n_loaded != (int) model.tensors.size()) { + } + else if (model.n_loaded != (int)model.tensors.size()) + { WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); return nullptr; } - } - if (!whisper_vad_init_context(vctx)) { + if (!whisper_vad_init_context(vctx)) + { whisper_vad_free(vctx); return nullptr; } @@ -5068,12 +5958,14 @@ struct whisper_vad_context * whisper_vad_init_with_params( } bool whisper_vad_detect_speech( - struct whisper_vad_context * vctx, - const float * samples, - int n_samples) { + struct whisper_vad_context *vctx, + const float *samples, + int n_samples) +{ int n_chunks = n_samples / vctx->n_window; - if (n_samples % vctx->n_window != 0) { - n_chunks += 1; // Add one more chunk for remaining samples. + if (n_samples % vctx->n_window != 0) + { + n_chunks += 1; // Add one more chunk for remaining samples. 
} WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples); @@ -5087,28 +5979,31 @@ bool whisper_vad_detect_speech( std::vector window(vctx->n_window, 0.0f); - auto & sched = vctx->sched.sched; + auto &sched = vctx->sched.sched; - ggml_cgraph * gf = whisper_vad_build_graph(*vctx); + ggml_cgraph *gf = whisper_vad_build_graph(*vctx); - if (!ggml_backend_sched_alloc_graph(sched, gf)) { + if (!ggml_backend_sched_alloc_graph(sched, gf)) + { WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); return false; } - struct ggml_tensor * frame = ggml_graph_get_tensor(gf, "frame"); - struct ggml_tensor * prob = ggml_graph_get_tensor(gf, "prob"); + struct ggml_tensor *frame = ggml_graph_get_tensor(gf, "frame"); + struct ggml_tensor *prob = ggml_graph_get_tensor(gf, "prob"); // we are going to reuse the graph multiple times for each chunk const int64_t t_start_vad_us = ggml_time_us(); - for (int i = 0; i < n_chunks; i++) { + for (int i = 0; i < n_chunks; i++) + { const int idx_start = i * vctx->n_window; const int idx_end = std::min(idx_start + vctx->n_window, n_samples); const int chunk_len = idx_end - idx_start; - if (chunk_len < vctx->n_window) { + if (chunk_len < vctx->n_window) + { WHISPER_LOG_INFO("%s: chunk_len: %d < n_window: %d\n", __func__, chunk_len, vctx->n_window); std::vector partial_chunk(vctx->n_window, 0.0f); std::copy(samples + idx_start, samples + idx_end, partial_chunk.begin()); @@ -5117,10 +6012,13 @@ bool whisper_vad_detect_speech( const int samples_to_copy_max = vctx->n_window; const int samples_to_copy_cur = std::min(samples_to_copy_max, (int)partial_chunk.size()); std::copy(partial_chunk.begin(), partial_chunk.begin() + samples_to_copy_cur, window.begin()); - if (samples_to_copy_cur < samples_to_copy_max) { + if (samples_to_copy_cur < samples_to_copy_max) + { std::fill(window.begin() + samples_to_copy_cur, window.end(), 0.0f); } - } else { + } + else + { // Copy current frame samples to the window. const int samples_to_copy = std::min(idx_end - idx_start, vctx->n_window); std::copy(samples + idx_start, samples + idx_start + samples_to_copy, window.begin()); @@ -5130,7 +6028,8 @@ bool whisper_vad_detect_speech( ggml_backend_tensor_set(frame, window.data(), 0, ggml_nelements(frame) * sizeof(float)); // do not reset the scheduler - we will reuse the graph in the next chunk - if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) { + if (!ggml_graph_compute_helper(sched, gf, vctx->n_threads, false)) + { WHISPER_LOG_ERROR("%s: failed to compute VAD graph\n", __func__); break; } @@ -5138,7 +6037,7 @@ bool whisper_vad_detect_speech( // Get the probability for this chunk. 
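// [editor note] ggml_graph_get_tensor() above located this tensor by its
// "prob" name; the call below copies one float (the speech probability for
// this chunk) from the start of that tensor into vctx->probs[i], regardless
// of which backend actually ran the graph.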
ggml_backend_tensor_get(prob, &vctx->probs[i], 0, sizeof(float)); - //WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]); + // WHISPER_LOG_DEBUG("chunk %d: p = %7.3f\n", i, probs[i]); } vctx->t_vad_us += ggml_time_us() - t_start_vad_us; @@ -5149,56 +6048,66 @@ bool whisper_vad_detect_speech( return true; } -int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) { +int whisper_vad_segments_n_segments(struct whisper_vad_segments *segments) +{ return segments->data.size(); } -float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments * segments, int i_segment) { +float whisper_vad_segments_get_segment_t0(struct whisper_vad_segments *segments, int i_segment) +{ return segments->data[i_segment].start; } -float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments * segments, int i_segment) { +float whisper_vad_segments_get_segment_t1(struct whisper_vad_segments *segments, int i_segment) +{ return segments->data[i_segment].end; } -int whisper_vad_n_probs(struct whisper_vad_context * vctx) { +int whisper_vad_n_probs(struct whisper_vad_context *vctx) +{ return vctx->probs.size(); } -float * whisper_vad_probs(struct whisper_vad_context * vctx) { +float *whisper_vad_probs(struct whisper_vad_context *vctx) +{ return vctx->probs.data(); } -struct whisper_vad_segments * whisper_vad_segments_from_probs( - struct whisper_vad_context * vctx, - whisper_vad_params params) { +struct whisper_vad_segments *whisper_vad_segments_from_probs( + struct whisper_vad_context *vctx, + whisper_vad_params params) +{ WHISPER_LOG_INFO("%s: detecting speech timestamps using %d probabilities\n", __func__, whisper_vad_n_probs(vctx)); - int n_probs = whisper_vad_n_probs(vctx); - float * probs = whisper_vad_probs(vctx); - float threshold = params.threshold; - int min_speech_duration_ms = params.min_speech_duration_ms; - int min_silence_duration_ms = params.min_silence_duration_ms; - float max_speech_duration_s = params.max_speech_duration_s; - int speech_pad_ms = params.speech_pad_ms; - int n_window = vctx->n_window; - int sample_rate = WHISPER_SAMPLE_RATE; - int min_silence_samples = sample_rate * min_silence_duration_ms / 1000; - int audio_length_samples = n_probs * n_window; + int n_probs = whisper_vad_n_probs(vctx); + float *probs = whisper_vad_probs(vctx); + float threshold = params.threshold; + int min_speech_duration_ms = params.min_speech_duration_ms; + int min_silence_duration_ms = params.min_silence_duration_ms; + float max_speech_duration_s = params.max_speech_duration_s; + int speech_pad_ms = params.speech_pad_ms; + int n_window = vctx->n_window; + int sample_rate = WHISPER_SAMPLE_RATE; + int min_silence_samples = sample_rate * min_silence_duration_ms / 1000; + int audio_length_samples = n_probs * n_window; // Min number of samples to be considered valid speech. - int min_speech_samples = sample_rate * min_speech_duration_ms / 1000; - int speech_pad_samples = sample_rate * speech_pad_ms / 1000; + int min_speech_samples = sample_rate * min_speech_duration_ms / 1000; + int speech_pad_samples = sample_rate * speech_pad_ms / 1000; // Max number of samples that a speech segment can contain before it is // split into multiple segments. 
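// --- Editor's note: illustrative sketch, not part of the patch --------------
// The loop below is a hysteresis state machine over the per-chunk
// probabilities: a segment opens once p >= threshold and only closes after at
// least min_silence_samples worth of audio with p < (threshold - 0.15).
// A stripped-down version, ignoring the max-duration splitting, padding and
// merging that follow, might look like this (names are illustrative only):
struct vad_span_ref { int start, end; }; // sample indices

static std::vector<vad_span_ref> vad_segment_ref(
        const float * probs, int n_probs, int n_window,
        float threshold, int min_silence_samples, int min_speech_samples) {
    const float neg_threshold = std::max(threshold - 0.15f, 0.01f);

    std::vector<vad_span_ref> out;

    bool in_speech = false;
    int  start     = 0;
    int  temp_end  = 0; // first sample of the current silence run, 0 = none

    for (int i = 0; i < n_probs; i++) {
        const int cur = i*n_window;
        if (probs[i] >= threshold) {
            if (!in_speech) { in_speech = true; start = cur; }
            temp_end = 0; // back in speech, discard the silence run
        } else if (in_speech && probs[i] < neg_threshold) {
            if (!temp_end) { temp_end = cur; }
            if (cur - temp_end >= min_silence_samples) {
                if (temp_end - start > min_speech_samples) { out.push_back({ start, temp_end }); }
                in_speech = false;
                temp_end  = 0;
            }
        }
    }
    if (in_speech && n_probs*n_window - start > min_speech_samples) {
        out.push_back({ start, n_probs*n_window });
    }
    return out;
}
// -----------------------------------------------------------------------------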
int max_speech_samples; - if (max_speech_duration_s > 100000.0f) { + if (max_speech_duration_s > 100000.0f) + { max_speech_samples = INT_MAX / 2; - } else { - int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s) - n_window - 2 * speech_pad_samples; + } + else + { + int64_t temp = (int64_t)sample_rate * (int64_t)(max_speech_duration_s)-n_window - 2 * speech_pad_samples; max_speech_samples = (temp > INT_MAX) ? INT_MAX / 2 : (int)temp; - if (max_speech_samples < 0) { + if (max_speech_samples < 0) + { max_speech_samples = INT_MAX / 2; } } @@ -5206,16 +6115,18 @@ struct whisper_vad_segments * whisper_vad_segments_from_probs( // is marked as a potential place where the segment could be split if // max_speech_samples is reached. The value 98 was taken from the original // silaro-vad python implementation: - //https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291 + // https://github.com/snakers4/silero-vad/blob/0dd45f0bcd7271463c234f3bae5ad25181f9df8b/src/silero_vad/utils_vad.py#L291 int min_silence_samples_at_max_speech = sample_rate * 98 / 1000; // Calculate lower threshold for detecting end of speech segments. float neg_threshold = threshold - 0.15f; - if (neg_threshold < 0.01f) { + if (neg_threshold < 0.01f) + { neg_threshold = 0.01f; } - struct speech_segment_t { + struct speech_segment_t + { int start; int end; }; @@ -5224,26 +6135,30 @@ struct whisper_vad_segments * whisper_vad_segments_from_probs( speeches.reserve(256); bool is_speech_segment = false; - int temp_end = 0; - int prev_end = 0; - int next_start = 0; - int curr_speech_start = 0; - bool has_curr_speech = false; + int temp_end = 0; + int prev_end = 0; + int next_start = 0; + int curr_speech_start = 0; + bool has_curr_speech = false; - for (int i = 0; i < n_probs; i++) { - float curr_prob = probs[i]; - int curr_sample = n_window * i; + for (int i = 0; i < n_probs; i++) + { + float curr_prob = probs[i]; + int curr_sample = n_window * i; // Reset temp_end when we get back to speech - if ((curr_prob >= threshold) && temp_end) { + if ((curr_prob >= threshold) && temp_end) + { temp_end = 0; - if (next_start < prev_end) { + if (next_start < prev_end) + { next_start = curr_sample; } } // Start a new speech segment when probability exceeds threshold and not already in speech - if ((curr_prob >= threshold) && !is_speech_segment) { + if ((curr_prob >= threshold) && !is_speech_segment) + { is_speech_segment = true; curr_speech_start = curr_sample; has_curr_speech = true; @@ -5251,20 +6166,27 @@ struct whisper_vad_segments * whisper_vad_segments_from_probs( } // Handle maximum speech duration - if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) { - if (prev_end) { - speeches.push_back({ curr_speech_start, prev_end }); + if (is_speech_segment && (curr_sample - curr_speech_start) > max_speech_samples) + { + if (prev_end) + { + speeches.push_back({curr_speech_start, prev_end}); has_curr_speech = true; - if (next_start < prev_end) { // Previously reached silence and is still not speech + if (next_start < prev_end) + { // Previously reached silence and is still not speech is_speech_segment = false; has_curr_speech = false; - } else { + } + else + { curr_speech_start = next_start; } prev_end = next_start = temp_end = 0; - } else { - speeches.push_back({ curr_speech_start, curr_sample }); + } + else + { + speeches.push_back({curr_speech_start, curr_sample}); prev_end = next_start = temp_end = 0; is_speech_segment = false; @@ -5274,23 
+6196,30 @@ struct whisper_vad_segments * whisper_vad_segments_from_probs( } // Handle silence after speech - if ((curr_prob < neg_threshold) && is_speech_segment) { - if (!temp_end) { + if ((curr_prob < neg_threshold) && is_speech_segment) + { + if (!temp_end) + { temp_end = curr_sample; } // Track potential segment ends for max_speech handling - if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) { + if ((curr_sample - temp_end) > min_silence_samples_at_max_speech) + { prev_end = temp_end; } // Check if silence is long enough to end the segment - if ((curr_sample - temp_end) < min_silence_samples) { + if ((curr_sample - temp_end) < min_silence_samples) + { continue; - } else { + } + else + { // End the segment if it's long enough - if ((temp_end - curr_speech_start) > min_speech_samples) { - speeches.push_back({ curr_speech_start, temp_end }); + if ((temp_end - curr_speech_start) > min_speech_samples) + { + speeches.push_back({curr_speech_start, temp_end}); } prev_end = next_start = temp_end = 0; @@ -5302,21 +6231,25 @@ struct whisper_vad_segments * whisper_vad_segments_from_probs( } // Handle the case if we're still in a speech segment at the end - if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) { - speeches.push_back({ curr_speech_start, audio_length_samples }); + if (has_curr_speech && (audio_length_samples - curr_speech_start) > min_speech_samples) + { + speeches.push_back({curr_speech_start, audio_length_samples}); } // Merge adjacent segments with small gaps in between (post-processing) - if (speeches.size() > 1) { + if (speeches.size() > 1) + { int merged_count = 0; - for (int i = 0; i < (int) speeches.size() - 1; i++) { + for (int i = 0; i < (int)speeches.size() - 1; i++) + { // Define maximum gap allowed for merging (e.g., 200ms converted to samples) int max_merge_gap_samples = sample_rate * 200 / 1000; // If the gap between this segment and the next is small enough - if (speeches[i+1].start - speeches[i].end < max_merge_gap_samples) { + if (speeches[i + 1].start - speeches[i].end < max_merge_gap_samples) + { // Merge by extending current segment to the end of next segment - speeches[i].end = speeches[i+1].end; + speeches[i].end = speeches[i + 1].end; speeches.erase(speeches.begin() + i + 1); i--; @@ -5324,78 +6257,88 @@ struct whisper_vad_segments * whisper_vad_segments_from_probs( } } WHISPER_LOG_INFO("%s: Merged %d adjacent segments, now have %d segments\n", - __func__, merged_count, (int) speeches.size()); + __func__, merged_count, (int)speeches.size()); } // Double-check for minimum speech duration - for (int i = 0; i < (int) speeches.size(); i++) { - if (speeches[i].end - speeches[i].start < min_speech_samples) { + for (int i = 0; i < (int)speeches.size(); i++) + { + if (speeches[i].end - speeches[i].start < min_speech_samples) + { WHISPER_LOG_INFO("%s: Removing segment %d (too short: %d samples)\n", - __func__, i, speeches[i].end - speeches[i].start); + __func__, i, speeches[i].end - speeches[i].start); speeches.erase(speeches.begin() + i); i--; } } - WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int) speeches.size()); + WHISPER_LOG_INFO("%s: Final speech segments after filtering: %d\n", __func__, (int)speeches.size()); // Allocate final segments std::vector segments; - if (speeches.size() > 0) { - try { + if (speeches.size() > 0) + { + try + { segments.resize(speeches.size()); - } catch (const std::bad_alloc &) { + } + catch (const std::bad_alloc &) + { WHISPER_LOG_ERROR("%s: failed to 
allocate memory for final segments\n", __func__); return nullptr; } } // Apply padding to segments and copy to final segments - for (int i = 0; i < (int) speeches.size(); i++) { + for (int i = 0; i < (int)speeches.size(); i++) + { // Apply padding to the start of the first segment - if (i == 0) { + if (i == 0) + { speeches[i].start = - (speeches[i].start > speech_pad_samples) ? - (speeches[i].start - speech_pad_samples) : 0; + (speeches[i].start > speech_pad_samples) ? (speeches[i].start - speech_pad_samples) : 0; } // Handle spacing between segments - if (i < (int) speeches.size() - 1) { - int silence_duration = speeches[i+1].start - speeches[i].end; + if (i < (int)speeches.size() - 1) + { + int silence_duration = speeches[i + 1].start - speeches[i].end; - if (silence_duration < 2 * speech_pad_samples) { + if (silence_duration < 2 * speech_pad_samples) + { // If segments are close, split the difference speeches[i].end += silence_duration / 2; - speeches[i+1].start = - (speeches[i+1].start > silence_duration / 2) ? - (speeches[i+1].start - silence_duration / 2) : 0; - } else { + speeches[i + 1].start = + (speeches[i + 1].start > silence_duration / 2) ? (speeches[i + 1].start - silence_duration / 2) : 0; + } + else + { // Otherwise, apply full padding to both speeches[i].end = - (speeches[i].end + speech_pad_samples < audio_length_samples) ? - (speeches[i].end + speech_pad_samples) : audio_length_samples; - speeches[i+1].start = - (speeches[i+1].start > speech_pad_samples) ? - (speeches[i+1].start - speech_pad_samples) : 0; + (speeches[i].end + speech_pad_samples < audio_length_samples) ? (speeches[i].end + speech_pad_samples) : audio_length_samples; + speeches[i + 1].start = + (speeches[i + 1].start > speech_pad_samples) ? (speeches[i + 1].start - speech_pad_samples) : 0; } - } else { + } + else + { // Apply padding to the end of the last segment speeches[i].end = - (speeches[i].end + speech_pad_samples < audio_length_samples) ? - (speeches[i].end + speech_pad_samples) : audio_length_samples; + (speeches[i].end + speech_pad_samples < audio_length_samples) ? 
(speeches[i].end + speech_pad_samples) : audio_length_samples; } // Convert from samples to centiseconds segments[i].start = samples_to_cs(speeches[i].start); - segments[i].end = samples_to_cs(speeches[i].end); + segments[i].end = samples_to_cs(speeches[i].end); WHISPER_LOG_INFO("%s: VAD segment %d: start = %.2f, end = %.2f (duration: %.2f)\n", - __func__, i, segments[i].start/100.0, segments[i].end/100.0, (segments[i].end - segments[i].start)/100.0); + __func__, i, segments[i].start / 100.0, segments[i].end / 100.0, (segments[i].end - segments[i].start) / 100.0); } - whisper_vad_segments * vad_segments = new whisper_vad_segments; - if (vad_segments == NULL) { + whisper_vad_segments *vad_segments = new whisper_vad_segments; + if (vad_segments == NULL) + { WHISPER_LOG_ERROR("%s: failed to allocate memory for whisper_vad_segments\n", __func__); return nullptr; } @@ -5405,42 +6348,50 @@ struct whisper_vad_segments * whisper_vad_segments_from_probs( return vad_segments; } -struct whisper_vad_segments * whisper_vad_segments_from_samples( - whisper_vad_context * vctx, - whisper_vad_params params, - const float * samples, - int n_samples) { +struct whisper_vad_segments *whisper_vad_segments_from_samples( + whisper_vad_context *vctx, + whisper_vad_params params, + const float *samples, + int n_samples) +{ WHISPER_LOG_INFO("%s: detecting speech timestamps in %d samples\n", __func__, n_samples); - if (!whisper_vad_detect_speech(vctx, samples, n_samples)) { + if (!whisper_vad_detect_speech(vctx, samples, n_samples)) + { WHISPER_LOG_ERROR("%s: failed to detect speech\n", __func__); return nullptr; } return whisper_vad_segments_from_probs(vctx, params); } -void whisper_vad_free(whisper_vad_context * ctx) { - if (ctx) { - for (ggml_context * context : ctx->model.ctxs) { +void whisper_vad_free(whisper_vad_context *ctx) +{ + if (ctx) + { + for (ggml_context *context : ctx->model.ctxs) + { ggml_free(context); } - for (ggml_backend_buffer_t buf : ctx->model.buffers) { + for (ggml_backend_buffer_t buf : ctx->model.buffers) + { ggml_backend_buffer_free(buf); } ggml_backend_sched_free(ctx->sched.sched); - for (auto & backend : ctx->backends) { + for (auto &backend : ctx->backends) + { ggml_backend_free(backend); } - delete ctx; } } -void whisper_vad_free_segments(whisper_vad_segments * segments) { - if (segments) { +void whisper_vad_free_segments(whisper_vad_segments *segments) +{ + if (segments) + { delete segments; } } @@ -5452,87 +6403,105 @@ void whisper_vad_free_segments(whisper_vad_segments * segments) { // Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as // pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`. 
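// --- Editor's note: illustrative usage, not part of the patch ---------------
// How the VAD API above fits together, end to end. The sketch assumes 16 kHz
// mono float PCM in `pcm`, and that whisper_vad_default_context_params() and
// whisper_vad_default_params() are the default-parameter helpers declared in
// whisper.h; error handling is reduced to early returns.
static void vad_usage_sketch(const char * vad_model_path, const float * pcm, int n_samples) {
    whisper_vad_context_params cparams = whisper_vad_default_context_params();

    whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(vad_model_path, cparams);
    if (!vctx) {
        return;
    }

    // runs whisper_vad_detect_speech() internally, then converts the per-chunk
    // probabilities into padded, merged speech segments
    whisper_vad_params vparams = whisper_vad_default_params();

    whisper_vad_segments * segs = whisper_vad_segments_from_samples(vctx, vparams, pcm, n_samples);
    if (segs) {
        for (int i = 0; i < whisper_vad_segments_n_segments(segs); i++) {
            // segment bounds are in centiseconds, see samples_to_cs() above
            const float t0 = whisper_vad_segments_get_segment_t0(segs, i);
            const float t1 = whisper_vad_segments_get_segment_t1(segs, i);
            printf("speech: %.2f s -> %.2f s\n", t0/100.0f, t1/100.0f);
        }
        whisper_vad_free_segments(segs);
    }

    whisper_vad_free(vctx);
}
// -----------------------------------------------------------------------------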
static std::pair, whisper_partial_utf8> decode_utf8( - const char * src, - whisper_partial_utf8 partial_start) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; - const char * pos = src; + const char *src, + whisper_partial_utf8 partial_start) +{ + static const int lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4}; + const char *pos = src; std::vector code_points; - uint32_t value = partial_start.value; - int n_remain = partial_start.n_remain; + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; // continue previous decode, if applicable - while (*pos != 0 && n_remain > 0) { + while (*pos != 0 && n_remain > 0) + { uint8_t next_byte = static_cast(*pos); - if ((next_byte >> 6) != 2) { + if ((next_byte >> 6) != 2) + { // invalid sequence, abort code_points.push_back(0); - return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 }); + return std::make_pair(std::move(code_points), whisper_partial_utf8{0, -1}); } value = (value << 6) + (next_byte & 0x3F); ++pos; --n_remain; } - if (partial_start.n_remain > 0 && n_remain == 0) { + if (partial_start.n_remain > 0 && n_remain == 0) + { code_points.push_back(value); } // decode any subsequent utf-8 sequences, which may end in an incomplete one - while (*pos != 0) { - uint8_t first_byte = static_cast(*pos); - uint8_t highbits = first_byte >> 4; - n_remain = lookup[highbits] - 1; + while (*pos != 0) + { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; - if (n_remain < 0) { + if (n_remain < 0) + { // invalid sequence, abort code_points.clear(); code_points.push_back(0); - return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain }); + return std::make_pair(std::move(code_points), whisper_partial_utf8{0, n_remain}); } - uint8_t mask = (1 << (7 - n_remain)) - 1; - value = first_byte & mask; + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; ++pos; - while (*pos != 0 && n_remain > 0) { + while (*pos != 0 && n_remain > 0) + { value = (value << 6) + (static_cast(*pos) & 0x3F); ++pos; --n_remain; } - if (n_remain == 0) { + if (n_remain == 0) + { code_points.push_back(value); } } code_points.push_back(0); - return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain }); + return std::make_pair(std::move(code_points), whisper_partial_utf8{value, n_remain}); } // returns true iff pos points to the end of one of the definitions of a rule -static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) { - switch (pos->type) { - case WHISPER_GRETYPE_END: return true; // NOLINT - case WHISPER_GRETYPE_ALT: return true; // NOLINT - default: return false; +static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element *pos) +{ + switch (pos->type) + { + case WHISPER_GRETYPE_END: + return true; // NOLINT + case WHISPER_GRETYPE_ALT: + return true; // NOLINT + default: + return false; } } // returns true iff chr satisfies the char range at pos (regular or inverse range) // asserts that pos is pointing to a char range element static std::pair whisper_grammar_match_char( - const whisper_grammar_element * pos, - const uint32_t chr) { + const whisper_grammar_element *pos, + const uint32_t chr) +{ - bool found = false; + bool found = false; bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR; WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT - do { - if (pos[1].type == 
WHISPER_GRETYPE_CHAR_RNG_UPPER) { + do + { + if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) + { // inclusive range, e.g. [a-z] found = found || (pos->value <= chr && chr <= pos[1].value); pos += 2; - } else { + } + else + { // exact char match, e.g. [a] or "a" found = found || pos->value == chr; pos += 1; @@ -5546,42 +6515,54 @@ static std::pair whisper_grammar_match_ch // range at pos (regular or inverse range) // asserts that pos is pointing to a char range element static bool whisper_grammar_match_partial_char( - const whisper_grammar_element * pos, - const whisper_partial_utf8 partial_utf8) { + const whisper_grammar_element *pos, + const whisper_partial_utf8 partial_utf8) +{ bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR; WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); uint32_t partial_value = partial_utf8.value; - int n_remain = partial_utf8.n_remain; + int n_remain = partial_utf8.n_remain; // invalid sequence or 7-bit char split across 2 bytes (overlong) - if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) + { return false; } // range of possible code points this partial UTF-8 sequence could complete to - uint32_t low = partial_value << (n_remain * 6); + uint32_t low = partial_value << (n_remain * 6); uint32_t high = low | ((1 << (n_remain * 6)) - 1); - if (low == 0) { - if (n_remain == 2) { + if (low == 0) + { + if (n_remain == 2) + { low = 1 << 11; - } else if (n_remain == 3) { + } + else if (n_remain == 3) + { low = 1 << 16; } } - do { - if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) { + do + { + if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) + { // inclusive range, e.g. [a-z] - if (pos->value <= high && low <= pos[1].value) { + if (pos->value <= high && low <= pos[1].value) + { return is_positive_char; } pos += 2; - } else { + } + else + { // exact char match, e.g. 
[a] or "a" - if (low <= pos->value && pos->value <= high) { + if (low <= pos->value && pos->value <= high) + { return is_positive_char; } pos += 1; @@ -5591,59 +6572,69 @@ static bool whisper_grammar_match_partial_char( return !is_positive_char; } - // transforms a grammar pushdown stack into N possible stacks, all ending // at a character range (terminal element) static void whisper_grammar_advance_stack( - const std::vector> & rules, - const std::vector & stack, - std::vector> & new_stacks) { + const std::vector> &rules, + const std::vector &stack, + std::vector> &new_stacks) +{ - if (stack.empty()) { + if (stack.empty()) + { new_stacks.emplace_back(); return; } - const whisper_grammar_element * pos = stack.back(); + const whisper_grammar_element *pos = stack.back(); - switch (pos->type) { - case WHISPER_GRETYPE_RULE_REF: { - const size_t rule_id = static_cast(pos->value); - const whisper_grammar_element * subpos = rules[rule_id].data(); - do { - // init new stack without the top (pos) - std::vector new_stack(stack.begin(), stack.end() - 1); - if (!whisper_grammar_is_end_of_sequence(pos + 1)) { - // if this rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 1); - } - if (!whisper_grammar_is_end_of_sequence(subpos)) { - // if alternate is nonempty, add to stack - new_stack.push_back(subpos); - } - whisper_grammar_advance_stack(rules, new_stack, new_stacks); - while (!whisper_grammar_is_end_of_sequence(subpos)) { - // scan to end of alternate def - subpos++; - } - if (subpos->type == WHISPER_GRETYPE_ALT) { - // there's another alternate def of this rule to process - subpos++; - } else { - break; - } - } while (true); - break; - } - case WHISPER_GRETYPE_CHAR: - case WHISPER_GRETYPE_CHAR_NOT: - new_stacks.push_back(stack); - break; - default: - // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range - // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on - // those - WHISPER_ASSERT(false); + switch (pos->type) + { + case WHISPER_GRETYPE_RULE_REF: + { + const size_t rule_id = static_cast(pos->value); + const whisper_grammar_element *subpos = rules[rule_id].data(); + do + { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(pos + 1)) + { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!whisper_grammar_is_end_of_sequence(subpos)) + { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + whisper_grammar_advance_stack(rules, new_stack, new_stacks); + while (!whisper_grammar_is_end_of_sequence(subpos)) + { + // scan to end of alternate def + subpos++; + } + if (subpos->type == WHISPER_GRETYPE_ALT) + { + // there's another alternate def of this rule to process + subpos++; + } + else + { + break; + } + } while (true); + break; + } + case WHISPER_GRETYPE_CHAR: + case WHISPER_GRETYPE_CHAR_NOT: + new_stacks.push_back(stack); + break; + default: + // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range + // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + WHISPER_ASSERT(false); } } @@ -5652,24 +6643,29 @@ static void whisper_grammar_advance_stack( // produces the N possible stacks if the given char is accepted at those // positions static std::vector> whisper_grammar_accept( - const std::vector> & rules, - const std::vector> & stacks, - 
const uint32_t chr) { + const std::vector> &rules, + const std::vector> &stacks, + const uint32_t chr) +{ std::vector> new_stacks; - for (const auto & stack : stacks) { - if (stack.empty()) { + for (const auto &stack : stacks) + { + if (stack.empty()) + { continue; } auto match = whisper_grammar_match_char(stack.back(), chr); - if (match.first) { - const whisper_grammar_element * pos = match.second; + if (match.first) + { + const whisper_grammar_element *pos = match.second; // update top of stack to next element, if any std::vector new_stack(stack.begin(), stack.end() - 1); - if (!whisper_grammar_is_end_of_sequence(pos)) { + if (!whisper_grammar_is_end_of_sequence(pos)) + { new_stack.push_back(pos); } whisper_grammar_advance_stack(rules, new_stack, new_stacks); @@ -5680,87 +6676,106 @@ static std::vector> whisper_grammar } static std::vector whisper_grammar_reject_candidates( - const std::vector> & rules, - const std::vector> & stacks, - const std::vector & candidates); + const std::vector> &rules, + const std::vector> &stacks, + const std::vector &candidates); static std::vector whisper_grammar_reject_candidates_for_stack( - const std::vector> & rules, - const std::vector & stack, - const std::vector & candidates) { + const std::vector> &rules, + const std::vector &stack, + const std::vector &candidates) +{ std::vector rejects; - if (stack.empty()) { - for (auto tok : candidates) { - if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + if (stack.empty()) + { + for (auto tok : candidates) + { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) + { rejects.push_back(tok); } } return rejects; } - const whisper_grammar_element * stack_pos = stack.back(); + const whisper_grammar_element *stack_pos = stack.back(); std::vector next_candidates; - for (auto tok : candidates) { - if (*tok.code_points == 0) { + for (auto tok : candidates) + { + if (*tok.code_points == 0) + { // reached end of full codepoints in token, reject iff it ended in a partial sequence // that cannot satisfy this position in grammar - if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) + { rejects.push_back(tok); } - } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) { - next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 }); - } else { + } + else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) + { + next_candidates.push_back({tok.id, tok.code_points + 1, tok.partial_utf8}); + } + else + { rejects.push_back(tok); } } - const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second; + const auto *stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second; // update top of stack to next element, if any std::vector stack_after(stack.begin(), stack.end() - 1); - if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) { + if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) + { stack_after.push_back(stack_pos_after); } std::vector> next_stacks; whisper_grammar_advance_stack(rules, stack_after, next_stacks); auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates); - for (auto tok : next_rejects) { - rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 }); + for (auto tok : next_rejects) + { + rejects.push_back({tok.id, tok.code_points - 1, tok.partial_utf8}); } return rejects; } static std::vector 
whisper_grammar_reject_candidates( - const std::vector> & rules, - const std::vector> & stacks, - const std::vector & candidates) { - if (candidates.empty() || stacks.empty()) { + const std::vector> &rules, + const std::vector> &stacks, + const std::vector &candidates) +{ + if (candidates.empty() || stacks.empty()) + { return std::vector(); } auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); - for (size_t i = 1, size = stacks.size(); i < size; ++i) { + for (size_t i = 1, size = stacks.size(); i < size; ++i) + { rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); } return rejects; } static struct whisper_grammar whisper_grammar_init( - const whisper_grammar_element ** rules, - size_t n_rules, - size_t i_start_rule) { - const whisper_grammar_element * pos; + const whisper_grammar_element **rules, + size_t n_rules, + size_t i_start_rule) +{ + const whisper_grammar_element *pos; // copy rule definitions into vectors std::vector> vec_rules(n_rules); - for (size_t i = 0; i < n_rules; i++) { - for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) { + for (size_t i = 0; i < n_rules; i++) + { + for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) + { vec_rules[i].push_back(*pos); } vec_rules[i].push_back({WHISPER_GRETYPE_END, 0}); @@ -5769,91 +6784,106 @@ static struct whisper_grammar whisper_grammar_init( // loop over alternates of start rule to build initial stacks std::vector> stacks; pos = rules[i_start_rule]; - do { + do + { std::vector stack; - if (!whisper_grammar_is_end_of_sequence(pos)) { + if (!whisper_grammar_is_end_of_sequence(pos)) + { // if alternate is nonempty, add to stack stack.push_back(pos); } whisper_grammar_advance_stack(vec_rules, stack, stacks); - while (!whisper_grammar_is_end_of_sequence(pos)) { + while (!whisper_grammar_is_end_of_sequence(pos)) + { // scan to end of alternate def pos++; } - if (pos->type == WHISPER_GRETYPE_ALT) { + if (pos->type == WHISPER_GRETYPE_ALT) + { // there's another alternate def of this rule to process pos++; - } else { + } + else + { break; } } while (true); - return { std::move(vec_rules), std::move(stacks), {} }; + return {std::move(vec_rules), std::move(stacks), {}}; } static void whisper_suppress_invalid_grammar( - whisper_context & ctx, - const whisper_full_params & params, - std::vector & logits, - const whisper_grammar & grammar) { + whisper_context &ctx, + const whisper_full_params ¶ms, + std::vector &logits, + const whisper_grammar &grammar) +{ - if (grammar.rules.empty() || grammar.stacks.empty()) { + if (grammar.rules.empty() || grammar.stacks.empty()) + { return; } - //bool allow_eot = false; - //for (const auto & stack : grammar.stacks) { - // if (stack.empty()) { - // allow_eot = true; - // break; - // } - //} + // bool allow_eot = false; + // for (const auto & stack : grammar.stacks) { + // if (stack.empty()) { + // allow_eot = true; + // break; + // } + // } const whisper_token eot = whisper_token_eot(&ctx); std::vector, whisper_partial_utf8>> candidates_decoded; - std::vector candidates_grammar; + std::vector candidates_grammar; - for (whisper_token id = 0; id < eot; ++id) { - const std::string & text = ctx.vocab.id_to_token[id]; - if (!text.empty()) { + for (whisper_token id = 0; id < eot; ++id) + { + const std::string &text = ctx.vocab.id_to_token[id]; + if (!text.empty()) + { candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); - candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), 
candidates_decoded.back().second }); + candidates_grammar.push_back({id, candidates_decoded.back().first.data(), candidates_decoded.back().second}); } } const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); - for (const auto & reject : rejects) { + for (const auto &reject : rejects) + { logits[reject.id] -= params.grammar_penalty; } // when the grammar allows a continuation, we penalize the end-of-text token - //if (!allow_eot) { + // if (!allow_eot) { // logits[eot] -= params.grammar_penalty; //} - //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); + // fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); } -static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) { - if (grammar.rules.empty() || grammar.stacks.empty()) { +static void whisper_grammar_accept_token(whisper_context &ctx, whisper_grammar &grammar, whisper_token token) +{ + if (grammar.rules.empty() || grammar.stacks.empty()) + { return; } - //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); + // fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); - const std::string & text = ctx.vocab.id_to_token[token]; + const std::string &text = ctx.vocab.id_to_token[token]; - if (text.rfind("[_", 0) == 0) { + if (text.rfind("[_", 0) == 0) + { // fprintf(stderr, " (skipped)\n"); return; } // fprintf(stderr, "\n"); // Note terminating 0 in decoded string - const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8); - const auto & code_points = decoded.first; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8); + const auto &code_points = decoded.first; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) + { grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it); } grammar.partial_utf8 = decoded.second; @@ -5865,158 +6895,167 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar //////////////////////////////////////////////////////////////////////////// -struct whisper_context_params * whisper_context_default_params_by_ref(void) { +struct whisper_context_params *whisper_context_default_params_by_ref(void) +{ struct whisper_context_params params = whisper_context_default_params(); - struct whisper_context_params* result = new whisper_context_params(); + struct whisper_context_params *result = new whisper_context_params(); *result = params; return result; } -struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) { +struct whisper_full_params *whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) +{ struct whisper_full_params params = whisper_full_default_params(strategy); - struct whisper_full_params* result = new whisper_full_params(); + struct whisper_full_params *result = new whisper_full_params(); *result = params; return result; } -struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { +struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) +{ struct whisper_full_params result = { - /*.strategy =*/ strategy, + /*.strategy =*/strategy, - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, + /*.n_threads 
=*/std::min(4, (int32_t)std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/16384, + /*.offset_ms =*/0, + /*.duration_ms =*/0, - /*.translate =*/ false, - /*.no_context =*/ true, - /*.no_timestamps =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, + /*.translate =*/false, + /*.no_context =*/true, + /*.no_timestamps =*/false, + /*.single_segment =*/false, + /*.print_special =*/false, + /*.print_progress =*/true, + /*.print_realtime =*/false, + /*.print_timestamps =*/true, - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.split_on_word =*/ false, - /*.max_tokens =*/ 0, + /*.token_timestamps =*/false, + /*.thold_pt =*/0.01f, + /*.thold_ptsum =*/0.01f, + /*.max_len =*/0, + /*.split_on_word =*/false, + /*.max_tokens =*/0, - /*.debug_mode =*/ false, - /*.audio_ctx =*/ 0, + /*.debug_mode =*/false, + /*.audio_ctx =*/0, - /*.tdrz_enable =*/ false, + /*.tdrz_enable =*/false, - /* suppress_regex =*/ nullptr, + /* suppress_regex =*/nullptr, - /*.initial_prompt =*/ nullptr, - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, + /*.initial_prompt =*/nullptr, + /*.prompt_tokens =*/nullptr, + /*.prompt_n_tokens =*/0, - /*.language =*/ "en", - /*.detect_language =*/ false, + /*.language =*/"en", + /*.detect_language =*/false, - /*.suppress_blank =*/ true, - /*.suppress_nst =*/ false, + /*.suppress_blank =*/true, + /*.suppress_nst =*/false, - /*.temperature =*/ 0.0f, - /*.max_initial_ts =*/ 1.0f, - /*.length_penalty =*/ -1.0f, + /*.temperature =*/0.0f, + /*.max_initial_ts =*/1.0f, + /*.length_penalty =*/-1.0f, - /*.temperature_inc =*/ 0.2f, - /*.entropy_thold =*/ 2.4f, - /*.logprob_thold =*/ -1.0f, - /*.no_speech_thold =*/ 0.6f, + /*.temperature_inc =*/0.2f, + /*.entropy_thold =*/2.4f, + /*.logprob_thold =*/-1.0f, + /*.no_speech_thold =*/0.6f, - /*.greedy =*/ { - /*.best_of =*/ -1, + /*.greedy =*/{ + /*.best_of =*/-1, }, - /*.beam_search =*/ { - /*.beam_size =*/ -1, + /*.beam_search =*/{ + /*.beam_size =*/-1, - /*.patience =*/ -1.0f, + /*.patience =*/-1.0f, }, - /*.new_segment_callback =*/ nullptr, - /*.new_segment_callback_user_data =*/ nullptr, + /*.new_segment_callback =*/nullptr, + /*.new_segment_callback_user_data =*/nullptr, - /*.progress_callback =*/ nullptr, - /*.progress_callback_user_data =*/ nullptr, + /*.progress_callback =*/nullptr, + /*.progress_callback_user_data =*/nullptr, - /*.encoder_begin_callback =*/ nullptr, - /*.encoder_begin_callback_user_data =*/ nullptr, + /*.encoder_begin_callback =*/nullptr, + /*.encoder_begin_callback_user_data =*/nullptr, - /*.abort_callback =*/ nullptr, - /*.abort_callback_user_data =*/ nullptr, + /*.abort_callback =*/nullptr, + /*.abort_callback_user_data =*/nullptr, - /*.logits_filter_callback =*/ nullptr, - /*.logits_filter_callback_user_data =*/ nullptr, + /*.logits_filter_callback =*/nullptr, + /*.logits_filter_callback_user_data =*/nullptr, - /*.grammar_rules =*/ nullptr, - /*.n_grammar_rules =*/ 0, - /*.i_start_rule =*/ 0, - /*.grammar_penalty =*/ 100.0f, + /*.grammar_rules =*/nullptr, + /*.n_grammar_rules =*/0, + /*.i_start_rule =*/0, + /*.grammar_penalty =*/100.0f, - /*.vad =*/ false, - /*.vad_model_path =*/ nullptr, + /*.vad =*/false, + /*.vad_model_path =*/nullptr, - /* vad_params =*/ whisper_vad_default_params(), + /* vad_params =*/whisper_vad_default_params(), }; - switch (strategy) { - case WHISPER_SAMPLING_GREEDY: - { - result.greedy = { - /*.best_of =*/ 5, 
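// A small usage sketch for the defaults initialized above: callers normally
// take whisper_full_default_params() by value and override only the fields
// they need before calling whisper_full(). The concrete values below are
// placeholders, not recommendations.
#include "whisper.h"

int transcribe(struct whisper_context * ctx, const float * pcm, int n_samples) {
    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    params.n_threads        = 8;
    params.language         = "en";
    params.print_progress   = false;
    params.token_timestamps = true;

    return whisper_full(ctx, params, pcm, n_samples);
}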
- }; - } break; - case WHISPER_SAMPLING_BEAM_SEARCH: - { - result.beam_search = { - /*.beam_size =*/ 5, + switch (strategy) + { + case WHISPER_SAMPLING_GREEDY: + { + result.greedy = { + /*.best_of =*/5, + }; + } + break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + result.beam_search = { + /*.beam_size =*/5, - /*.patience =*/ -1.0f, - }; - } break; + /*.patience =*/-1.0f, + }; + } + break; } return result; } // forward declarations -static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); +static std::vector get_signal_energy(const float *signal, int n_samples, int n_samples_per_half_window); static void whisper_exp_compute_token_level_timestamps( - struct whisper_context & ctx, - struct whisper_state & state, - int i_segment, - float thold_pt, - float thold_ptsum); + struct whisper_context &ctx, + struct whisper_state &state, + int i_segment, + float thold_pt, + float thold_ptsum); -static inline bool should_split_on_word(const char * txt, bool split_on_word) { - if (!split_on_word) return true; +static inline bool should_split_on_word(const char *txt, bool split_on_word) +{ + if (!split_on_word) + return true; return txt[0] == ' '; } static void whisper_exp_compute_token_level_timestamps_dtw( - struct whisper_context * ctx, - struct whisper_state * state, - struct whisper_full_params params, - int i_segment, - size_t n_segments, - int seek, - int n_frames, - int medfilt_width, - int n_threads); + struct whisper_context *ctx, + struct whisper_state *state, + struct whisper_full_params params, + int i_segment, + size_t n_segments, + int seek, + int n_frames, + int medfilt_width, + int n_threads); // wrap the last segment to max_len characters // returns the number of new segments -static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) { +static int whisper_wrap_segment(struct whisper_context &ctx, struct whisper_state &state, int max_len, bool split_on_word) +{ auto segment = state.result_all.back(); int res = 1; @@ -6024,16 +7063,19 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta std::string text; - for (int i = 0; i < (int) segment.tokens.size(); i++) { - const auto & token = segment.tokens[i]; - if (token.id >= whisper_token_eot(&ctx)) { + for (int i = 0; i < (int)segment.tokens.size(); i++) + { + const auto &token = segment.tokens[i]; + if (token.id >= whisper_token_eot(&ctx)) + { continue; } const auto txt = whisper_token_to_str(&ctx, token.id); const int cur = strlen(txt); - if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) { + if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) + { state.result_all.back().text = std::move(text); state.result_all.back().t1 = token.t0; state.result_all.back().tokens.resize(i); @@ -6046,8 +7088,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta // add tokens [i, end] to the new segment state.result_all.back().tokens.insert( state.result_all.back().tokens.end(), - segment.tokens.begin() + i, - segment.tokens.end()); + segment.tokens.begin() + i, + segment.tokens.end()); state.result_all.back().speaker_turn_next = segment.speaker_turn_next; @@ -6058,7 +7100,9 @@ static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_sta i = -1; res++; - } else { + } + else + { acc += cur; text += txt; } @@ -6073,40 +7117,51 @@ static const std::vector non_speech_tokens = { "\"", "#", "(", ")", "*", "+", "/", 
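// Standalone sketch of the wrapping rule applied by whisper_wrap_segment()
// above: accumulate token text until adding the next token would exceed
// max_len, and only break when the token starts a new word (leading space),
// mirroring should_split_on_word(). Plain strings stand in for
// whisper_token_to_str() output here.
#include <string>
#include <vector>

std::vector<std::string> wrap_tokens(const std::vector<std::string> & tokens, int max_len) {
    std::vector<std::string> segments(1);
    int acc = 0;
    for (size_t i = 0; i < tokens.size(); ++i) {
        const int  cur        = (int) tokens[i].size();
        const bool word_start = !tokens[i].empty() && tokens[i][0] == ' ';
        if (acc + cur > max_len && i > 0 && word_start) {
            segments.emplace_back(); // start a new segment at the word boundary
            acc = 0;
        }
        segments.back() += tokens[i];
        acc += cur;
    }
    return segments;
}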
":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", - "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" -}; + "♪♪♪", "♩", "♪", "♫", "♬", "♭", "♮", "♯"}; static void whisper_compute_logprobs( - const std::vector & logits, - const int n_logits, - std::vector & logprobs) { + const std::vector &logits, + const int n_logits, + std::vector &logprobs) +{ const float logit_max = *std::max_element(logits.begin(), logits.end()); float logsumexp = 0.0f; - for (int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { + for (int i = 0; i < n_logits; ++i) + { + if (logits[i] > -INFINITY) + { logsumexp += expf(logits[i] - logit_max); } } logsumexp = logf(logsumexp) + logit_max; - for (int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { + for (int i = 0; i < n_logits; ++i) + { + if (logits[i] > -INFINITY) + { logprobs[i] = logits[i] - logsumexp; - } else { + } + else + { logprobs[i] = -INFINITY; } } } static void whisper_compute_probs( - const std::vector & logits, - const int n_logits, - const std::vector & logprobs, - std::vector & probs) { - for (int i = 0; i < n_logits; ++i) { - if (logits[i] == -INFINITY) { + const std::vector &logits, + const int n_logits, + const std::vector &logprobs, + std::vector &probs) +{ + for (int i = 0; i < n_logits; ++i) + { + if (logits[i] == -INFINITY) + { probs[i] = 0.0f; - } else { + } + else + { probs[i] = expf(logprobs[i]); } } @@ -6117,30 +7172,33 @@ static void whisper_compute_probs( // - computes logprobs and probs // TODO: optimize static void whisper_process_logits( - struct whisper_context & ctx, - struct whisper_state & state, - struct whisper_decoder & decoder, - const struct whisper_full_params params, - float temperature) { - const auto & vocab = ctx.vocab; - const auto & tokens_cur = decoder.sequence.tokens; + struct whisper_context &ctx, + struct whisper_state &state, + struct whisper_decoder &decoder, + const struct whisper_full_params params, + float temperature) +{ + const auto &vocab = ctx.vocab; + const auto &tokens_cur = decoder.sequence.tokens; const bool is_initial = tokens_cur.size() == 0; - const int n_logits = vocab.id_to_token.size(); + const int n_logits = vocab.id_to_token.size(); WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); // extract the logits for the last token // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly - auto & probs = decoder.probs; - auto & logits = decoder.logits; - auto & logprobs = decoder.logprobs; + auto &probs = decoder.probs; + auto &logits = decoder.logits; + auto &logprobs = decoder.logprobs; { logits.resize(n_logits); - memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float)); + memcpy(logits.data(), state.logits.data() + decoder.i_batch * n_logits, n_logits * sizeof(float)); - if (temperature > 0.0f) { - for (int i = 0; i < n_logits; i++) { + if (temperature > 0.0f) + { + for (int i = 0; i < n_logits; i++) + { logits[i] /= temperature; } } @@ -6155,9 +7213,11 @@ static void whisper_process_logits( { // suppress blank // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390 - if (params.suppress_blank) { - if (is_initial) { - logits[vocab.token_eot] = -INFINITY; + if (params.suppress_blank) + { + if (is_initial) + { + logits[vocab.token_eot] = -INFINITY; logits[vocab.token_to_id.at(" ")] = -INFINITY; } } @@ 
-6165,44 +7225,52 @@ static void whisper_process_logits( // suppress <|notimestamps|> token // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 logits[vocab.token_not] = -INFINITY; - if (params.no_timestamps) { - for (int i = vocab.token_beg; i < n_logits; ++i) { + if (params.no_timestamps) + { + for (int i = vocab.token_beg; i < n_logits; ++i) + { logits[i] = -INFINITY; } } // suppress sot and nosp tokens - logits[vocab.token_sot] = -INFINITY; + logits[vocab.token_sot] = -INFINITY; logits[vocab.token_nosp] = -INFINITY; // [TDRZ] when tinydiarize is disabled, suppress solm token - if (params.tdrz_enable == false) { + if (params.tdrz_enable == false) + { logits[vocab.token_solm] = -INFINITY; } // suppress task tokens - logits[vocab.token_translate] = -INFINITY; + logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; - logits[vocab.token_prev] = -INFINITY; + logits[vocab.token_prev] = -INFINITY; // suppress lang tokens - for (size_t i = 0; i < g_lang.size(); ++i) { + for (size_t i = 0; i < g_lang.size(); ++i) + { logits[whisper_token_lang(&ctx, i)] = -INFINITY; } // suppress prev token logits[vocab.token_prev] = -INFINITY; - if (params.logits_filter_callback) { + if (params.logits_filter_callback) + { params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); } // suppress any tokens matching a regular expression // ref: https://github.com/openai/whisper/discussions/1041 - if (params.suppress_regex != nullptr) { + if (params.suppress_regex != nullptr) + { std::regex re(params.suppress_regex); - for (std::pair token_id : vocab.token_to_id) { - if (std::regex_match(token_id.first, re)) { + for (std::pair token_id : vocab.token_to_id) + { + if (std::regex_match(token_id.first, re)) + { logits[token_id.second] = -INFINITY; } } @@ -6210,21 +7278,27 @@ static void whisper_process_logits( // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 - if (params.suppress_nst) { - for (const std::string & token : non_speech_tokens) { + if (params.suppress_nst) + { + for (const std::string &token : non_speech_tokens) + { const std::string suppress_tokens[] = {token, " " + token}; - for (const std::string & suppress_token : suppress_tokens) { - if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) { + for (const std::string &suppress_token : suppress_tokens) + { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) + { logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; } } } // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word - if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) { + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) + { logits[vocab.token_to_id.at(" -")] = -INFINITY; } - if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) { + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) + { logits[vocab.token_to_id.at(" '")] = -INFINITY; } } @@ -6232,18 +7306,24 @@ static void whisper_process_logits( // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 { - const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; + const bool 
last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; - //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + // WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); - if (last_was_timestamp) { - if (penultimate_was_timestamp) { - for (int i = vocab.token_beg; i < n_logits; ++i) { + if (last_was_timestamp) + { + if (penultimate_was_timestamp) + { + for (int i = vocab.token_beg; i < n_logits; ++i) + { logits[i] = -INFINITY; } - } else { - for (int i = 0; i < vocab.token_eot; ++i) { + } + else + { + for (int i = 0; i < vocab.token_eot; ++i) + { logits[i] = -INFINITY; } } @@ -6252,21 +7332,25 @@ static void whisper_process_logits( // the initial timestamp cannot be larger than max_initial_ts // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 - if (is_initial && params.max_initial_ts > 0.0f) { - const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; - const int tid0 = std::round(params.max_initial_ts/precision); + if (is_initial && params.max_initial_ts > 0.0f) + { + const float precision = float(WHISPER_CHUNK_SIZE) / ctx.model.hparams.n_audio_ctx; + const int tid0 = std::round(params.max_initial_ts / precision); - for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) { + for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) + { logits[i] = -INFINITY; } } // condition timestamp tokens to be increasing // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556 - if (decoder.has_ts) { - const int tid0 = decoder.seek_delta/2; + if (decoder.has_ts) + { + const int tid0 = decoder.seek_delta / 2; - for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) { + for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) + { logits[i] = -INFINITY; } } @@ -6282,44 +7366,58 @@ static void whisper_process_logits( { float logsumexp = 0.0f; const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); - for (int i = vocab.token_beg; i < n_logits; ++i) { - if (logprobs[i] > -INFINITY) { + for (int i = vocab.token_beg; i < n_logits; ++i) + { + if (logprobs[i] > -INFINITY) + { logsumexp += expf(logprobs[i] - logprob_max); } } - if (logsumexp > 0.0f) { + if (logsumexp > 0.0f) + { timestamp_logprob = logf(logsumexp) + logprob_max; } } const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); - //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + // WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); - if (timestamp_logprob > max_text_token_logprob) { - for (int i = 0; i < vocab.token_beg; ++i) { - logits[i] = -INFINITY; + if (timestamp_logprob > max_text_token_logprob) + { + for (int i = 0; i < vocab.token_beg; ++i) + { + logits[i] = -INFINITY; logprobs[i] = -INFINITY; } - } else { - if (params.n_grammar_rules > 0) { + } + else + { + if (params.n_grammar_rules > 0) + { whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); // populate the logprobs array (log_softmax) { const float logit_max = *std::max_element(logits.begin(), logits.end()); float logsumexp = 0.0f; - for 
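// Sketch of the timestamp-vs-text decision applied above: if the total
// probability mass of all timestamp tokens exceeds the probability of the
// best text token, text tokens are masked so a timestamp must be sampled
// (the OpenAI decoding rule). Assumes 0 <= token_beg < logprobs.size();
// the helper name is illustrative.
#include <algorithm>
#include <cmath>
#include <vector>

bool must_emit_timestamp(const std::vector<float> & logprobs, int token_beg) {
    const float ts_max = *std::max_element(logprobs.begin() + token_beg, logprobs.end());
    float sum = 0.0f;
    for (size_t i = token_beg; i < logprobs.size(); ++i) {
        if (logprobs[i] > -INFINITY) sum += expf(logprobs[i] - ts_max);
    }
    const float timestamp_logprob      = sum > 0.0f ? logf(sum) + ts_max : -INFINITY;
    const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + token_beg);
    return timestamp_logprob > max_text_token_logprob;
}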
(int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { + for (int i = 0; i < n_logits; ++i) + { + if (logits[i] > -INFINITY) + { logsumexp += expf(logits[i] - logit_max); } } logsumexp = logf(logsumexp) + logit_max; - for (int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { + for (int i = 0; i < n_logits; ++i) + { + if (logits[i] > -INFINITY) + { logprobs[i] = logits[i] - logsumexp; - } else { + } + else + { logprobs[i] = -INFINITY; } } @@ -6386,13 +7484,17 @@ static void whisper_process_logits( #endif } -static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) { - if (a.tokens.size() != b.tokens.size()) { +static bool whisper_sequence_tokens_equal(const whisper_sequence &a, const whisper_sequence &b) +{ + if (a.tokens.size() != b.tokens.size()) + { return false; } // sequences are more likely to diverge at the end - for (int i = a.tokens.size() - 1; i >= 0; i--) { - if (a.tokens[i].id != b.tokens[i].id) { + for (int i = a.tokens.size() - 1; i >= 0; i--) + { + if (a.tokens[i].id != b.tokens[i].id) + { return false; } } @@ -6400,17 +7502,27 @@ static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whis } static whisper_token_data whisper_sample_token( - whisper_context & ctx, - const whisper_decoder & decoder, - bool best) { + whisper_context &ctx, + const whisper_decoder &decoder, + bool best) +{ whisper_token_data result = { - 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f, + 0, + 0, + 0.0f, + 0.0f, + 0.0f, + 0.0f, + -1, + -1, + -1, + 0.0f, }; - const auto & vocab = ctx.vocab; + const auto &vocab = ctx.vocab; - const auto & probs = decoder.probs; - const auto & logprobs = decoder.logprobs; + const auto &probs = decoder.probs; + const auto &logprobs = decoder.logprobs; const int n_logits = vocab.n_vocab; @@ -6418,62 +7530,73 @@ static whisper_token_data whisper_sample_token( double sum_ts = 0.0; double max_ts = 0.0; - for (int i = vocab.token_beg; i < n_logits; i++) { - if (probs[i] == -INFINITY) { + for (int i = vocab.token_beg; i < n_logits; i++) + { + if (probs[i] == -INFINITY) + { continue; } sum_ts += probs[i]; - if (max_ts < probs[i]) { + if (max_ts < probs[i]) + { max_ts = probs[i]; result.tid = i; } } - result.pt = max_ts/(sum_ts + 1e-10); + result.pt = max_ts / (sum_ts + 1e-10); result.ptsum = sum_ts; } - if (best) { - for (int i = 0; i < n_logits; ++i) { - if (result.p < probs[i]) { - result.id = i; - result.p = probs[i]; + if (best) + { + for (int i = 0; i < n_logits; ++i) + { + if (result.p < probs[i]) + { + result.id = i; + result.p = probs[i]; result.plog = logprobs[i]; } } - } else { + } + else + { std::discrete_distribution<> dist(probs.begin(), probs.end()); - result.id = dist(decoder.rng); - result.p = probs[result.id]; + result.id = dist(decoder.rng); + result.p = probs[result.id]; result.plog = logprobs[result.id]; } - if (result.id >= vocab.token_beg) { + if (result.id >= vocab.token_beg) + { result.tid = result.id; - result.pt = result.p; + result.pt = result.p; } return result; } static std::vector whisper_sample_token_topk( - whisper_context & ctx, - whisper_decoder & decoder, - int k) { - const auto & vocab = ctx.vocab; + whisper_context &ctx, + whisper_decoder &decoder, + int k) +{ + const auto &vocab = ctx.vocab; - const auto & probs = decoder.probs; - const auto & logits = decoder.logits; - const auto & logprobs = decoder.logprobs; + const auto &probs = decoder.probs; + const auto &logits = decoder.logits; + const auto &logprobs = decoder.logprobs; const int n_logits = 
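// Sketch of the two sampling modes used by whisper_sample_token() above:
// "best" takes the argmax of the probabilities, otherwise a token is drawn
// from std::discrete_distribution over the same probabilities.
#include <random>
#include <vector>

int sample_token(const std::vector<float> & probs, std::mt19937 & rng, bool best) {
    if (best) {
        int id = 0;
        for (size_t i = 1; i < probs.size(); ++i) {
            if (probs[i] > probs[id]) id = (int) i;
        }
        return id;
    }
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    return dist(rng);
}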
vocab.n_vocab; - auto & logits_id = decoder.logits_id; + auto &logits_id = decoder.logits_id; logits_id.resize(n_logits); - for (int i = 0; i < n_logits; ++i) { + for (int i = 0; i < n_logits; ++i) + { logits_id[i].first = logits[i]; logits_id[i].second = i; } @@ -6481,11 +7604,12 @@ static std::vector whisper_sample_token_topk( { using pair_type = std::remove_reference::type::value_type; std::partial_sort( - logits_id.begin(), - logits_id.begin() + k, logits_id.end(), - [](const pair_type & a, const pair_type & b) { - return a.first > b.first; - }); + logits_id.begin(), + logits_id.begin() + k, logits_id.end(), + [](const pair_type &a, const pair_type &b) + { + return a.first > b.first; + }); } std::vector result; @@ -6493,40 +7617,56 @@ static std::vector whisper_sample_token_topk( whisper_token tid = vocab.token_beg; - float pt = 0.0; + float pt = 0.0; float ptsum = 0.0; { double sum_ts = 0.0; double max_ts = 0.0; - for (int i = vocab.token_beg; i < n_logits; i++) { - if (probs[i] == -INFINITY) { + for (int i = vocab.token_beg; i < n_logits; i++) + { + if (probs[i] == -INFINITY) + { continue; } sum_ts += probs[i]; - if (max_ts < probs[i]) { + if (max_ts < probs[i]) + { max_ts = probs[i]; tid = i; } } - pt = max_ts/(sum_ts + 1e-10); + pt = max_ts / (sum_ts + 1e-10); ptsum = sum_ts; } std::discrete_distribution<> dist(probs.begin(), probs.end()); - for (int i = 0; i < k; ++i) { + for (int i = 0; i < k; ++i) + { const auto id = dist(decoder.rng); - //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum); - - result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, }); + // printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum); + + result.push_back({ + id, + tid, + probs[id], + logprobs[id], + pt, + ptsum, + -1, + -1, + -1, + 0.0f, + }); - if (result[i].id >= vocab.token_beg) { + if (result[i].id >= vocab.token_beg) + { result[i].tid = result[i].id; - result[i].pt = result[i].p; + result[i].pt = result[i].p; } } @@ -6535,28 +7675,32 @@ static std::vector whisper_sample_token_topk( // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192 static void whisper_sequence_score( - const struct whisper_full_params & params, - whisper_sequence & sequence) { - if (sequence.result_len == 0) { + const struct whisper_full_params ¶ms, + whisper_sequence &sequence) +{ + if (sequence.result_len == 0) + { return; } double result = 0.0f; - for (int i = 0; i < sequence.result_len; ++i) { + for (int i = 0; i < sequence.result_len; ++i) + { result += sequence.tokens[i].plog; } sequence.sum_logprobs = result; - sequence.avg_logprobs = result/sequence.result_len; + sequence.avg_logprobs = result / sequence.result_len; double penalty = sequence.result_len; - if (params.length_penalty > 0.0f) { - penalty = pow((5.0 + penalty)/6.0, params.length_penalty); + if (params.length_penalty > 0.0f) + { + penalty = pow((5.0 + penalty) / 6.0, params.length_penalty); } - sequence.score = result/penalty; + sequence.score = result / penalty; // compute the entropy of the sequence of the last 32 tokens { @@ -6566,16 +7710,18 @@ static void whisper_sequence_score( double entropy = 0.0f; std::map token_counts; - for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) { + for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) + { token_counts[sequence.tokens[i].id]++; cnt++; } - for (const auto & kv : token_counts) { - const auto p = 
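// Sketch of the top-k selection used by whisper_sample_token_topk() above:
// pair each logit with its token id and std::partial_sort so only the k best
// candidates end up ordered at the front. Assumes k <= logits.size().
#include <algorithm>
#include <utility>
#include <vector>

std::vector<std::pair<float, int>> top_k(const std::vector<float> & logits, int k) {
    std::vector<std::pair<float, int>> id(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        id[i] = { logits[i], (int) i };
    }
    std::partial_sort(id.begin(), id.begin() + k, id.end(),
        [](const auto & a, const auto & b) { return a.first > b.first; });
    id.resize(k);
    return id;
}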
kv.second/(double)cnt; - entropy -= p*log(p); + for (const auto &kv : token_counts) + { + const auto p = kv.second / (double)cnt; + entropy -= p * log(p); - //WHISPER_LOG_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); + // WHISPER_LOG_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); } sequence.entropy = entropy; @@ -6583,12 +7729,13 @@ static void whisper_sequence_score( } static bool whisper_vad( - struct whisper_context * ctx, - struct whisper_state * state, - struct whisper_full_params params, - const float * samples, - int n_samples, - std::vector & filtered_samples) { + struct whisper_context *ctx, + struct whisper_state *state, + struct whisper_full_params params, + const float *samples, + int n_samples, + std::vector &filtered_samples) +{ WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__); int filtered_n_samples = 0; @@ -6596,10 +7743,12 @@ static bool whisper_vad( state->vad_mapping_table.clear(); state->has_vad_segments = false; - if (state->vad_context == nullptr) { + if (state->vad_context == nullptr) + { struct whisper_vad_context_params vad_ctx_params = whisper_vad_default_context_params(); - struct whisper_vad_context * vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params); - if (vctx == nullptr) { + struct whisper_vad_context *vctx = whisper_vad_init_from_file_with_params(params.vad_model_path, vad_ctx_params); + if (vctx == nullptr) + { WHISPER_LOG_ERROR("%s: failed to initialize VAD context\n", __func__); return false; } @@ -6607,11 +7756,12 @@ static bool whisper_vad( } auto vctx = state->vad_context; - const whisper_vad_params & vad_params = params.vad_params; + const whisper_vad_params &vad_params = params.vad_params; - whisper_vad_segments * vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples); + whisper_vad_segments *vad_segments = whisper_vad_segments_from_samples(vctx, vad_params, samples, n_samples); - if (vad_segments->data.size() > 0) { + if (vad_segments->data.size() > 0) + { state->has_vad_segments = true; ctx->state->vad_segments.clear(); ctx->state->vad_segments.reserve(vad_segments->data.size()); @@ -6624,21 +7774,23 @@ static bool whisper_vad( float overlap_seconds = vad_params.samples_overlap; int overlap_samples = overlap_seconds * WHISPER_SAMPLE_RATE; - for (int i = 0; i < (int)vad_segments->data.size(); i++) { + for (int i = 0; i < (int)vad_segments->data.size(); i++) + { int segment_start_samples = cs_to_samples(vad_segments->data[i].start); - int segment_end_samples = cs_to_samples(vad_segments->data[i].end); + int segment_end_samples = cs_to_samples(vad_segments->data[i].end); - if (i < (int)vad_segments->data.size() - 1) { + if (i < (int)vad_segments->data.size() - 1) + { segment_end_samples += overlap_samples; } segment_end_samples = std::min(segment_end_samples, n_samples - 1); - filtered_n_samples += (segment_end_samples - segment_start_samples); + filtered_n_samples += (segment_end_samples - segment_start_samples); WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n", - __func__, i, vad_segments->data[i].start/100.0, - (vad_segments->data[i].end/100.0 + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)), - (vad_segments->data[i].end - vad_segments->data[i].start)/100.0 + - (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)); + __func__, i, vad_segments->data[i].start / 100.0, + (vad_segments->data[i].end / 100.0 + (i < (int)vad_segments->data.size() - 1 ? 
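// Sketch of the repetition heuristic scored above: the entropy of the token
// histogram over the last (up to) 32 tokens; a low value means the decoder is
// cycling over a few tokens and the sequence is later flagged as failed when
// it falls below params.entropy_thold.
#include <algorithm>
#include <cmath>
#include <map>
#include <vector>

double token_entropy(const std::vector<int> & tokens, int window = 32) {
    std::map<int, int> counts;
    const int n     = (int) tokens.size();
    const int start = std::max(0, n - window);
    for (int i = start; i < n; ++i) counts[tokens[i]]++;

    double entropy = 0.0;
    for (const auto & kv : counts) {
        const double p = kv.second / (double) (n - start);
        entropy -= p * std::log(p);
    }
    return entropy;
}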
overlap_seconds : 0)), + (vad_segments->data[i].end - vad_segments->data[i].start) / 100.0 + + (i < (int)vad_segments->data.size() - 1 ? overlap_seconds : 0)); } int silence_samples = 0.1 * WHISPER_SAMPLE_RATE; @@ -6646,11 +7798,14 @@ static bool whisper_vad( int total_samples_needed = filtered_n_samples + total_silence_samples; WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n", - __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE); + __func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE); - try { + try + { filtered_samples.resize(total_samples_needed); - } catch (const std::bad_alloc & /* e */) { + } + catch (const std::bad_alloc & /* e */) + { WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__); whisper_vad_free_segments(vad_segments); whisper_vad_free(vctx); @@ -6658,25 +7813,28 @@ static bool whisper_vad( } int offset = 0; - for (int i = 0; i < (int)vad_segments->data.size(); i++) { + for (int i = 0; i < (int)vad_segments->data.size(); i++) + { int segment_start_samples = cs_to_samples(vad_segments->data[i].start); - int segment_end_samples = cs_to_samples(vad_segments->data[i].end); + int segment_end_samples = cs_to_samples(vad_segments->data[i].end); - if (i < (int)vad_segments->data.size() - 1) { + if (i < (int)vad_segments->data.size() - 1) + { segment_end_samples += overlap_samples; } segment_start_samples = std::min(segment_start_samples, n_samples - 1); segment_end_samples = std::min(segment_end_samples, n_samples); int segment_length = segment_end_samples - segment_start_samples; - if (segment_length > 0) { + if (segment_length > 0) + { whisper_state::vad_segment_info segment; segment.orig_start = vad_segments->data[i].start; - segment.orig_end = vad_segments->data[i].end; + segment.orig_end = vad_segments->data[i].end; segment.vad_start = samples_to_cs(offset); - segment.vad_end = samples_to_cs(offset + segment_length); + segment.vad_end = samples_to_cs(offset + segment_length); // Add segment boundaries to mapping table vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start}; @@ -6687,16 +7845,19 @@ static bool whisper_vad( // Add intermediate points for longer segments to improve interpolation accuracy const int64_t min_segment_length = 100; // 1 second - const int64_t point_interval = 20; // Add a point every 200ms + const int64_t point_interval = 20; // Add a point every 200ms - if (segment.vad_end - segment.vad_start > min_segment_length) { + if (segment.vad_end - segment.vad_start > min_segment_length) + { int64_t segment_duration = segment.vad_end - segment.vad_start; int num_points = (int)(segment_duration / point_interval) - 1; - for (int j = 1; j <= num_points; j++) { + for (int j = 1; j <= num_points; j++) + { int64_t vad_time = segment.vad_start + j * point_interval; - if (vad_time >= segment.vad_end) continue; + if (vad_time >= segment.vad_end) + continue; int64_t vad_elapsed = vad_time - segment.vad_start; int64_t vad_total = segment.vad_end - segment.vad_start; @@ -6709,7 +7870,7 @@ static bool whisper_vad( } WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n", - __func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0); + __func__, segment.orig_start / 100.0, segment.orig_end / 100.0, segment.vad_start / 100.0, segment.vad_end / 100.0); ctx->state->vad_segments.push_back(segment); // Copy this speech segment @@ -6717,13 +7878,14 @@ static bool whisper_vad( offset += 
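// Hedged reconstruction of the unit conversions assumed by cs_to_samples()
// and samples_to_cs() used above: segment times are in centiseconds (10 ms
// units, hence the /100.0 in the log messages) and the audio is 16 kHz.
// The real helpers live earlier in the file; these are illustrations only.
#include <cmath>
#include <cstdint>

#define WHISPER_SAMPLE_RATE 16000

static int     cs_to_samples_sketch(int64_t cs)      { return (int) std::round(cs / 100.0 * WHISPER_SAMPLE_RATE); }
static int64_t samples_to_cs_sketch(int64_t samples) { return (int64_t) std::round(samples * 100.0 / WHISPER_SAMPLE_RATE); }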
segment_length; // Add silence after this segment (except after the last segment) - if (i < (int)vad_segments->data.size() - 1) { + if (i < (int)vad_segments->data.size() - 1) + { // Calculate the start and end time of the silence gap in processed audio int64_t silence_start_vad = samples_to_cs(offset); int64_t silence_end_vad = samples_to_cs(offset + silence_samples); // Calculate the corresponding original times int64_t orig_silence_start = segment.orig_end; - int64_t orig_silence_end = vad_segments->data[i+1].start; + int64_t orig_silence_end = vad_segments->data[i + 1].start; // Add mapping points for silence boundaries state->vad_mapping_table.push_back({silence_start_vad, orig_silence_start}); @@ -6738,17 +7900,20 @@ static bool whisper_vad( // Sort the mapping table by processed time std::sort(state->vad_mapping_table.begin(), state->vad_mapping_table.end(), - [](const vad_time_mapping& a, const vad_time_mapping& b) { - return a.processed_time < b.processed_time; - }); + [](const vad_time_mapping &a, const vad_time_mapping &b) + { + return a.processed_time < b.processed_time; + }); // Remove any duplicate processed times to ensure monotonicity which is // needed for binary search and interpolation later. - if (!state->vad_mapping_table.empty()) { + if (!state->vad_mapping_table.empty()) + { auto last = std::unique(state->vad_mapping_table.begin(), state->vad_mapping_table.end(), - [](const vad_time_mapping& a, const vad_time_mapping& b) { - return a.processed_time == b.processed_time; - }); + [](const vad_time_mapping &a, const vad_time_mapping &b) + { + return a.processed_time == b.processed_time; + }); state->vad_mapping_table.erase(last, state->vad_mapping_table.end()); } @@ -6756,37 +7921,42 @@ static bool whisper_vad( filtered_n_samples = offset; WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n", - __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples)); + __func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples)); } return true; } int whisper_full_with_state( - struct whisper_context * ctx, - struct whisper_state * state, - struct whisper_full_params params, - const float * samples, - int n_samples) { + struct whisper_context *ctx, + struct whisper_state *state, + struct whisper_full_params params, + const float *samples, + int n_samples) +{ // clear old results - auto & result_all = state->result_all; + auto &result_all = state->result_all; result_all.clear(); - if (n_samples > 0) { + if (n_samples > 0) + { // compute log mel spectrogram - if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) + { WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); return -2; } } // auto-detect language if not specified - if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) { + if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) + { std::vector probs(whisper_lang_max_id() + 1, 0.0f); const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); - if (lang_id < 0) { + if (lang_id < 0) + { WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__); return -3; } @@ -6794,73 +7964,86 @@ int whisper_full_with_state( params.language = 
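// Hedged sketch of how the sorted, deduplicated VAD mapping table built above
// could be queried: binary-search the processed (VAD-filtered) time and
// linearly interpolate the original time between the two surrounding mapping
// points. The struct mirrors the comparator above; the second field name and
// the lookup routine are assumptions, not the exact code used elsewhere.
#include <algorithm>
#include <cstdint>
#include <vector>

struct vad_time_mapping_sketch { int64_t processed_time; int64_t original_time; };

int64_t map_to_original(const std::vector<vad_time_mapping_sketch> & table, int64_t t) {
    if (table.empty()) return t;
    auto it = std::lower_bound(table.begin(), table.end(), t,
        [](const vad_time_mapping_sketch & m, int64_t v) { return m.processed_time < v; });
    if (it == table.begin()) return table.front().original_time;
    if (it == table.end())   return table.back().original_time;
    const auto & a = *(it - 1);
    const auto & b = *it;
    const double f = double(t - a.processed_time) / double(b.processed_time - a.processed_time);
    return a.original_time + (int64_t)(f * double(b.original_time - a.original_time));
}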
whisper_lang_str(lang_id); WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); - if (params.detect_language) { + if (params.detect_language) + { return 0; } } - if (params.token_timestamps) { - state->t_beg = 0; - state->t_last = 0; + if (params.token_timestamps) + { + state->t_beg = 0; + state->t_last = 0; state->tid_last = 0; - if (n_samples > 0) { + if (n_samples > 0) + { state->energy = get_signal_energy(samples, n_samples, 32); } } - const int seek_start = params.offset_ms/10; - const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10; + const int seek_start = params.offset_ms / 10; + const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms / 10; // if length of spectrogram is less than 100ms (10 frames), then return // basically don't process anything that is less than 100ms // ref: https://github.com/ggml-org/whisper.cpp/issues/2065 const int delta_min = 10; - if (seek_end < seek_start + delta_min) { - WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10); + if (seek_end < seek_start + delta_min) + { + WHISPER_LOG_WARN("%s: input is too short - %d ms < 100 ms. consider padding the input audio with silence\n", __func__, (seek_end - seek_start) * 10); return 0; } // a set of temperatures to use // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ] std::vector temperatures; - if (params.temperature_inc > 0.0f) { - for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) { + if (params.temperature_inc > 0.0f) + { + for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) + { temperatures.push_back(t); } - } else { + } + else + { temperatures.push_back(params.temperature); } // initialize the decoders int n_decoders = 1; - switch (params.strategy) { - case WHISPER_SAMPLING_GREEDY: - { - n_decoders = params.greedy.best_of; - } break; - case WHISPER_SAMPLING_BEAM_SEARCH: - { - n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size); - } break; + switch (params.strategy) + { + case WHISPER_SAMPLING_GREEDY: + { + n_decoders = params.greedy.best_of; + } + break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size); + } + break; }; n_decoders = std::max(1, n_decoders); - if (n_decoders > WHISPER_MAX_DECODERS) { + if (n_decoders > WHISPER_MAX_DECODERS) + { WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS); return -4; } // TAGS: WHISPER_DECODER_INIT - for (int j = 1; j < n_decoders; j++) { - auto & decoder = state->decoders[j]; + for (int j = 1; j < n_decoders; j++) + { + auto &decoder = state->decoders[j]; decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); - decoder.probs.resize (ctx->vocab.n_vocab); - decoder.logits.resize (ctx->vocab.n_vocab); + decoder.probs.resize(ctx->vocab.n_vocab); + decoder.logits.resize(ctx->vocab.n_vocab); decoder.logprobs.resize(ctx->vocab.n_vocab); decoder.logits_id.reserve(ctx->model.hparams.n_vocab); @@ -6868,8 +8051,9 @@ int whisper_full_with_state( } // the accumulated text context so far - auto & prompt_past = state->prompt_past; - if (params.no_context) { + auto &prompt_past = state->prompt_past; + if (params.no_context) + { prompt_past.clear(); } @@ 
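// Sketch of the fallback temperature schedule built above: start at
// params.temperature and keep adding params.temperature_inc while the value
// stays below 1.0 (plus a small epsilon), e.g. 0.0, 0.2, 0.4, ..., 1.0.
// With a non-positive increment only the initial temperature is used.
#include <vector>

std::vector<float> temperature_schedule(float t0, float inc) {
    std::vector<float> temps;
    if (inc > 0.0f) {
        for (float t = t0; t < 1.0f + 1e-6f; t += inc) temps.push_back(t);
    } else {
        temps.push_back(t0);
    }
    return temps;
}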
-6878,22 +8062,26 @@ int whisper_full_with_state( std::vector prompt_tokens; // initial prompt - if (!params.prompt_tokens && params.initial_prompt) { + if (!params.prompt_tokens && params.initial_prompt) + { prompt_tokens.resize(1024); int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()); - if (n_needed < 0) { + if (n_needed < 0) + { prompt_tokens.resize(-n_needed); n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()); } prompt_tokens.resize(n_needed); - params.prompt_tokens = prompt_tokens.data(); + params.prompt_tokens = prompt_tokens.data(); params.prompt_n_tokens = prompt_tokens.size(); } // prepend the prompt tokens to the prompt_past - if (params.prompt_tokens && params.prompt_n_tokens > 0) { + if (params.prompt_tokens && params.prompt_n_tokens > 0) + { // parse tokens from the pointer - for (int i = 0; i < params.prompt_n_tokens; i++) { + for (int i = 0; i < params.prompt_n_tokens; i++) + { prompt_past.push_back(params.prompt_tokens[i]); } std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); @@ -6901,22 +8089,29 @@ int whisper_full_with_state( } // overwrite audio_ctx, max allowed is hparams.n_audio_ctx - if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { + if (params.audio_ctx > whisper_n_audio_ctx(ctx)) + { WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); return -5; } state->exp_n_audio_ctx = params.audio_ctx; // these tokens determine the task that will be performed - std::vector prompt_init = { whisper_token_sot(ctx), }; + std::vector prompt_init = { + whisper_token_sot(ctx), + }; - if (whisper_is_multilingual(ctx)) { + if (whisper_is_multilingual(ctx)) + { const int lang_id = whisper_lang_id(params.language); state->lang_id = lang_id; prompt_init.push_back(whisper_token_lang(ctx, lang_id)); - if (params.translate) { + if (params.translate) + { prompt_init.push_back(whisper_token_translate(ctx)); - } else { + } + else + { prompt_init.push_back(whisper_token_transcribe(ctx)); } } @@ -6924,13 +8119,15 @@ int whisper_full_with_state( // first release distilled models require the "no_timestamps" token { const bool is_distil = ctx->model.hparams.n_text_layer == 2 && ctx->model.hparams.n_vocab != 51866; - if (is_distil && !params.no_timestamps) { + if (is_distil && !params.no_timestamps) + { WHISPER_LOG_WARN("%s: using first release distilled models - forcing no_timestamps\n", __func__); params.no_timestamps = true; } } - if (params.no_timestamps) { + if (params.no_timestamps) + { prompt_init.push_back(whisper_token_not(ctx)); } @@ -6939,7 +8136,8 @@ int whisper_full_with_state( std::vector prompt; prompt.reserve(whisper_n_text_ctx(ctx)); - struct beam_candidate { + struct beam_candidate + { int decoder_idx; int seek_delta; @@ -6953,60 +8151,75 @@ int whisper_full_with_state( std::vector beam_candidates; // main loop - while (true) { - if (params.progress_callback) { - const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); + while (true) + { + if (params.progress_callback) + { + const int progress_cur = (100 * (seek - seek_start)) / (seek_end - seek_start); params.progress_callback( ctx, state, progress_cur, params.progress_callback_user_data); } // if only 100ms left, then stop - if (seek + delta_min >= seek_end) { + if (seek + delta_min >= seek_end) + { break; } - if (params.encoder_begin_callback) { - if 
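// Sketch of the tokenize-retry pattern used for params.initial_prompt above:
// whisper_tokenize() returns a negative value when the output buffer is too
// small, and the magnitude is the required token count, so the buffer is
// resized once and the call repeated.
#include <algorithm>
#include <vector>
#include "whisper.h"

std::vector<whisper_token> tokenize_prompt(struct whisper_context * ctx, const char * text) {
    std::vector<whisper_token> tokens(1024);
    int n = whisper_tokenize(ctx, text, tokens.data(), (int) tokens.size());
    if (n < 0) {
        tokens.resize(-n);
        n = whisper_tokenize(ctx, text, tokens.data(), (int) tokens.size());
    }
    tokens.resize(std::max(n, 0));
    return tokens;
}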
(params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) { + if (params.encoder_begin_callback) + { + if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) + { WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__); break; } } // encode audio features starting at offset seek - if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) + { WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); return -6; } // if there is a very short audio segment left to process, we remove any past prompt since it tends // to confuse the decoder and often make it repeat or hallucinate stuff - if (seek > seek_start && seek + 500 >= seek_end) { + if (seek > seek_start && seek + 500 >= seek_end) + { prompt_past.clear(); } int best_decoder_id = 0; - for (int it = 0; it < (int) temperatures.size(); ++it) { + for (int it = 0; it < (int)temperatures.size(); ++it) + { const float t_cur = temperatures[it]; int n_decoders_cur = 1; - switch (params.strategy) { - case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: - { - if (t_cur > 0.0f) { - n_decoders_cur = params.greedy.best_of; - } - } break; - case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: - { - if (t_cur > 0.0f) { - n_decoders_cur = params.greedy.best_of; - } else { - n_decoders_cur = params.beam_search.beam_size; - } - } break; + switch (params.strategy) + { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur > 0.0f) + { + n_decoders_cur = params.greedy.best_of; + } + } + break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + if (t_cur > 0.0f) + { + n_decoders_cur = params.greedy.best_of; + } + else + { + n_decoders_cur = params.beam_search.beam_size; + } + } + break; }; n_decoders_cur = std::max(1, n_decoders_cur); @@ -7014,26 +8227,30 @@ int whisper_full_with_state( WHISPER_LOG_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur); // TAGS: WHISPER_DECODER_INIT - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + for (int j = 0; j < n_decoders_cur; ++j) + { + auto &decoder = state->decoders[j]; decoder.sequence.tokens.clear(); - decoder.sequence.result_len = 0; + decoder.sequence.result_len = 0; decoder.sequence.sum_logprobs_all = 0.0; - decoder.sequence.sum_logprobs = -INFINITY; - decoder.sequence.avg_logprobs = -INFINITY; - decoder.sequence.entropy = 0.0; - decoder.sequence.score = -INFINITY; + decoder.sequence.sum_logprobs = -INFINITY; + decoder.sequence.avg_logprobs = -INFINITY; + decoder.sequence.entropy = 0.0; + decoder.sequence.score = -INFINITY; - decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; + decoder.seek_delta = 100 * WHISPER_CHUNK_SIZE; - decoder.failed = false; + decoder.failed = false; decoder.completed = false; - decoder.has_ts = false; + decoder.has_ts = false; - if (params.grammar_rules != nullptr) { + if (params.grammar_rules != nullptr) + { decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule); - } else { + } + else + { decoder.grammar = {}; } } @@ -7044,10 +8261,11 @@ int whisper_full_with_state( prompt.clear(); // if we have already generated some text, use it as a prompt to condition the next generation - if 
(!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) { - int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) + { + int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx) / 2), int(prompt_past.size())); - prompt = { whisper_token_prev(ctx) }; + prompt = {whisper_token_prev(ctx)}; prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); } @@ -7056,13 +8274,15 @@ int whisper_full_with_state( // print the prompt WHISPER_LOG_DEBUG("\n\n"); - for (int i = 0; i < (int) prompt.size(); i++) { + for (int i = 0; i < (int)prompt.size(); i++) + { WHISPER_LOG_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); } WHISPER_LOG_DEBUG("\n\n"); // recreate the KV cache if the number of decoders has changed - if (state->kv_self_n_dec < n_decoders_cur) { + if (state->kv_self_n_dec < n_decoders_cur) + { WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur); whisper_kv_cache_free(state->kv_self); @@ -7071,9 +8291,10 @@ int whisper_full_with_state( const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1; if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype, - ctx->model.hparams.n_text_state, - ctx->model.hparams.n_text_layer, - GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) { + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_text_ctx, 256) * factor)) + { WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__); whisper_free_state(state); return -7; @@ -7086,7 +8307,8 @@ int whisper_full_with_state( whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0); - if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) { + if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) + { WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -8; } @@ -7110,25 +8332,29 @@ int whisper_full_with_state( whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur); - for (int j = 1; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + for (int j = 1; j < n_decoders_cur; ++j) + { + auto &decoder = state->decoders[j]; whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1); - memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); - memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); - memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size() * sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size() * sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size() * sizeof(decoder.logprobs[0])); } state->t_sample_us += ggml_time_us() - t_start_sample_us; } } - for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + for (int i = 0, n_max = whisper_n_text_ctx(ctx) / 2 - 4; i < n_max; ++i) + { const int64_t t_start_sample_us = ggml_time_us(); - 
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { - for (auto & bc : bc_per_dec) { + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) + { + for (auto &bc : bc_per_dec) + { bc.clear(); } } @@ -7138,125 +8364,158 @@ int whisper_full_with_state( { std::atomic j_cur(0); - auto process = [&]() { - while (true) { + auto process = [&]() + { + while (true) + { const int j = j_cur.fetch_add(1); - if (j >= n_decoders_cur) { + if (j >= n_decoders_cur) + { break; } - auto & decoder = state->decoders[j]; + auto &decoder = state->decoders[j]; - if (decoder.completed || decoder.failed) { + if (decoder.completed || decoder.failed) + { continue; } - switch (params.strategy) { - case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: - { - if (t_cur < 1e-6f) { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); - } else { - decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); - } - - decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; - } break; - case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: - { - const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); - - for (const auto & token : tokens_new) { - bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, }); - bc_per_dec[j].back().sequence.tokens.push_back(token); - bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog; - } - } break; + switch (params.strategy) + { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur < 1e-6f) + { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + } + else + { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + } + break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + + for (const auto &token : tokens_new) + { + bc_per_dec[j].push_back({ + j, + decoder.seek_delta, + decoder.has_ts, + decoder.sequence, + decoder.grammar, + }); + bc_per_dec[j].back().sequence.tokens.push_back(token); + bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog; + } + } + break; }; } }; const int n_threads = std::min(params.n_threads, n_decoders_cur); - if (n_threads == 1) { + if (n_threads == 1) + { process(); - } else { + } + else + { std::vector threads(n_threads - 1); - for (int t = 0; t < n_threads - 1; ++t) { + for (int t = 0; t < n_threads - 1; ++t) + { threads[t] = std::thread(process); } process(); - for (int t = 0; t < n_threads - 1; ++t) { + for (int t = 0; t < n_threads - 1; ++t) + { threads[t].join(); } } } beam_candidates.clear(); - for (const auto & bc : bc_per_dec) { + for (const auto &bc : bc_per_dec) + { beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end()); - if (!bc.empty()) { + if (!bc.empty()) + { state->n_sample += 1; } } // for beam-search, choose the top candidates and update the KV caches - if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) + { std::sort( - beam_candidates.begin(), - beam_candidates.end(), - [](const beam_candidate & a, const beam_candidate & b) { - if (a.sequence.sum_logprobs_all != b.sequence.sum_logprobs_all) { - return a.sequence.sum_logprobs_all > 
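// Sketch of the work-distribution pattern used above for per-decoder
// sampling: a shared atomic counter hands out decoder indices, the calling
// thread also participates, and the helper threads are joined afterwards.
// The "work on item j" body is a placeholder.
#include <atomic>
#include <thread>
#include <vector>

void for_each_index(int n_items, int n_threads) {
    std::atomic<int> next(0);
    auto process = [&]() {
        while (true) {
            const int j = next.fetch_add(1);
            if (j >= n_items) break;
            // ... work on item j ...
        }
    };
    std::vector<std::thread> workers;
    for (int t = 0; t < n_threads - 1; ++t) workers.emplace_back(process);
    process();              // the calling thread does its share
    for (auto & w : workers) w.join();
}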
b.sequence.sum_logprobs_all; - } - return a.decoder_idx < b.decoder_idx; - }); + beam_candidates.begin(), + beam_candidates.end(), + [](const beam_candidate &a, const beam_candidate &b) + { + if (a.sequence.sum_logprobs_all != b.sequence.sum_logprobs_all) + { + return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; + } + return a.decoder_idx < b.decoder_idx; + }); uint32_t cur_c = 0; - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + for (int j = 0; j < n_decoders_cur; ++j) + { + auto &decoder = state->decoders[j]; - if (decoder.completed || decoder.failed) { + if (decoder.completed || decoder.failed) + { continue; } - if (cur_c >= beam_candidates.size()) { + if (cur_c >= beam_candidates.size()) + { cur_c = 0; } - auto & cur = beam_candidates[cur_c++]; + auto &cur = beam_candidates[cur_c++]; - while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) { + while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) + { ++cur_c; } decoder.seek_delta = cur.seek_delta; - decoder.has_ts = cur.has_ts; - decoder.sequence = cur.sequence; - decoder.grammar = cur.grammar; + decoder.has_ts = cur.has_ts; + decoder.sequence = cur.sequence; + decoder.grammar = cur.grammar; whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1); WHISPER_LOG_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", - __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); + __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); } - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + for (int j = 0; j < n_decoders_cur; ++j) + { + auto &decoder = state->decoders[j]; - if (decoder.completed || decoder.failed) { + if (decoder.completed || decoder.failed) + { continue; } - whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1); + whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1); whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1); - whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1); + whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1); } } @@ -7264,28 +8523,32 @@ int whisper_full_with_state( // - check if the sequence is completed // - check if the sequence is failed // - update sliding window based on timestamp tokens - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + for (int j = 0; j < n_decoders_cur; ++j) + { + auto &decoder = state->decoders[j]; - if (decoder.completed || decoder.failed) { + if (decoder.completed || decoder.failed) + { continue; } - auto & has_ts = decoder.has_ts; - auto & failed = decoder.failed; - auto & completed = decoder.completed; - auto & seek_delta = decoder.seek_delta; - auto & result_len = decoder.sequence.result_len; + auto &has_ts = decoder.has_ts; + auto &failed = decoder.failed; + auto &completed = decoder.completed; + auto &seek_delta = decoder.seek_delta; + auto &result_len = decoder.sequence.result_len; { - const auto & token = decoder.sequence.tokens.back(); + const auto &token = decoder.sequence.tokens.back(); // timestamp token - update sliding window - if 
(token.id > whisper_token_beg(ctx)) { - const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + if (token.id > whisper_token_beg(ctx)) + { + const int seek_delta_new = 2 * (token.id - whisper_token_beg(ctx)); // do not allow to go back in time - if (has_ts && seek_delta > seek_delta_new && result_len < i) { + if (has_ts && seek_delta > seek_delta_new && result_len < i) + { WHISPER_LOG_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new); failed = true; // TODO: maybe this is not a failure ? continue; @@ -7302,28 +8565,34 @@ int whisper_full_with_state( { const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; WHISPER_LOG_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", - __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); + __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); } #endif // end of segment - if (token.id == whisper_token_eot(ctx) || // end of text token - (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached - (has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms) - ) { - if (result_len == 0 && !params.no_timestamps) { - if (seek + seek_delta + delta_min >= seek_end) { + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + delta_min >= seek_end) // end of audio reached (100ms) + ) + { + if (result_len == 0 && !params.no_timestamps) + { + if (seek + seek_delta + delta_min >= seek_end) + { result_len = i + 1; - } else { + } + else + { WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j); failed = true; continue; } } - if (params.single_segment || params.no_timestamps) { + if (params.single_segment || params.no_timestamps) + { result_len = i + 1; - seek_delta = 100*WHISPER_CHUNK_SIZE; + seek_delta = 100 * WHISPER_CHUNK_SIZE; } WHISPER_LOG_DEBUG("%s: decoder %d completed\n", __func__, j); @@ -7332,8 +8601,9 @@ int whisper_full_with_state( } // TESTS: if no tensors are loaded, it means we are running tests - if (ctx->model.n_loaded == 0) { - seek_delta = 100*WHISPER_CHUNK_SIZE; + if (ctx->model.n_loaded == 0) + { + seek_delta = 100 * WHISPER_CHUNK_SIZE; completed = true; continue; } @@ -7341,7 +8611,8 @@ int whisper_full_with_state( // sometimes, the decoding can get stuck in a repetition loop // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy - if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100 * WHISPER_CHUNK_SIZE / 2)) + { WHISPER_LOG_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j); failed = true; continue; @@ -7352,17 +8623,20 @@ int whisper_full_with_state( { bool completed_all = true; - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + for (int j = 0; j < n_decoders_cur; ++j) + { + auto &decoder = state->decoders[j]; - if (decoder.completed || decoder.failed) { + if (decoder.completed || decoder.failed) + { continue; } completed_all = false; } - if (completed_all) { + if (completed_all) + { break; } } @@ -7371,34 +8645,37 @@ int whisper_full_with_state( // obtain logits for the next token { - auto & 
batch = state->batch; + auto &batch = state->batch; batch.n_tokens = 0; const int n_past = prompt.size() + i; - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + for (int j = 0; j < n_decoders_cur; ++j) + { + auto &decoder = state->decoders[j]; - if (decoder.failed || decoder.completed) { + if (decoder.failed || decoder.completed) + { continue; } - //WHISPER_LOG_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta); + // WHISPER_LOG_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta); decoder.i_batch = batch.n_tokens; - batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id; - batch.pos [batch.n_tokens] = n_past; - batch.n_seq_id[batch.n_tokens] = 1; - batch.seq_id [batch.n_tokens][0] = j; - batch.logits [batch.n_tokens] = 1; + batch.token[batch.n_tokens] = decoder.sequence.tokens.back().id; + batch.pos[batch.n_tokens] = n_past; + batch.n_seq_id[batch.n_tokens] = 1; + batch.seq_id[batch.n_tokens][0] = j; + batch.logits[batch.n_tokens] = 1; batch.n_tokens++; } assert(batch.n_tokens > 0); - if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) { + if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) + { WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); return -9; } @@ -7409,17 +8686,21 @@ int whisper_full_with_state( { std::atomic j_cur(0); - auto process = [&]() { - while (true) { + auto process = [&]() + { + while (true) + { const int j = j_cur.fetch_add(1); - if (j >= n_decoders_cur) { + if (j >= n_decoders_cur) + { break; } - auto & decoder = state->decoders[j]; + auto &decoder = state->decoders[j]; - if (decoder.failed || decoder.completed) { + if (decoder.failed || decoder.completed) + { continue; } @@ -7429,18 +8710,23 @@ int whisper_full_with_state( const int n_threads = std::min(params.n_threads, n_decoders_cur); - if (n_threads == 1) { + if (n_threads == 1) + { process(); - } else { + } + else + { std::vector threads(n_threads - 1); - for (int t = 0; t < n_threads - 1; ++t) { + for (int t = 0; t < n_threads - 1; ++t) + { threads[t] = std::thread(process); } process(); - for (int t = 0; t < n_threads - 1; ++t) { + for (int t = 0; t < n_threads - 1; ++t) + { threads[t].join(); } } @@ -7454,10 +8740,12 @@ int whisper_full_with_state( { double best_score = -INFINITY; - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + for (int j = 0; j < n_decoders_cur; ++j) + { + auto &decoder = state->decoders[j]; - if (decoder.failed) { + if (decoder.failed) + { continue; } @@ -7465,11 +8753,12 @@ int whisper_full_with_state( whisper_sequence_score(params, decoder.sequence); WHISPER_LOG_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", - __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); + __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); - if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) { + if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) + { WHISPER_LOG_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", - __func__, j, 
decoder.sequence.entropy, params.entropy_thold); + __func__, j, decoder.sequence.entropy, params.entropy_thold); decoder.failed = true; state->n_fail_h++; @@ -7477,7 +8766,8 @@ int whisper_full_with_state( continue; } - if (best_score < decoder.sequence.score) { + if (best_score < decoder.sequence.score) + { best_score = decoder.sequence.score; best_decoder_id = j; } @@ -7491,21 +8781,24 @@ int whisper_full_with_state( // was the decoding successful for the current temperature? // do fallback only if: // - we are not at the last temperature - if (it != (int) temperatures.size() - 1) { - const auto & decoder = state->decoders[best_decoder_id]; + if (it != (int)temperatures.size() - 1) + { + const auto &decoder = state->decoders[best_decoder_id]; if (decoder.failed || - (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) { + (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) + { WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold); success = false; state->n_fail_p++; } } - if (success) { - //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { - // WHISPER_LOG_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); - //} + if (success) + { + // for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { + // WHISPER_LOG_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + // } break; } @@ -7515,91 +8808,108 @@ int whisper_full_with_state( // output results through a user-provided callback { - const auto & best_decoder = state->decoders[best_decoder_id]; + const auto &best_decoder = state->decoders[best_decoder_id]; auto seek_delta = best_decoder.seek_delta; const auto result_len = best_decoder.sequence.result_len; - const auto & tokens_cur = best_decoder.sequence.tokens; + const auto &tokens_cur = best_decoder.sequence.tokens; // [EXPERIMENTAL] Token-level timestamps with DTW const auto n_segments_before = state->result_all.size(); const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold && - best_decoder.sequence.avg_logprobs < params.logprob_thold); + best_decoder.sequence.avg_logprobs < params.logprob_thold); - //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); + // WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); // update prompt_past prompt_past.clear(); - if (prompt.front() == whisper_token_prev(ctx)) { + if (prompt.front() == whisper_token_prev(ctx)) + { prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); } - for (int i = 0; i < result_len && !is_no_speech; ++i) { + for (int i = 0; i < result_len && !is_no_speech; ++i) + { prompt_past.push_back(tokens_cur[i].id); } - if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) { - int i0 = 0; - auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + if 
(!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) + { + int i0 = 0; + auto t0 = seek + 2 * (tokens_cur.front().tid - whisper_token_beg(ctx)); std::string text; bool speaker_turn_next = false; - for (int i = 0; i < (int) tokens_cur.size(); i++) { - //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, - // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, - // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + for (int i = 0; i < (int)tokens_cur.size(); i++) + { + // printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); - if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) { + if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) + { text += whisper_token_to_str(ctx, tokens_cur[i].id); } // [TDRZ] record if speaker turn was predicted after current segment - if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) { + if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) + { speaker_turn_next = true; } - if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { - const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) + { + const auto t1 = seek + 2 * (tokens_cur[i].tid - whisper_token_beg(ctx)); - if (!text.empty()) { + if (!text.empty()) + { const auto tt0 = t0; const auto tt1 = t1; - if (params.print_realtime) { - if (params.print_timestamps) { + if (params.print_realtime) + { + if (params.print_timestamps) + { printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { + } + else + { printf("%s", text.c_str()); fflush(stdout); } } - //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); + // printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); - result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next }); - for (int j = i0; j <= i; j++) { + result_all.push_back({tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next}); + for (int j = i0; j <= i; j++) + { result_all.back().tokens.push_back(tokens_cur[j]); } int n_new = 1; - if (params.token_timestamps) { + if (params.token_timestamps) + { whisper_exp_compute_token_level_timestamps( - *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); - if (params.max_len > 0) { + if (params.max_len > 0) + { n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); } } - if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) { + if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) + { params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } } text = ""; - while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + while (i < (int)tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) + { i++; } i--; @@ -7609,37 +8919,46 @@ int whisper_full_with_state( } } - if (!text.empty()) { + if (!text.empty()) + { const auto t1 = seek + 
seek_delta; const auto tt0 = t0; const auto tt1 = t1; - if (params.print_realtime) { - if (params.print_timestamps) { + if (params.print_realtime) + { + if (params.print_timestamps) + { printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { + } + else + { printf("%s", text.c_str()); fflush(stdout); } } - result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next }); - for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.push_back({tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next}); + for (int j = i0; j < (int)tokens_cur.size(); j++) + { result_all.back().tokens.push_back(tokens_cur[j]); } int n_new = 1; - if (params.token_timestamps) { + if (params.token_timestamps) + { whisper_exp_compute_token_level_timestamps( - *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); - if (params.max_len > 0) { + if (params.max_len > 0) + { n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); } } - if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) { + if (params.new_segment_callback && !ctx->params.dtw_token_timestamps) + { params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); } } @@ -7649,12 +8968,15 @@ int whisper_full_with_state( // [EXPERIMENTAL] Token-level timestamps with DTW { const int n_segments = state->result_all.size() - n_segments_before; - if (ctx->params.dtw_token_timestamps && n_segments) { + if (ctx->params.dtw_token_timestamps && n_segments) + { const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek); whisper_exp_compute_token_level_timestamps_dtw( - ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads); - if (params.new_segment_callback) { - for (int seg = (int) result_all.size() - n_segments; seg < n_segments; seg++) { + ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads); + if (params.new_segment_callback) + { + for (int seg = (int)result_all.size() - n_segments; seg < n_segments; seg++) + { params.new_segment_callback(ctx, state, seg, params.new_segment_callback_user_data); } } @@ -7663,9 +8985,10 @@ int whisper_full_with_state( // ref: https://github.com/ggml-org/whisper.cpp/pull/2629 const bool single_timestamp_ending = tokens_cur.size() > 1 && - tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) && - tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx); - if (single_timestamp_ending) { + tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) && + tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx); + if (single_timestamp_ending) + { WHISPER_LOG_DEBUG("single timestamp ending - skip entire chunk\n"); seek_delta = std::min(seek_end - seek, WHISPER_CHUNK_SIZE * 100); } @@ -7681,19 +9004,23 @@ int whisper_full_with_state( } int whisper_full( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples) { + struct whisper_context *ctx, + struct whisper_full_params params, + const float *samples, + int n_samples) +{ std::vector vad_samples; - if (params.vad) { + if (params.vad) + { WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__); - if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) { + if (!whisper_vad(ctx, ctx->state, params, 
samples, n_samples, vad_samples)) + { WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__); return -1; } - if (vad_samples.empty()) { + if (vad_samples.empty()) + { ctx->state->result_all.clear(); return 0; } @@ -7704,24 +9031,29 @@ int whisper_full( } int whisper_full_parallel( - struct whisper_context * ctx, - struct whisper_full_params params, - const float * samples, - int n_samples, - int n_processors) { + struct whisper_context *ctx, + struct whisper_full_params params, + const float *samples, + int n_samples, + int n_processors) +{ - if (n_processors == 1) { + if (n_processors == 1) + { return whisper_full(ctx, params, samples, n_samples); } std::vector vad_samples; - if (params.vad) { + if (params.vad) + { WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__); - if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) { + if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) + { WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__); return -1; } - if (vad_samples.empty()) { + if (vad_samples.empty()) + { return 0; } samples = vad_samples.data(); @@ -7730,20 +9062,21 @@ int whisper_full_parallel( int ret = 0; // prepare separate states for each thread - std::vector states; + std::vector states; - const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; - const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; + const int offset_samples = (WHISPER_SAMPLE_RATE * params.offset_ms) / 1000; + const int n_samples_per_processor = (n_samples - offset_samples) / n_processors; // the calling thread will process the first chunk // while the other threads will process the remaining chunks std::vector workers(n_processors - 1); - for (int i = 0; i < n_processors - 1; ++i) { + for (int i = 0; i < n_processors - 1; ++i) + { // create a new state for each thread states.push_back(whisper_init_state(ctx)); - const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; + const int start_samples = offset_samples + (i + 1) * n_samples_per_processor; const int n_samples_cur = (i == n_processors - 2) ? 
n_samples - start_samples : n_samples_per_processor; auto params_cur = params; @@ -7771,30 +9104,35 @@ int whisper_full_parallel( ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); } - for (int i = 0; i < n_processors - 1; ++i) { + for (int i = 0; i < n_processors - 1; ++i) + { workers[i].join(); } - const int64_t offset_t = (int64_t) params.offset_ms/10.0; + const int64_t offset_t = (int64_t)params.offset_ms / 10.0; // combine results into result_state->result_all from all other states - for (int i = 0; i < n_processors - 1; ++i) { - auto& results_i = states[i]->result_all; + for (int i = 0; i < n_processors - 1; ++i) + { + auto &results_i = states[i]->result_all; - for (auto& result : results_i) { + for (auto &result : results_i) + { // correct the segment timestamp taking into account the offset result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; // make sure that segments are not overlapping - if (!ctx->state->result_all.empty()) { + if (!ctx->state->result_all.empty()) + { result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); } ctx->state->result_all.push_back(std::move(result)); // call the new_segment_callback for each segment - if (params.new_segment_callback) { + if (params.new_segment_callback) + { params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); } } @@ -7817,7 +9155,7 @@ int whisper_full_parallel( } // average the timings - ctx->state->t_mel_us /= n_processors; + ctx->state->t_mel_us /= n_processors; ctx->state->t_sample_us /= n_processors; ctx->state->t_encode_us /= n_processors; ctx->state->t_decode_us /= n_processors; @@ -7825,53 +9163,184 @@ int whisper_full_parallel( // print information about the audio boundaries WHISPER_LOG_WARN("\n"); WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); - for (int i = 0; i < n_processors - 1; ++i) { - WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str()); + for (int i = 0; i < n_processors - 1; ++i) + { + WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t).c_str()); } WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__); return ret; } -int whisper_full_n_segments_from_state(struct whisper_state * state) { +int whisper_full_batch_parallel(struct whisper_context *ctx, + struct whisper_full_params params, + const float *const *batches, + const int *size_per_batch, + int n_batches, + int n_processors) +{ + int ret = 0; + n_processors = std::min(n_processors, n_batches); + if (n_batches > n_processors) + { + throw std::runtime_error("batch size must be equal to number of processors"); + } + // prepare separate states for each thread + std::vector states; + std::vector> batches_vector; + batches_vector.reserve(n_batches); + for (int i = 0; i < n_batches; ++i) + { + int batch_size = size_per_batch[i]; + batches_vector.emplace_back(batches[i], batches[i] + batch_size); + } + + // the calling thread will process the first chunk + // while the other threads will process the remaining chunks + const int n_parallel_processes = n_processors - 1; + std::vector workers(n_parallel_processes); + for (int i = 0; i 
< n_parallel_processes; ++i) + { + if (i + 1 > n_batches - 1) + { + // break when batch not exist for parallel process + break; + } + const float *samples = batches_vector[i + 1].data(); + const int n_samples = batches_vector[i + 1].size(); + // create a new state for each thread + states.push_back(whisper_init_state(ctx)); + + auto params_cur = params; + + params_cur.offset_ms = 0; + params_cur.print_progress = false; + params_cur.print_realtime = false; + + params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback_user_data = nullptr; + + params_cur.progress_callback = nullptr; + params_cur.progress_callback_user_data = nullptr; + + workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples, n_samples); + } + + { + auto params_cur = params; + + // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk. + params_cur.print_realtime = false; + + const float *samples = batches_vector[0].data(); + const int n_samples = batches_vector[0].size(); + + // Run the first transformation using default state but only for the first chunk. + ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, n_samples); + } + + for (int i = 0; i < n_parallel_processes; ++i) + { + workers[i].join(); + } + + // combine results into result_state->result_all from all other states + for (int i = 0; i < n_processors - 1; ++i) + { + auto &results_i = states[i]->result_all; + + for (auto &result : results_i) + { + + // make sure that segments are not overlapping + if (!ctx->state->result_all.empty()) + { + result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); + } + + ctx->state->result_all.push_back(std::move(result)); + + // call the new_segment_callback for each segment + if (params.new_segment_callback) + { + params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); + } + } + + ctx->state->t_mel_us += states[i]->t_mel_us; + + ctx->state->t_sample_us += states[i]->t_sample_us; + ctx->state->t_encode_us += states[i]->t_encode_us; + ctx->state->t_decode_us += states[i]->t_decode_us; + ctx->state->t_batchd_us += states[i]->t_batchd_us; + ctx->state->t_prompt_us += states[i]->t_prompt_us; + + ctx->state->n_sample += states[i]->n_sample; + ctx->state->n_encode += states[i]->n_encode; + ctx->state->n_decode += states[i]->n_decode; + ctx->state->n_batchd += states[i]->n_batchd; + ctx->state->n_prompt += states[i]->n_prompt; + + whisper_free_state(states[i]); + } + + // average the timings + ctx->state->t_mel_us /= n_processors; + ctx->state->t_sample_us /= n_processors; + ctx->state->t_encode_us /= n_processors; + ctx->state->t_decode_us /= n_processors; + + return ret; +} + +int whisper_full_n_segments_from_state(struct whisper_state *state) +{ return state->result_all.size(); } -int whisper_full_n_segments(struct whisper_context * ctx) { +int whisper_full_n_segments(struct whisper_context *ctx) +{ return ctx->state->result_all.size(); } -int whisper_full_lang_id_from_state(struct whisper_state * state) { +int whisper_full_lang_id_from_state(struct whisper_state *state) +{ return state->lang_id; } -int whisper_full_lang_id(struct whisper_context * ctx) { +int whisper_full_lang_id(struct whisper_context *ctx) +{ return ctx->state->lang_id; } -static int64_t map_processed_to_original_time(int64_t processed_time, const std::vector & mapping_table) { - if (mapping_table.empty()) { +static int64_t map_processed_to_original_time(int64_t processed_time, const 
std::vector &mapping_table) +{ + if (mapping_table.empty()) + { return processed_time; } - if (processed_time <= mapping_table.front().processed_time) { + if (processed_time <= mapping_table.front().processed_time) + { return mapping_table.front().original_time; // Before first mapping point } - if (processed_time >= mapping_table.back().processed_time) { + if (processed_time >= mapping_table.back().processed_time) + { return mapping_table.back().original_time; // After last mapping point } // Binary search over the time map that finds the first entry that has a // processed time greater than or equal to the current processed time. auto upper = std::lower_bound(mapping_table.begin(), mapping_table.end(), processed_time, - [](const vad_time_mapping & entry, int64_t time) { - return entry.processed_time < time; - } - ); + [](const vad_time_mapping &entry, int64_t time) + { + return entry.processed_time < time; + }); // If exact match found - if (upper->processed_time == processed_time) { + if (upper->processed_time == processed_time) + { return upper->original_time; } @@ -7882,7 +9351,8 @@ static int64_t map_processed_to_original_time(int64_t processed_time, const std: int64_t original_diff = upper->original_time - lower->original_time; int64_t offset = processed_time - lower->processed_time; - if (processed_diff == 0) { + if (processed_diff == 0) + { return lower->original_time; } @@ -7891,9 +9361,11 @@ static int64_t map_processed_to_original_time(int64_t processed_time, const std: } // Function to get the starting timestamp of a segment -int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { +int64_t whisper_full_get_segment_t0_from_state(struct whisper_state *state, int i_segment) +{ // If VAD wasn't used, return the original timestamp - if (!state->has_vad_segments || state->vad_mapping_table.empty()) { + if (!state->has_vad_segments || state->vad_mapping_table.empty()) + { return state->result_all[i_segment].t0; } @@ -7905,9 +9377,11 @@ int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int } // Function to get the ending timestamp of a segment -int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) { +int64_t whisper_full_get_segment_t1_from_state(struct whisper_state *state, int i_segment) +{ // If VAD wasn't used, return the original timestamp - if (!state->has_vad_segments || state->vad_mapping_table.empty()) { + if (!state->has_vad_segments || state->vad_mapping_table.empty()) + { return state->result_all[i_segment].t1; } @@ -7922,83 +9396,101 @@ int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int // Ensure minimum duration to prevent zero-length segments const int64_t min_duration = 10; // 10ms minimum - if (orig_t1 - orig_t0 < min_duration) { + if (orig_t1 - orig_t0 < min_duration) + { orig_t1 = orig_t0 + min_duration; } return orig_t1; } - -int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { +int64_t whisper_full_get_segment_t0(struct whisper_context *ctx, int i_segment) +{ return whisper_full_get_segment_t0_from_state(ctx->state, i_segment); } -int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { +int64_t whisper_full_get_segment_t1(struct whisper_context *ctx, int i_segment) +{ return whisper_full_get_segment_t1_from_state(ctx->state, i_segment); } -bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) { +bool 
whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state *state, int i_segment) +{ return state->result_all[i_segment].speaker_turn_next; } -bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) { +bool whisper_full_get_segment_speaker_turn_next(struct whisper_context *ctx, int i_segment) +{ return ctx->state->result_all[i_segment].speaker_turn_next; } -const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { +const char *whisper_full_get_segment_text_from_state(struct whisper_state *state, int i_segment) +{ return state->result_all[i_segment].text.c_str(); } -const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { +const char *whisper_full_get_segment_text(struct whisper_context *ctx, int i_segment) +{ return ctx->state->result_all[i_segment].text.c_str(); } -int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) { +int whisper_full_n_tokens_from_state(struct whisper_state *state, int i_segment) +{ return state->result_all[i_segment].tokens.size(); } -int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { +int whisper_full_n_tokens(struct whisper_context *ctx, int i_segment) +{ return ctx->state->result_all[i_segment].tokens.size(); } -const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) { +const char *whisper_full_get_token_text_from_state(struct whisper_context *ctx, struct whisper_state *state, int i_segment, int i_token) +{ return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); } -const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { +const char *whisper_full_get_token_text(struct whisper_context *ctx, int i_segment, int i_token) +{ return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); } -whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) { +whisper_token whisper_full_get_token_id_from_state(struct whisper_state *state, int i_segment, int i_token) +{ return state->result_all[i_segment].tokens[i_token].id; } -whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { +whisper_token whisper_full_get_token_id(struct whisper_context *ctx, int i_segment, int i_token) +{ return ctx->state->result_all[i_segment].tokens[i_token].id; } -struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) { +struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state *state, int i_segment, int i_token) +{ return state->result_all[i_segment].tokens[i_token]; } -struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { +struct whisper_token_data whisper_full_get_token_data(struct whisper_context *ctx, int i_segment, int i_token) +{ return ctx->state->result_all[i_segment].tokens[i_token]; } -float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) { +float whisper_full_get_token_p_from_state(struct whisper_state *state, int i_segment, int i_token) +{ return state->result_all[i_segment].tokens[i_token].p; } -float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { +float whisper_full_get_token_p(struct 
whisper_context *ctx, int i_segment, int i_token) +{ return ctx->state->result_all[i_segment].tokens[i_token].p; } -float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) { +float whisper_full_get_segment_no_speech_prob(struct whisper_context *ctx, int i_segment) +{ return ctx->state->result_all[i_segment].no_speech_prob; } -float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * state, int i_segment) { +float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state *state, int i_segment) +{ return state->result_all[i_segment].no_speech_prob; } @@ -8009,55 +9501,60 @@ float whisper_full_get_segment_no_speech_prob_from_state(struct whisper_state * // Will be removed in the future when ggml becomes a separate library // -WHISPER_API int whisper_bench_memcpy(int n_threads) { +WHISPER_API int whisper_bench_memcpy(int n_threads) +{ fputs(whisper_bench_memcpy_str(n_threads), stderr); return 0; } -WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { +WHISPER_API const char *whisper_bench_memcpy_str(int n_threads) +{ static std::string s; s = ""; char strbuf[256]; ggml_time_init(); - size_t n = 20; - size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations + size_t n = 20; + size_t arr = n_threads > 0 ? 1024llu : n_threads; // trick to avoid compiler optimizations // 1GB array - const size_t size = arr*1e6; + const size_t size = arr * 1e6; - double sum = 0.0; + double sum = 0.0; // heat-up { - char * src = (char *) malloc(size); - char * dst = (char *) malloc(size); + char *src = (char *)malloc(size); + char *dst = (char *)malloc(size); - for (size_t i = 0; i < size; i++) src[i] = i; + for (size_t i = 0; i < size; i++) + src[i] = i; memcpy(dst, src, size); // heat-up double tsum = 0.0; - for (size_t i = 0; i < n; i++) { + for (size_t i = 0; i < n; i++) + { const int64_t t0 = ggml_time_us(); memcpy(dst, src, size); const int64_t t1 = ggml_time_us(); - tsum += (t1 - t0)*1e-6; + tsum += (t1 - t0) * 1e-6; src[rand() % size] = rand() % 256; } - snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9)); + snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double)(n * size) / (tsum * 1e9)); s += strbuf; // needed to prevent the compiler from optimizing the memcpy away { - for (size_t i = 0; i < size; i++) sum += dst[i]; + for (size_t i = 0; i < size; i++) + sum += dst[i]; } free(src); @@ -8066,33 +9563,36 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { // single-thread { - char * src = (char *) malloc(size); - char * dst = (char *) malloc(size); + char *src = (char *)malloc(size); + char *dst = (char *)malloc(size); - for (size_t i = 0; i < size; i++) src[i] = i; + for (size_t i = 0; i < size; i++) + src[i] = i; memcpy(dst, src, size); // heat-up double tsum = 0.0; - for (size_t i = 0; i < n; i++) { + for (size_t i = 0; i < n; i++) + { const int64_t t0 = ggml_time_us(); memcpy(dst, src, size); const int64_t t1 = ggml_time_us(); - tsum += (t1 - t0)*1e-6; + tsum += (t1 - t0) * 1e-6; src[rand() % size] = rand() % 256; } - snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9)); + snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double)(n * size) / (tsum * 1e9)); s += strbuf; // needed to prevent the compiler from optimizing the memcpy away { - for (size_t i = 0; i < size; i++) sum += dst[i]; + for (size_t i = 0; i < size; i++) + sum += dst[i]; } 
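        // The reported bandwidth is simply total bytes copied over total elapsed time:
        //
        //   GB/s = (n * size) / (tsum * 1e9)
        //
        // e.g. with the defaults n = 20 iterations and size = 1024e6 bytes (~1 GB per copy),
        // a hypothetical measured tsum of 2.0 s works out to 20 * 1.024e9 / (2.0 * 1e9) ~ 10.2 GB/s.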
free(src); @@ -8101,21 +9601,25 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { // multi-thread - for (int32_t k = 1; k <= n_threads; k++) { - char * src = (char *) malloc(size); - char * dst = (char *) malloc(size); + for (int32_t k = 1; k <= n_threads; k++) + { + char *src = (char *)malloc(size); + char *dst = (char *)malloc(size); - for (size_t i = 0; i < size; i++) src[i] = i; + for (size_t i = 0; i < size; i++) + src[i] = i; memcpy(dst, src, size); // heat-up double tsum = 0.0; - auto helper = [&](int th) { - const int64_t i0 = (th + 0)*size/k; - const int64_t i1 = (th + 1)*size/k; + auto helper = [&](int th) + { + const int64_t i0 = (th + 0) * size / k; + const int64_t i1 = (th + 1) * size / k; - for (size_t i = 0; i < n; i++) { + for (size_t i = 0; i < n; i++) + { memcpy(dst + i0, src + i0, i1 - i0); src[i0 + rand() % (i1 - i0)] = rand() % 256; @@ -8125,26 +9629,29 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { const int64_t t0 = ggml_time_us(); std::vector threads(k - 1); - for (int32_t th = 0; th < k - 1; ++th) { + for (int32_t th = 0; th < k - 1; ++th) + { threads[th] = std::thread(helper, th); } helper(k - 1); - for (int32_t th = 0; th < k - 1; ++th) { + for (int32_t th = 0; th < k - 1; ++th) + { threads[th].join(); } const int64_t t1 = ggml_time_us(); - tsum += (t1 - t0)*1e-6; + tsum += (t1 - t0) * 1e-6; - snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k); + snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double)(n * size) / (tsum * 1e9), k); s += strbuf; // needed to prevent the compiler from optimizing the memcpy away { - for (size_t i = 0; i < size; i++) sum += dst[i]; + for (size_t i = 0; i < size; i++) + sum += dst[i]; } free(src); @@ -8157,12 +9664,14 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { return s.c_str(); } -WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) { +WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) +{ fputs(whisper_bench_ggml_mul_mat_str(n_threads), stderr); return 0; } -WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { +WHISPER_API const char *whisper_bench_ggml_mul_mat_str(int n_threads) +{ static std::string s; s = ""; char strbuf[256]; @@ -8172,7 +9681,13 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { const int n_max = 128; const std::vector sizes = { - 64, 128, 256, 512, 1024, 2048, 4096, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096, }; const size_t N_max = sizes.back(); @@ -8181,12 +9696,14 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { // b: N*N*sizeof(float) // c: N*N*sizeof(float) // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) - std::vector buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead()); + std::vector buf(3llu * N_max * N_max * sizeof(float) + 3 * ggml_tensor_overhead() + ggml_graph_overhead()); // put a bunch of random data in the buffer - for (size_t i = 0; i < buf.size(); i++) buf[i] = i; + for (size_t i = 0; i < buf.size(); i++) + buf[i] = i; - for (int j = 0; j < (int) sizes.size(); j++) { + for (int j = 0; j < (int)sizes.size(); j++) + { int n_q4_0 = 0; int n_q4_1 = 0; int n_q5_0 = 0; @@ -8206,32 +9723,43 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { const size_t N = sizes[j]; - for (int k = 0; k < 7; ++k) { + for (int k = 0; k < 7; ++k) + { const ggml_type wtype = - k == 0 ? 
GGML_TYPE_Q4_0 : - k == 1 ? GGML_TYPE_Q4_1 : - k == 2 ? GGML_TYPE_Q5_0 : - k == 3 ? GGML_TYPE_Q5_1 : - k == 4 ? GGML_TYPE_Q8_0 : - k == 5 ? GGML_TYPE_F16 : GGML_TYPE_F32; - - double & s = k == 0 ? s_q4_0 : k == 1 ? s_q4_1 : k == 2 ? s_q5_0 : k == 3 ? s_q5_1 : k == 4 ? s_q8_0 : k == 5 ? s_fp16 : /*k == 6*/ s_fp32; - int & n = k == 0 ? n_q4_0 : k == 1 ? n_q4_1 : k == 2 ? n_q5_0 : k == 3 ? n_q5_1 : k == 4 ? n_q8_0 : k == 5 ? n_fp16 : /*k == 6*/ n_fp32; + k == 0 ? GGML_TYPE_Q4_0 : k == 1 ? GGML_TYPE_Q4_1 + : k == 2 ? GGML_TYPE_Q5_0 + : k == 3 ? GGML_TYPE_Q5_1 + : k == 4 ? GGML_TYPE_Q8_0 + : k == 5 ? GGML_TYPE_F16 + : GGML_TYPE_F32; + + double &s = k == 0 ? s_q4_0 : k == 1 ? s_q4_1 + : k == 2 ? s_q5_0 + : k == 3 ? s_q5_1 + : k == 4 ? s_q8_0 + : k == 5 ? s_fp16 + : /*k == 6*/ s_fp32; + int &n = k == 0 ? n_q4_0 : k == 1 ? n_q4_1 + : k == 2 ? n_q5_0 + : k == 3 ? n_q5_1 + : k == 4 ? n_q8_0 + : k == 5 ? n_fp16 + : /*k == 6*/ n_fp32; struct ggml_init_params gparams = { - /*.mem_size =*/ buf.size(), - /*.mem_buffer =*/ buf.data(), - /*.no_alloc =*/ false, + /*.mem_size =*/buf.size(), + /*.mem_buffer =*/buf.data(), + /*.no_alloc =*/false, }; - struct ggml_context * ctx0 = ggml_init(gparams); + struct ggml_context *ctx0 = ggml_init(gparams); - struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N); - struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); + struct ggml_tensor *a = ggml_new_tensor_2d(ctx0, wtype, N, N); + struct ggml_tensor *b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); - struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b); + struct ggml_tensor *c = ggml_mul_mat(ctx0, a, b); - struct ggml_cgraph * gf = ggml_new_graph(ctx0); + struct ggml_cgraph *gf = ggml_new_graph(ctx0); ggml_build_forward_expand(gf, c); @@ -8240,39 +9768,41 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { // heat-up ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr); - for (int i = 0; i < n_max; ++i) { + for (int i = 0; i < n_max; ++i) + { const int64_t t0 = ggml_time_us(); ggml_graph_compute_helper(gf, n_threads, nullptr, nullptr); const int64_t t1 = ggml_time_us(); - tsum += (t1 - t0)*1e-6; + tsum += (t1 - t0) * 1e-6; n++; - if (tsum > 1.0 && n >= 3) { + if (tsum > 1.0 && n >= 3) + { break; } } ggml_free(ctx0); - s = ((2.0*N*N*N*n)/tsum)*1e-9; + s = ((2.0 * N * N * N * n) / tsum) * 1e-9; } // Q4_0 | Q4_1 snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q4_0 %7.1f GFLOPS (%3d runs) | Q4_1 %7.1f GFLOPS (%3d runs)\n", - N, N, s_q4_0, n_q4_0, s_q4_1, n_q4_1); + N, N, s_q4_0, n_q4_0, s_q4_1, n_q4_1); s += strbuf; // Q5_0 | Q5_1 | Q8_0 snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q5_0 %7.1f GFLOPS (%3d runs) | Q5_1 %7.1f GFLOPS (%3d runs) | Q8_0 %7.1f GFLOPS (%3d runs)\n", - N, N, s_q5_0, n_q5_0, s_q5_1, n_q5_1, s_q8_0, n_q8_0); + N, N, s_q5_0, n_q5_0, s_q5_1, n_q5_1, s_q8_0, n_q8_0); s += strbuf; // F16 | F32 snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: F16 %7.1f GFLOPS (%3d runs) | F32 %7.1f GFLOPS (%3d runs)\n", - N, N, s_fp16, n_fp16, s_fp32, n_fp32); + N, N, s_fp16, n_fp16, s_fp32, n_fp32); s += strbuf; } @@ -8296,29 +9826,45 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { // token-level timestamps // -static int64_t sample_to_timestamp(int i_sample) { - return (100ll*i_sample)/WHISPER_SAMPLE_RATE; +static int64_t sample_to_timestamp(int i_sample) +{ + return (100ll * i_sample) / WHISPER_SAMPLE_RATE; } // a cost-function / heuristic that is high for text that takes longer to pronounce // obviously, can be improved -static float 
voice_length(const std::string & text) { +static float voice_length(const std::string &text) +{ float res = 0.0f; - for (char c : text) { - if (c == ' ') { + for (char c : text) + { + if (c == ' ') + { res += 0.01f; - } else if (c == ',') { + } + else if (c == ',') + { res += 2.00f; - } else if (c == '.') { + } + else if (c == '.') + { res += 3.00f; - } else if (c == '!') { + } + else if (c == '!') + { res += 3.00f; - } else if (c == '?') { + } + else if (c == '?') + { res += 3.00f; - } else if (c >= '0' && c <= '9') { + } + else if (c >= '0' && c <= '9') + { res += 3.00f; - } else { + } + else + { res += 1.00f; } } @@ -8327,48 +9873,56 @@ static float voice_length(const std::string & text) { } // average the fabs of the signal -static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) { +static std::vector get_signal_energy(const float *signal, int n_samples, int n_samples_per_half_window) +{ const int hw = n_samples_per_half_window; std::vector result(n_samples); - for (int i = 0; i < n_samples; i++) { + for (int i = 0; i < n_samples; i++) + { float sum = 0; - for (int j = -hw; j <= hw; j++) { - if (i + j >= 0 && i + j < n_samples) { + for (int j = -hw; j <= hw; j++) + { + if (i + j >= 0 && i + j < n_samples) + { sum += fabs(signal[i + j]); } } - result[i] = sum/(2*hw + 1); + result[i] = sum / (2 * hw + 1); } return result; } -static int timestamp_to_sample(int64_t t, int64_t segment_t0, int n_samples) { +static int timestamp_to_sample(int64_t t, int64_t segment_t0, int n_samples) +{ // Convert absolute timestamp to segment-relative timestamp int64_t relative_t = t - segment_t0; int sample = (int)((relative_t * WHISPER_SAMPLE_RATE) / 100); return std::max(0, std::min(n_samples - 1, sample)); } -static int64_t sample_to_timestamp(int i_sample, int64_t segment_t0) { +static int64_t sample_to_timestamp(int i_sample, int64_t segment_t0) +{ int64_t relative_timestamp = (100ll * i_sample) / WHISPER_SAMPLE_RATE; return relative_timestamp + segment_t0; } static void whisper_exp_compute_token_level_timestamps( - struct whisper_context & ctx, - struct whisper_state & state, - int i_segment, - float thold_pt, - float thold_ptsum) { - auto & segment = state.result_all[i_segment]; - auto & tokens = segment.tokens; + struct whisper_context &ctx, + struct whisper_state &state, + int i_segment, + float thold_pt, + float thold_ptsum) +{ + auto &segment = state.result_all[i_segment]; + auto &tokens = segment.tokens; const int n_samples = state.energy.size(); - if (n_samples == 0) { + if (n_samples == 0) + { WHISPER_LOG_ERROR("%s: no signal data available\n", __func__); return; } @@ -8378,44 +9932,53 @@ static void whisper_exp_compute_token_level_timestamps( const int n = tokens.size(); - if (n == 0) { + if (n == 0) + { return; } - if (n == 1) { + if (n == 1) + { tokens[0].t0 = t0; tokens[0].t1 = t1; return; } - auto & t_beg = state.t_beg; - auto & t_last = state.t_last; - auto & tid_last = state.tid_last; + auto &t_beg = state.t_beg; + auto &t_last = state.t_last; + auto &tid_last = state.tid_last; - for (int j = 0; j < n; ++j) { - auto & token = tokens[j]; + for (int j = 0; j < n; ++j) + { + auto &token = tokens[j]; - if (j == 0) { - if (token.id == whisper_token_beg(&ctx)) { - tokens[j ].t0 = t0; - tokens[j ].t1 = t0; + if (j == 0) + { + if (token.id == whisper_token_beg(&ctx)) + { + tokens[j].t0 = t0; + tokens[j].t1 = t0; tokens[j + 1].t0 = t0; - t_beg = t0; - t_last = t0; + t_beg = t0; + t_last = t0; tid_last = whisper_token_beg(&ctx); - } else { - 
tokens[j ].t0 = t_last; + } + else + { + tokens[j].t0 = t_last; } } - const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx)); + const int64_t tt = t_beg + 2 * (token.tid - whisper_token_beg(&ctx)); tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id)); - if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { - if (j > 0) { + if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) + { + if (j > 0) + { tokens[j - 1].t1 = tt; } tokens[j].t0 = tt; @@ -8435,52 +9998,63 @@ static void whisper_exp_compute_token_level_timestamps( int p0 = 0; int p1 = 0; - while (true) { - while (p1 < n && tokens[p1].t1 < 0) { + while (true) + { + while (p1 < n && tokens[p1].t1 < 0) + { p1++; } - if (p1 >= n) { + if (p1 >= n) + { p1--; } - //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1); + // printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1); - if (p1 > p0) { + if (p1 > p0) + { double psum = 0.0; - for (int j = p0; j <= p1; j++) { + for (int j = p0; j <= p1; j++) + { psum += tokens[j].vlen; } - //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + // printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); const double dt = tokens[p1].t1 - tokens[p0].t0; // split the time proportionally to the voice length - for (int j = p0 + 1; j <= p1; j++) { - const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + for (int j = p0 + 1; j <= p1; j++) + { + const double ct = tokens[j - 1].t0 + dt * tokens[j - 1].vlen / psum; tokens[j - 1].t1 = ct; - tokens[j ].t0 = ct; + tokens[j].t0 = ct; } } p1++; p0 = p1; - if (p1 >= n) { + if (p1 >= n) + { break; } } } // fix up (just in case) - for (int j = 0; j < n - 1; j++) { - if (tokens[j].t1 < 0) { + for (int j = 0; j < n - 1; j++) + { + if (tokens[j].t1 < 0) + { tokens[j + 1].t0 = tokens[j].t1; } - if (j > 0) { - if (tokens[j - 1].t1 > tokens[j].t0) { + if (j > 0) + { + if (tokens[j - 1].t1 > tokens[j].t0) + { tokens[j].t0 = tokens[j - 1].t1; tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); } @@ -8490,10 +10064,12 @@ static void whisper_exp_compute_token_level_timestamps( // VAD // expand or contract tokens based on voice activity { - const int hw = WHISPER_SAMPLE_RATE/8; + const int hw = WHISPER_SAMPLE_RATE / 8; - for (int j = 0; j < n; j++) { - if (tokens[j].id >= whisper_token_eot(&ctx)) { + for (int j = 0; j < n; j++) + { + if (tokens[j].id >= whisper_token_eot(&ctx)) + { continue; } @@ -8507,26 +10083,35 @@ static void whisper_exp_compute_token_level_timestamps( float sum = 0.0f; - for (int k = ss0; k < ss1; k++) { + for (int k = ss0; k < ss1; k++) + { sum += state.energy[k]; } - const float thold = 0.5*sum/ns; + const float thold = 0.5 * sum / ns; { int k = s0; - if (state.energy[k] > thold && j > 0) { - while (k > 0 && state.energy[k] > thold) { + if (state.energy[k] > thold && j > 0) + { + while (k > 0 && state.energy[k] > thold) + { k--; } tokens[j].t0 = sample_to_timestamp(k, segment.t0); - if (tokens[j].t0 < tokens[j - 1].t1) { + if (tokens[j].t0 < tokens[j - 1].t1) + { tokens[j].t0 = tokens[j - 1].t1; - } else { + } + else + { s0 = k; } - } else { - while (state.energy[k] < thold && k < s1) { + } + else + { + while (state.energy[k] < thold && k < s1) + { k++; } s0 = k; @@ -8536,18 +10121,26 @@ static void whisper_exp_compute_token_level_timestamps( { int k = s1; - if (state.energy[k] > thold) { - while (k < n_samples - 1 && state.energy[k] > thold) { + if (state.energy[k] > thold) + { + while (k < 
n_samples - 1 && state.energy[k] > thold) + { k++; } tokens[j].t1 = sample_to_timestamp(k, segment.t0); - if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) { + if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) + { tokens[j].t1 = tokens[j + 1].t0; - } else { + } + else + { s1 = k; } - } else { - while (state.energy[k] < thold && k > s0) { + } + else + { + while (state.energy[k] < thold && k > s0) + { k--; } s1 = k; @@ -8572,7 +10165,7 @@ static void whisper_exp_compute_token_level_timestamps( //} // debug info - //for (int j = 0; j < n; ++j) { + // for (int j = 0; j < n; ++j) { // const auto & token = tokens[j]; // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(&ctx, token.tid) : "[?]"; // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, @@ -8590,20 +10183,30 @@ static void whisper_exp_compute_token_level_timestamps( // n_text_layer -> total text layers on model // n_head -> total heads per text layer on model -static std::vector get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) { +static std::vector get_alignment_heads_by_layer(const whisper_context_params &cparams, int il, int n_text_layer, int n_head) +{ std::vector ret; - if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) { + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) + { return ret; - } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) { - if (il >= n_text_layer - cparams.dtw_n_top) { - for (int32_t i = 0; i < n_head; ++i) { + } + else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) + { + if (il >= n_text_layer - cparams.dtw_n_top) + { + for (int32_t i = 0; i < n_head; ++i) + { ret.push_back(i); } } - } else { + } + else + { const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset); - for (size_t i = 0; i < aheads.n_heads; ++i) { - if (aheads.heads[i].n_text_layer == il) { + for (size_t i = 0; i < aheads.n_heads; ++i) + { + if (aheads.heads[i].n_text_layer == il) + { ret.push_back(aheads.heads[i].n_head); } } @@ -8614,13 +10217,14 @@ static std::vector get_alignment_heads_by_layer(const whisper_context_ // dtw + backtrace to return found path // based on // https://github.com/openai/whisper/blob/main/whisper/timing.py#L83 -static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) { +static ggml_tensor *dtw_and_backtrace(ggml_context *ctx, ggml_tensor *x) +{ WHISPER_ASSERT(ggml_n_dims(x) == 2); int64_t N = x->ne[0]; int64_t M = x->ne[1]; - struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1); - struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1); + struct ggml_tensor *cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1); + struct ggml_tensor *trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1); cost = whisper_set_f32(cost, INFINITY); trace = whisper_set_i32(trace, -1); @@ -8629,21 +10233,28 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) { // dtw // supposedly can be optmized by computing diagonals in parallel ? // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most. 
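    // The nested loops below fill `cost` with the standard DTW dynamic program
    // (a minimal description, matching the min/argmin selection that follows):
    //
    //   cost[i][j] = x[i-1][j-1] + min(cost[i-1][j-1], cost[i-1][j], cost[i][j-1])
    //
    // cost[0][0] is expected to be seeded to 0 (done in context not shown in this hunk),
    // while the rest of the border keeps the INFINITY set above. `trace[i][j]` records
    // which neighbour won (0 = diagonal, 1 = step in i only, 2 = step in j only) so the
    // backtrace further down can walk from (N, M) back to (0, 0) and recover the
    // monotonic token <-> audio-frame alignment path.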
-    for (int64_t j = 1; j < M + 1; ++j) {
-        for (int64_t i = 1; i < N + 1; ++i) {
+    for (int64_t j = 1; j < M + 1; ++j)
+    {
+        for (int64_t i = 1; i < N + 1; ++i)
+        {
             float c0 = whisper_get_f32_nd(cost, i - 1, j - 1, 0, 0);
             float c1 = whisper_get_f32_nd(cost, i - 1, j, 0, 0);
             float c2 = whisper_get_f32_nd(cost, i, j - 1, 0, 0);

             float c;
             int32_t t;

-            if (c0 < c1 && c0 < c2) {
+            if (c0 < c1 && c0 < c2)
+            {
                 c = c0;
                 t = 0;
-            } else if (c1 < c0 && c1 < c2) {
+            }
+            else if (c1 < c0 && c1 < c2)
+            {
                 c = c1;
                 t = 1;
-            } else {
+            }
+            else
+            {
                 c = c2;
                 t = 2;
             }
@@ -8656,30 +10267,38 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
     // Backtrace
     const int64_t BT_MAX_ROWS = N + M - 1;
-    struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
+    struct ggml_tensor *bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2);
     // trace[0, :] = 2;
     for (int64_t i = 0; i < M + 1; ++i)
         whisper_set_i32_nd(trace, 0, i, 0, 0, 2);

-    //trace[:, 0] = 1;
+    // trace[:, 0] = 1;
     for (int64_t i = 0; i < N + 1; ++i)
         whisper_set_i32_nd(trace, i, 0, 0, 0, 1);

     int bt_row_idx = BT_MAX_ROWS - 1;
     int64_t i = N;
     int64_t j = M;
-    while (i > 0 || j > 0) {
+    while (i > 0 || j > 0)
+    {
         whisper_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1);
         whisper_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1);
         --bt_row_idx;
         int32_t t = whisper_get_i32_nd(trace, i, j, 0, 0);
-        if (t == 0) {
+        if (t == 0)
+        {
             --i;
             --j;
-        } else if (t == 1) {
+        }
+        else if (t == 1)
+        {
             --i;
-        } else if (t == 2) {
+        }
+        else if (t == 2)
+        {
             --j;
-        } else {
+        }
+        else
+        {
             WHISPER_ASSERT(0);
         }
     }
@@ -8688,11 +10307,13 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
     // Clip + transpose
     // This might not be entirely necessary for our case, but leaving it for now so output matrix
     // is identical to dtw on openAI timing.py
-    const int64_t result_n_cols = BT_MAX_ROWS-bt_row_idx-1;
-    ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
-    for (int64_t i = 0; i < 2; ++i) {
-        for (int64_t j = 0; j < result_n_cols; ++j) {
-            int32_t v = whisper_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0);
+    const int64_t result_n_cols = BT_MAX_ROWS - bt_row_idx - 1;
+    ggml_tensor *r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols);
+    for (int64_t i = 0; i < 2; ++i)
+    {
+        for (int64_t j = 0; j < result_n_cols; ++j)
+        {
+            int32_t v = whisper_get_i32_nd(bt, j + bt_row_idx + 1, i, 0, 0);
             whisper_set_i32_nd(r, i, j, 0, 0, v);
         }
     }
@@ -8700,15 +10321,18 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
     return r;
 }

-struct median_filter_user_data {
+struct median_filter_user_data
+{
     int filter_width;
 };

-static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int /*nth*/, void * userdata) {
-    if (ith != 0) {
+static void median_filter(struct ggml_tensor *dst, const struct ggml_tensor *a, int ith, int /*nth*/, void *userdata)
+{
+    if (ith != 0)
+    {
         return;
     }
-    int filter_width = ((median_filter_user_data *) userdata)->filter_width;
+    int filter_width = ((median_filter_user_data *)userdata)->filter_width;
     WHISPER_ASSERT(filter_width < a->ne[2]);
     WHISPER_ASSERT(filter_width % 2);
     WHISPER_ASSERT(ggml_n_dims(a) == 3);
@@ -8716,22 +10340,29 @@ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor *

     std::vector<float> filter;
     filter.reserve(filter_width);
-    for (int64_t i = 0; i < a->ne[0]; ++i) {
-        for (int64_t j = 0; j < a->ne[1]; ++j) {
-            for (int64_t k = 0; k < a->ne[2]; ++k) {
-                for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) {
+    for (int64_t i = 0; i < a->ne[0]; ++i)
+    {
+        for (int64_t j = 0; j < a->ne[1]; ++j)
+        {
+            for (int64_t k = 0; k < a->ne[2]; ++k)
+            {
+                for (int64_t off = -filter_width / 2; off <= filter_width / 2; ++off)
+                {
                     // "reflect" padding
                     int64_t idx = k + off;
-                    if (idx < 0) {
+                    if (idx < 0)
+                    {
                         idx = -idx;
-                    } else if (idx >= a->ne[2]) {
-                        idx = 2*(a->ne[2] - 1) - idx;
+                    }
+                    else if (idx >= a->ne[2])
+                    {
+                        idx = 2 * (a->ne[2] - 1) - idx;
                     }
                     filter.push_back(whisper_get_f32_nd(a, i, j, idx, 0));
                 }
                 std::sort(filter.begin(), filter.end());
-                const float v = filter[filter.size()/2];
+                const float v = filter[filter.size() / 2];
                 whisper_set_f32_nd(dst, i, j, k, 0, v);
                 filter.clear();
             }
@@ -8740,15 +10371,15 @@ static void median_filter(struct ggml_tensor * dst , const struct ggml_tensor *
 }

 static void whisper_exp_compute_token_level_timestamps_dtw(
-        struct whisper_context * ctx,
-        struct whisper_state * state,
-        struct whisper_full_params params,
-        int i_segment,
-        size_t n_segments,
-        int seek,
-        int n_frames,
-        int medfilt_width,
-        int n_threads)
+    struct whisper_context *ctx,
+    struct whisper_state *state,
+    struct whisper_full_params params,
+    int i_segment,
+    size_t n_segments,
+    int seek,
+    int n_frames,
+    int medfilt_width,
+    int n_threads)
 {
     const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx;
     WHISPER_ASSERT(medfilt_width % 2);
@@ -8759,27 +10390,33 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     // Our ggml buffer should be pre-allocated somewhere during init and reused
     // when we call this function
     struct ggml_init_params gparams = {
-        /*.mem_size =*/ ctx->params.dtw_mem_size,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc =*/ false,
+        /*.mem_size =*/ctx->params.dtw_mem_size,
+        /*.mem_buffer =*/NULL,
+        /*.no_alloc =*/false,
     };
-    struct ggml_context * gctx = ggml_init(gparams);
+    struct ggml_context *gctx = ggml_init(gparams);

     // Build token sequence that will be passed to decoder
     // sot + [lang] + text result + eot
-    std::vector<whisper_token> tokens = { whisper_token_sot(ctx), };
-    if (whisper_is_multilingual(ctx)) {
+    std::vector<whisper_token> tokens = {
+        whisper_token_sot(ctx),
+    };
+    if (whisper_is_multilingual(ctx))
+    {
         const int lang_id = whisper_lang_id(params.language);
         state->lang_id = lang_id;
         tokens.push_back(whisper_token_lang(ctx, lang_id));
     }
     const size_t sot_sequence_length = tokens.size();
     tokens.push_back(whisper_token_not(ctx));
-    for (size_t i = i_segment; i < i_segment + n_segments; ++i) {
-        auto & segment = state->result_all[i];
-        for (auto &t: segment.tokens) {
+    for (size_t i = i_segment; i < i_segment + n_segments; ++i)
+    {
+        auto &segment = state->result_all[i];
+        for (auto &t : segment.tokens)
+        {
             // Only text tokens
-            if (t.id < whisper_token_eot(ctx)) {
+            if (t.id < whisper_token_eot(ctx))
+            {
                 tokens.push_back(t.id);
             }
         }
@@ -8793,13 +10430,14 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     whisper_kv_cache_clear(state->kv_self);
     whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0);
     whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1);
-    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) {
+    if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr))
+    {
         WHISPER_LOG_INFO("DECODER FAILED\n");
         WHISPER_ASSERT(0);
     }
     WHISPER_ASSERT(state->aheads_cross_QKs != nullptr);

-    const auto n_audio_tokens = n_frames/2;
+    const auto n_audio_tokens = n_frames / 2;
     WHISPER_ASSERT(state->aheads_cross_QKs != NULL);
     WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]);
     const auto n_tokens = state->aheads_cross_QKs->ne[0];
@@ -8811,17 +10449,18 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
     WHISPER_ASSERT(state->aheads_cross_QKs->type == GGML_TYPE_F32);
     WHISPER_ASSERT(ggml_is_contiguous(state->aheads_cross_QKs));
-    ggml_tensor * w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
-    auto & data = state->aheads_cross_QKs_data;
+    ggml_tensor *w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
+    auto &data = state->aheads_cross_QKs_data;
     data.resize(n_tokens * n_audio_ctx * n_heads);
     ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads);

-    for (int k = 0; k < n_heads; ++k) {
-        for (int j = 0; j < n_audio_tokens; ++j) {
+    for (int k = 0; k < n_heads; ++k)
+    {
+        for (int j = 0; j < n_audio_tokens; ++j)
+        {
             memcpy(
-                (char *) w->data + j * w->nb[1] + k * w->nb[2],
+                (char *)w->data + j * w->nb[1] + k * w->nb[2],
                 data.data() + j * n_tokens + k * n_tokens * n_audio_ctx,
-                n_tokens * sizeof(float)
-            );
+                n_tokens * sizeof(float));
         }
     }
@@ -8832,7 +10471,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
     // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
     w = ggml_norm(gctx, w, 1e-9f);
-    w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
+    w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0, 3), 0, 2, 1, 3);

     // Pass median filter - this is done over AUDIO_TOKENS dimension.
     // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
     // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
@@ -8852,29 +10491,33 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
     w = ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]);

     // Compute
-    struct ggml_cgraph * gf = ggml_new_graph(gctx);
+    struct ggml_cgraph *gf = ggml_new_graph(gctx);
     ggml_build_forward_expand(gf, w);

-    ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
+    ggml_backend_ptr backend{ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr)};
     ggml_backend_graph_compute(backend.get(), gf);

-    ggml_tensor * alignment = dtw_and_backtrace(gctx, w);
+    ggml_tensor *alignment = dtw_and_backtrace(gctx, w);

     // Place timestamps on segments
     int32_t last_v = 0;
     auto seg_i = state->result_all.begin() + i_segment;
     auto tok_i = seg_i->tokens.begin();
-    for (int i = 0; i < alignment->ne[1]; ++i) {
+    for (int i = 0; i < alignment->ne[1]; ++i)
+    {
         int32_t v = whisper_get_i32_nd(alignment, 0, i, 0, 0);
-        if (v != last_v) {
+        if (v != last_v)
+        {
             int32_t time_index = whisper_get_i32_nd(alignment, 1, i, 0, 0);
             int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20mS audio
             last_v = v;

             // Skip non-text tokens
-            while (!(tok_i->id < whisper_token_eot(ctx))) {
+            while (!(tok_i->id < whisper_token_eot(ctx)))
+            {
                 ++tok_i;
-                if (tok_i == seg_i->tokens.end()) {
+                if (tok_i == seg_i->tokens.end())
+                {
                     ++seg_i;
                     tok_i = seg_i->tokens.begin();
                 }
@@ -8882,7 +10525,8 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
             tok_i->t_dtw = timestamp;
             ++tok_i;
-            if (tok_i == seg_i->tokens.end()) {
+            if (tok_i == seg_i->tokens.end())
+            {
                 ++seg_i;
                 tok_i = seg_i->tokens.begin();
             }
@@ -8902,27 +10546,33 @@
     ggml_free(gctx);
 }

-void whisper_log_set(ggml_log_callback log_callback, void * user_data) {
+void whisper_log_set(ggml_log_callback log_callback, void *user_data)
+{
     g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default;
     g_state.log_callback_user_data = user_data;
     ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);
 }

-const char * whisper_version(void) {
+const char *whisper_version(void)
+{
     return WHISPER_VERSION;
 }

 GGML_ATTRIBUTE_FORMAT(2, 3)
-static void whisper_log_internal(ggml_log_level level, const char * format, ...) {
+static void whisper_log_internal(ggml_log_level level, const char *format, ...)
+{
     va_list args;
     va_start(args, format);
     char buffer[1024];
     int len = vsnprintf(buffer, 1024, format, args);
-    if (len < 1024) {
+    if (len < 1024)
+    {
         g_state.log_callback(level, buffer, g_state.log_callback_user_data);
-    } else {
-        char* buffer2 = new char[len+1];
-        vsnprintf(buffer2, len+1, format, args);
+    }
+    else
+    {
+        char *buffer2 = new char[len + 1];
+        vsnprintf(buffer2, len + 1, format, args);
         buffer2[len] = 0;
         g_state.log_callback(level, buffer2, g_state.log_callback_user_data);
         delete[] buffer2;
@@ -8930,11 +10580,13 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
     va_end(args);
 }

-static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
-    (void) level;
-    (void) user_data;
+static void whisper_log_callback_default(ggml_log_level level, const char *text, void *user_data)
+{
+    (void)level;
+    (void)user_data;
 #ifndef WHISPER_DEBUG
-    if (level == GGML_LOG_LEVEL_DEBUG) {
+    if (level == GGML_LOG_LEVEL_DEBUG)
+    {
         return;
     }
 #endif
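
Reviewer note (not part of the diff): the hunks above only re-brace dtw_and_backtrace; as context for what that function computes, here is a minimal, self-contained sketch of the same DTW recurrence and backtrace over a plain row-major distance matrix. The helper name dtw_path and the std::vector-based storage are illustrative assumptions, not whisper.cpp symbols.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

// Sketch of the DTW pass in dtw_and_backtrace: x is an N x M distance
// matrix (row-major); the result is the warping path as (i, j) pairs.
static std::vector<std::pair<int64_t, int64_t>> dtw_path(const std::vector<float> & x, int64_t N, int64_t M) {
    assert((int64_t) x.size() == N * M);

    std::vector<float>   cost((N + 1) * (M + 1), INFINITY);
    std::vector<int32_t> trace((N + 1) * (M + 1), -1);

    auto C = [&](int64_t i, int64_t j) -> float &   { return cost [i * (M + 1) + j]; };
    auto T = [&](int64_t i, int64_t j) -> int32_t & { return trace[i * (M + 1) + j]; };

    C(0, 0) = 0.0f;

    // Fill cost/trace: each cell remembers which neighbour (diagonal/up/left) it came from.
    for (int64_t j = 1; j < M + 1; ++j) {
        for (int64_t i = 1; i < N + 1; ++i) {
            const float c0 = C(i - 1, j - 1); // 0: diagonal
            const float c1 = C(i - 1, j);     // 1: up
            const float c2 = C(i, j - 1);     // 2: left

            float c; int32_t t;
            if (c0 < c1 && c0 < c2) { c = c0; t = 0; }
            else if (c1 < c0 && c1 < c2) { c = c1; t = 1; }
            else { c = c2; t = 2; }

            C(i, j) = x[(i - 1) * M + (j - 1)] + c;
            T(i, j) = t;
        }
    }

    // Border traces force the backtrace to walk back to (0, 0).
    for (int64_t j = 0; j < M + 1; ++j) T(0, j) = 2;
    for (int64_t i = 0; i < N + 1; ++i) T(i, 0) = 1;

    // Backtrace from (N, M); reverse at the end instead of filling a
    // fixed-size tensor from the back as the whisper.cpp version does.
    std::vector<std::pair<int64_t, int64_t>> path;
    int64_t i = N, j = M;
    while (i > 0 || j > 0) {
        path.emplace_back(i - 1, j - 1);
        switch (T(i, j)) {
            case 0: --i; --j; break;
            case 1: --i;      break;
            case 2: --j;      break;
            default: assert(false);
        }
    }
    std::reverse(path.begin(), path.end());
    return path;
}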
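
Reviewer note (not part of the diff): the median_filter hunks rely on "reflect" padding, which mirrors out-of-range indices back into the valid range. A small sketch of just that index mapping follows; the helper name reflect_index is an illustrative assumption, and, like the original, it only handles offsets whose magnitude stays below n (which WHISPER_ASSERT(filter_width < a->ne[2]) guarantees).

#include <cassert>
#include <cstdint>

// Mirror an out-of-range index back into [0, n), matching the "reflect"
// padding in median_filter:  ... 2 1 | 0 1 2 3 4 | 3 2 ...   (n = 5)
static int64_t reflect_index(int64_t idx, int64_t n) {
    assert(n > 1);
    if (idx < 0) {
        idx = -idx;               // -1 -> 1, -2 -> 2, ...
    } else if (idx >= n) {
        idx = 2 * (n - 1) - idx;  //  n -> n - 2, n + 1 -> n - 3, ...
    }
    return idx;
}

// Example: a width-5 filter centered at k = 0 over n = 5 samples reads
// offsets {-2, -1, 0, 1, 2}, which map to indices {2, 1, 0, 1, 2}.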
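
Reviewer note (not part of the diff): the timestamp-placement hunks convert a DTW column index into a token timestamp with (time_index * 2) + seek. A hedged sketch of that conversion follows; it assumes, as the surrounding code and the "20mS" comment suggest, that one DTW column corresponds to one audio token (two 10 ms frames) and that seek and the resulting t_dtw are expressed in the same 10 ms units as the segment t0/t1 fields.

#include <cstdint>

// One DTW column = one audio (cross-attention) token = 2 mel frames = 20 ms.
// Returns a timestamp in 10 ms units, offset by the window start `seek`.
static int64_t dtw_index_to_timestamp(int32_t time_index, int64_t seek) {
    return (int64_t) time_index * 2 + seek;
}

// Example: time_index = 150 in a window that starts at seek = 3000
// (30 s into the audio) -> 3000 + 300 = 3300, i.e. 33.0 s.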