Commit 5647f70: initial commit.
az-ja committed Nov 5, 2022
1 parent d022287 commit 5647f70
Showing 21 changed files with 2,606 additions and 1 deletion.
15 changes: 15 additions & 0 deletions .gitignore
@@ -0,0 +1,15 @@
correlation.egg-info
checkpoints/*
.vscode/*

*.png
*.csv
*.pyc

*.sh
*.zip
chairs_split.txt
alt_cuda_corr/build/
alt_cuda_corr/dist/
*.flo
13 changes: 13 additions & 0 deletions CITATIONS.bib
@@ -0,0 +1,13 @@
@inproceedings{jahediMultiScaleRAFTCombining2022,
title = {Multi-Scale {{RAFT}}: Combining Hierarchical Concepts for Learning-Based Optical Flow Estimation},
shorttitle = {Multi-{{Scale RAFT}}},
booktitle = {2022 {{IEEE International Conference}} on {{Image Processing}} ({{ICIP}})},
  author = {Jahedi, Azin and Mehl, Lukas and Rivinius, Marc and Bruhn, Andr{\'e}},
year = {2022},
month = oct,
pages = {1236--1240},
publisher = {{IEEE}},
issn = {2381-8549},
doi = {10.1109/ICIP46576.2022.9898048},
}

79 changes: 78 additions & 1 deletion README.md
@@ -1,3 +1,80 @@
# MS_RAFT

In this repository we release (for now) the inference code for our work:

> **[Multi-Scale RAFT: Combining Hierarchical Concepts for Learning-Based Optical Flow Estimation](https://dx.doi.org/10.1109/ICIP46576.2022.9898048)**<br/>
> _ICIP 2022_ <br/>
> Azin Jahedi, Lukas Mehl, Marc Rivinius, and André Bruhn

If you find our work useful, please [cite it via BibTeX](CITATIONS.bib).


## 🆕 Follow-Up Work

We improved the accuracy further by extending the method and applying a modified training setup.
Our new approach is called `MS_RAFT_plus` and won the [Robust Vision Challenge 2022](http://www.robustvision.net/).

The code is available on [GitHub](https://github.com/cv-stuttgart/MS_RAFT_plus).


## Requirements

The code has been tested with PyTorch 1.10.2+cu113.
Install the required dependencies via
```Shell
pip install -r requirements.txt
```

Alternatively, you can install the following packages manually in your virtual environment:
- `torch`, `torchvision`, and `torchaudio` (e.g., with `--extra-index-url https://download.pytorch.org/whl/cu113` for CUDA 11.3)
- `matplotlib`
- `scipy`
- `tensorboard`
- `opencv-python`
- `tqdm`
- `parse`


## Pre-Trained Checkpoints

You can download our pre-trained model from the [releases page](https://github.com/cv-stuttgart/MS_RAFT/releases/tag/v1.0.0).


## Datasets

Datasets are expected to be located under `./data` in the following layout:
```
./data
├── kitti15 # KITTI 2015
│ └── dataset
│ ├── testing/...
│ └── training/...
└── sintel # Sintel
├── test/...
└── training/...
```
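A quick sanity check against the layout above can be sketched as follows. This helper is hypothetical (not part of the repository); it simply verifies that the expected directories exist before running evaluation:

```python
import os

# Directories implied by the layout above (hypothetical helper, not repo code).
EXPECTED_DIRS = [
    "data/kitti15/dataset/testing",
    "data/kitti15/dataset/training",
    "data/sintel/test",
    "data/sintel/training",
]

def missing_dataset_dirs(root="."):
    """Return the expected dataset directories that do not exist under root."""
    return [d for d in EXPECTED_DIRS if not os.path.isdir(os.path.join(root, d))]

if __name__ == "__main__":
    for d in missing_dataset_dirs():
        print(f"missing: {d}")
```

Running it from the repository root prints one line per missing dataset directory.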


## Running MS_RAFT

You can evaluate a trained model via
```Shell
python evaluate.py --model sintel.pth --dataset sintel
```
This needs about 12 GB of GPU VRAM on MPI Sintel images.
If your GPU has less memory available, compile the CUDA correlation module (once) via:

```Shell
cd alt_cuda_corr && python setup.py install && cd ..
```
and then run:
```Shell
python evaluate.py --model sintel.pth --dataset sintel --cuda_corr
```
Using `--cuda_corr`, estimating the flow on MPI Sintel images needs about 4 GB of GPU VRAM.
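The gap between the two memory figures is plausible from a back-of-the-envelope estimate (illustrative only; the actual footprint depends on feature resolution and the number of scales): an all-pairs correlation volume over H×W feature maps stores (H·W)² float32 entries, which the on-demand CUDA lookup avoids materializing.

```python
# Illustrative estimate only: size of a full all-pairs correlation volume
# for H x W feature maps stored as float32 ((H*W)^2 entries, 4 bytes each).
def corr_volume_gib(h, w, bytes_per_entry=4):
    return (h * w) ** 2 * bytes_per_entry / 2**30

# Sintel frames are 1024x436; at 1/4 feature resolution that is 256x109,
# giving roughly a 2.9 GiB volume for a single correlation level.
print(f"{corr_volume_gib(109, 256):.1f} GiB")
```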


## Acknowledgement

Parts of this repository are adapted from [RAFT](https://github.com/princeton-vl/RAFT) ([license](licenses/RAFT/LICENSE)).
We thank the authors for their excellent work.
54 changes: 54 additions & 0 deletions alt_cuda_corr/correlation.cpp
@@ -0,0 +1,54 @@
#include <torch/extension.h>
#include <vector>

// CUDA forward declarations
std::vector<torch::Tensor> corr_cuda_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    int radius);

std::vector<torch::Tensor> corr_cuda_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor corr_grad,
    int radius);

// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> corr_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);

  return corr_cuda_forward(fmap1, fmap2, coords, radius);
}

std::vector<torch::Tensor> corr_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor corr_grad,
    int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);
  CHECK_INPUT(corr_grad);

  return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &corr_forward, "CORR forward");
  m.def("backward", &corr_backward, "CORR backward");
}
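As a rough illustration of the operation `corr_forward` dispatches to the CUDA kernel, here is a pure-NumPy sketch under simplifying assumptions (single image, channels-last layout, nearest-neighbor sampling at the `coords` locations instead of bilinear interpolation, zeros outside image bounds; the real kernel runs on GPU tensors and is far more efficient):

```python
import numpy as np

def corr_forward_sketch(fmap1, fmap2, coords, radius):
    """For each pixel of fmap1, correlate its feature vector with fmap2
    features in a (2*radius+1)^2 window centered at coords[y, x] = (x, y).
    fmap1, fmap2: (H, W, C) arrays; out-of-bounds samples contribute zero."""
    H, W, C = fmap1.shape
    r = radius
    out = np.zeros((H, W, 2 * r + 1, 2 * r + 1), dtype=fmap1.dtype)
    for y in range(H):
        for x in range(W):
            cx, cy = coords[y, x]
            for dy in range(-r, r + 1):
                for dx in range(-r, r + 1):
                    sx = int(round(cx)) + dx  # sample column in fmap2
                    sy = int(round(cy)) + dy  # sample row in fmap2
                    if 0 <= sx < W and 0 <= sy < H:
                        out[y, x, dy + r, dx + r] = fmap1[y, x] @ fmap2[sy, sx]
    return out
```

With identity coordinates and constant feature maps, the window center holds the plain dot product of the two feature vectors, while window entries that fall outside the image stay zero.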