-
Notifications
You must be signed in to change notification settings - Fork 9
/
segment_service.py
136 lines (108 loc) · 4.7 KB
/
segment_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import argparse
import json
import requests
from datetime import datetime
from flask import Flask, jsonify, abort, make_response, request, Response
from flask_cors import CORS
from model import WhisperSegmenter, WhisperSegmenterFast
import librosa
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import time
import threading
import base64
import io
# Make Flask application
# `app` is the module-level Flask instance that the @app.route decorators
# below register their handlers on.
app = Flask(__name__)
# Wrap the app with flask_cors; presumably so that browser front-ends served
# from another origin can call this service -- confirm against deployment.
CORS(app)
# maintain the returned order of keys!
# Turning off sort_keys on the app's JSON provider keeps the keys of
# jsonify() responses in the order the dict was built, rather than sorted.
app.json.sort_keys = False
def decimal_to_seconds( decimal_time ):
    """Convert a colon-separated time string into a number of seconds.

    Accepts either ``"M:S"`` or ``"H:M:S"``. Hours and minutes are parsed
    as integers, the seconds field as a float (so it may carry a
    fractional part).

    Raises:
        AssertionError: if the string has neither two nor three
            colon-separated fields.
        ValueError: if a field cannot be parsed as a number.
    """
    fields = decimal_time.split(":")
    field_count = len(fields)

    if field_count == 3:
        hh, mm, ss = fields
    elif field_count == 2:
        # No hours field present: treat it as zero hours.
        hh = 0
        mm, ss = fields
    else:
        assert False

    whole_seconds = int(hh) * 3600 + int(mm) * 60
    return whole_seconds + float(ss)
def seconds_to_decimal( seconds ):
    """Format a non-negative duration in seconds as ``M:SS.mmm`` or ``H:MM:SS.mmm``.

    The hours field is only emitted when it is non-zero. The value is
    rounded to whole milliseconds *before* being split into fields, so
    that rounding can never produce a seconds field of ``60.000``; the
    carry goes into the minutes (and hours) instead.

    Args:
        seconds: duration in seconds (int or float), expected to be >= 0.

    Returns:
        The formatted time string, e.g. ``"1:30.500"`` or ``"1:01:01.500"``.
    """
    # int() around round() also covers numeric types whose round() returns
    # a float rather than an int.
    total_ms = int(round(seconds * 1000))

    hours, remainder_ms = divmod(total_ms, 3600 * 1000)
    # Minutes are taken from what is left after removing whole hours, so the
    # field stays within 0..59 when an hours field is present.
    minutes, remainder_ms = divmod(remainder_ms, 60 * 1000)
    secs = remainder_ms / 1000

    if hours > 0:
        return "%d:%02d:%06.3f" % (hours, minutes, secs)
    return "%d:%06.3f" % (minutes, secs)
def bytes_to_base64_string(f_bytes):
    """Return the base64 encoding of ``f_bytes`` as a text string."""
    encoded = base64.b64encode(f_bytes)
    # b64encode yields bytes; decode them so the result is a str.
    return encoded.decode('ASCII')
def base64_string_to_bytes(base64_string):
    """Decode a base64 string back into the raw bytes it encodes."""
    raw_bytes = base64.b64decode(base64_string)
    return raw_bytes
