Skip to content

Commit

Permalink
Merge pull request #348 from shadowcz007/whisper-sensevoice
Browse files Browse the repository at this point in the history
Whisper sensevoice
  • Loading branch information
shadowcz007 authored Oct 12, 2024
2 parents 1dc3192 + edd7af9 commit 10c9eff
Show file tree
Hide file tree
Showing 9 changed files with 336 additions and 25 deletions.
15 changes: 13 additions & 2 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ def mix_status(request):
# from .nodes.Vae import VAELoader,VAEDecode
from .nodes.ScreenShareNode import ScreenShareNode,FloatingVideo

from .nodes.Audio import AudioPlayNode,SpeechRecognition,SpeechSynthesis
from .nodes.Audio import AudioPlayNode,SpeechRecognition,SpeechSynthesis,AnalyzeAudioNone
from .nodes.Utils import CreateJsonNode,KeyInput,IncrementingListNode,ListSplit,CreateLoraNames,CreateSampler_names,CreateCkptNames,CreateSeedNode,TESTNODE_,TESTNODE_TOKEN,AppInfo,IntNumber,FloatSlider,TextInput,ColorInput,FontInput,TextToNumber,DynamicDelayProcessor,LimitNumber,SwitchByIndex,MultiplicationNode
from .nodes.Mask import PreviewMask_,MaskListReplace,MaskListMerge,OutlineMask,FeatheredMask

Expand Down Expand Up @@ -1103,6 +1103,7 @@ def mix_status(request):
"SpeechRecognition":SpeechRecognition,
"SpeechSynthesis":SpeechSynthesis,
"AudioPlay":AudioPlayNode,
"AnalyzeAudio":AnalyzeAudioNone,

# Text
"TextToNumber":TextToNumber,
Expand Down Expand Up @@ -1220,6 +1221,7 @@ def mix_status(request):
"SpeechSynthesis":"SpeechSynthesis ♾️Mixlab",
"SpeechRecognition":"SpeechRecognition ♾️Mixlab",
"AudioPlay":"Preview Audio ♾️Mixlab",
"AnalyzeAudio":"Analyze Audio ♾️Mixlab",

# Utils
"DynamicDelayProcessor":"DynamicDelayByText ♾️Mixlab",
Expand Down Expand Up @@ -1433,11 +1435,20 @@ def mix_status(request):
from .nodes.SenseVoice import SenseVoiceNode
logging.info('SenseVoice.available')
NODE_CLASS_MAPPINGS['SenseVoiceNode']=SenseVoiceNode
NODE_DISPLAY_NAME_MAPPINGS["SenseVoiceNode"]= "Sense Voice"
NODE_DISPLAY_NAME_MAPPINGS["SenseVoiceNode"]= "Sense Voice ♾️Mixlab"

except Exception as e:
logging.info('SenseVoice.available False' )

try:
    # Optional Whisper speech-to-text nodes; registration is skipped
    # cleanly when the dependency stack is not installed.
    from .nodes.Whisper import LoadWhisperModel,WhisperTranscribe
    logging.info('Whisper.available')
    NODE_CLASS_MAPPINGS['LoadWhisperModel_']=LoadWhisperModel
    NODE_CLASS_MAPPINGS['WhisperTranscribe_']=WhisperTranscribe
    NODE_DISPLAY_NAME_MAPPINGS["LoadWhisperModel_"]= "Load Whisper Model ♾️Mixlab"
    NODE_DISPLAY_NAME_MAPPINGS["WhisperTranscribe_"]= "Whisper Transcribe ♾️Mixlab"

except Exception as e:
    # Include the actual exception so a missing/broken dependency is
    # diagnosable from the log instead of being silently swallowed.
    logging.info('Whisper.available False, %s', e)

logging.info('\033[93m -------------- \033[0m')
95 changes: 95 additions & 0 deletions nodes/Audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,101 @@
import folder_paths
import torchaudio

class AnyType(str):
    """Wildcard socket type: reports itself as never unequal to any other
    value, so type validation accepts every connection. Credit to
    pythongosssss."""

    def __ne__(self, other: object) -> bool:
        # Always "not unequal" -> matches any socket type.
        return False


any_type = AnyType("*")


def analyze_audio_data(audio_data):
total_duration = 0
total_gap_duration = 0
emotion_counts = {}
audio_types = set()
languages = set()

for i, entry in enumerate(audio_data):
# Calculate the duration of each audio segment
start_time = entry['start_time']
end_time = entry['end_time']
duration = end_time - start_time
total_duration += duration

# Count the emotions
if "emotion" in entry:
emotion = entry['emotion']
if emotion in emotion_counts:
emotion_counts[emotion] += 1
else:
emotion_counts[emotion] = 1

# Collect the audio types
if "audio_type" in entry:
audio_types.add(entry['audio_type'])

if "language" in entry:
languages.add(entry['language'])

# Calculate gap duration if not the last entry
if i < len(audio_data) - 1:
next_start_time = audio_data[i + 1]['start_time']
gap_duration = next_start_time - end_time
if gap_duration > 0:
total_gap_duration += gap_duration

# Get the most frequent emotion
if len(emotion_counts.keys())>0:
most_frequent_emotion = max(emotion_counts, key=emotion_counts.get)
else:
most_frequent_emotion=None

# Convert audio_types set to list for better readability
audio_types = list(audio_types)

languages=list(languages)

# Print the results
print(f"Total Effective Duration: {total_duration:.2f} seconds")
print(f"Total Gap Duration: {total_gap_duration:.2f} seconds")
print(f"Emotion Changes: {emotion_counts}")
print(f"Most Frequent Emotion: {most_frequent_emotion}")
print(f"Audio Types: {audio_types}")


