|
| 1 | +import warnings |
| 2 | +import torch |
| 3 | +import numpy as np |
| 4 | +import argparse |
| 5 | +import soundfile as sf |
| 6 | +import torch.nn.functional as F |
| 7 | +import itertools as it |
| 8 | +from fairseq import utils |
| 9 | +from fairseq.models import BaseFairseqModel |
| 10 | +from fairseq.data import Dictionary |
| 11 | +from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecEncoder, Wav2Vec2CtcConfig |
| 12 | +from fairseq.dataclass.utils import convert_namespace_to_omegaconf |
| 13 | + |
| 14 | +try: |
| 15 | + from flashlight.lib.text.dictionary import create_word_dict, load_words |
| 16 | + from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes |
| 17 | + from flashlight.lib.text.decoder import ( |
| 18 | + CriterionType, |
| 19 | + LexiconDecoderOptions, |
| 20 | + KenLM, |
| 21 | + LM, |
| 22 | + LMState, |
| 23 | + SmearingMode, |
| 24 | + Trie, |
| 25 | + LexiconDecoder, |
| 26 | + ) |
| 27 | +except: |
| 28 | + warnings.warn( |
| 29 | + "flashlight python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python" |
| 30 | + ) |
| 31 | + LM = object |
| 32 | + LMState = object |
| 33 | + |
| 34 | + |
| 35 | +class Wav2VecCtc(BaseFairseqModel): |
| 36 | + def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel): |
| 37 | + super().__init__() |
| 38 | + self.cfg = cfg |
| 39 | + self.w2v_encoder = w2v_encoder |
| 40 | + |
| 41 | + def upgrade_state_dict_named(self, state_dict, name): |
| 42 | + super().upgrade_state_dict_named(state_dict, name) |
| 43 | + return state_dict |
| 44 | + |
| 45 | + @classmethod |
| 46 | + def build_model(cls, cfg: Wav2Vec2CtcConfig, target_dictionary): ##change here |
| 47 | + """Build a new model instance.""" |
| 48 | + w2v_encoder = Wav2VecEncoder(cfg, target_dictionary) |
| 49 | + return cls(cfg, w2v_encoder) |
| 50 | + |
| 51 | + def get_normalized_probs(self, net_output, log_probs): |
| 52 | + """Get normalized probabilities (or log probs) from a net's output.""" |
| 53 | + |
| 54 | + logits = net_output["encoder_out"] |
| 55 | + if log_probs: |
| 56 | + return utils.log_softmax(logits.float(), dim=-1) |
| 57 | + else: |
| 58 | + return utils.softmax(logits.float(), dim=-1) |
| 59 | + |
| 60 | + def get_logits(self, net_output): |
| 61 | + logits = net_output["encoder_out"] |
| 62 | + padding = net_output["padding_mask"] |
| 63 | + if padding is not None and padding.any(): |
| 64 | + padding = padding.T |
| 65 | + logits[padding][...,0] = 0 |
| 66 | + logits[padding][...,1:] = float('-inf') |
| 67 | + |
| 68 | + return logits |
| 69 | + |
| 70 | + def forward(self, **kwargs): |
| 71 | + x = self.w2v_encoder(**kwargs) |
| 72 | + return x |
| 73 | + |
| 74 | + |
| 75 | +class W2lDecoder(object): |
| 76 | + def __init__(self, args, tgt_dict): |
| 77 | + self.tgt_dict = tgt_dict |
| 78 | + self.vocab_size = len(tgt_dict) |
| 79 | + #print(args) |
| 80 | + self.nbest = args['nbest'] |
| 81 | + |
| 82 | + # criterion-specific init |
| 83 | + if args['criterion'] == "ctc": |
| 84 | + self.criterion_type = CriterionType.CTC |
| 85 | + self.blank = ( |
| 86 | + tgt_dict.index("<ctc_blank>") |
| 87 | + if "<ctc_blank>" in tgt_dict.indices |
| 88 | + else tgt_dict.bos() |
| 89 | + ) |
| 90 | + if "<sep>" in tgt_dict.indices: |
| 91 | + self.silence = tgt_dict.index("<sep>") |
| 92 | + elif "|" in tgt_dict.indices: |
| 93 | + self.silence = tgt_dict.index("|") |
| 94 | + else: |
| 95 | + self.silence = tgt_dict.eos() |
| 96 | + self.asg_transitions = None |
| 97 | + elif args.criterion == "asg_loss": |
| 98 | + self.criterion_type = CriterionType.ASG |
| 99 | + self.blank = -1 |
| 100 | + self.silence = -1 |
| 101 | + self.asg_transitions = args.asg_transitions |
| 102 | + self.max_replabel = args.max_replabel |
| 103 | + assert len(self.asg_transitions) == self.vocab_size ** 2 |
| 104 | + else: |
| 105 | + raise RuntimeError(f"unknown criterion: {args.criterion}") |
| 106 | + |
| 107 | + def generate(self, models, sample, **unused): |
| 108 | + """Generate a batch of inferences.""" |
| 109 | + # model.forward normally channels prev_output_tokens into the decoder |
| 110 | + # separately, but SequenceGenerator directly calls model.encoder |
| 111 | + encoder_input = { |
| 112 | + k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens" |
| 113 | + } |
| 114 | + emissions = self.get_emissions(models, encoder_input) |
| 115 | + return self.decode(emissions) |
| 116 | + |
| 117 | + def get_emissions(self, models, encoder_input): |
| 118 | + """Run encoder and normalize emissions""" |
| 119 | + model = models ## change here |
| 120 | + encoder_out = model(**encoder_input) |
| 121 | + if self.criterion_type == CriterionType.CTC: |
| 122 | + if hasattr(model, "get_logits"): |
| 123 | + emissions = model.get_logits(encoder_out) # no need to normalize emissions |
| 124 | + else: |
| 125 | + emissions = model.get_normalized_probs(encoder_out, log_probs=True) |
| 126 | + elif self.criterion_type == CriterionType.ASG: |
| 127 | + emissions = encoder_out["encoder_out"] |
| 128 | + return emissions.transpose(0, 1).float().cpu().contiguous() |
| 129 | + |
| 130 | + def get_tokens(self, idxs): |
| 131 | + """Normalize tokens by handling CTC blank, ASG replabels, etc.""" |
| 132 | + idxs = (g[0] for g in it.groupby(idxs)) |
| 133 | + if self.criterion_type == CriterionType.CTC: |
| 134 | + idxs = filter(lambda x: x != self.blank, idxs) |
| 135 | + elif self.criterion_type == CriterionType.ASG: |
| 136 | + idxs = filter(lambda x: x >= 0, idxs) |
| 137 | + idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel) |
| 138 | + return torch.LongTensor(list(idxs)) |
| 139 | + |
| 140 | + |
| 141 | +class W2lViterbiDecoder(W2lDecoder): |
| 142 | + def __init__(self, args, tgt_dict): |
| 143 | + super().__init__(args, tgt_dict) |
| 144 | + |
| 145 | + def decode(self, emissions): |
| 146 | + B, T, N = emissions.size() |
| 147 | + hypos = [] |
| 148 | + if self.asg_transitions is None: |
| 149 | + transitions = torch.FloatTensor(N, N).zero_() |
| 150 | + else: |
| 151 | + transitions = torch.FloatTensor(self.asg_transitions).view(N, N) |
| 152 | + viterbi_path = torch.IntTensor(B, T) |
| 153 | + workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N)) |
| 154 | + CpuViterbiPath.compute( |
| 155 | + B, |
| 156 | + T, |
| 157 | + N, |
| 158 | + get_data_ptr_as_bytes(emissions), |
| 159 | + get_data_ptr_as_bytes(transitions), |
| 160 | + get_data_ptr_as_bytes(viterbi_path), |
| 161 | + get_data_ptr_as_bytes(workspace), |
| 162 | + ) |
| 163 | + return [ |
| 164 | + [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}] |
| 165 | + for b in range(B) |
| 166 | + ] |
| 167 | + |
| 168 | + |
| 169 | +class W2lKenLMDecoder(W2lDecoder): |
| 170 | + def __init__(self, args, tgt_dict): |
| 171 | + super().__init__(args, tgt_dict) |
| 172 | + |
| 173 | + self.unit_lm = getattr(args, "unit_lm", False) |
| 174 | + |
| 175 | + if args['lexicon']: |
| 176 | + self.lexicon = load_words(args['lexicon']) |
| 177 | + self.word_dict = create_word_dict(self.lexicon) |
| 178 | + self.unk_word = self.word_dict.get_index("<unk>") |
| 179 | + |
| 180 | + self.lm = KenLM(args['kenlm_model'], self.word_dict) |
| 181 | + self.trie = Trie(self.vocab_size, self.silence) |
| 182 | + |
| 183 | + start_state = self.lm.start(False) |
| 184 | + for i, (word, spellings) in enumerate(self.lexicon.items()): |
| 185 | + word_idx = self.word_dict.get_index(word) |
| 186 | + _, score = self.lm.score(start_state, word_idx) |
| 187 | + for spelling in spellings: |
| 188 | + spelling_idxs = [tgt_dict.index(token) for token in spelling] |
| 189 | + assert ( |
| 190 | + tgt_dict.unk() not in spelling_idxs |
| 191 | + ), f"{spelling} {spelling_idxs}" |
| 192 | + self.trie.insert(spelling_idxs, word_idx, score) |
| 193 | + self.trie.smear(SmearingMode.MAX) |
| 194 | + |
| 195 | + self.decoder_opts = LexiconDecoderOptions( |
| 196 | + beam_size=args['beam'], |
| 197 | + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), |
| 198 | + beam_threshold=args['beam_threshold'], |
| 199 | + lm_weight=args['lm_weight'], |
| 200 | + word_score=args['word_score'], |
| 201 | + unk_score=args['unk_weight'], |
| 202 | + sil_score=args['sil_weight'], |
| 203 | + log_add=False, |
| 204 | + criterion_type=self.criterion_type, |
| 205 | + ) |
| 206 | + |
| 207 | + if self.asg_transitions is None: |
| 208 | + N = 768 |
| 209 | + # self.asg_transitions = torch.FloatTensor(N, N).zero_() |
| 210 | + self.asg_transitions = [] |
| 211 | + |
| 212 | + self.decoder = LexiconDecoder( |
| 213 | + self.decoder_opts, |
| 214 | + self.trie, |
| 215 | + self.lm, |
| 216 | + self.silence, |
| 217 | + self.blank, |
| 218 | + self.unk_word, |
| 219 | + self.asg_transitions, |
| 220 | + self.unit_lm, |
| 221 | + ) |
| 222 | + else: |
| 223 | + assert args.unit_lm, "lexicon free decoding can only be done with a unit language model" |
| 224 | + from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions |
| 225 | + |
| 226 | + d = {w: [[w]] for w in tgt_dict.symbols} |
| 227 | + self.word_dict = create_word_dict(d) |
| 228 | + self.lm = KenLM(args.kenlm_model, self.word_dict) |
| 229 | + self.decoder_opts = LexiconFreeDecoderOptions( |
| 230 | + beam_size=args.beam, |
| 231 | + beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), |
| 232 | + beam_threshold=args.beam_threshold, |
| 233 | + lm_weight=args.lm_weight, |
| 234 | + sil_score=args.sil_weight, |
| 235 | + log_add=False, |
| 236 | + criterion_type=self.criterion_type, |
| 237 | + ) |
| 238 | + self.decoder = LexiconFreeDecoder( |
| 239 | + self.decoder_opts, self.lm, self.silence, self.blank, [] |
| 240 | + ) |
| 241 | + |
| 242 | + |
| 243 | + def decode(self, emissions): |
| 244 | + B, T, N = emissions.size() |
| 245 | + hypos = [] |
| 246 | + for b in range(B): |
| 247 | + emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0) |
| 248 | + results = self.decoder.decode(emissions_ptr, T, N) |
| 249 | + |
| 250 | + nbest_results = results[: self.nbest] |
| 251 | + hypos.append( |
| 252 | + [ |
| 253 | + { |
| 254 | + "tokens": self.get_tokens(result.tokens), |
| 255 | + "score": result.score, |
| 256 | + "words": [ |
| 257 | + self.word_dict.get_entry(x) for x in result.words if x >= 0 |
| 258 | + ], |
| 259 | + } |
| 260 | + for result in nbest_results |
| 261 | + ] |
| 262 | + ) |
| 263 | + return hypos |
| 264 | + |
| 265 | +def get_feature(filepath): |
| 266 | + def postprocess(feats, sample_rate): |
| 267 | + if feats.dim == 2: |
| 268 | + feats = feats.mean(-1) |
| 269 | + |
| 270 | + assert feats.dim() == 1, feats.dim() |
| 271 | + |
| 272 | + with torch.no_grad(): |
| 273 | + feats = F.layer_norm(feats, feats.shape) |
| 274 | + return feats |
| 275 | + |
| 276 | + wav, sample_rate = sf.read(filepath) |
| 277 | + feats = torch.from_numpy(wav).float() |
| 278 | + feats = postprocess(feats, sample_rate) |
| 279 | + return feats |
| 280 | + |
| 281 | +def post_process(sentence: str, symbol: str): |
| 282 | + if symbol == "sentencepiece": |
| 283 | + sentence = sentence.replace(" ", "").replace("\u2581", " ").strip() |
| 284 | + elif symbol == 'wordpiece': |
| 285 | + sentence = sentence.replace(" ", "").replace("_", " ").strip() |
| 286 | + elif symbol == 'letter': |
| 287 | + sentence = sentence.replace(" ", "").replace("|", " ").strip() |
| 288 | + elif symbol == "_EOW": |
| 289 | + sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() |
| 290 | + elif symbol is not None and symbol != 'none': |
| 291 | + sentence = (sentence + " ").replace(symbol, "").rstrip() |
| 292 | + return sentence |
| 293 | + |
| 294 | + |
| 295 | + |
| 296 | +def get_results(wav_path,dict_path,generator,use_cuda=False,w2v_path=None,model=None, half=None): |
| 297 | + sample = dict() |
| 298 | + net_input = dict() |
| 299 | + feature = get_feature(wav_path) |
| 300 | + target_dict = Dictionary.load(dict_path) |
| 301 | + |
| 302 | + model.eval() |
| 303 | + |
| 304 | + if half: |
| 305 | + net_input["source"] = feature.unsqueeze(0).half() |
| 306 | + else: |
| 307 | + net_input["source"] = feature.unsqueeze(0) |
| 308 | + |
| 309 | + padding_mask = torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0) |
| 310 | + |
| 311 | + net_input["padding_mask"] = padding_mask |
| 312 | + sample["net_input"] = net_input |
| 313 | + sample = utils.move_to_cuda(sample) if use_cuda else sample |
| 314 | + |
| 315 | + with torch.no_grad(): |
| 316 | + hypo = generator.generate(model, sample, prefix_tokens=None) |
| 317 | + hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu()) |
| 318 | + text=post_process(hyp_pieces, 'letter') |
| 319 | + |
| 320 | + return text |
| 321 | + |
| 322 | + |
| 323 | +def load_model(model_path): |
| 324 | + return torch.load(model_path)#,map_location=torch.device("cuda")) |
| 325 | + |
| 326 | + |
| 327 | + |
| 328 | +def get_args(lexicon_path, lm_path, BEAM=128, LM_WEIGHT=2, WORD_SCORE=-1): |
| 329 | + args = {} |
| 330 | + args['lexicon'] = lexicon_path |
| 331 | + args['kenlm_model'] = lm_path |
| 332 | + args['beam'] = BEAM |
| 333 | + args['beam_threshold'] = 25 |
| 334 | + args['lm_weight'] = LM_WEIGHT |
| 335 | + args['word_score'] = WORD_SCORE |
| 336 | + args['unk_weight'] = -np.inf |
| 337 | + args['sil_weight'] = 0 |
| 338 | + args['nbest'] = 1 |
| 339 | + args['criterion'] ='ctc' |
| 340 | + args['labels']='ltr' |
| 341 | + return args |
| 342 | + |
| 343 | +def parse_transcription(model_path, dict_path, wav_path, cuda, decoder="viterbi", lexicon_path=None, lm_path=None, half=None): |
| 344 | + target_dict = Dictionary.load(dict_path) |
| 345 | + args = get_args(lexicon_path, lm_path) |
| 346 | + |
| 347 | + if decoder=="viterbi": |
| 348 | + generator = W2lViterbiDecoder(args, target_dict) |
| 349 | + else: |
| 350 | + generator = W2lKenLMDecoder(args, target_dict) |
| 351 | + |
| 352 | + result = '' |
| 353 | + |
| 354 | + if cuda: |
| 355 | + model = load_model(model_path) |
| 356 | + model.cuda() |
| 357 | + else: |
| 358 | + model = load_model(model_path) |
| 359 | + |
| 360 | + |
| 361 | + if half: |
| 362 | + model.half() |
| 363 | + |
| 364 | + result = get_results(wav_path=wav_path, dict_path=dict_path, generator=generator, use_cuda=cuda, model=model, half=half) |
| 365 | + |
| 366 | + return result |
| 367 | + |
| 368 | +if __name__ == "__main__": |
| 369 | + parser = argparse.ArgumentParser(description='Run') |
| 370 | + parser.add_argument('-m', '--model', type=str, help="Custom model path") |
| 371 | + parser.add_argument('-d', '--dict', type=str, help="Dict path") |
| 372 | + parser.add_argument('-w', '--wav', type=str, help= "Wav file path") |
| 373 | + parser.add_argument('-c', '--cuda', default=False, type=bool, help="CUDA True or False") |
| 374 | + parser.add_argument('-D', '--decoder', type=str, help= "Which decoder to use kenlm or viterbi") |
| 375 | + parser.add_argument('-l', '--lexicon', default=None, type=str, help= "Lexicon path if decoder is kenlm") |
| 376 | + parser.add_argument('-L', '--lm-path', default=None, type=str, help= "Language mode path if decoder is kenlm") |
| 377 | + parser.add_argument('-H', '--half', default=False, type=bool, help="Half True or False") |
| 378 | + |
| 379 | + args_local = parser.parse_args() |
| 380 | + |
| 381 | + result = parse_transcription(args_local.model, args_local.dict, args_local.wav, args_local.cuda, args_local.decoder, args_local.lexicon, args_local.lm_path, args_local.half) |
| 382 | + print(result) |
0 commit comments