# pandora_transcribe/transcribe.py
#
# 280 lines
# 8.1 KiB
# Python

import json
import logging
import os
import shutil
import signal
import subprocess
import tempfile
import time
import ox
import ox.iso
from django.conf import settings
from annotation import tasks
from item import utils
from itemlist.models import List
from item.models import Item
from user.models import User
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Default ``target_length`` of prepare_joint_annotations(), where it is compared
# against the character length of the sentence being assembled. Can be
# overridden with the Django setting TRANSCRIBE_TARGET_LENGTH (fallback: 200).
TARGET_LENGTH = getattr(settings, 'TRANSCRIBE_TARGET_LENGTH', 200)
def prepare_annotations(result, join_sentences=False):
    """Turn a transcription result into a list of annotation dicts.

    Args:
        result: mapping with a "segments" list; every segment provides
            "start", "end" and "text".
        join_sentences: when true, delegate to prepare_joint_annotations()
            so annotations follow sentences instead of segments.

    Returns:
        A list of ``{"in": ..., "out": ..., "value": ...}`` dicts, one per
        segment. "out" is the segment end pushed back by 0.3 and "value" is
        the segment text with surrounding whitespace removed.
    """
    if not join_sentences:
        return [
            {
                "in": seg["start"],
                "out": seg["end"] + 0.3,
                "value": seg["text"].strip(),
            }
            for seg in result["segments"]
        ]
    return prepare_joint_annotations(result)
def prepare_joint_annotations(result, target_length=TARGET_LENGTH):
    """Merge word-level timestamps into sentence-level annotations.

    All words of all segments are concatenated and then cut into sentences.
    A sentence is closed at a word ending in "." that is longer than two
    characters and is not one of the known abbreviations. A short sentence
    (fewer than ``target_length`` characters so far) is kept open across
    such a word when the following word starts less than 0.8 after it, so
    that quick successive sentences share one annotation.

    Args:
        result: mapping with a "segments" list; every segment has a "words"
            list whose entries provide "text", "start" and "end".
        target_length: character length below which a sentence may absorb
            the sentence that follows it.

    Returns:
        A list of ``{"in": ..., "out": ..., "value": ...}`` dicts. "in" is
        the start of the first word of the sentence, "out" the end of its
        last word plus 0.3. Words left over after the last sentence-ending
        period are returned as a final annotation instead of being dropped.
    """
    abbrevs = ["Mr.", "Mrs.", "Dr."]
    # Filler tokens to suppress; currently empty, so the branches that test
    # membership in it are inactive.
    ignore = []

    all_words = []
    for segment in result["segments"]:
        all_words.extend(segment["words"])

    last_index = len(all_words) - 1
    new_segs = []
    sentence = ""
    in_ = None
    # Walk by position: looking the word up again with list.index() is
    # quadratic and returns the wrong neighbour when two word dicts are equal.
    for index, w in enumerate(all_words):
        is_last = index == last_index
        next_w = None if is_last else all_words[index + 1]
        text = w["text"]

        if is_last and sentence == "" and text in ignore:
            continue
        if sentence == "":
            # first word of a sentence fixes the in point
            in_ = w["start"]
            if text in ignore and next_w is not None and next_w["text"][0].isupper():
                continue
        # NOTE(review): a single character never equals "The", so this branch
        # cannot trigger as written; kept unchanged because the intended
        # comparison is not evident from this code - confirm with the author.
        if (
            sentence == "The music "
            and next_w is not None
            and next_w["text"][0] == "The"
        ):
            sentence = ""
            continue

        sentence += text + " "
        ends_with_period = text.endswith(".")

        # Short sentence followed closely by the next word: keep it open.
        if (
            ends_with_period
            and not is_last
            and (next_w["start"] - w["end"]) < 0.8
            and len(sentence) < target_length
            and next_w["text"] not in ignore
        ):
            continue

        if (
            ends_with_period and text not in abbrevs and len(text) > 2
        ) or (
            text in ignore
            and sentence.strip() == text
            and (is_last or next_w["text"][0].isupper())
        ):
            # end the sentence, delay end a bit
            new_segs.append(
                {"in": in_, "out": w["end"] + 0.3, "value": sentence.strip()}
            )
            sentence = ""

    # Text that never reached a sentence-ending period (e.g. it ends in "?"
    # or is cut off) would otherwise be lost.
    if sentence.strip():
        new_segs.append(
            {
                "in": in_,
                "out": all_words[-1]["end"] + 0.3,
                "value": sentence.strip(),
            }
        )

    return [seg for seg in new_segs if seg["value"].strip() not in ignore]
