Skip to content

Commit

Permalink
add flask endpoint in transcriptions folder
Browse files Browse the repository at this point in the history
  • Loading branch information
tdanielles committed Nov 23, 2024
1 parent a783c5f commit 6593098
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 1 deletion.
Binary file added .DS_Store
Binary file not shown.
Empty file added transcription/.gitignore
Empty file.
40 changes: 40 additions & 0 deletions transcription/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from flask import Flask, request, jsonify
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
import torch

app = Flask(__name__)

# load model and processor once during init
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

@app.route("/transcribe", methods=["POST"])
def transcribe():
if "image" not in request.files:
return jsonify({"error": "No image file provided"}), 400

image_file = request.files["image"]
try:
# open and preprocess image
image = Image.open(image_file).convert("RGB")
prompt = "<OCR>"
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
num_beams=3,
do_sample=False
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

return jsonify({"transcription": generated_text})
except Exception as e:
return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
app.run(debug=True)
2 changes: 1 addition & 1 deletion transcription/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def florence():

prompt = "<OCR>"

url = "../assets/kkl.jpg"
url = "../assets/Filled_Logbook_page-0001.jpg"
image = Image.open(url).convert("RGB")

inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
Expand Down

0 comments on commit 6593098

Please sign in to comment.