
Code for segmenting a single image #5

Open
AlexanderZeilmann opened this issue Aug 29, 2024 · 2 comments

@AlexanderZeilmann

With the original Segment Anything, all I have to do to segment a single image is download the checkpoint and run:

```python
from segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry["<model_type>"](checkpoint="<path/to/checkpoint>")
predictor = SamPredictor(sam)
predictor.set_image(<your_image>)
masks, _, _ = predictor.predict(<input_prompts>)
```

How can I do something similar in FastSAM3D? I have downloaded the FastSAM3D checkpoint and have a 3D image with prompts ready. How do I use FastSAM3D to segment my image with my prompts?

@skill-diver
Collaborator

You could use infer.sh and modify the parameters as you want: `-vp` is the visualization path where output is written, and `-tdp` is the path to the image data you want to segment.

```bash
python validation.py --seed 2023 \
    -vp ./results/vis_sam_med3d \
    -tdp data/initial_test_dataset/total_segment \
    -nc 1
```
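
For a programmatic route analogous to the SamPredictor snippet above, the sketch below shows one way the checkpoint might be loaded before running inference. It is a minimal sketch, not the repo's documented API: the `sam_model_registry3D` import path, the `"vit_b"` model type, the `"model_state_dict"` checkpoint key, and the checkpoint path are all assumptions carried over from the SAM-Med3D codebase that FastSAM3D builds on; validation.py remains the authoritative reference for how the model is actually constructed here.

```python
import torch

# Assumed import: SAM-Med3D (which FastSAM3D extends) exposes its 3D model
# builders as sam_model_registry3D in segment_anything/build_sam3D.py.
from segment_anything.build_sam3D import sam_model_registry3D

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the architecture first with no weights, then load the FastSAM3D
# checkpoint into it, so any key/shape mismatch fails loudly right here.
model = sam_model_registry3D["vit_b"](checkpoint=None).to(device)  # model type is a guess

ckpt = torch.load("path/to/fastsam3d.pth", map_location=device)  # hypothetical path
state = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
model.load_state_dict(state)
model.eval()

# From here, validation.py shows the repo's own preprocessing (cropping and
# normalizing the 3D volume) and how click prompts are fed to the prompt
# encoder and mask decoder; mirror those steps to get actual predictions.
```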

@williamtrayn0r

Regarding this issue: I have my dataset in the specified format and have saved the FastSAM3D checkpoint locally. When I run infer.sh, I get errors while loading the model, related to missing encoder blocks. Additionally, looking through validation.py, I see that many functions refer to tuning the model; in my case I do not want to tune the model, only test its segmentation on my data. Any help with this would be greatly appreciated.

Here is the error message:

```
RuntimeError: Error(s) in loading state_dict for Sam3D:
Missing key(s) in state_dict: "image_encoder.blocks.0.norm1.weight", "image_encoder.blocks.0.norm1.bias", "image_encoder.blocks.0.attn.rel_pos_d", "image_encoder.blocks.0.attn.rel_pos_h", "image_encoder.blocks.0.attn.rel_pos_w", "image_encoder.blocks.0.attn.qkv.weight", "image_encoder.blocks.0.attn.qkv.bias", "image_encoder.blocks.0.attn.proj.weight", "image_encoder.blocks.0.attn.proj.bias", "image_encoder.blocks.1.norm1.weight", "image_encoder.blocks.1.norm1.bias", "image_encoder.blocks.1.attn.rel_pos_d", "image_encoder.blocks.1.attn.rel_pos_h", "image_encoder.blocks.1.attn.rel_pos_w", "image_encoder.blocks.1.attn.qkv.weight", "image_encoder.blocks.1.attn.qkv.bias", "image_encoder.blocks.1.attn.proj.weight", "image_encoder.blocks.1.attn.proj.bias", "image_encoder.blocks.6.norm1.weight", "image_encoder.blocks.6.norm1.bias", "image_encoder.blocks.6.attn.rel_pos_d", "image_encoder.blocks.6.attn.rel_pos_h", "image_encoder.blocks.6.attn.rel_pos_w", "image_encoder.blocks.6.attn.qkv.weight", "image_encoder.blocks.6.attn.qkv.bias", "image_encoder.blocks.6.attn.proj.weight", "image_encoder.blocks.6.attn.proj.bias", "image_encoder.blocks.6.norm2.weight", "image_encoder.blocks.6.norm2.bias", "image_encoder.blocks.6.mlp.lin1.weight", "image_encoder.blocks.6.mlp.lin1.bias", "image_encoder.blocks.6.mlp.lin2.weight", "image_encoder.blocks.6.mlp.lin2.bias", "image_encoder.blocks.7.norm1.weight", "image_encoder.blocks.7.norm1.bias", "image_encoder.blocks.7.attn.rel_pos_d", "image_encoder.blocks.7.attn.rel_pos_h", "image_encoder.blocks.7.attn.rel_pos_w", "image_encoder.blocks.7.attn.qkv.weight", "image_encoder.blocks.7.attn.qkv.bias", "image_encoder.blocks.7.attn.proj.weight", "image_encoder.blocks.7.attn.proj.bias", "image_encoder.blocks.7.norm2.weight", "image_encoder.blocks.7.norm2.bias", "image_encoder.blocks.7.mlp.lin1.weight", "image_encoder.blocks.7.mlp.lin1.bias", "image_encoder.blocks.7.mlp.lin2.weight", "image_encoder.blocks.7.mlp.lin2.bias", "image_encoder.blocks.8.norm1.weight", "image_encoder.blocks.8.norm1.bias", "image_encoder.blocks.8.attn.rel_pos_d", "image_encoder.blocks.8.attn.rel_pos_h", "image_encoder.blocks.8.attn.rel_pos_w", "image_encoder.blocks.8.attn.qkv.weight", "image_encoder.blocks.8.attn.qkv.bias", "image_encoder.blocks.8.attn.proj.weight", "image_encoder.blocks.8.attn.proj.bias", "image_encoder.blocks.8.norm2.weight", "image_encoder.blocks.8.norm2.bias", "image_encoder.blocks.8.mlp.lin1.weight", "image_encoder.blocks.8.mlp.lin1.bias", "image_encoder.blocks.8.mlp.lin2.weight", "image_encoder.blocks.8.mlp.lin2.bias", "image_encoder.blocks.9.norm1.weight", "image_encoder.blocks.9.norm1.bias", "image_encoder.blocks.9.attn.rel_pos_d", "image_encoder.blocks.9.attn.rel_pos_h", "image_encoder.blocks.9.attn.rel_pos_w", "image_encoder.blocks.9.attn.qkv.weight", "image_encoder.blocks.9.attn.qkv.bias", "image_encoder.blocks.9.attn.proj.weight", "image_encoder.blocks.9.attn.proj.bias", "image_encoder.blocks.9.norm2.weight", "image_encoder.blocks.9.norm2.bias", "image_encoder.blocks.9.mlp.lin1.weight", "image_encoder.blocks.9.mlp.lin1.bias", "image_encoder.blocks.9.mlp.lin2.weight", "image_encoder.blocks.9.mlp.lin2.bias", "image_encoder.blocks.10.norm1.weight", "image_encoder.blocks.10.norm1.bias", "image_encoder.blocks.10.attn.rel_pos_d", "image_encoder.blocks.10.attn.rel_pos_h", "image_encoder.blocks.10.attn.rel_pos_w", "image_encoder.blocks.10.attn.qkv.weight", "image_encoder.blocks.10.attn.qkv.bias", "image_encoder.blocks.10.attn.proj.weight", "image_encoder.blocks.10.attn.proj.bias", 
"image_encoder.blocks.10.norm2.weight", "image_encoder.blocks.10.norm2.bias", "image_encoder.blocks.10.mlp.lin1.weight", "image_encoder.blocks.10.mlp.lin1.bias", "image_encoder.blocks.10.mlp.lin2.weight", "image_encoder.blocks.10.mlp.lin2.bias", "image_encoder.blocks.11.norm1.weight", "image_encoder.blocks.11.norm1.bias", "image_encoder.blocks.11.attn.rel_pos_d", "image_encoder.blocks.11.attn.rel_pos_h", "image_encoder.blocks.11.attn.rel_pos_w", "image_encoder.blocks.11.attn.qkv.weight", "image_encoder.blocks.11.attn.qkv.bias", "image_encoder.blocks.11.attn.proj.weight", "image_encoder.blocks.11.attn.proj.bias", "image_encoder.blocks.11.norm2.weight", "image_encoder.blocks.11.norm2.bias", "image_encoder.blocks.11.mlp.lin1.weight", "image_encoder.blocks.11.mlp.lin1.bias", "image_encoder.blocks.11.mlp.lin2.weight", "image_encoder.blocks.11.mlp.lin2.bias".
size mismatch for image_encoder.blocks.2.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
size mismatch for image_encoder.blocks.2.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
size mismatch for image_encoder.blocks.2.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
size mismatch for image_encoder.blocks.3.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
size mismatch for image_encoder.blocks.3.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
size mismatch for image_encoder.blocks.3.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
size mismatch for image_encoder.blocks.4.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
size mismatch for image_encoder.blocks.4.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
size mismatch for image_encoder.blocks.4.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([27, 64]).
size mismatch for image_encoder.blocks.5.attn.rel_pos_d: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
size mismatch for image_encoder.blocks.5.attn.rel_pos_h: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
size mismatch for image_encoder.blocks.5.attn.rel_pos_w: copying a param with shape torch.Size([15, 128]) from checkpoint, the shape in current model is torch.Size([15, 64]).
```
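
Reading the traceback: the "Missing key(s)" are parameters the constructed model expects (all of encoder blocks 6–11, plus attention and norm parameters in blocks 0–1) that the checkpoint does not contain, and the `[15, 128]` vs `[15, 64]` size mismatches show that the attention parameters present on both sides have different widths. In other words, validation.py appears to be building a different image-encoder configuration (depth and embedding size) than the one the FastSAM3D checkpoint was saved from. Below is a generic PyTorch helper for pinpointing such mismatches; it is a sketch that assumes only that the checkpoint is a raw state_dict or wraps one under a "model_state_dict" key.

```python
import torch

def diff_state_dicts(model: torch.nn.Module, ckpt_path: str) -> None:
    """Compare the parameters a model expects against a checkpoint on disk."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Unwrap the common {"model_state_dict": ...} layout; fall back to a raw dict.
    state = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt

    model_state = model.state_dict()
    model_keys, ckpt_keys = set(model_state), set(state)

    print("missing from checkpoint:", sorted(model_keys - ckpt_keys))
    print("unexpected in checkpoint:", sorted(ckpt_keys - model_keys))

    # For keys present on both sides, report any shape disagreements.
    for k in sorted(model_keys & ckpt_keys):
        if model_state[k].shape != state[k].shape:
            print(f"shape mismatch {k}: checkpoint {tuple(state[k].shape)} "
                  f"vs model {tuple(model_state[k].shape)}")
```

If the key sets differ the way this traceback shows, the usual fix is to construct the model with the configuration the checkpoint was trained with (matching encoder depth and embedding size), rather than loading with `strict=False`, which would leave the missing blocks randomly initialized and silently degrade segmentation quality.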
