Commit 8a5335f

Add files via upload
1 parent 97dee25 commit 8a5335f

1 file changed: 382 additions & 0 deletions
@@ -0,0 +1,382 @@
import warnings
import torch
import numpy as np
import argparse
import soundfile as sf
import torch.nn.functional as F
import itertools as it
from fairseq import utils
from fairseq.models import BaseFairseqModel
from fairseq.data import Dictionary
from fairseq.models.wav2vec.wav2vec2_asr import Wav2VecEncoder, Wav2Vec2CtcConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
# unpack_replabels is required by W2lDecoder.get_tokens for the ASG criterion;
# it ships with fairseq's speech_recognition example package (assumed on PYTHONPATH)
from examples.speech_recognition.data.replabels import unpack_replabels

try:
    from flashlight.lib.text.dictionary import create_word_dict, load_words
    from flashlight.lib.sequence.criterion import CpuViterbiPath, get_data_ptr_as_bytes
    from flashlight.lib.text.decoder import (
        CriterionType,
        LexiconDecoderOptions,
        KenLM,
        LM,
        LMState,
        SmearingMode,
        Trie,
        LexiconDecoder,
    )
except ImportError:
    # a bare except here would also swallow KeyboardInterrupt and mask real bugs
    warnings.warn(
        "flashlight python bindings are required to use this functionality. "
        "Please install from https://github.com/facebookresearch/flashlight/tree/master/bindings/python"
    )
    LM = object
    LMState = object


class Wav2VecCtc(BaseFairseqModel):
    def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2CtcConfig, target_dictionary):  # change here: takes the target dictionary directly instead of a fairseq task
        """Build a new model instance."""
        w2v_encoder = Wav2VecEncoder(cfg, target_dictionary)
        return cls(cfg, w2v_encoder)

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output["encoder_out"]
        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def get_logits(self, net_output):
        logits = net_output["encoder_out"]  # (T, B, C)
        padding = net_output["padding_mask"]  # (B, T)
        if padding is not None and padding.any():
            # Chained boolean-mask assignment (logits[padding][..., 0] = 0)
            # writes into a temporary copy and is silently lost. Build the
            # masking row once -- blank (index 0) scores 0, everything else
            # -inf -- and assign it in a single indexing operation instead.
            masking = logits.new_full((logits.size(-1),), float("-inf"))
            masking[0] = 0
            logits[padding.T] = masking

        return logits

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x
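# A minimal sketch (hypothetical paths, not part of this file) of restoring
# Wav2VecCtc from a fine-tuned fairseq checkpoint, assuming the checkpoint
# stores its config under "cfg" and its weights under "model":
#
#   target_dict = Dictionary.load("dict.ltr.txt")
#   state = torch.load("checkpoint_best.pt", map_location="cpu")
#   model = Wav2VecCtc.build_model(state["cfg"].model, target_dict)
#   model.load_state_dict(state["model"], strict=True)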


class W2lDecoder(object):
    def __init__(self, args, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = args['nbest']

        # criterion-specific init
        if args['criterion'] == "ctc":
            self.criterion_type = CriterionType.CTC
            self.blank = (
                tgt_dict.index("<ctc_blank>")
                if "<ctc_blank>" in tgt_dict.indices
                else tgt_dict.bos()
            )
            if "<sep>" in tgt_dict.indices:
                self.silence = tgt_dict.index("<sep>")
            elif "|" in tgt_dict.indices:
                self.silence = tgt_dict.index("|")
            else:
                self.silence = tgt_dict.eos()
            self.asg_transitions = None
        elif args['criterion'] == "asg_loss":
            # args is a plain dict throughout, so use item access here too
            # (the original mixed attribute access into this branch)
            self.criterion_type = CriterionType.ASG
            self.blank = -1
            self.silence = -1
            self.asg_transitions = args['asg_transitions']
            self.max_replabel = args['max_replabel']
            assert len(self.asg_transitions) == self.vocab_size ** 2
        else:
            raise RuntimeError(f"unknown criterion: {args['criterion']}")

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        model = models  # change here: a single model is passed in, not a list
        encoder_out = model(**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            if hasattr(model, "get_logits"):
                emissions = model.get_logits(encoder_out)  # no need to normalize emissions
            else:
                emissions = model.get_normalized_probs(encoder_out, log_probs=True)
        elif self.criterion_type == CriterionType.ASG:
            emissions = encoder_out["encoder_out"]
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
        if self.criterion_type == CriterionType.CTC:
            idxs = filter(lambda x: x != self.blank, idxs)
        elif self.criterion_type == CriterionType.ASG:
            idxs = filter(lambda x: x >= 0, idxs)
            idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel)
        return torch.LongTensor(list(idxs))
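# Example: with blank id 0, get_tokens maps the frame sequence
# [0, 5, 5, 0, 7, 7, 0] to [5, 7] -- repeated frames are merged first
# (itertools.groupby), then blanks are filtered out.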


class W2lViterbiDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

    def decode(self, emissions):
        B, T, N = emissions.size()
        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
            for b in range(B)
        ]
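# Note: for the CTC case the transition matrix above is all zeros, so the
# Viterbi path reduces to a per-frame argmax over the emissions, i.e. plain
# greedy CTC decoding.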


class W2lKenLMDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        # args is a plain dict, so use .get() rather than getattr()
        self.unit_lm = args.get("unit_lm", False)

        if args['lexicon']:
            self.lexicon = load_words(args['lexicon'])
            self.word_dict = create_word_dict(self.lexicon)
            self.unk_word = self.word_dict.get_index("<unk>")

            self.lm = KenLM(args['kenlm_model'], self.word_dict)
            self.trie = Trie(self.vocab_size, self.silence)

            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                word_idx = self.word_dict.get_index(word)
                _, score = self.lm.score(start_state, word_idx)
                for spelling in spellings:
                    spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder_opts = LexiconDecoderOptions(
                beam_size=args['beam'],
                beam_size_token=int(args.get("beam_size_token", len(tgt_dict))),
                beam_threshold=args['beam_threshold'],
                lm_weight=args['lm_weight'],
                word_score=args['word_score'],
                unk_score=args['unk_weight'],
                sil_score=args['sil_weight'],
                log_add=False,
                criterion_type=self.criterion_type,
            )

            if self.asg_transitions is None:
                N = 768
                # self.asg_transitions = torch.FloatTensor(N, N).zero_()
                self.asg_transitions = []

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                self.asg_transitions,
                self.unit_lm,
            )
        else:
            assert self.unit_lm, "lexicon free decoding can only be done with a unit language model"
            from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions

            d = {w: [[w]] for w in tgt_dict.symbols}
            self.word_dict = create_word_dict(d)
            self.lm = KenLM(args['kenlm_model'], self.word_dict)
            self.decoder_opts = LexiconFreeDecoderOptions(
                beam_size=args['beam'],
                beam_size_token=int(args.get("beam_size_token", len(tgt_dict))),
                beam_threshold=args['beam_threshold'],
                lm_weight=args['lm_weight'],
                sil_score=args['sil_weight'],
                log_add=False,
                criterion_type=self.criterion_type,
            )
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            # 4 bytes per float32 element: step the raw pointer to batch item b
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    {
                        "tokens": self.get_tokens(result.tokens),
                        "score": result.score,
                        "words": [
                            self.word_dict.get_entry(x) for x in result.words if x >= 0
                        ],
                    }
                    for result in nbest_results
                ]
            )
        return hypos
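# The lexicon file consumed by load_words maps each word to one or more
# token spellings, one entry per line; hypothetical example for a
# letter-based model:
#
#   hello   h e l l o |
#   world   w o r l d |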

def get_feature(filepath):
    def postprocess(feats, sample_rate):
        # feats.dim is a method; comparing it to 2 without calling it is
        # always False, so stereo input was never downmixed -- call it
        if feats.dim() == 2:
            feats = feats.mean(-1)

        assert feats.dim() == 1, feats.dim()

        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
        return feats

    wav, sample_rate = sf.read(filepath)
    feats = torch.from_numpy(wav).float()
    feats = postprocess(feats, sample_rate)
    return feats
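# Note: wav2vec 2.0 checkpoints expect 16 kHz mono audio, and sample_rate is
# not validated above, so resample mismatched files beforehand, e.g. with
# sox (hypothetical paths): sox in.wav -r 16000 -c 1 out.wav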

def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == 'wordpiece':
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == 'letter':
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol is not None and symbol != 'none':
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    return sentence
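# Example: with the 'letter' scheme used in get_results below,
# "H E L L O | W O R L D |" post-processes to "HELLO WORLD".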


def get_results(wav_path, dict_path, generator, use_cuda=False, w2v_path=None, model=None, half=None):
    sample = dict()
    net_input = dict()
    feature = get_feature(wav_path)
    target_dict = Dictionary.load(dict_path)

    model.eval()

    if half:
        net_input["source"] = feature.unsqueeze(0).half()
    else:
        net_input["source"] = feature.unsqueeze(0)

    # single utterance per batch, so no frames are padded
    padding_mask = torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)

    net_input["padding_mask"] = padding_mask
    sample["net_input"] = net_input
    sample = utils.move_to_cuda(sample) if use_cuda else sample

    with torch.no_grad():
        hypo = generator.generate(model, sample, prefix_tokens=None)
    hyp_pieces = target_dict.string(hypo[0][0]["tokens"].int().cpu())
    text = post_process(hyp_pieces, 'letter')

    return text


def load_model(model_path):
    # the checkpoint holds a full pickled model object; pass map_location
    # (e.g. torch.device("cpu")) to torch.load when loading a GPU checkpoint
    # on a CPU-only machine
    return torch.load(model_path)


def get_args(lexicon_path, lm_path, BEAM=128, LM_WEIGHT=2, WORD_SCORE=-1):
    args = {}
    args['lexicon'] = lexicon_path
    args['kenlm_model'] = lm_path
    args['beam'] = BEAM
    args['beam_threshold'] = 25
    args['lm_weight'] = LM_WEIGHT
    args['word_score'] = WORD_SCORE
    args['unk_weight'] = -np.inf
    args['sil_weight'] = 0
    args['nbest'] = 1
    args['criterion'] = 'ctc'
    args['labels'] = 'ltr'
    return args
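# The beam width, LM weight, and word score above are starting points rather
# than universal values; they typically need tuning on a dev set for each
# acoustic model / language model pair.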

def parse_transcription(model_path, dict_path, wav_path, cuda, decoder="viterbi", lexicon_path=None, lm_path=None, half=None):
    target_dict = Dictionary.load(dict_path)
    args = get_args(lexicon_path, lm_path)

    if decoder == "viterbi":
        generator = W2lViterbiDecoder(args, target_dict)
    else:
        generator = W2lKenLMDecoder(args, target_dict)

    model = load_model(model_path)
    if cuda:
        model.cuda()

    if half:
        model.half()

    result = get_results(wav_path=wav_path, dict_path=dict_path, generator=generator, use_cuda=cuda, model=model, half=half)

    return result

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run')
    parser.add_argument('-m', '--model', type=str, help="Custom model path")
    parser.add_argument('-d', '--dict', type=str, help="Dict path")
    parser.add_argument('-w', '--wav', type=str, help="Wav file path")
    # type=bool is a trap: bool("False") is True, so any non-empty value
    # would enable the flag; store_true gives the intended on/off behavior
    parser.add_argument('-c', '--cuda', default=False, action='store_true', help="Run on CUDA")
    parser.add_argument('-D', '--decoder', default='viterbi', type=str, help="Which decoder to use: kenlm or viterbi")
    parser.add_argument('-l', '--lexicon', default=None, type=str, help="Lexicon path if decoder is kenlm")
    parser.add_argument('-L', '--lm-path', default=None, type=str, help="Language model path if decoder is kenlm")
    parser.add_argument('-H', '--half', default=False, action='store_true', help="Run model in half precision")

    args_local = parser.parse_args()

    result = parse_transcription(args_local.model, args_local.dict, args_local.wav, args_local.cuda, args_local.decoder, args_local.lexicon, args_local.lm_path, args_local.half)
    print(result)
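# Example invocations (hypothetical file names):
#   python infer.py -m checkpoint.pt -d dict.ltr.txt -w sample.wav -D viterbi
#   python infer.py -m checkpoint.pt -d dict.ltr.txt -w sample.wav -D kenlm \
#       -l lexicon.lst -L lm.binary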
