Skip to content

Commit f1982f6

Browse files
akashmjn and ggerganov authored
whisper : support speaker segmentation (local diarization) of mono audio via tinydiarize (ggml-org#1058)
* add HuggingFace mirror to download ggml model
* support tdrz via simple hack overriding solm tokens
* fix incorrect translate/transcribe token_ids that are not static const
* add apollo 13 sample for tdrz demo
* render [SPEAKER TURN] consistently in all terminal output using vocab.id_to_token
* extend whisper_segment with speaker_turn_next field and save in json output
* fix failing go build
* slipped in some python syntax whoops
* whisper : finalize tinydiarize support (add flag + fixes)
* whisper : tdrz support for word-level timestamps (respect max_len)
* java : try to fix tests after adding tdrz_enable flag
* main : remove TODO leftover
* java : fix params order list after adding "tdrz_enable"
* whisper : fix solm and add nosp token
* main : print tinydiarize help

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent db39b70 commit f1982f6

File tree

8 files changed

+223
-138
lines changed

8 files changed

+223
-138
lines changed

Makefile

+4
Original file line number | Diff line number | Diff line change
@@ -308,12 +308,16 @@ samples:
308308
@wget --quiet --show-progress -O samples/gb1.ogg https://upload.wikimedia.org/wikipedia/commons/1/1f/George_W_Bush_Columbia_FINAL.ogg
309309
@wget --quiet --show-progress -O samples/hp0.ogg https://upload.wikimedia.org/wikipedia/en/d/d4/En.henryfphillips.ogg
310310
@wget --quiet --show-progress -O samples/mm1.wav https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav
311+
@wget --quiet --show-progress -O samples/a13.mp3 https://upload.wikimedia.org/wikipedia/commons/transcoded/6/6f/Apollo13-wehaveaproblem.ogg/Apollo13-wehaveaproblem.ogg.mp3
311312
@echo "Converting to 16-bit WAV ..."
312313
@ffmpeg -loglevel -0 -y -i samples/gb0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb0.wav
313314
@ffmpeg -loglevel -0 -y -i samples/gb1.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/gb1.wav
314315
@ffmpeg -loglevel -0 -y -i samples/hp0.ogg -ar 16000 -ac 1 -c:a pcm_s16le samples/hp0.wav
316+
@rm samples/*.ogg
315317
@ffmpeg -loglevel -0 -y -i samples/mm1.wav -ar 16000 -ac 1 -c:a pcm_s16le samples/mm0.wav
316318
@rm samples/mm1.wav
319+
@ffmpeg -loglevel -0 -y -i samples/a13.mp3 -ar 16000 -ac 1 -c:a pcm_s16le -ss 00:00:00 -to 00:00:30 samples/a13.wav
320+
@rm samples/a13.mp3
317321

318322
#
319323
# Models

bindings/go/whisper.go

+4-4
Original file line number | Diff line number | Diff line change
@@ -270,13 +270,13 @@ func (ctx *Context) Whisper_token_lang(lang_id int) Token {
270270
}
271271

272272
// Task tokens
273-
func Whisper_token_translate() Token {
274-
return Token(C.whisper_token_translate())
273+
func (ctx *Context) Whisper_token_translate() Token {
274+
return Token(C.whisper_token_translate((*C.struct_whisper_context)(ctx)))
275275
}
276276

277277
// Task tokens
278-
func Whisper_token_transcribe() Token {
279-
return Token(C.whisper_token_transcribe())
278+
func (ctx *Context) Whisper_token_transcribe() Token {
279+
return Token(C.whisper_token_transcribe((*C.struct_whisper_context)(ctx)))
280280
}
281281

282282
// Performance information

bindings/java/src/main/java/io/github/ggerganov/whispercpp/WhisperCppJnaLibrary.java

+2-2
Original file line number | Diff line number | Diff line change
@@ -224,8 +224,8 @@ public interface WhisperCppJnaLibrary extends Library {
224224
int whisper_token_lang(Pointer ctx, int lang_id);
225225

226226
// Task tokens
227-
int whisper_token_translate();
228-
int whisper_token_transcribe();
227+
int whisper_token_translate (Pointer ctx);
228+
int whisper_token_transcribe(Pointer ctx);
229229

230230
// Performance information from the default state.
231231
void whisper_print_timings(Pointer ctx);

bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java

+9-1
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,14 @@ public void speedUp(boolean enable) {
137137
/** Overwrite the audio context size (0 = use default). */
138138
public int audio_ctx;
139139

140+
/** Enable tinydiarize (default = false) */
141+
public CBool tdrz_enable;
142+
143+
/** Enable tinydiarize (default = false) */
144+
public void tdrzEnable(boolean enable) {
145+
tdrz_enable = enable ? CBool.TRUE : CBool.FALSE;
146+
}
147+
140148
/** Tokens to provide to the whisper decoder as an initial prompt.
141149
* These are prepended to any existing text context from a previous call. */
142150
public String initial_prompt;
@@ -302,7 +310,7 @@ protected List<String> getFieldOrder() {
302310
"no_context", "single_segment",
303311
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
304312
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "speed_up", "audio_ctx",
305-
"initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
313+
"tdrz_enable", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
306314
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty",
307315
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
308316
"new_segment_callback", "new_segment_callback_user_data",

examples/main/main.cpp

+83-57
Original file line number | Diff line number | Diff line change
@@ -68,28 +68,32 @@ struct whisper_params {
6868
float entropy_thold = 2.40f;
6969
float logprob_thold = -1.00f;
7070

71-
bool speed_up = false;
72-
bool translate = false;
73-
bool detect_language= false;
74-
bool diarize = false;
75-
bool split_on_word = false;
76-
bool no_fallback = false;
77-
bool output_txt = false;
78-
bool output_vtt = false;
79-
bool output_srt = false;
80-
bool output_wts = false;
81-
bool output_csv = false;
82-
bool output_jsn = false;
83-
bool output_lrc = false;
84-
bool print_special = false;
85-
bool print_colors = false;
86-
bool print_progress = false;
87-
bool no_timestamps = false;
88-
89-
std::string language = "en";
71+
bool speed_up = false;
72+
bool translate = false;
73+
bool detect_language = false;
74+
bool diarize = false;
75+
bool tinydiarize = false;
76+
bool split_on_word = false;
77+
bool no_fallback = false;
78+
bool output_txt = false;
79+
bool output_vtt = false;
80+
bool output_srt = false;
81+
bool output_wts = false;
82+
bool output_csv = false;
83+
bool output_jsn = false;
84+
bool output_lrc = false;
85+
bool print_special = false;
86+
bool print_colors = false;
87+
bool print_progress = false;
88+
bool no_timestamps = false;
89+
90+
std::string language = "en";
9091
std::string prompt;
9192
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
92-
std::string model = "models/ggml-base.en.bin";
93+
std::string model = "models/ggml-base.en.bin";
94+
95+
// [TDRZ] speaker turn string
96+
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
9397

9498
std::vector<std::string> fname_inp = {};
9599
std::vector<std::string> fname_out = {};
@@ -115,41 +119,42 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
115119
whisper_print_usage(argc, argv, params);
116120
exit(0);
117121
}
118-
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
119-
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
120-
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
121-
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
122-
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
123-
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
124-
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
125-
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
126-
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
127-
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
128-
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
129-
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
130-
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
131-
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
132-
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
133-
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
134-
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
135-
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
136-
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
137-
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
138-
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
139-
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
140-
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
141-
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
142-
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
143-
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
144-
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
145-
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
146-
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
147-
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
148-
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
149-
else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
150-
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
151-
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
152-
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
122+
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
123+
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
124+
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
125+
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
126+
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
127+
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
128+
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
129+
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
130+
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
131+
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
132+
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
133+
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
134+
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
135+
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
136+
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
137+
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
138+
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
139+
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
140+
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
141+
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
142+
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
143+
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
144+
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
145+
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
146+
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
147+
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
148+
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
149+
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
150+
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
151+
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
152+
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
153+
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
154+
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
155+
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
156+
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
157+
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
153158
else {
154159
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
155160
whisper_print_usage(argc, argv, params);
@@ -182,6 +187,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
182187
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
183188
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
184189
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
190+
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
185191
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
186192
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
187193
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
@@ -297,6 +303,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
297303
printf("%s%s", speaker.c_str(), text);
298304
}
299305

306+
if (params.tinydiarize) {
307+
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
308+
printf("%s", params.tdrz_speaker_turn.c_str());
309+
}
310+
}
311+
300312
// with timestamps or speakers: each segment on new line
301313
if (!params.no_timestamps || params.diarize) {
302314
printf("\n");
@@ -564,6 +576,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
564576
const int n_segments = whisper_full_n_segments(ctx);
565577
for (int i = 0; i < n_segments; ++i) {
566578
const char * text = whisper_full_get_segment_text(ctx, i);
579+
567580
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
568581
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
569582

@@ -576,11 +589,15 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
576589
value_i("from", t0 * 10, false);
577590
value_i("to", t1 * 10, true);
578591
end_obj(false);
579-
value_s("text", text, !params.diarize);
592+
value_s("text", text, !params.diarize && !params.tinydiarize);
580593

581594
if (params.diarize && pcmf32s.size() == 2) {
582595
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
583596
}
597+
598+
if (params.tinydiarize) {
599+
value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true);
600+
}
584601
end_obj(i == (n_segments - 1));
585602
}
586603

@@ -777,6 +794,12 @@ int main(int argc, char ** argv) {
777794
exit(0);
778795
}
779796

797+
if (params.diarize && params.tinydiarize) {
798+
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
799+
whisper_print_usage(argc, argv, params);
800+
exit(0);
801+
}
802+
780803
// whisper init
781804

782805
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
@@ -818,11 +841,12 @@ int main(int argc, char ** argv) {
818841
if (params.detect_language) {
819842
params.language = "auto";
820843
}
821-
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
844+
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
822845
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
823846
params.n_threads, params.n_processors,
824847
params.language.c_str(),
825848
params.translate ? "translate" : "transcribe",
849+
params.tinydiarize ? "tdrz = 1, " : "",
826850
params.no_timestamps ? 0 : 1);
827851

828852
fprintf(stderr, "\n");
@@ -853,6 +877,8 @@ int main(int argc, char ** argv) {
853877

854878
wparams.speed_up = params.speed_up;
855879

880+
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
881+
856882
wparams.initial_prompt = params.prompt.c_str();
857883

858884
wparams.greedy.best_of = params.best_of;

models/download-ggml-model.sh

+7-1
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@ function get_script_path() {
2222
models_path="$(get_script_path)"
2323

2424
# Whisper models
25-
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large" )
25+
models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small.en-tdrz" "small" "medium.en" "medium" "large-v1" "large" )
2626

2727
# list available models
2828
function list_models {
@@ -50,6 +50,12 @@ if [[ ! " ${models[@]} " =~ " ${model} " ]]; then
5050
exit 1
5151
fi
5252

53+
# check if model contains `tdrz` and update the src and pfx accordingly
54+
if [[ $model == *"tdrz"* ]]; then
55+
src="https://huggingface.co/akashmjn/tinydiarize-whisper.cpp"
56+
pfx="resolve/main/ggml"
57+
fi
58+
5359
# download ggml model
5460

5561
printf "Downloading ggml model $model from '$src' ...\n"

0 commit comments

Comments
 (0)