@app.route('/segment', methods=['POST'])
def segment():
    """Segment a base64-encoded audio file sent as a JSON POST body.

    Request JSON fields (fields whose value is null are treated as absent):
        audio_file_base64_string: required, the audio file content in base64.
        sr: required, sampling rate passed to ``librosa.load`` and the segmenter.
        min_frequency, spec_time_step, min_segment_length, eps: optional,
            forwarded to the segmenter as ``None`` when absent.
        num_trials: optional, defaults to 3.
        channel_id: optional, defaults to 0; selects the channel when the
            loaded audio is two-dimensional (multi-channel).
        adobe_audition_compatible: optional, defaults to False; when true the
            result is reshaped into a marker table with the columns
            Name / Start / Duration / Time Format / Type / Description.

    Returns:
        A ``(response, 201)`` pair. The JSON body is the segmenter's
        prediction (read here through its "onset" and "offset" entries), or
        the marker table described above. If anything fails while parsing
        the request, loading the audio, or running the model, an empty
        prediction with "onset", "offset" and "cluster" lists is returned
        instead and the error is printed.
    """
    global args, segmenter, sem

    # Hold the semaphore only around request parsing and model inference.
    # The context manager releases it on every exit path, so an unexpected
    # error can never leave the service permanently blocked.
    with sem:
        try:
            ### drop all the key-value pairs whose value is None, since we will determine the default value within this function.
            request_info = {
                key: value
                for key, value in request.json.items()
                if value is not None
            }

            audio_file_base64_string = request_info["audio_file_base64_string"]
            sr = request_info["sr"]
            min_frequency = request_info.get("min_frequency", None)
            spec_time_step = request_info.get("spec_time_step", None)
            min_segment_length = request_info.get("min_segment_length", None)
            eps = request_info.get("eps", None)
            num_trials = request_info.get("num_trials", 3)
            channel_id = request_info.get("channel_id", 0)
            adobe_audition_compatible = request_info.get("adobe_audition_compatible", False)

            audio_bytes = base64_string_to_bytes(audio_file_base64_string)
            audio, _ = librosa.load(io.BytesIO(audio_bytes), sr=sr, mono=False)
            ### for multiple channel audio, choose the desired channel
            if len(audio.shape) == 2:
                audio = audio[channel_id]

            prediction = segmenter.segment(
                audio,
                sr=sr,
                min_frequency=min_frequency,
                spec_time_step=spec_time_step,
                min_segment_length=min_segment_length,
                eps=eps,
                num_trials=num_trials,
                batch_size=args.batch_size,
            )
        except Exception:
            # Best-effort endpoint: report the failure on the server side and
            # answer with an empty prediction instead of an HTTP error.
            import traceback

            print("Segmentation Error! Returning an empty prediction ...")
            traceback.print_exc()
            prediction = {
                "onset": [],
                "offset": [],
                "cluster": []
            }
            # The empty prediction is always returned in the plain format.
            adobe_audition_compatible = False

    if adobe_audition_compatible:
        onsets = prediction["onset"]
        offsets = prediction["offset"]
        Start_list = [seconds_to_decimal(onset) for onset in onsets]
        Duration_list = [
            seconds_to_decimal(offset - onset)
            for onset, offset in zip(onsets, offsets)
        ]
        num_markers = len(Start_list)
        prediction = {
            # The first column name carries a leading byte-order mark (U+FEFF).
            "\ufeffName": [""] * num_markers,
            "Start": Start_list,
            "Duration": Duration_list,
            "Time Format": ["decimal"] * num_markers,
            "Type": ["Cue"] * num_markers,
            "Description": [""] * num_markers
        }

    return jsonify(prediction), 201
if __name__ == '__main__':
    # Command-line configuration for the service.
    parser = argparse.ArgumentParser()
    parser.add_argument("--flask_port", help="The port of the flask app.", default=8050, type=int)
    parser.add_argument("--model_path")
    parser.add_argument("--device", help="cpu or cuda", default = "cuda")
    parser.add_argument("--device_ids", help="a list of GPU ids", type = int, nargs = "+", default = [0,])
    parser.add_argument("--batch_size", default=8, type=int)
    args = parser.parse_args()

    # Try the WhisperSegmenterFast loader first and fall back to
    # WhisperSegmenter if it fails. Only ordinary errors trigger the
    # fallback: Ctrl-C or SystemExit during loading now stop the program
    # instead of silently starting a second model load.
    try:
        segmenter = WhisperSegmenterFast( args.model_path, device = args.device, device_ids = args.device_ids )
        print("The loaded model is the Ctranslated version.")
    except Exception:
        segmenter = WhisperSegmenter( args.model_path, device = args.device, device_ids = args.device_ids )
        print("The loaded model is the original huggingface version.")

    # Semaphore() starts with a counter of 1, so the /segment handler lets
    # one request at a time into the section it guards, even though Flask
    # runs with threaded=True.
    sem = threading.Semaphore()

    print("Waiting for requests...")
    app.run(host='0.0.0.0', port=args.flask_port, threaded = True )