Sudden model size increase for high pruning ratios #417

Open
Hrayo712 opened this issue Aug 28, 2024 · 1 comment
Comments

@Hrayo712

Hrayo712 commented Aug 28, 2024

Hi @VainF

First of all, thanks for the amazing work.

I am currently working on deriving a tradeoff curve for ViT-b-16 (torchvision) to show the accuracy vs. compression tradeoff for a given pruning method and configuration. Everything seems to work great. However, I have noticed that at high pruning ratios the model size suddenly increases. Upon further inspection, this happens because some layers are left out of pruning entirely at these high sparsity ratios.

For example, using the simplified (and self-contained) code below:

import torch
import torch_pruning as tp
from torchvision import models
from torch_pruning.utils import count_ops_and_params

pruning_ratios = [
    0.1,
    0.2,
    0.3,
    0.4,
    0.5,
    0.6,
    0.7,
    0.8,
    0.9,
]

for ratio in pruning_ratios:
    model = models.get_model(
        "vit_b_16",
        weights="IMAGENET1K_V1",
        num_classes=1000,
        image_size=224
    )

    # round pruned channel counts to a multiple of the number of attention heads
    round_to = model.encoder.layers[0].num_heads
    # also prune the class token and positional embedding along their embedding dim (dim=2)
    unwrapped_parameters = [
        (model.class_token, 2), (model.encoder.pos_embedding, 2)
    ]
    example_inputs = torch.randn(1, 3, 224, 224)
    pruner = tp.MetaPruner(
        model=model,
        example_inputs=example_inputs,
        importance=tp.importance.MagnitudeImportance(p=2),
        global_pruning=True,
        pruning_ratio=ratio,
        ignored_layers=[],
        round_to=round_to,
        unwrapped_parameters=unwrapped_parameters,
    )

    pruner.step()
    # keep the hidden_dim attribute consistent with the pruned embedding size
    model.hidden_dim = model.conv_proj.out_channels
    
    macs, params = count_ops_and_params(
        model, 
        example_inputs,
    )
    print(f"Number of parameters for ratio: {ratio}: {params/1e6}")
    print(f"Number of MACs for ratio: {ratio}: {macs/1e6}")
    del pruner
    del model

We obtain the following:

Number of parameters for ratio: 0.1: 75.533176
Number of MACs for ratio: 0.1: 15394.679128

Number of parameters for ratio: 0.2: 57.909372
Number of MACs for ratio: 0.2: 11858.000928

Number of parameters for ratio: 0.3: 46.424916
Number of MACs for ratio: 0.3: 9557.7033

Number of parameters for ratio: 0.4: 38.525112
Number of MACs for ratio: 0.4: 7987.479672

Number of parameters for ratio: 0.5: 31.303536
Number of MACs for ratio: 0.5: 6554.09394

Number of parameters for ratio: 0.6: 24.763224
Number of MACs for ratio: 0.6: 5256.025044

Number of parameters for ratio: 0.7: 16.665684
Number of MACs for ratio: 0.7: 3622.225068

Number of parameters for ratio: 0.8: 3.670524
Number of MACs for ratio: 0.8: 865.9407

Number of parameters for ratio: 0.9: 35.735992
Number of MACs for ratio: 0.9: 7579.648216

As you can see, at pruning ratio 0.9 the model size suddenly increases; in this example, because the self-attention layers are no longer pruned.
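
A quick way to see which blocks were left untouched is to print each encoder block's attention and MLP widths after pruner.step(). A minimal sketch (attribute names assume the torchvision ViT module layout):

# Inspect per-block widths after pruning (torchvision ViT module layout assumed).
for name, block in model.encoder.layers.named_children():
    attn = block.self_attention  # nn.MultiheadAttention
    mlp_in = block.mlp[0]        # first Linear of the MLP block
    print(
        f"{name}: qkv weight {tuple(attn.in_proj_weight.shape)}, "
        f"mlp hidden {mlp_in.out_features}"
    )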

This example is based on ViT, but I have noticed similar behavior on other models.

So I wonder: is this a bug, or expected behavior?

I hope you can comment and clarify.

Thanks in advance for your support!

P.S. The above was tested using the following versions

Torch: 2.4.0
Torchvision: 0.19.0
Torch-pruning: 1.4.2
Hrayo712 changed the title from "Sudden model size increase for high sparsity ratios" to "Sudden model size increase for high pruning ratios" on Aug 30, 2024
@janthmueller

The pruner has a built-in safeguard to prevent layer collapse, which is probably the main issue here. If you try to prune all channels from a layer (or several layers), this safeguard steps in to stop it. PyTorch can’t handle empty tensors well since they disrupt operations like gradient calculation. By preventing layer collapse, some parameters stay untouched, which can result in more parameters remaining than expected, even with a high pruning ratio.
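
If you want to keep global pruning but avoid running into the safeguard, one option might be to cap how much any single layer can lose, e.g. via MetaPruner's max_pruning_ratio argument. A rough sketch (please check that the argument is available in your torch-pruning version):

# Sketch: cap the per-layer pruning ratio so no layer is pruned down to (near) zero channels.
# max_pruning_ratio is assumed to be available in torch-pruning 1.4.x.
pruner = tp.MetaPruner(
    model=model,
    example_inputs=example_inputs,
    importance=tp.importance.MagnitudeImportance(p=2),
    global_pruning=True,
    pruning_ratio=0.9,
    max_pruning_ratio=0.95,  # remove at most 95% of any single layer's channels
    round_to=round_to,
    unwrapped_parameters=unwrapped_parameters,
)
pruner.step()

That keeps every layer non-empty while still letting the global criterion decide where most of the pruning happens.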
