This repository is the official implementation of Generalized Data Weighting via Class-level Gradient Manipulation (NeurIPS 2021).
If you find this code useful in your research, please cite:

```bibtex
@article{chen2021generalized,
  title={Generalized Data Weighting via Class-Level Gradient Manipulation},
  author={Chen, Can and Zheng, Shuhao and Chen, Xi and Dong, Erqun and Liu, Xue Steve and Liu, Hao and Dou, Dejing},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  pages={14097--14109},
  year={2021}
}
```
Requirements:

- Linux
- Python 3.7
- PyTorch 1.9.0
- torchvision 0.9.1
To install the Python dependencies, run:

```bash
pip install -r requirements.txt
```
Download CIFAR-10 and place it in `./data`.
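The following is a minimal sketch for fetching the data. It assumes that `main.py` reads the standard torchvision layout, which places `cifar-10-batches-py` (or `cifar-100-python`) under `./data`. The helper name `download_cifar` is hypothetical and is not part of this repository.

```python
def download_cifar(root="./data", dataset="cifar10"):
    """Download CIFAR-10 or CIFAR-100 into `root` with torchvision (hypothetical helper)."""
    # Map the --dataset value to the torchvision class name; other names raise KeyError.
    class_name = {"cifar10": "CIFAR10", "cifar100": "CIFAR100"}[dataset]
    # Lazy import: torchvision is only needed when a download actually happens.
    from torchvision import datasets
    # One archive holds both the train and test splits, so a single call fetches everything.
    getattr(datasets, class_name)(root=root, train=True, download=True)


if __name__ == "__main__":
    download_cifar("./data", "cifar10")
```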
To compare MW-Net and GDW on CIFAR-10 with 40% uniform label noise, run:

```bash
python -u main.py --corruption_prob 0.4 --dataset cifar10 --mode mw-net --outer_lr 100
python -u main.py --corruption_prob 0.4 --dataset cifar10 --mode gdw --outer_lr 100
```
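To make `--corruption_prob` concrete, here is one common reading of "uniform noise": each training label is replaced, with probability `corruption_prob`, by a class drawn uniformly from all classes. Under this reading the new label can coincide with the old one. This is an assumption, and the exact scheme in `main.py` may differ. The function name `uniform_corrupt` is hypothetical.

```python
import numpy as np


def uniform_corrupt(labels, num_classes, corruption_prob, seed=0):
    """Hypothetical sketch: with probability corruption_prob, resample a label uniformly
    from all num_classes classes (the resampled label may equal the original)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    flip = rng.random(labels.shape[0]) < corruption_prob           # which labels get resampled
    random_labels = rng.integers(0, num_classes, size=labels.shape[0])
    return np.where(flip, random_labels, labels)
```

With 10 classes and `corruption_prob=0.4`, about 0.4 × (1 − 1/10) = 36% of the labels actually change under this scheme.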
We set the outer-level learning rate (`--outer_lr`) to 100 on CIFAR-10 and 1000 on CIFAR-100.
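The two `--mode` values weight the training signal at different granularities. The following toy sketch illustrates the "class-level" idea in the paper's title; it is not this repository's code.

- For a linear softmax model, the cross-entropy gradient with respect to the logits of example i is `p_i - y_i`, with one component per class.
- Instance-level weighting (one scalar per example, the granularity MW-Net uses) scales that whole vector.
- Class-level weighting scales each component separately.

The sketch proceeds in three steps:

1. It takes one look-ahead SGD step on the weighted training loss.
2. It differentiates a clean meta loss with respect to the weights using the chain rule.
3. It applies one outer gradient step scaled by `outer_lr`.

The linear model, the function names, the zero clipping and the toy data are all hypothetical. The sketch also omits any constraints or normalization that GDW itself may apply.

```python
import numpy as np


def softmax(z):
    # Row-wise softmax; subtracting the row max keeps exp() from overflowing.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def logit_grad(W, X, y):
    # d(cross-entropy)/d(logits) for the linear model z = X @ W: row i is p_i - onehot(y_i).
    return softmax(X @ W) - np.eye(W.shape[1])[y]


def meta_weight_grads(W, X, y, X_meta, y_meta, omega, inner_lr):
    """Gradient of the mean meta loss (after one inner SGD step) w.r.t. omega[i, j].

    omega[i, j] scales class component j of example i's logit gradient.
    """
    n = X.shape[0]
    R = logit_grad(W, X, y)                   # (n, C) per-class logit-gradient components
    G = X.T @ (omega * R) / n                 # (d, C) weighted training gradient
    W_new = W - inner_lr * G                  # one look-ahead SGD step
    M = X_meta.T @ logit_grad(W_new, X_meta, y_meta) / X_meta.shape[0]  # meta gradient at W_new
    # Chain rule: dL_meta/domega[i, j] = -inner_lr / n * R[i, j] * (x_i . M[:, j])
    return -inner_lr / n * R * (X @ M)


def outer_step(omega, grad, outer_lr):
    # Plain gradient descent on the weights, clipped at zero (hypothetical choice).
    return np.maximum(omega - outer_lr * grad, 0.0)


# Toy data (hypothetical): d features, C classes, n noisy training and m clean meta examples.
rng = np.random.default_rng(0)
d, C, n, m = 5, 3, 8, 4
W = 0.1 * rng.normal(size=(d, C))
X, y = rng.normal(size=(n, d)), rng.integers(0, C, size=n)
X_meta, y_meta = rng.normal(size=(m, d)), rng.integers(0, C, size=m)

omega = np.ones((n, C))                                    # start from uniform weights
g_class = meta_weight_grads(W, X, y, X_meta, y_meta, omega, inner_lr=0.1)
g_inst = g_class.sum(axis=1)                               # gradient if each example has one shared weight
omega_new = outer_step(omega, g_class, outer_lr=100.0)
```

Summing `g_class` over classes recovers the instance-level gradient, because instance weighting is the special case in which every class of an example shares one weight.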
Training logs for the commands above are in `./log`, and the results are listed below:
Method | mw-net | gdw |
---|---|---|
Accuracy | 86.62% | 87.97% |
We thank the authors of the PyTorch implementation of MW-Net (https://github.com/xjtushujun/meta-weight-net).