Squashed changes from private branch:
* Update readme
* Docu of recent changes
* Implement more sensible translation distribution for augmentation
* Add the alignment for Biwi from the SRHP/OPAL paper
* Add BIWI Benchmark
* Fix labels in evaluate_stability plots
* Fix non_blocking data transfers.
* Stabilize NLL loss by mixing in a uniform probability density
* Remove rotation laplace distribution
* Change lower bound for variance output
* Fix warning from loading weights with torch.load
* Comment awful train update func and fix type annotations
* Try post-training quantization. Added as an option to the exporter.
* Parameterize variance like earlier and remove BN (OpenTrack release v0.2)
* Rename the variance heads (they are actually scale heads)
* New training schedules: Exponential LR rampup; NLL rampup; train variance parameters slower;
* Always include the shape plausibility loss
* Implement shape plausibility loss without torch gmm package
* Tests for support of negative log likelihood losses
* Improve grouping of losses and visualization of losses
* Save models in a more self-describing format so the architecture doesn't have to be known prior to loading
* Bring back the learnable pose offset in local head coordinates
* Vastly simplify the ONNX compute graph!
* Rewrote image geometric transforms:
    - Torch implementation for differentiability and GPU execution (currently not used)
    - OpenCV implementation for speed on cpu
* Evaluation script fixes
DaWelter committed Oct 31, 2024
1 parent f95bdb4 commit 9891d60
Showing 47 changed files with 1,925 additions and 1,048 deletions.
39 changes: 39 additions & 0 deletions doc/recent-changes.md
@@ -0,0 +1,39 @@
# Recent Changes

## Post-Training Quantization
* Added to the ONNX export (see the sketch below).
* Minimal accuracy loss for the ResNet18 model. The MobileNet variant becomes too noisy and is probably unusable.
* Inference time reduced to ca. 60% of the float32 original.
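
A minimal sketch of how dynamic post-training quantization of an exported ONNX model could look with `onnxruntime`; the file names are placeholders, and the exporter in this repository may use static quantization or different settings.

```python
from onnxruntime.quantization import QuantType, quantize_dynamic

# Quantize the exported model's weights to int8 (activations stay float32).
# File names are placeholders; the actual exporter may use different settings.
quantize_dynamic(
    model_input="poseestimator.onnx",
    model_output="poseestimator_int8.onnx",
    weight_type=QuantType.QInt8,
)
```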

## Variance Parameterization
* Back to the earlier implementation:
  - Use the "smoothclip" function to force the diagonals of the covariance factors to be positive (see the sketch below).
  - Overparameterize with a scale factor that is applied to all covariances.
* Remove batch norm (BN).
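
A minimal sketch of a smooth positivity constraint in the spirit of "smoothclip"; the softplus form and the lower bound are assumptions, not necessarily the repository's actual implementation.

```python
import torch
import torch.nn.functional as F

def smoothclip(x: torch.Tensor, lower_bound: float = 1e-2) -> torch.Tensor:
    # Smoothly map any real input to a value above `lower_bound` (softplus-style).
    # The actual "smoothclip" in this code base may use a different form and bound.
    return lower_bound + F.softplus(x)

# Diagonal entries of the covariance factor are forced positive; a single
# learnable scale overparameterizes all covariances.
raw_diag = torch.randn(3, requires_grad=True)
log_scale = torch.zeros(1, requires_grad=True)
cov_diag = smoothclip(raw_diag) * torch.exp(log_scale)
```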

## Training
* Exponential LR warmup
* Train variance parameters 10x slower
* First train without NLL losses. After the LR warmup, ramp up the weight of the NLL losses (see the schedule sketch below).
* Add the "shape plausibility loss" back in. It's based on a GMM probability density for the face shape parameters.
* The warmup changes helped to get:
- Decent performance without BN in the variance heads
- Smoother loss curves than before
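
A minimal sketch of what the warmup schedules could look like; the step counts, the starting LR fraction, and the function names are illustrative, not the repository's actual values.

```python
def lr_warmup_factor(step: int, warmup_steps: int = 1000, start_fraction: float = 0.01) -> float:
    # Exponential ramp: start at `start_fraction` of the base LR, reach 1.0 at the end of warmup.
    if step >= warmup_steps:
        return 1.0
    return start_fraction ** (1.0 - step / warmup_steps)

def nll_weight(step: int, warmup_steps: int = 1000, ramp_steps: int = 1000) -> float:
    # Keep the NLL losses off during the LR warmup, then ramp their weight linearly to 1.
    return min(max(step - warmup_steps, 0) / ramp_steps, 1.0)
```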

Curve of rotation loss at time of publication:

![Curve of rotation loss at time of publication](traincurve-paper.jpg)

Now:

![Curve of rotation loss now](traincurve-now.jpg)


## Model
* Add back learnable local pose offsets (it doesn't appear to help).
* Simplify the ONNX graph.
* Store the model config in the checkpoint, so it can be loaded without knowing the config in advance (see the sketch below).
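
A minimal sketch of such a self-describing checkpoint, assuming a plain `torch.save` dictionary; the key names and the `build_model` callable are placeholders, not the repository's actual format.

```python
import torch

def save_checkpoint(model: torch.nn.Module, config: dict, path: str) -> None:
    # Keep the config next to the weights so loading needs no prior knowledge of the architecture.
    torch.save({"config": config, "state_dict": model.state_dict()}, path)

def load_checkpoint(path: str, build_model) -> torch.nn.Module:
    ckpt = torch.load(path, map_location="cpu")
    model = build_model(**ckpt["config"])   # reconstruct the network from the stored config
    model.load_state_dict(ckpt["state_dict"])
    return model
```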

## Evaluation
* Add evaluation of the 2D NME for the 68-landmark AFLW2000-3D benchmark (a sketch of the metric follows below).
* Add evaluation on the Biwi benchmark.
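
A minimal sketch of the 2D NME metric, assuming normalization by the ground-truth bounding-box size (sqrt of width times height); the evaluation script may use a different normalization convention.

```python
import numpy as np

def nme_2d(pred_landmarks: np.ndarray, gt_landmarks: np.ndarray) -> float:
    # Mean euclidean error over the 68 landmarks, normalized by the ground-truth
    # bounding-box size sqrt(w*h). Inputs have shape (68, 2) in pixels.
    x_min, y_min = gt_landmarks.min(axis=0)
    x_max, y_max = gt_landmarks.max(axis=0)
    norm = np.sqrt((x_max - x_min) * (y_max - y_min))
    errors = np.linalg.norm(pred_landmarks - gt_landmarks, axis=1)
    return float(errors.mean() / norm)
```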
Binary file added doc/traincurve-now.jpg
Binary file added doc/traincurve-paper.jpg
119 changes: 80 additions & 39 deletions readme.md
@@ -1,13 +1,24 @@
# OpNet: On the power of data augmentation for head pose estimation networks
OpenTrack "NeuralNet Tracker" Training & Evaluation
===================================================

A.K.A. OpenTrack's NeuralNet Tracker Training and Evaluation Code
/ [**"OpNet: On the power of data augmentation for head pose estimation"**](https://arxiv.org/abs/2407.05357)

Intro
-----
If you are looking for the code for the publication, please see the [`paper` branch](https://github.com/opentrack/neuralnet-tracker-traincode/tree/paper),
which is a specially tailored snapshot for the publication.

This branch contains the code for the publication. Beware, it also contains leftover things from past experiments.
This repository contains the code to train the neural nets for the NeuralNet tracker plugin of [Opentrack](https://github.com/opentrack/opentrack). It allows head tracking with a simple webcam.


Overview
--------

The tracker plugin is based on deep learning, i.e. neural network models optimized using data to perform their tasks.
There are two parts: a localizer network and the actual pose estimation network.
The localizer tries to find a single face and generates a bounding box around it, from which a crop is extracted for the pose network to analyze.

The following outlines the steps to reproduce the networks delivered with OpenTrack, including training and evaluation. The instructions currently focus on the pose estimator; a section on the localizer comes at the end.

This readme contains instructions for evaluation and training.

Install
-------
@@ -42,13 +53,15 @@ Evaluation

Download AFLW2000-3D from http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm.

Biwi can be obtained from Kaggle https://www.kaggle.com/datasets/kmader/biwi-kinect-head-pose-database. I couldn't find a better source that is still accessible.

Download a PyTorch model checkpoint.

* Baseline Ensemble: https://drive.google.com/file/d/19LrssD36COWzKDp7akxtJFeVTcltNVlR/view?usp=sharing
* Additionally trained on Face Synthetics (BL+FS): https://drive.google.com/file/d/19zN8KICVEbLnGFGB5KkKuWrPjfet-jC8/view?usp=sharing
* Labeling Ensemble (RA-300W-LP from Table 3): https://drive.google.com/file/d/13LSi6J4zWSJnEzEXwZxr5UkWndFXdjcb/view?usp=sharing

### Option 1
### Option 1 (AFLW2000 3D)

Run `scripts/AFLW20003dEvaluation.ipynb`.
It should give results pretty close to the paper; the face crop selection is different though, so the results won't be exactly the same.
@@ -58,59 +71,62 @@ It should give results pretty close to the paper. The face crop selection is dif
Run the preprocessing and then the evaluation script.

```bash
# The output filename "aflw2k.h5" must batch the hardcoded value in "pipelines.py"
python scripts/dsaflw2k_processing.py $DATADIR/AFLW2000-3D.zip $DATADIR/aflw2k.h5`
# Preprocess the data. The output filename "aflw2k.h5" must match the hardcoded value in "pipelines.py"
python scripts/dsaflw2k_processing.py <path to>/AFLW2000-3D.zip $DATADIR/aflw2k.h5
# Will look in $DATADIR for aflw2k.h5.
python scripts/evaluate_pose_network.py --ds aflw2k3d <path to model(.onnx|.ckpt)>
```

It supports ONNX conversions as well as pytorch checkpoints. But the script must be adapted to the concrete model configuration for the checkpoint if that is used. If you wish to process the outputs further, like for averaging like in the paper, there is an option to generate json files.
It supports ONNX conversions as well as PyTorch checkpoints. For PyTorch the script must be adapted to the concrete model configuration for the checkpoint. If you wish to process the outputs further, e.g. for averaging as in the paper, there is an option to generate JSON files.

Evaluation on the Biwi benchmark works similarly. However, we use the annotations file from https://github.com/pcr-upm/opal23_headpose in order to adhere to the experimental protocol. It can be found under https://github.com/pcr-upm/opal23_headpose/blob/main/annotations/biwi_ann.txt.
```bash
# Preprocess the data.
python scripts/dsprocess_biwi.py --opal-annotation <path to>/biwi_ann.txt <path to>/biwi.zip $DATADIR/biwi-v3.h5
# Will look in $DATADIR for biwi-v3.h5.
python scripts/evaluate_pose_network.py --ds biwi --roi-expansion 0.8 --perspective-correction <path to model(.onnx|.ckpt)>
```
You want `--perspective-correction` for SOTA results. It corrects the orientation obtained from the face crop for camera perspective: with the Kinect's field of view, the assumption of orthographic projection no longer holds. I.e. the pose from the crop is transformed into the global coordinate frame, and it is compared with the original labels w.r.t. this frame. Without the correction, the pose from the crop is compared directly with the labels (a sketch of the correction follows below).
Setting `--roi-expansion 0.8` causes the cropped area to be smaller relative to the bounding box annotation. That is also necessary for good results because the annotations have much larger bounding boxes than the networks were trained with.
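
A minimal sketch of the assumed correction: rotate the crop-local orientation by the rotation that aligns the camera's optical axis with the viewing ray through the crop center. The conventions (pixel offsets from the principal point, intrinsics) are assumptions; the evaluation script may differ in detail.

```python
import numpy as np
from scipy.spatial.transform import Rotation

def perspective_corrected_rotation(rot_crop: Rotation, crop_center: tuple, focal_length: float) -> Rotation:
    # Viewing ray through the crop center, given as pixel offset from the principal point.
    ray = np.array([crop_center[0], crop_center[1], focal_length], dtype=float)
    ray /= np.linalg.norm(ray)
    # Rotation that maps the optical axis (0, 0, 1) onto that ray.
    axis = np.cross([0.0, 0.0, 1.0], ray)
    norm = np.linalg.norm(axis)
    angle = np.arccos(np.clip(ray[2], -1.0, 1.0))
    correction = Rotation.identity() if norm < 1e-9 else Rotation.from_rotvec(axis / norm * angle)
    # Express the crop-local head orientation in the camera's global frame.
    return correction * rot_crop
```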
Integration in OpenTrack
------------------------
https://github.com/opentrack/opentrack

Choose the "Neuralnet" tracker plugin. It currently comes with some older models which don't
achieve the same SOTA benchmark results but are a little bit more noise resistant and invariant
to eye movements.

Training
--------

Several datasets are used, all of which are preprocessed; the results are (partially) stored in h5 files.
Rough guidelines for reproduction follow.

### Datasets

First, to get the data, the expositional script below enumerates everything. Note: `gdown` is a pip-installable tool for downloading from Google Drive; you can of course use anything you want.

```bash
# 300W-LP
# Go to http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm and find the download for 300W-LP.zip.
# Currently it's on Google Drive with the ID as used below. Better check it yourself.
gdown 0B7OEHD3T4eCkVGs0TkhUWFN6N1k

# AFLW2000-3D
wget www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/Database/AFLW2000-3D.zip

# LaPa Megaface 3D Labeled "Large Pose" Extension
# https://drive.google.com/file/d/1K4CQ8QqAVXj3Cd-yUt3HU9Z8o8gDmSEV/view?usp=drive_link
gdown 1K4CQ8QqAVXj3Cd-yUt3HU9Z8o8gDmSEV

# 300W-LP Reproduction
# https://drive.google.com/file/d/1uEqba5JCGQMzrULnPHxf4EJa04z_yHWw/view?usp=drive_link
gdown 1uEqba5JCGQMzrULnPHxf4EJa04z_yHWw

# WFLW 3D Labeled "Large Pose" Extension
# https://drive.google.com/file/d/1SY33foUF8oZP8RUsFmcEIjq5xF5m3oJ1/view?usp=drive_link
gdown 1SY33foUF8oZP8RUsFmcEIjq5xF5m3oJ1

# Face Synthetics (https://github.com/microsoft/FaceSynthetics)
wget --tries=0 --continue --server-response --timeout=0 --retry-connrefused https://facesyntheticspubwedata.blob.core.windows.net/iccv-2021/dataset_100000.zip
```

#### 300W-LP & AFLW2000-3D
There should be download links for `300W-LP.zip` and `AFLW2000-3D.zip` on http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm.

#### 300W-LP Reproduction
My version of 300W-LP with custom out-of-plane rotation augmentation applied.
Includes "closed-eyes" augmentation as well as directional illumination.
On Google Drive https://drive.google.com/file/d/1uEqba5JCGQMzrULnPHxf4EJa04z_yHWw/view?usp=drive_link.

#### LaPa Megaface 3D Labeled "Large Pose" Extension
My pseudo / semi-automatically labeled subset of the Megaface frames from LaPa.
On Google Drive https://drive.google.com/file/d/1K4CQ8QqAVXj3Cd-yUt3HU9Z8o8gDmSEV/view?usp=drive_link.

#### WFLW 3D Labeled "Large Pose" Extension
My pseudo / semi-automatically labeled subset.
On Google Drive https://drive.google.com/file/d/1SY33foUF8oZP8RUsFmcEIjq5xF5m3oJ1/view?usp=drive_link.

#### Face Synthetics
There should be a download link on https://github.com/microsoft/FaceSynthetics for the 100k samples variant `dataset_100000.zip`.

### Preprocessing

Now some preprocessing and unpacking:

```bash
python scripts/dsprocess_aflw2k.py AFLW2000-3D.zip $DATADIR/aflw2k.h5
@@ -129,6 +145,8 @@ unzip reproduction_300wlp-v12.zip -d ../$DATADIR/

The processed files can be inspected in the notebook `DataVisualization.ipynb`.

### Training Process

Now training should be possible. For the baseline it should be:
```bash
python scripts/train_poseestimator.py --lr 1.e-3 --epochs 1500 --ds "repro_300_wlp+lapa_megaface_lp:20000+wflw_lp" \
@@ -159,6 +177,20 @@ It will look at the environment variable `DATADIR` to find the datasets. Notable
--ds "repro_300_wlp" # Train only on the 300W-LP reproduction
--ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" # Train the "BL + FS" case which should give best performing models.
```
### Deployment

I use ONNX for deployment and most evaluation purposes. There is a script for the conversion. WARNING: it is necessary to adapt its code to the model configuration. :-/ It is easy though: only the one statement where the model is instantiated needs to be changed. The script has two modes. For exports for OpenTrack, use
```bash
python scripts/export_model.py --posenet <model.ckpt>
```
It omits the landmark predictions and renames the output tensors (for historical reasons). The script performs sanity checks to ensure the outputs from ONNX are almost equal to PyTorch outputs.
To use the model in OpenTrack, find the directory with the other `.onnx` models and copy the new one there. Then in OpenTrack, in the tracker settings, there is a button to select the model file.

For evaluation use
```
python scripts/export_model.py --full --posenet <model.ckpt>
```
The model created in this way includes all outputs.

Creation of 3D Labeled WFLW & LaPa Large Pose Expansions
--------------------------------------------------------
@@ -209,4 +241,13 @@ encoded as JPG, else as PNG. When `storage` is set to `image_filename` then the
files. The other label fields are label data and should be relatively self-explanatory.
Relevant code for reading and writing those files can be found in `trackertraincode/datasets/dshdf5.py`,
`trackertraincode/datasets/dshdf5pose.py` and the preprocessing scripts `scripts/dsprocess_*.py`.
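
A minimal inspection sketch for these h5 files using `h5py`; the file name is only an example, and the authoritative layout is defined in the modules listed above.

```python
import h5py

# Print every dataset in one of the preprocessed files with its shape, dtype and attributes.
# "aflw2k.h5" is only an example file name.
with h5py.File("aflw2k.h5", "r") as f:
    def show(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(name, obj.shape, obj.dtype, dict(obj.attrs))
    f.visititems(show)
```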
Localizer Network
-----------------
There is an old notebook to train this network.
The training data is a processed version of the Wider Face dataset. The processing accounts for the fact that Wider Face contains images with potentially many faces. Therefore, sections which contain only one face or none are extracted.
The localizer network is trained to generate a "heatmap" with a peak where it suspects the center of a face. In addition, parameters of a bounding box are output (see the sketch below).
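
A minimal sketch of how such an output could be decoded, assuming the peak location gives the face center and separate outputs give the box size; the real network's output format and decoding are defined in the repository and may differ.

```python
import numpy as np

def decode_localizer_output(heatmap: np.ndarray, box_size: np.ndarray):
    # Take the heatmap peak as the face center (assumed convention) ...
    iy, ix = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    confidence = float(heatmap[iy, ix])
    # ... and combine it with the predicted box width/height (assumed output format).
    w, h = float(box_size[0]), float(box_size[1])
    box = (ix - w / 2, iy - h / 2, ix + w / 2, iy + h / 2)
    return confidence, box
```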
3 changes: 2 additions & 1 deletion run.sh
@@ -6,4 +6,5 @@ python scripts/train_poseestimator.py --lr 1.e-3 --epochs 1500 --ds "repro_300_w
--with-nll-loss \
--roi-override original \
--no-blurpool \
--backbone mobilenetv1
--backbone resnet18 \
--outdir model_files/
14 changes: 5 additions & 9 deletions scripts/AFLW20003dEvaluation.ipynb

