# app.py
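"""Streamlit app for live sign-language recognition from a webcam.

MediaPipe Hands locates the hand in each frame, the crop is classified by a
fastai learner, and the UI shows the annotated frame, the cropped hand, and
the predicted label.
"""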
from fastai.vision.all import load_learner, PILImage
from PIL import Image
import mediapipe as mp
import streamlit as st
import cv2
st.set_page_config(
    page_title="Sign Language Recognition",
    page_icon="✌️",
    layout="centered",
    initial_sidebar_state="expanded",
)
st.markdown("<h1 style='text-align: center;'>Webcam Live Feed for Sign Language Recognition</h1>", unsafe_allow_html=True)
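# Selectable model checkpoints (pickled fastai learners, loaded below via load_learner).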
model_options = {
    "Model v1": "./models/sign_language_model_v1.pkl",
    "Model v2": "./models/sign_language_model_v2.pkl",
}
# Remember the chosen model across Streamlit reruns.
if 'selected_model' not in st.session_state:
    st.session_state.selected_model = list(model_options.keys())[0]
selected_model = st.selectbox(
    "Select Model",
    list(model_options.keys()),
    index=list(model_options.keys()).index(st.session_state.selected_model),
)
if selected_model != st.session_state.selected_model:
    st.session_state.selected_model = selected_model
    st.session_state.run = True  # note the switch; the Run button below drives the loop
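# Layout placeholders: Run button, live frame, cropped hand, and prediction text.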
_, _, col3 = st.columns([1, 1, 2.6])
with col3:
    run = st.button("Run")
_, col_frame, _ = st.columns([0.5, 2.5, 0.5])
with col_frame:
    FRAME_WINDOW = st.image([])
_, col_cropped, _ = st.columns([1, 1, 1])
with col_cropped:
    CROPPED_HAND_WINDOW = st.image([])
_, col_pred, _ = st.columns([1, 1, 1])
with col_pred:
    PREDICTION_TEXT = st.empty()
@st.cache_resource  # cache the learner so it is not re-loaded on every rerun
def load_selected_model(model_name):
    try:
        return load_learner(model_options[model_name])
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.stop()

learner = load_selected_model(selected_model)
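# MediaPipe Hands detects hand landmarks, which are used to crop the hand region.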
mp_hands = mp.solutions.hands
hands = mp_hands.Hands(min_detection_confidence=0.7, min_tracking_confidence=0.5)
mp_draw = mp.solutions.drawing_utils  # not used below; handy for debugging landmarks
# Webcam capture. CAP_V4L2 is Linux-specific; drop the flag on other platforms.
cap = cv2.VideoCapture(0, cv2.CAP_V4L2)
def preprocess_image(hand_img):
    # Convert the BGR crop to RGB and resize to the 224x224 input the model expects.
    hand_img_pil = Image.fromarray(cv2.cvtColor(hand_img, cv2.COLOR_BGR2RGB)).resize((224, 224))
    return PILImage.create(hand_img_pil)
def predict_hand_sign(image):
    try:
        pred, _, _ = learner.predict(image)  # fastai returns (label, index, probabilities)
        return str(pred)
    except Exception as e:
        st.error(f"Prediction error: {e}")
        return 'Error'
def process_video_feed():
    # Read, detect, crop, and classify frames while Run is active on this rerun.
    while run:
        ret, frame = cap.read()
        if not ret:
            st.write("Failed to capture image")
            break
        # MediaPipe expects RGB; OpenCV delivers BGR.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = hands.process(frame_rgb)
        if results.multi_hand_landmarks:
            for hand_landmarks in results.multi_hand_landmarks:
                x_min, x_max, y_min, y_max = get_hand_bounding_box(hand_landmarks, frame.shape)
                cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
                hand_img = frame[y_min:y_max, x_min:x_max]
                if hand_img.size > 0:
                    hand_img_fastai = preprocess_image(hand_img)
                    predicted_label = predict_hand_sign(hand_img_fastai)
                    update_ui(frame, hand_img_fastai, predicted_label)
                else:
                    update_ui(frame, None, "Hand cropped image is empty")
        else:
            update_ui(frame, None, "Hand not detected")
    cap.release()
    hands.close()
def get_hand_bounding_box(hand_landmarks, frame_shape):
    # Landmark coordinates are normalised to [0, 1]; scale them to pixels.
    x_min = min(lm.x for lm in hand_landmarks.landmark) * frame_shape[1]
    x_max = max(lm.x for lm in hand_landmarks.landmark) * frame_shape[1]
    y_min = min(lm.y for lm in hand_landmarks.landmark) * frame_shape[0]
    y_max = max(lm.y for lm in hand_landmarks.landmark) * frame_shape[0]
    # Pad the box by 10% on each side, clamped to the frame bounds.
    width = x_max - x_min
    height = y_max - y_min
    margin_x = 0.1 * width
    margin_y = 0.1 * height
    x_min = int(max(0, x_min - margin_x))
    x_max = int(min(frame_shape[1], x_max + margin_x))
    y_min = int(max(0, y_min - margin_y))
    y_max = int(min(frame_shape[0], y_max + margin_y))
    return x_min, x_max, y_min, y_max
def update_ui(frame, cropped_img, prediction_text):
    FRAME_WINDOW.image(frame, channels="BGR")  # frame is still BGR from OpenCV
    if cropped_img is not None:
        CROPPED_HAND_WINDOW.image(cropped_img)
    PREDICTION_TEXT.text(f"Predicted Label: {prediction_text}")
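# Start the capture/predict loop for this Streamlit run.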
process_video_feed()