Showing 21 changed files with 2,606 additions and 1 deletion.
`.gitignore` (new file):
```
correlation.egg-info
checkpoints/*
.vscode/*

*.png
*.csv
*.pyc

*.sh
*.zip
chairs_split.txt
alt_cuda_corr/build/
alt_cuda_corr/dist/
*.flo
```
`CITATIONS.bib` (new file):
```bibtex
@inproceedings{jahediMultiScaleRAFTCombining2022,
  title = {Multi-Scale {{RAFT}}: Combining Hierarchical Concepts for Learning-Based Optical Flow Estimation},
  shorttitle = {Multi-{{Scale RAFT}}},
  booktitle = {2022 {{IEEE International Conference}} on {{Image Processing}} ({{ICIP}})},
  author = {Jahedi, Azin and Mehl, Lukas and Rivinius, Marc and Bruhn, Andr{\'e}s},
  year = {2022},
  month = oct,
  pages = {1236--1240},
  publisher = {{IEEE}},
  issn = {2381-8549},
  doi = {10.1109/ICIP46576.2022.9898048},
}
```
`README.md`:
# MS_RAFT

In this repository we release (for now) the inference code for our work:

> **[Multi-Scale RAFT: Combining Hierarchical Concepts for Learning-Based Optical Flow Estimation](https://dx.doi.org/10.1109/ICIP46576.2022.9898048)**<br/>
> _ICIP 2022_ <br/>
> Azin Jahedi, Lukas Mehl, Marc Rivinius and Andrés Bruhn

If you find our work useful, please [cite via BibTeX](CITATIONS.bib).


## 🆕 Follow-Up Work

We improved the accuracy further by extending the method and applying a modified training setup.
Our new approach, `MS_RAFT_plus`, won the [Robust Vision Challenge 2022](http://www.robustvision.net/).

The code is available on [GitHub](https://github.com/cv-stuttgart/MS_RAFT_plus).


## Requirements

The code has been tested with PyTorch 1.10.2+cu113.
Install the required dependencies via
```
pip install -r requirements.txt
```

Alternatively, you can manually install the following packages in your virtual environment:
- `torch`, `torchvision`, and `torchaudio` (e.g., with `--extra-index-url https://download.pytorch.org/whl/cu113` for CUDA 11.3)
- `matplotlib`
- `scipy`
- `tensorboard`
- `opencv-python`
- `tqdm`
- `parse`
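As a sketch of that manual route (assuming a CUDA 11.3 machine; adjust the wheel suffix to your system), the listed packages can be installed in two commands:

```shell
# Matches the tested setup (PyTorch 1.10.2 built for CUDA 11.3)
pip install torch==1.10.2+cu113 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
pip install matplotlib scipy tensorboard opencv-python tqdm parse
```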

## Pre-Trained Checkpoints

You can download our pre-trained model from the [releases page](https://github.com/cv-stuttgart/MS_RAFT/releases/tag/v1.0.0).


## Datasets

Datasets are expected to be located under `./data` in the following layout:
```
./data
├── kitti15   # KITTI 2015
│   └── dataset
│       ├── testing/...
│       └── training/...
└── sintel    # Sintel
    ├── test/...
    └── training/...
```
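The layout above can be checked programmatically before running an evaluation; the following is a minimal sketch (a hypothetical helper, not part of this repository) that reports which expected directories are missing:

```python
import os

# Sub-directories the evaluation expects under the data root (from the layout above).
EXPECTED = [
    "kitti15/dataset/testing",
    "kitti15/dataset/training",
    "sintel/test",
    "sintel/training",
]

def missing_dirs(root="./data"):
    """Return the expected sub-directories that do not exist under `root`."""
    return [d for d in EXPECTED if not os.path.isdir(os.path.join(root, d))]
```

An empty result means the layout is complete; otherwise the returned list names the directories still to be set up.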

## Running MS_RAFT

You can evaluate a trained model via
```Shell
python evaluate.py --model sintel.pth --dataset sintel
```
This needs about 12 GB of GPU VRAM on MPI Sintel images.

If your GPU has less memory, please compile the CUDA correlation module (once) via:
```Shell
cd alt_cuda_corr && python setup.py install && cd ..
```
and then run:
```Shell
python evaluate.py --model sintel.pth --dataset sintel --cuda_corr
```
Using `--cuda_corr`, estimating the flow on MPI Sintel images needs only about 4 GB of GPU VRAM.
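The memory saving comes from not materializing the full all-pairs correlation volume. As a sketch (in NumPy, not the repository's implementation): for feature maps of shape `(C, H, W)` that volume has `H*W x H*W` entries, which is what dominates VRAM at Sintel resolution, and the `alt_cuda_corr` extension instead computes only the needed entries on the fly.

```python
import numpy as np

def all_pairs_correlation(fmap1, fmap2):
    """Dense RAFT-style correlation volume between two (C, H, W) feature maps.

    Returns a (H, W, H, W) array of scaled feature dot products; its O((H*W)^2)
    size is exactly what the on-demand CUDA lookup avoids storing.
    """
    C, H, W = fmap1.shape
    f1 = fmap1.reshape(C, H * W)      # (C, HW)
    f2 = fmap2.reshape(C, H * W)      # (C, HW)
    corr = f1.T @ f2 / np.sqrt(C)     # (HW, HW), scaled dot products
    return corr.reshape(H, W, H, W)
```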

## Acknowledgement

Parts of this repository are adapted from [RAFT](https://github.com/princeton-vl/RAFT) ([license](licenses/RAFT/LICENSE)).
We thank the authors for their excellent work.
C++ interface of the `alt_cuda_corr` extension (new file):
```cpp
#include <torch/extension.h>
#include <vector>

// CUDA forward declarations
std::vector<torch::Tensor> corr_cuda_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    int radius);

std::vector<torch::Tensor> corr_cuda_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor corr_grad,
    int radius);

// C++ interface: validate inputs, then dispatch to the CUDA kernels
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> corr_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);

  return corr_cuda_forward(fmap1, fmap2, coords, radius);
}

std::vector<torch::Tensor> corr_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor corr_grad,
    int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);
  CHECK_INPUT(corr_grad);

  return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &corr_forward, "CORR forward");
  m.def("backward", &corr_backward, "CORR backward");
}
```
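To illustrate the semantics of the `radius` parameter that `corr_forward` passes through, here is a simplified NumPy reference (an assumption-laden sketch, not the extension's actual kernel: it uses integer coordinates and zero padding, whereas the real CUDA implementation operates on batched tensors with sub-pixel sampling): for each pixel of `fmap1`, it correlates against a `(2*radius+1)^2` window of `fmap2` centred at the matching entry of `coords`.

```python
import numpy as np

def windowed_correlation(fmap1, fmap2, coords, radius):
    """Correlate each pixel of fmap1 with a window of fmap2.

    fmap1, fmap2: (C, H, W) feature maps.
    coords:      (2, H, W) target centres in fmap2, as (x, y) per pixel.
    Returns a ((2*radius+1)**2, H, W) volume of dot products.
    """
    C, H, W = fmap1.shape
    side = 2 * radius + 1
    out = np.zeros((side * side, H, W), dtype=fmap1.dtype)
    for y in range(H):
        for x in range(W):
            cx, cy = coords[:, y, x].astype(int)  # window centre in fmap2
            k = 0
            for dy in range(-radius, radius + 1):
                for dx in range(-radius, radius + 1):
                    ty, tx = cy + dy, cx + dx
                    if 0 <= ty < H and 0 <= tx < W:  # zero outside the image
                        out[k, y, x] = fmap1[:, y, x] @ fmap2[:, ty, tx]
                    k += 1
    return out
```

Only `(2*radius+1)^2` values per pixel are ever produced, which is why this lookup needs a fraction of the memory of the dense volume.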
Commit comment: Hello dear authors, you two really did a great job. Could you also share `train.py`? Thank you very much.