From 0f6072782809155e8c467a1f08c4da4d24118213 Mon Sep 17 00:00:00 2001 From: Alexander Pearson-Goulart Date: Tue, 14 Nov 2023 20:32:22 -0800 Subject: [PATCH] frontend + backend SAM integration --- backend/deepcell_label/blueprints.py | 4 +- frontend/src/Project/Canvas/ComposeCanvas.js | 1 - .../Project/Canvas/ToolCanvas/SamCanvas.js | 10 +- .../Canvas/ToolCanvas/SamMaskCanvas.js | 133 ------------------ .../Project/Canvas/ToolCanvas/ToolCanvas.js | 4 +- frontend/src/Project/ProjectContext.js | 2 - .../service/edit/segment/samMachine.js | 40 +++--- .../service/labels/segmentApiMachine.js | 43 ++++-- 8 files changed, 54 insertions(+), 183 deletions(-) delete mode 100644 frontend/src/Project/Canvas/ToolCanvas/SamMaskCanvas.js diff --git a/backend/deepcell_label/blueprints.py b/backend/deepcell_label/blueprints.py index 99c43d81d..a3d2e1572 100644 --- a/backend/deepcell_label/blueprints.py +++ b/backend/deepcell_label/blueprints.py @@ -19,6 +19,8 @@ from deepcell_label.label import Edit from deepcell_label.loaders import Loader from deepcell_label.models import Project +import json + import cv2 import numpy as np @@ -289,7 +291,7 @@ def test_sam_prediction(): The output of this endpoint is an ndarray of 0's and 1's indicating where to draw the mask on the frontend. """ - json_data = request.get_json() + json_data = json.loads(request.data, strict=False) bbox = BBox(**json_data) image_embedding, ort_session = retrieve_sam_model_data() diff --git a/frontend/src/Project/Canvas/ComposeCanvas.js b/frontend/src/Project/Canvas/ComposeCanvas.js index 2b3e1331d..ee4bcab6b 100644 --- a/frontend/src/Project/Canvas/ComposeCanvas.js +++ b/frontend/src/Project/Canvas/ComposeCanvas.js @@ -7,7 +7,6 @@ import { useCanvas } from '../ProjectContext'; const Canvas = styled('canvas')``; export const ComposeCanvas = ({ bitmaps }) => { - console.log(bitmaps) const canvas = useCanvas(); const { sx, sy, zoom, sw, sh, scale } = useSelector( canvas, diff --git a/frontend/src/Project/Canvas/ToolCanvas/SamCanvas.js b/frontend/src/Project/Canvas/ToolCanvas/SamCanvas.js index 62b058e1a..832f4782d 100644 --- a/frontend/src/Project/Canvas/ToolCanvas/SamCanvas.js +++ b/frontend/src/Project/Canvas/ToolCanvas/SamCanvas.js @@ -3,7 +3,7 @@ import { useEffect, useRef, useState } from 'react'; import { useSam, useCanvas } from '../../ProjectContext'; import "./styles/sam-canvas.css" -const SamCanvas = ({setRunONNXCommand}) => { +const SamCanvas = () => { const canvas = useCanvas(); const width = useSelector(canvas, (state) => state.context.width); const height = useSelector(canvas, (state) => state.context.height); @@ -33,13 +33,7 @@ const SamCanvas = ({setRunONNXCommand}) => { useEffect(() => { // User has selected a region for segmentation if (!isMouseDown && startX && startY && endX && endY) { - if (window.confirm("Send selected region for segmentation?")) { - // sam.send({ type: 'SEND_TO_API' }) - setRunONNXCommand(true) - // setRunONNXCommand(false) - } else { - sam.send({ type: 'CLEAR_SELECTION' }) - } + sam.send({ type: 'SEND_TO_API' }) } }, [isMouseDown]) diff --git a/frontend/src/Project/Canvas/ToolCanvas/SamMaskCanvas.js b/frontend/src/Project/Canvas/ToolCanvas/SamMaskCanvas.js deleted file mode 100644 index 9c5bbba74..000000000 --- a/frontend/src/Project/Canvas/ToolCanvas/SamMaskCanvas.js +++ /dev/null @@ -1,133 +0,0 @@ -import { InferenceSession, Tensor } from "onnxruntime-web"; -import { useSelector } from '@xstate/react'; -import { useEffect, useRef, useState } from 'react'; -import { useSam, useCanvas } from '../../ProjectContext'; -import "./styles/sam-canvas.css" -import { onnxMaskToImage } from "./utils/util"; -import { modelData } from "./utils/onnxModelAPI"; -import npyjs from "npyjs"; -const ort = require("onnxruntime-web"); - -const SamMaskCanvas = ({runONNXCommand}) => { - const canvas = useCanvas(); - const width = useSelector(canvas, (state) => state.context.width); - const height = useSelector(canvas, (state) => state.context.height); - const zoom = useSelector(canvas, (state) => state.context.zoom); - const sx = useSelector(canvas, (state) => state.context.sx); - const sy = useSelector(canvas, (state) => state.context.sy); - const ref = useRef(null); - - const sam = useSam(); - - const x = useSelector(sam, (state) => state.context.x); - const y = useSelector(sam, (state) => state.context.y); - const isMouseDown = useSelector(sam, (state) => state.context.isMouseDown); - const startX = useSelector(sam, (state) => state.context.startX); - const startY = useSelector(sam, (state) => state.context.startY); - const endX = useSelector(sam, (state) => state.context.endX); - const endY = useSelector(sam, (state) => state.context.endY); - - const clicks = [ - {x: 40, y: 60, clickType: 1} - ] - - - // SAM Code from github - const [maskImg, setMaskImg] = useState(null) - const [model, setModel] = useState(null); // ONNX model - const [tensor, setTensor] = useState(null); // Image embedding tensor - - // The ONNX model expects the input to be rescaled to 1024. - // The modelScale state variable keeps track of the scale values. - // TODO: Change this to be dynamic based on the image size - const [modelScale, setModelScale] = useState({samScale: 0.5, height: 512, width: 512}); - - // Initialize the ONNX model. load the image, and load the SAM - // pre-computed image embedding - useEffect(() => { - console.log("INITTING") - // Initialize the ONNX model - const initModel = async () => { - console.log("CALLING MODEL") - try { - const URL = "sam_onnx_quantized_example.onnx"; - console.log("THIS IS URL", URL) - const model = await InferenceSession.create(URL, { executionProviders: ['wasm'] }); - console.log("MADE MODEL", model) - setModel(model); - } catch (e) { - console.log("THIS IS ERR") - console.log(e); - } - }; - initModel(); - - // Load the Segment Anything pre-computed embedding - Promise.resolve(loadNpyTensor("./tif_embedding.npy", "float32")).then( - (embedding) => setTensor(embedding) - ); - }, []); - - useEffect(() => { - console.log("RUN ONNX COMMAND", runONNXCommand) - if (runONNXCommand) { - runONNX(); - } - }, [runONNXCommand]) - - // Decode a Numpy file into a tensor. - const loadNpyTensor = async (tensorFile, dType) => { - let npLoader = new npyjs(); - const npArray = await npLoader.load(tensorFile); - const tensor = new ort.Tensor(dType, npArray.data, npArray.shape); - console.log("TENSOR", tensor) - return tensor; - }; - - const runONNX = async () => { - console.log("GOT HERE") - console.log("MODEL", model) - console.log("CLICKS", clicks) - console.log("TENSOR", tensor) - console.log("MODEL SCALE", modelScale) - try { - if ( - model === null || - clicks === null || - tensor === null || - modelScale === null - ) - return; - else { - console.log("INSIDE MODEL") - // Preapre the model input in the correct format for SAM. - // The modelData function is from onnxModelAPI.tsx. - const feeds = modelData({ - clicks, - tensor, - modelScale, - }); - if (feeds === undefined) return; - // Run the SAM ONNX model with the feeds returned from modelData() - const results = await model.run(feeds); - const output = results[model.outputNames[0]]; - // The predicted mask returned from the ONNX model is an array which is - // rendered as an HTML image using onnxMaskToImage() from maskUtils.tsx. - setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3])); - } - } catch (e) { - console.log(e); - } - }; - - useEffect(() => { - console.log(maskImg) - console.log(tensor) - }, [maskImg, tensor]) - - - return
-
-}; - -export default SamMaskCanvas; diff --git a/frontend/src/Project/Canvas/ToolCanvas/ToolCanvas.js b/frontend/src/Project/Canvas/ToolCanvas/ToolCanvas.js index 3e4162a69..c25b7a048 100644 --- a/frontend/src/Project/Canvas/ToolCanvas/ToolCanvas.js +++ b/frontend/src/Project/Canvas/ToolCanvas/ToolCanvas.js @@ -19,7 +19,6 @@ import SwapCanvas from './SwapCanvas'; import ThresholdCanvas from './ThresholdCanvas'; import WatershedCanvas from './WatershedCanvas'; import SamCanvas from './SamCanvas'; -import SamMaskCanvas from './SamMaskCanvas'; import { useState } from 'react'; function ToolCanvas({ setBitmaps }) { @@ -65,8 +64,7 @@ function ToolCanvas({ setBitmaps }) { return ; case 'sam': return <> - - + ; default: return null; diff --git a/frontend/src/Project/ProjectContext.js b/frontend/src/Project/ProjectContext.js index bdf4ae2f3..98d213ae7 100644 --- a/frontend/src/Project/ProjectContext.js +++ b/frontend/src/Project/ProjectContext.js @@ -300,7 +300,6 @@ export function useEditSegment() { } export function useEditSegmentCopy() { - console.log("HERE") const project = useProject(); const segment = useSelector(project, (state) => { const tool = state.context.toolRef; @@ -308,7 +307,6 @@ export function useEditSegmentCopy() { return segment; }); - console.log("SEGMENT") return segment; } diff --git a/frontend/src/Project/service/edit/segment/samMachine.js b/frontend/src/Project/service/edit/segment/samMachine.js index ba03b8753..b70b11bb0 100644 --- a/frontend/src/Project/service/edit/segment/samMachine.js +++ b/frontend/src/Project/service/edit/segment/samMachine.js @@ -2,27 +2,12 @@ import { assign, Machine, send } from 'xstate'; import { fromEventBus } from '../../eventBus'; -async function sendToSamAPI(ctx) { - const id = new URLSearchParams(window.location.search).get('projectId') - console.log("SHOULD SEND TO API", ctx) - const options = { - method: 'POST', - body: JSON.stringify(ctx), - 'Content-Type': 'application/json', - }; - const response = await fetch(`${document.location.origin}/api/sendToSam/${id}`, options) - const data = await response.json() - - await new Promise(r => setTimeout(r, 4000)) - - return data -} - const createSAMMachine = (context) => Machine( { invoke: [ { src: fromEventBus('watershed', () => context.eventBuses.canvas, 'COORDINATES') }, + { id: 'arrays', src: fromEventBus('sam', () => context.eventBuses.arrays, ['EDITED_SEGMENT']) }, ], context: { isMouseDown: false, @@ -50,16 +35,12 @@ Machine( waiting: { on: { CLEAR_SELECTION: { target: 'done', actions: ['clearSelection']}, - SEND_TO_API: { target: "fetching" } + SEND_TO_API: {target: "fetching", actions: ["sendToAPI"]} }, }, fetching: { - invoke: { - src: sendToSamAPI, - onDone: { - target: 'done', - actions: ['clearSelection'] - }, + on: { + EDITED_SEGMENT: { target: 'done', actions: ['clearSelection'] }, } }, done: { @@ -78,6 +59,19 @@ Machine( setMouseIsDown: assign({ isMouseDown: true }), setMouseIsNotDown: assign({ isMouseDown: false }), clearSelection: assign({ startX: null, startY: null, endX: null, endY: null, isMouseDown: false, x: 0, y: 0 }), + sendToAPI: send( + (ctx) => ({ + type: 'EDIT', + action: 'sam', + args: { + x_start: ctx.startX, + y_start: ctx.startY, + x_end: ctx.endX, + y_end: ctx.endY, + }, + }), + { to: 'arrays' } + ), }, } ); diff --git a/frontend/src/Project/service/labels/segmentApiMachine.js b/frontend/src/Project/service/labels/segmentApiMachine.js index f57100e68..9636db676 100644 --- a/frontend/src/Project/service/labels/segmentApiMachine.js +++ b/frontend/src/Project/service/labels/segmentApiMachine.js @@ -40,22 +40,41 @@ async function makeEditZip(context, event) { return zipBlob; } -/** Sends a label zip to the DeepCell Label API to edit. */ -async function edit(context, event) { - const form = new FormData(); - const zipBlob = await makeEditZip(context, event); - form.append('labels', zipBlob, 'labels.zip'); - const width = context.labeled[0].length; - const height = context.labeled.length; - +async function testSamPrediction(context, event) { const options = { method: 'POST', - body: form, - 'Content-Type': 'multipart/form-data', + body: JSON.stringify(event.args), + 'Content-Type': 'application/json', }; - return fetch(`${document.location.origin}/api/edit`, options) + return fetch(`${document.location.origin}/api/testSamPrediction`, options) .then(checkResponseCode) - .then((res) => parseResponseZip(res, width, height)); + .then((res) => res.json()) + .then((data) => { + return { labeled: data.data[0], cells: context.cells} + }) +} + +/** Sends a label zip to the DeepCell Label API to edit. */ +async function edit(context, event) { + + if (event.action === "sam") { + return testSamPrediction(context, event) + } else { + const form = new FormData(); + const zipBlob = await makeEditZip(context, event); + form.append('labels', zipBlob, 'labels.zip'); + const width = context.labeled[0].length; + const height = context.labeled.length; + + const options = { + method: 'POST', + body: form, + 'Content-Type': 'multipart/form-data', + }; + return fetch(`${document.location.origin}/api/edit`, options) + .then(checkResponseCode) + .then((res) => parseResponseZip(res, width, height)); + } } function checkResponseCode(response) {