def run_demucs(src, output):
    """Run demucs on *src* and return the path of the separated vocals.

    Args:
        src: path of the media file to process.
        output: directory passed to demucs via ``-o``.

    Returns:
        Path of the first ``htdemucs/*/vocals.wav`` found below *output*.

    Raises:
        subprocess.CalledProcessError: demucs exited with a non-zero status.
        OSError: the demucs executable could not be started.
        FileNotFoundError: demucs finished but no vocals.wav was found.
    """
    # Imported here because the module's import block does not provide glob.
    import glob

    cmd = [
        # each argument needs its own list element; without the comma the
        # executable path and "--two-stems" are joined into one string
        "/opt/whisper-timestamped/bin/demucs",
        "--two-stems", "vocals",
        "-o", output,
        src,
    ]
    subprocess.check_call(cmd)
    pattern = os.path.join(glob.escape(output), "htdemucs", "*", "vocals.wav")
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError("demucs produced no vocals.wav in %s" % output)
    return matches[0]
def run_whisper(src, language=None, translate=False, gpu=False, model="small", demucs=False):
    """Transcribe *src* with the bundled ``run_whisper.py`` helper script.

    Args:
        src: path of the media file.
        language: optional language code passed as ``--language``.
        translate: falsy, or a container of language codes; when *language*
            is in it, ``--translate`` is passed to the script.
        gpu: accepted for interface compatibility; not used by this function.
        model: model name passed as ``--model``.
        demucs: when true, run_demucs() is applied first and the resulting
            vocals file is transcribed instead of *src*.

    Returns:
        The parsed JSON written by the script, or None when demucs or the
        script failed (the failure is logged).

    The scratch directory is removed on every exit path, including when the
    run is interrupted.
    """
    tmp = tempfile.mkdtemp()
    try:
        if demucs:
            try:
                src = run_demucs(src, tmp)
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still reach the caller.
            except Exception:
                logger.exception("failed to run demucs for %s", src)
                return None

        output = os.path.join(tmp, "output.json")
        run_py = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "run_whisper.py"
        )
        cmd = ["/opt/whisper-timestamped/bin/python", run_py, "--model", model]
        if language:
            cmd += ["--language", language]
        if translate and language in translate:
            cmd += ["--translate"]
        cmd += [src, output]

        try:
            subprocess.check_call(cmd)
        except Exception:
            logger.exception("failed to run: %s", cmd)
            return None

        with open(output) as fd:
            return json.load(fd)
    finally:
        # the result is fully loaded into memory, nothing in tmp is needed
        shutil.rmtree(tmp, ignore_errors=True)
def extract_subtitles(item, user, layer, translate, gpu=False, join_sentences=False, model="small"):
    """Transcribe the first stream of *item* and queue the annotations.

    Args:
        item: item to transcribe; its ``data["language"]`` (if present),
            ``streams()`` and ``public_id`` are used.
        user: user whose ``username`` is recorded with the annotations.
        layer: id of the layer the annotations are added to.
        translate: falsy, or a container of language codes that should be
            translated instead of transcribed.
        gpu: forwarded to run_whisper().
        join_sentences: forwarded to prepare_annotations().
        model: model name forwarded to run_whisper().

    Returns:
        True when annotations were handed to ``tasks.add_annotations``,
        False when the item has no media, transcription failed or no
        annotations were produced.
    """
    language = None
    # guard against a present but empty language list
    if "language" in item.data and item.data["language"]:
        language = ox.iso.langTo2Code(item.data["language"][0])

    streams = item.streams()
    if not streams:
        logger.error("skip item without media %s", item.public_id)
        return False
    src = streams[0].media.path

    response = run_whisper(src, language, translate, gpu, model)
    if not response:
        logger.error("extract failed for %s", item.public_id)
        return False

    annotations = prepare_annotations(response, join_sentences=join_sentences)
    if not annotations:
        return False

    # run_whisper() translates under exactly this condition, so the text is
    # English and must not be tagged with the source language (same rule as
    # in extract_subtitles_cmd).
    if translate and language in translate:
        language = "en"

    if language and language != "en":
        for annotation in annotations:
            annotation["value"] = '<span lang="%s">%s</span>' % (
                language,
                annotation["value"],
            )

    tasks.add_annotations.delay(
        {
            "item": item.public_id,
            "layer": layer,
            "user": user.username,
            "annotations": annotations,
        }
    )
    return True