return {
"total_duration": total_duration,
"total_gap_duration": total_gap_duration,
"emotion_changes": emotion_counts,
"most_frequent_emotion": most_frequent_emotion,
"audio_types": audio_types,
"languages":languages
}


# Analyze audio data.
# NOTE(review): the class name "AnalyzeAudioNone" looks like a typo for
# "AnalyzeAudioNode", but it is the identifier registered in __init__.py,
# so it must stay unchanged for backward compatibility.
class AnalyzeAudioNone:
    @classmethod
    def INPUT_TYPES(s):
        # Accepts any upstream value; analyze_audio_data expects the
        # segment-dict list emitted by the SenseVoice node.
        return {
            "required": {
                "json": (any_type,),
            },
        }

    RETURN_TYPES = (any_type,)
    RETURN_NAMES = ("result",)

    FUNCTION = "run"

    CATEGORY = "♾️Mixlab/Audio"

    def run(self, json):
        # Delegate all aggregation to the module-level helper.
        summary = analyze_audio_data(json)
        return (summary,)



class SpeechRecognition:
@classmethod
def INPUT_TYPES(s):
Expand Down
28 changes: 24 additions & 4 deletions nodes/ChatGPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,12 @@ class JsonRepair:
def INPUT_TYPES(s):
return {
"required": {
"json_string":("STRING", {"forceInput": True,}),
"key":("STRING", {"multiline": False,"dynamicPrompts": False,"default": ""}),
}
"json_string":("STRING", {"forceInput": True,}),
"key":("STRING", {"multiline": False,"dynamicPrompts": False,"default": ""}),
},
"optional":{
"json_string2":("STRING", {"forceInput": True,})
},
}

INPUT_IS_LIST = False
Expand All @@ -800,15 +803,32 @@ def INPUT_TYPES(s):

CATEGORY = "♾️Mixlab/GPT"

def run(self, json_string,key=""):
def run(self, json_string,key="",json_string2=None):

if not isinstance(json_string, str):
json_string=json.dumps(json_string)

json_string=extract_json_strings(json_string)
# print(json_string)
good_json_string = repair_json(json_string)

# 将 JSON 字符串解析为 Python 对象
data = json.loads(good_json_string)

if json_string2!=None:
if not isinstance(json_string2, str):
json_string2=json.dumps(json_string2)

json_string2=extract_json_strings(json_string2)
# print(json_string)
good_json_string2 = repair_json(json_string2)

# 将 JSON 字符串解析为 Python 对象
data2 = json.loads(good_json_string2)

data={**data, **data2}


v=""
if key!="" and (key in data):
v=data[key]
Expand Down
43 changes: 27 additions & 16 deletions nodes/SenseVoice.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def format_time(seconds):

pattern = r"<\|(.+?)\|><\|(.+?)\|><\|(.+?)\|><\|(.+?)\|>(.+)"
match = re.match(pattern,asr_result)
print('#format_to_srt',match,asr_result)
if match==None:
return None, None, None, None,None,start_time,end_time,None
lang, emotion, audio_type, itn, text = match.groups()
# 😊 表示高兴,😡 表示愤怒,😔 表示悲伤。对于音频事件,🎼 表示音乐,😀 表示笑声,👏 表示掌声

Expand Down Expand Up @@ -115,16 +118,17 @@ def process_audio(self, waveform, _sample_rate, language, use_itn):
part[1],
asr_result)

results.append({
"language":lang,
"emotion":emotion,
"audio_type":audio_type,
"itn":itn,
"srt_content":srt_content,
"start_time":start_time,
"end_time":end_time,
"text":text
})
if lang!=None:
results.append({
"language":lang,
"emotion":emotion,
"audio_type":audio_type,
"itn":itn,
"srt_content":srt_content,
"start_time":start_time,
"end_time":end_time,
"text":text
})

self.vad.vad.all_reset_detection()
pbar.update(1) # 更新进度条
Expand Down Expand Up @@ -168,8 +172,9 @@ def INPUT_TYPES(s):

OUTPUT_NODE = True
FUNCTION = "run"
RETURN_TYPES = (any_type,)
RETURN_NAMES = ("result",)

RETURN_TYPES = (any_type,"STRING","STRING","FLOAT",)
RETURN_NAMES = ("result","srt","text","total_seconds",)

def run(self,audio,device,language,num_threads,use_int8,use_itn ):

Expand Down Expand Up @@ -200,16 +205,22 @@ def run(self,audio,device,language,num_threads,use_int8,use_itn ):

if 'waveform' in audio and 'sample_rate' in audio:
waveform = audio['waveform']
sample_rate = audio['sample_rate']
# print("Original shape:", waveform.shape) # 打印原始形状
if waveform.ndim == 3 and waveform.shape[0] == 1: # 检查是否为三维且 batch_size 为 1
waveform = waveform.squeeze(0) # 移除 batch_size 维度
waveform_numpy = waveform.numpy().transpose(1, 0) # 转换为 (num_samples, num_channels)
else:
raise ValueError("Unexpected waveform dimensions")

_sample_rate = audio['sample_rate']
print("waveform.shape:", waveform.shape)
total_length_seconds = waveform.shape[1] / sample_rate

waveform_numpy = waveform.numpy().transpose(1, 0) # 转换为 (num_samples, num_channels)

results=self.processor.process_audio(waveform_numpy, _sample_rate, language, use_itn)
results=self.processor.process_audio(waveform_numpy, sample_rate, language, use_itn)

srt_content="\n".join([s['srt_content'] for s in results])
text="\n".join([s['text'] for s in results])

return (results,)
return (results,srt_content,text,total_length_seconds,)

Loading

0 comments on commit 10c9eff

Please sign in to comment.