Integrate SAM 2.1 #8610
base: develop
Conversation
Segment Anything 2.0 requires compiling a .cu file with nvcc at build time. Hence, a CUDA devel base image is required to build the Nuclio container.
…e not embedded into installation package)
…l as required in Dockerfile.ui for extra plugins, adjust function.yaml and function-gpu.yaml to accommodate SAM2.1 and frontend plugin, add index.tsx and inference.worker.ts, update main.py and model_handler.py accordingly
…since not included in self.predictor.get_image_embedding() and needed in decoder and a new encoder class is added in image_encoder.py
…changes in index.tsx, and accommodate those inputs in inference.worker.ts and add sam2.1_hiera_large.decoder where postprocessing steps are added to undergo minimal changes in sam2 plugin
Walkthrough
The changes in this pull request involve updates to the documentation and the introduction of new functionality for the Computer Vision Annotation Tool (CVAT).
Sequence Diagram(s)
```mermaid
sequenceDiagram
    participant User
    participant CVAT
    participant SAM2Plugin
    participant InferenceWorker
    User->>CVAT: Upload Image
    CVAT->>SAM2Plugin: Process Image
    SAM2Plugin->>InferenceWorker: Send Image for Inference
    InferenceWorker->>InferenceWorker: Run Model
    InferenceWorker-->>SAM2Plugin: Return Inference Results
    SAM2Plugin-->>CVAT: Update with Results
    CVAT-->>User: Display Segmentation Results
```
Actionable comments posted: 19
🧹 Outside diff range and nitpick comments (11)
serverless/pytorch/facebookresearch/sam2/nuclio/image_encoder.py (2)
Lines 1-13: LGTM! Consider adding a class docstring.
The class structure and initialization look good. Type hints are properly used, and the inheritance from torch.nn.Module is appropriate.
Add a docstring to describe the purpose and usage of the SAM2Encoder class:
```diff
 class SAM2Encoder(torch.nn.Module):
+    """SAM 2.1 image encoder that processes input images through the backbone and prepares features for the mask decoder.
+
+    Args:
+        sam2_model (SAM2Base): Base SAM 2.1 model containing the image encoder and mask decoder components.
+    """
     def __init__(self, sam2_model: SAM2Base) -> None:
```
Lines 14-34: Improve type hints and documentation.
The method could benefit from more specific type hints and comprehensive documentation.
```diff
-    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Process input image through SAM encoder and prepare multi-scale features.
+
+        Args:
+            x (torch.Tensor): Input image tensor of shape (1, 3, H, W)
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Multi-scale features at three
+                different resolutions, ordered from highest to lowest resolution.
+                Each tensor has shape (1, C, H_i, W_i) where H_i, W_i are the
+                spatial dimensions at scale i.
+        """
```
serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (2)
Lines 1-10: Remove unnecessary empty line.
There's an extra empty line at line 11 that can be removed to maintain consistent spacing.
Lines 12-38: Consider architectural improvements for the serverless environment.
As this runs in a serverless environment, consider these architectural improvements:
- Implement async processing using async/await for better scalability
- Add a health check method to verify model status
- Implement model warmup to improve cold start performance
Example health check implementation:
```python
async def health_check(self):
    try:
        # Verify model and device status
        if self.device.type == 'cuda':
            assert torch.cuda.is_available()
        # Run a small inference to ensure the model is responsive
        dummy_input = torch.zeros((1, 3, 64, 64), device=self.device)
        with torch.inference_mode():
            _ = self.sam2_encoder(dummy_input)
        return True
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return False
```
serverless/pytorch/facebookresearch/sam2/nuclio/function.yaml (2)
Lines 12-17: Update help message and example GIF for SAM 2.1.
The help message and example GIF should be specific to SAM 2.1:
- The help message could better describe SAM 2.1's unique capabilities and improvements over SAM 2.0
- The animated GIF currently shows an HRNet example (hrnet_example.gif), which should be replaced with a SAM 2.1-specific demonstration
Lines 20-23: Consider updating the Python runtime and event timeout.
- Python 3.8 is approaching end-of-life. Consider upgrading to Python 3.10+ for better performance and security updates.
- The 30-second event timeout might be insufficient for processing large images or batch requests. Consider increasing it based on your performance testing results.
serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (1)
Lines 16-17: Update demo GIF and help message for SAM 2.1.
The animated GIF currently points to an HRNet example. Consider updating it with a SAM 2.1-specific demonstration. Additionally, the help message could be more descriptive about SAM 2.1's specific capabilities and limitations.
```diff
- animated_gif: https://raw.githubusercontent.com/cvat-ai/cvat/develop/site/content/en/images/hrnet_example.gif
- help_message: The interactor allows to get a mask of an object using at least one positive, and any negative points inside it
+ animated_gif: https://raw.githubusercontent.com/cvat-ai/cvat/develop/site/content/en/images/sam2_example.gif
+ help_message: SAM 2.1 interactor generates high-quality object masks from user-provided prompts. Supports positive/negative points and optional bounding box input for improved accuracy.
```
docker-compose.dev.yml (1)
Line 101: Consider adding a default value for CLIENT_PLUGINS.
The addition of the CLIENT_PLUGINS build argument is correct, but consider providing a default empty value to ensure backward compatibility and prevent build failures when the variable is not set.
```diff
- CLIENT_PLUGINS: ${CLIENT_PLUGINS}
+ CLIENT_PLUGINS: ${CLIENT_PLUGINS:-}
```
README.md (1)
Lines 194-195: Consider adding ONNX conversion documentation.
Since the PR objectives mention that the decoder was converted to ONNX format using a specific export script, it would be helpful to document this process, either in the serverless function's README or as a note in the table.
Example addition:
```diff
 | [Segment Anything 2.1](/serverless/pytorch/facebookresearch/sam2/nuclio/) | interactor | PyTorch | ✔️ | ✔️ |
+| [Segment Anything 2.1 (ONNX)](/serverless/pytorch/facebookresearch/sam2/nuclio/onnx) | interactor | ONNX | ✔️ | ✔️ |
```
serverless/pytorch/facebookresearch/sam2/nuclio/main.py (1)
Line 27: Remove unused variable `image_`.
The variable `image_` is assigned but never used afterward. Consider removing it to clean up the code.
Apply this diff to remove the unused variable:
```diff
-high_res_feats_0, high_res_feats_1, image_embed, image_ = context.user_data.model.handle(image)
+high_res_feats_0, high_res_feats_1, image_embed = context.user_data.model.handle(image)
```
cvat-ui/plugins/sam2/src/ts/inference.worker.ts (1)
Lines 54-55: Avoid disabling ESLint rules without justification.
The line `// eslint-disable-next-line no-restricted-globals` disables the ESLint rule `no-restricted-globals`. If possible, refactor the code to avoid using restricted globals, or add a comment explaining why the rule must be disabled, to improve code readability and maintainability.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (9)
- README.md (1 hunks)
- cvat-ui/plugins/sam2/src/ts/index.tsx (1 hunks)
- cvat-ui/plugins/sam2/src/ts/inference.worker.ts (1 hunks)
- docker-compose.dev.yml (1 hunks)
- serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (1 hunks)
- serverless/pytorch/facebookresearch/sam2/nuclio/function.yaml (1 hunks)
- serverless/pytorch/facebookresearch/sam2/nuclio/image_encoder.py (1 hunks)
- serverless/pytorch/facebookresearch/sam2/nuclio/main.py (1 hunks)
- serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (1 hunks)
🔇 Additional comments (12)
serverless/pytorch/facebookresearch/sam2/nuclio/model_handler.py (1)
Lines 29-38: 🛠️ Refactor suggestion
Consider adding input size limits and memory management
The handle method could benefit from additional safeguards:
- Add maximum input size validation to prevent memory issues
- Consider adding memory cleanup after processing
- Add more comprehensive input validation
```diff
 def handle(self, image):
+    max_image_size = 1024 * 1024 * 4  # 4MP limit
     with torch.inference_mode():
         assert isinstance(image, Image)
+        # Validate image size
+        w, h = image.size
+        if w * h > max_image_size:
+            raise ValueError(f"Image too large: {w}x{h}. Maximum size allowed: {max_image_size} pixels")
+
         input_image = self._transforms(image)
         input_image = input_image[None, ...].to(self.device)
         assert (
             len(input_image.shape) == 4 and input_image.shape[1] == 3
         ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
         high_res_feats_0, high_res_feats_1, image_embed = self.sam2_encoder(input_image)
-        return high_res_feats_0, high_res_feats_1, image_embed, input_image
+        # Clean up CUDA memory if needed
+        if self.device.type == 'cuda':
+            torch.cuda.empty_cache()
+        return high_res_feats_0, high_res_feats_1, image_embed, input_image
```
serverless/pytorch/facebookresearch/sam2/nuclio/function.yaml (2)
Lines 44-45: Verify CUDA configuration for the SAM2 installation.
The SAM2 installation has CUDA disabled (SAM2_BUILD_CUDA=0). If this is the CPU-only version, consider creating a separate GPU configuration file (function-gpu.yaml) with CUDA enabled for better performance on GPU instances.
Lines 53-58: Review resource limits and timeouts.
Please verify these configuration values based on your performance testing:
- `maxWorkers: 2` might need adjustment based on expected load and resource availability
- `maxRequestBodySize: 33554432` (32MB) might be insufficient for high-resolution images
- `workerAvailabilityTimeoutMilliseconds: 10000` (10s) seems short; consider increasing it if worker initialization takes longer
serverless/pytorch/facebookresearch/sam2/nuclio/function-gpu.yaml (3)
Lines 69-74: LGTM: Platform configuration is well-defined.
The platform settings with automatic restart policy and proper volume mounting are appropriate for a GPU-enabled service.
Lines 59-67: Review resource constraints for production workloads.
The current configuration might be restrictive for production use:
- The 32MB request body limit might not accommodate high-resolution images
- The 10-second worker availability timeout is quite aggressive
- The single-worker configuration could become a bottleneck
Lines 21-23: Consider runtime and timeout adjustments.
- Python 3.8 is aging; consider upgrading to Python 3.10+ for better performance and features.
- The 30-second timeout might be insufficient for processing large images or batch requests.
docker-compose.dev.yml (1)
Line 101: Verify documentation for the new build argument.
The new CLIENT_PLUGINS build argument needs to be documented to help users understand its purpose and usage, especially in the context of the SAM 2.1 plugin integration.
README.md (1)
Lines 194-195: LGTM! SAM 2.1 entry added correctly.
The addition of Segment Anything 2.1 to the serverless functions table is well-formatted and accurately reflects its capabilities as an interactive segmentation model supporting both CPU and GPU environments.
cvat-ui/plugins/sam2/src/ts/inference.worker.ts (2)
Line 29: Clarify the use of the index signature in the `DecodeBody` interface.
The index signature `readonly [name: string]: Tensor;` in the `DecodeBody` interface allows any string property to be a `Tensor`. Ensure this is intentional to accept dynamic keys, as it might mask unintended extra properties or typos in property names.
Line 55: Confirm validity of the worker context check.
The condition `if ((self as any).importScripts)` is used to determine whether the code is running in a web worker context. Verify that this check is reliable across all target browsers and environments, or consider a more standard method such as `typeof importScripts === 'function'`.
cvat-ui/plugins/sam2/src/ts/index.tsx (2)
Lines 316-343: Avoid overwriting `worker.onmessage`; centralize message handling.
Overwriting the worker's `onmessage` handler here can interfere with other message handling logic. As previously mentioned, it's better to have a single `onmessage` handler for the worker.
Refer to the previous suggestion on centralizing the `onmessage` handler to handle different actions appropriately.
Line 355: Verify the accessibility of the model file path.
The model URL is set to `/assets/sam2.1_hiera_large.decoder.onnx`. Ensure that this file exists at the specified path and is accessible in the production environment.
Run the following script to check if the model file exists:
✅ Verification successful
Let me search for more context about how this model file is handled in the codebase.
Let me check the webpack configuration to understand how assets are handled during the build process.
Based on the webpack configuration, I can see that ONNX model files from plugin assets are handled by the CopyPlugin configuration:
```js
{
    from: 'plugins/**/assets/*.(onnx|js)',
    to: 'assets/[name][ext]',
}
```
Let me verify if the model file exists in the plugin's assets directory:
Model file is present and correctly configured.
The model file `sam2.1_hiera_large.decoder.onnx` exists in the correct location at `cvat-ui/plugins/sam2/assets/`. The webpack configuration correctly handles copying ONNX files from plugin assets to the public assets directory during build. The model URL path `/assets/sam2.1_hiera_large.decoder.onnx` will be valid in production.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```bash
#!/bin/bash
# Description: Verify that the model file exists in the assets directory
# Test: Check for the existence of the model file
if [ -f ./assets/sam2.1_hiera_large.decoder.onnx ]; then
    echo "Model file exists."
else
    echo "Model file does not exist."
fi
```
Length of output: 116
Script:
```bash
#!/bin/bash
# Search for references to the model file name or related configuration
rg -g '!*.{lock,min.js}' "sam2.1_hiera_large.decoder.onnx" -C 2
# Search for any webpack or build configuration files that might handle assets
fd -e config.js -e webpack.config.js
# Look for documentation about model files
fd README -t f
```
Length of output: 1136
Script:
```bash
#!/bin/bash
# Check webpack configuration for asset handling
cat cvat-ui/webpack.config.js
# Look for any documentation about model setup in the plugin's directory
fd -t f -p "**/sam2/**/README*"
# Check if there are any build or asset related instructions in the main README
head -n 100 README.md
```
Length of output: 13083
Script:
```bash
#!/bin/bash
# Check if the model exists in the plugin's assets directory
fd -t f -e onnx . cvat-ui/plugins/sam2/
# Also check if there are any build instructions or documentation specific to the SAM2 plugin
rg -g '!*.{lock,min.js}' -C 3 "model.*\.onnx" cvat-ui/plugins/sam2/
```
Length of output: 770
```python
        self.image_encoder = sam2_model.image_encoder
        self.no_mem_embed = sam2_model.no_mem_embed

    def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
```
🛠️ Refactor suggestion
Add input validation and shape assertions.
The forward method should validate input tensor dimensions and device placement to fail fast with clear error messages.
```diff
 def forward(self, x: torch.Tensor) -> tuple[Any, Any, Any]:
+    if x.dim() != 4:
+        raise ValueError(f"Expected 4D input tensor (B,C,H,W), got {x.dim()}D")
+    if x.size(0) != 1:
+        raise ValueError(f"Expected batch size 1, got {x.size(0)}")
+    if x.size(1) != 3:
+        raise ValueError(f"Expected 3 channels (RGB), got {x.size(1)}")
```
```python
        self.sam2_checkpoint = "/opt/nuclio/sam2/sam2.1_hiera_large.pt"
        self.sam2_model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)
```
🛠️ Refactor suggestion
Consider making model paths configurable
The model checkpoint and configuration paths are hardcoded. Consider making these configurable through environment variables or configuration files for better flexibility and maintainability.
```diff
-        self.sam2_checkpoint = "/opt/nuclio/sam2/sam2.1_hiera_large.pt"
-        self.sam2_model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+        self.sam2_checkpoint = os.getenv('SAM2_CHECKPOINT', "/opt/nuclio/sam2/sam2.1_hiera_large.pt")
+        self.sam2_model_cfg = os.getenv('SAM2_CONFIG', "configs/sam2.1/sam2.1_hiera_l.yaml")
```
```python
        sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)
        self.sam2_encoder = SAM2Encoder(sam2_model) if torch.cuda.is_available() else SAM2Encoder(sam2_model).cpu()
        self._transforms = SAM2Transforms(resolution=sam2_model.image_size, mask_threshold=0.0)
```
Add error handling and logging for model initialization
The model loading process lacks error handling and logging. Consider adding try-catch blocks and logging statements to handle potential failures gracefully.
```diff
+import logging
+logger = logging.getLogger(__name__)
+
+try:
     sam2_model = build_sam2(self.sam2_model_cfg, self.sam2_checkpoint, device=self.device)
     self.sam2_encoder = SAM2Encoder(sam2_model) if torch.cuda.is_available() else SAM2Encoder(sam2_model).cpu()
+    logger.info(f"SAM2 model loaded successfully on {self.device}")
+except FileNotFoundError as e:
+    logger.error(f"Failed to load SAM2 model: {e}")
+    raise
+except Exception as e:
+    logger.error(f"Unexpected error during model initialization: {e}")
+    raise
```
```yaml
- kind: RUN
  value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6
```
🛠️ Refactor suggestion
Optimize Docker image size
Consider cleaning up apt cache after installation to reduce the image size:
```diff
-  value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6
+  value: apt-get update && apt-get -y install curl git python3 python3-pip ffmpeg libsm6 libxext6 && \
+    apt-get clean && \
+    rm -rf /var/lib/apt/lists/*
```
```yaml
- kind: RUN
  value: curl -O https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
```
Add checksum verification for model weights
The model weights are downloaded without verifying their integrity. Consider adding SHA256 checksum verification:
```diff
-  value: curl -O https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt
+  value: |
+    curl -O https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt && \
+    echo "expected-sha256-checksum  sam2.1_hiera_large.pt" | sha256sum -c
```
Committable suggestion was skipped due to low confidence.
```ts
if (typeof query.jobID === 'number') {
    [plugin.data.jobs[query.jobID]] = results;
}
```
Fix the confusing assignment to `plugin.data.jobs[query.jobID]`.
The statement `[plugin.data.jobs[query.jobID]] = results;` uses array destructuring to assign the first element of `results` to the property. While syntactically valid, it is easy to misread as array construction and can hide bugs.
Apply this fix to make the assignment explicit:
```diff
-    [plugin.data.jobs[query.jobID]] = results;
+    plugin.data.jobs[query.jobID] = results[0];
```
This assumes `results` is an array and the first element should be assigned to `plugin.data.jobs[query.jobID]`.
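A minimal sketch showing that both forms write `results[0]` into the map, and why the explicit form reads better; the `Job` shape and the keys below are hypothetical:

```typescript
type Job = { id: number };

const jobs: Record<number, Job> = {};
const results: Job[] = [{ id: 42 }];

// Destructuring assignment: stores results[0] under key 5, but reads
// like array construction at first glance.
[jobs[5]] = results;

// Explicit equivalent: same effect, clearer intent.
jobs[7] = results[0];
```

Both properties end up holding the same object; the second form states directly which element is being stored.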
```ts
plugin.data.jobs = {
    // we do not need to store old job instances
    [job.id]: job,
};
```
Avoid overwriting the `plugin.data.jobs` object with a single job.
Overwriting the entire `plugin.data.jobs` object with a single job can lead to loss of previously stored jobs. This can cause issues if other parts of the code rely on accessing other job instances.
Modify the code to update the job entry without overwriting the entire object:
```diff
-    plugin.data.jobs = {
-        // we do not need to store old job instances
-        [job.id]: job,
-    };
+    plugin.data.jobs[job.id] = job;
```
If it's intentional to discard old job instances, consider resetting the object explicitly and documenting this behavior.
```ts
const isLowResMaskSuitable = JSON
    .stringify(clicks.slice(0, -1)) === JSON.stringify(plugin.data.lastClicks);
```
Handle edge cases when comparing `clicks` arrays.
Using `clicks.slice(0, -1)` when `clicks` has 0 or 1 elements may not behave as expected, since slicing with `-1` can return an empty array.
Ensure that the comparison accounts for cases when `clicks` has fewer than two elements:
```ts
const isLowResMaskSuitable = clicks.length > 1 &&
    JSON.stringify(clicks.slice(0, -1)) === JSON.stringify(plugin.data.lastClicks);
```
This addition checks that there are at least two clicks before performing the comparison.
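A quick self-contained sketch of why the length guard matters: `slice(0, -1)` on a 0- or 1-element array yields `[]`, which would spuriously match an empty `lastClicks`. The `Click` shape below mirrors the plugin's click objects but is an assumption for illustration:

```typescript
type Click = { clickType: number; x: number; y: number };

// Guarded version of the suitability check from the review comment.
function isLowResMaskSuitable(clicks: Click[], lastClicks: Click[]): boolean {
    return clicks.length > 1 &&
        JSON.stringify(clicks.slice(0, -1)) === JSON.stringify(lastClicks);
}

const click: Click = { clickType: 1, x: 10, y: 20 };
```

Without the `clicks.length > 1` guard, a single fresh click compared against an empty history would serialize to `"[]" === "[]"` and incorrectly reuse the low-res mask.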
```ts
if (obj_bbox.length) {
    clicks.push({ clickType: 2, x: obj_bbox[0][0], y: obj_bbox[0][1] });
    clicks.push({ clickType: 3, x: obj_bbox[1][0], y: obj_bbox[1][1] });
}
```
Add a check for the `obj_bbox` length before accessing its elements.
Accessing `obj_bbox[1]` without checking the length of `obj_bbox` can result in an error if it contains fewer than two elements.
Add a condition to ensure that `obj_bbox` has at least two elements:
```ts
if (obj_bbox.length >= 2) {
    clicks.push({ clickType: 2, x: obj_bbox[0][0], y: obj_bbox[0][1] });
    clicks.push({ clickType: 3, x: obj_bbox[1][0], y: obj_bbox[1][1] });
}
```
```ts
function toMatImage(input: number[], width: number, height: number): number[][] {
    const image = Array(height).fill(0);
    for (let i = 0; i < image.length; i++) {
        image[i] = Array(width).fill(0);
    }

    for (let i = 0; i < input.length; i++) {
        const row = Math.floor(i / width);
        const col = i % width;
        image[row][col] = input[i] > 0 ? 255 : 0;
    }

    return image;
}
```
🛠️ Refactor suggestion
Optimize the `toMatImage` function for better performance.
The current implementation initializes a 2D array with zeros and then fills it using nested loops, which can be inefficient for large images.
Consider flattening the data processing:
```diff
 function toMatImage(input: number[], width: number, height: number): number[][] {
-    const image = Array(height).fill(0);
-    for (let i = 0; i < image.length; i++) {
-        image[i] = Array(width).fill(0);
-    }
-
-    for (let i = 0; i < input.length; i++) {
-        const row = Math.floor(i / width);
-        const col = i % width;
-        image[row][col] = input[i] > 0 ? 255 : 0;
-    }
+    const image = new Array(height);
+    for (let i = 0; i < height; i++) {
+        image[i] = input.slice(i * width, (i + 1) * width).map(value => (value > 0 ? 255 : 0));
+    }
     return image;
 }
```
This approach avoids initializing the array with zeros and directly maps the input values to the image array.
Very nice work @hashJoe! I think the next big milestone would be to integrate the encoder in the frontend to unlock decentralized tracking capabilities for video annotation (this is mostly what SAM2 is all about)!
@jeanchristopheruel I need to look into video prediction and check how to implement it in CVAT. How would integrating the encoder in the frontend help? There is a tracking model example which is done in the backend (both encoder and decoder), here. Any insights would be helpful!
@hashJoe I believe the SAM2 encoder is lightweight enough to be supported by the frontend (I think it is about 1 GB). Porting it to the frontend would remove the need for an inference backend server. It would also reduce the latency associated with sending the state in the request and returning the embeddings. Video annotation faster than ever. Each user has their own internal tracking state, which is a memory embedding in SAM2.
I'm following this PR.
SAM 2 for video tracking would really transform our annotation workflow (fish monitoring)! I hope this is possible to implement in CVAT.
Summary
To integrate SAM 2.1 into CVAT, I:
- `CLIENT_PLUGINS` argument to pass the plugin
- `cvat-ui/plugins/sam2`
,Motivation and context
This pull request builds upon #8243, which aimed to integrate SAM 2 into CVAT. Progress on that contribution has stalled, and this request serves as a continuation of integrating SAM 2.
Main enhancements:
This way, the structure of the SAM integration is maintained in SAM 2 with minimal changes.
How has this been tested?
Using the following commands, and applying the model on several images.
Checklist
- develop branch
- (cvat-canvas, cvat-core, cvat-data and cvat-ui)
License
Feel free to contact the maintainers if that's a concern.