def extract_subtitles_cmd(item, user, layer, translate, gpu=False, model="small"):
    """Transcribe *item* with the ``whisper_timestamped`` command line tool.

    The tool writes SRT files into a scratch directory; the subtitle file
    (not the ``words.srt`` variant) is loaded with ``ox.srt.load`` and the
    result is handed to ``tasks.add_annotations``.

    Args:
        item: item to transcribe; its ``data["language"]`` (if present),
            ``streams()`` and ``public_id`` are used.
        user: user whose ``username`` is recorded with the annotations.
        layer: id of the layer the annotations are added to.
        translate: falsy, or a container of language codes for which
            ``--task translate`` is requested.
        gpu: when false, ``--fp16 False`` is passed to the tool.
        model: model name passed as ``--model``.

    Returns:
        False when the item has no media or the tool failed; True otherwise,
        including the case where no subtitles were detected.
    """
    language = None
    # guard against a present but empty language list
    if "language" in item.data and item.data["language"]:
        language = ox.iso.langTo2Code(item.data["language"][0])

    streams = item.streams()
    if not streams:
        logger.error("skip item without media %s", item.public_id)
        return False
    src = streams[0].media.path

    cmd = ["/opt/whisper-timestamped/bin/whisper_timestamped", "--model", model]
    if language:
        cmd += ["--language", language]
    if translate and language in translate:
        cmd += ["--task", "translate"]
        # output is English from here on, so it is not wrapped in a span
        language = "en"
    if not gpu:
        cmd += ["--fp16", "False"]

    tmp = tempfile.mkdtemp()
    try:
        cmd += ["-f", "srt", "--accurate", "--output_dir", tmp, src]
        try:
            subprocess.check_output(cmd)
        # Exception rather than a bare except, so KeyboardInterrupt and
        # SystemExit still reach the caller.
        except Exception:
            logger.exception(
                "failed to extract subtitles from item %s\n%s", item.public_id, cmd
            )
            return False

        annotations = []
        for f in os.listdir(tmp):
            if f.endswith(".srt") and "words.srt" not in f:
                annotations = ox.srt.load(os.path.join(tmp, f))

        if not annotations:
            logger.error("no subtitles detected %s", item.public_id)
            return True

        if language and language != "en":
            for annotation in annotations:
                annotation["value"] = '<span lang="%s">%s</span>' % (
                    language,
                    annotation["value"],
                )

        tasks.add_annotations.delay(
            {
                "item": item.public_id,
                "layer": layer,
                "user": user.username,
                "annotations": annotations,
            }
        )
        return True
    finally:
        # remove the scratch directory on every exit path
        shutil.rmtree(tmp, ignore_errors=True)
def main(**kwargs):
    """Process the transcription queue until interrupted.

    Keyword Args:
        user: username owning both lists.
        queue: name of the list holding items waiting for transcription.
        done: name of the list that receives successfully processed items.
        layer: optional layer id; when missing, the layer flagged
            ``isSubtitles`` in ``settings.CONFIG["layers"]`` is used.
        translate: optional comma separated ``key:value`` pairs, turned into
            a dict before being passed to extract_subtitles().
        gpu, join_sentences, model: forwarded to extract_subtitles();
            ``model`` falls back to "small" when the key is absent.

    Items are moved from the queue list to the done list once
    extract_subtitles() succeeds. When a full pass moves nothing, the loop
    sleeps five minutes before checking the queue again. KeyboardInterrupt
    ends the loop quietly.
    """
    user = User.objects.get(username=kwargs["user"])
    queue = List.objects.get(user=user, name=kwargs["queue"])
    done = List.objects.get(user=user, name=kwargs["done"])
    layer = kwargs.get("layer")
    translate = kwargs.get("translate")
    if translate:
        translate = dict(pair.split(":") for pair in translate.split(","))
    if not layer:
        subtitle_layer = utils.get_by_key(
            settings.CONFIG["layers"], "isSubtitles", True
        )
        if not subtitle_layer:
            logger.error("no layer defined and config has no subtitle layer")
            return
        layer = subtitle_layer["id"]

    try:
        while True:
            moved_any = False
            for item in queue.get_items(queue.user).all():
                succeeded = extract_subtitles(
                    item,
                    user,
                    layer,
                    translate,
                    kwargs.get("gpu"),
                    join_sentences=kwargs.get("join_sentences"),
                    model=kwargs.get("model", "small"),
                )
                if not succeeded:
                    continue
                done.items.add(item)
                queue.items.remove(item)
                moved_any = True
            if not moved_any:
                # nothing processed in this pass: wait before polling again
                time.sleep(5 * 60)
    except KeyboardInterrupt:
        pass