From 1d781d25b4ad3246be46e6df52685a2197c4977c Mon Sep 17 00:00:00 2001
From: Louis Jean
Date: Wed, 8 Dec 2021 17:10:19 +0100
Subject: [PATCH] fix(tensorrt): yolox postprocessing in C++

Onnx post processing returns wrong results due to a tensorrt bug
---
 src/backends/tensorrt/.gitignore      |  1 +
 src/backends/tensorrt/models/yolo.hpp | 76 +++++++++++++++++++++++++++
 src/backends/tensorrt/tensorrtlib.cc  | 53 ++++++++++++-------
 src/backends/tensorrt/tensorrtlib.h   |  2 +
 tests/CMakeLists.txt                  | 15 ++++--
 tests/ut-tensorrtapi.cc               | 34 +++++++++---
 tools/torch/trace_yolox.py            | 24 ++++++---
 7 files changed, 168 insertions(+), 37 deletions(-)
 create mode 100644 src/backends/tensorrt/.gitignore
 create mode 100644 src/backends/tensorrt/models/yolo.hpp
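Note (illustration only, not part of the change itself): the new helper
dd::yolo_utils::parse_yolo_output() added in src/backends/tensorrt/models/yolo.hpp
converts the raw YOLOX head output (per detection: center x, center y, width,
height, objectness, then one score per class) into the detection layout used
downstream (batch id, class id, confidence, xmin, ymin, xmax, ymax, normalized),
sorted by confidence. A minimal usage sketch, assuming it is compiled from the
src/backends/tensorrt directory and fed a single made-up detection (all numeric
values below are illustrative, not taken from the patch):

    #include "models/yolo.hpp"

    #include <cstdio>
    #include <vector>

    int main()
    {
      // One image, top_k = 1, two classes. Raw layout per detection:
      // center x, center y, width, height, objectness, then one score per class.
      std::vector<float> raw{ 320.f, 320.f, 64.f, 64.f, 0.9f, 0.1f, 0.8f };

      std::vector<float> out = dd::yolo_utils::parse_yolo_output(
          raw, /*batch_size=*/1, /*top_k=*/1, /*n_classes=*/2,
          /*im_width=*/640, /*im_height=*/640);

      // out holds 7 values per kept detection: batch id (always 0 here),
      // class id, confidence, then xmin, ymin, xmax, ymax in [0, 1].
      std::printf("class %d conf %.2f\n", static_cast<int>(out[1]), out[2]);
      return 0;
    }

With these made-up values the helper keeps class 1 with confidence
0.9 * 0.8 = 0.72 and normalizes the box corners by (im_width - 1) and
(im_height - 1), which matches the layout consumed by the bbox path in
tensorrtlib.cc below.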
diff --git a/src/backends/tensorrt/.gitignore b/src/backends/tensorrt/.gitignore
new file mode 100644
index 000000000..626cfe0e1
--- /dev/null
+++ b/src/backends/tensorrt/.gitignore
@@ -0,0 +1 @@
+-models/
diff --git a/src/backends/tensorrt/models/yolo.hpp b/src/backends/tensorrt/models/yolo.hpp
new file mode 100644
index 000000000..51a749520
--- /dev/null
+++ b/src/backends/tensorrt/models/yolo.hpp
@@ -0,0 +1,76 @@
+// Copyright (C) 2021 Jolibrain http://www.jolibrain.com
+
+// Author: Louis Jean
+
+// This program is free software; you can redistribute it and/or
+// modify it under the terms of the GNU General Public License
+// as published by the Free Software Foundation; either version 3
+// of the License, or (at your option) any later version.
+
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU General Public License for more details.
+
+// You should have received a copy of the GNU General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+#include <vector>
+#include <algorithm>
+
+namespace dd
+{
+  namespace yolo_utils
+  {
+    /** Convert from format:
+     * unsorted bbox*4 | objectness | class softmax*n_classes
+     * to format:
+     * sorted batch id | class_id | class confidence | bbox * 4 */
+    static std::vector<float>
+    parse_yolo_output(const std::vector<float> &model_out, size_t batch_size,
+                      size_t top_k, size_t n_classes, size_t im_width,
+                      size_t im_height)
+    {
+      std::vector<float> vals;
+      vals.reserve(batch_size * top_k * 7);
+      size_t step = n_classes + 5;
+      auto batch_it = model_out.begin();
+
+      for (size_t batch = 0; batch < batch_size; ++batch)
+        {
+          std::vector<std::vector<float>> result;
+          result.reserve(top_k);
+          auto end_it = batch_it + top_k * step;
+
+          for (; batch_it != end_it; batch_it += step)
+            {
+              // get class id & confidence
+              auto max_batch_it
+                  = std::max_element(batch_it + 5, batch_it + step);
+              float cls_pred = std::distance(batch_it + 5, max_batch_it);
+              float prob = *max_batch_it * (*(batch_it + 4));
+
+              // convert center, dims to xyxy
+              float xc = *batch_it, yc = *(batch_it + 1), w = *(batch_it + 2),
+                    h = *(batch_it + 3);
+              result.push_back(std::vector<float>{
+                  0, cls_pred, prob, (xc - w / 2) / (im_width - 1),
+                  (yc - h / 2) / (im_height - 1),
+                  (xc + w / 2) / (im_width - 1),
+                  (yc + h / 2) / (im_height - 1) });
+            }
+
+          std::sort(result.begin(), result.end(),
+                    [](const std::vector<float> &a,
+                       const std::vector<float> &b) { return a[2] > b[2]; });
+
+          for (auto &val : result)
+            {
+              vals.insert(vals.end(), val.begin(), val.end());
+            }
+          batch_it = end_it;
+        }
+      return vals;
+    }
+  }
+}
diff --git a/src/backends/tensorrt/tensorrtlib.cc b/src/backends/tensorrt/tensorrtlib.cc
index d859dc4a4..54d76c72b 100644
--- a/src/backends/tensorrt/tensorrtlib.cc
+++ b/src/backends/tensorrt/tensorrtlib.cc
@@ -32,6 +32,7 @@
 #include
 #endif
 #include "utils/bbox.hpp"
+#include "models/yolo.hpp"
 
 namespace dd
 {
@@ -108,6 +109,7 @@ namespace dd
     _timeserie = tl._timeserie;
     _regression = tl._regression;
     _need_nms = tl._need_nms;
+    _template = tl._template;
     _inputIndex = tl._inputIndex;
     _outputIndex0 = tl._outputIndex0;
     _outputIndex1 = tl._outputIndex1;
@@ -200,23 +202,21 @@ namespace dd
               + this->_mlmodel._repo);
       }
 
-    // XXX(louis): this default value should be moved out of trt lib when
-    // init_mllib will be changed to DTOs
     if (ad.has("topk"))
       _top_k = ad.get("topk").get<int>();
 
     if (ad.has("template"))
       {
-        std::string tmplate = ad.get("template").get<std::string>();
-        this->_logger->info("Model template is {}", tmplate);
+        _template = ad.get("template").get<std::string>();
+        this->_logger->info("Model template is {}", _template);
 
-        if (tmplate == "yolox")
+        if (_template == "yolox")
           {
             this->_mltype = "detection";
            _need_nms = true;
           }
         else
-          throw MLLibBadParamException("Unknown template " + tmplate);
+          throw MLLibBadParamException("Unknown template " + _template);
       }
 
     _builder = std::shared_ptr<nvinfer1::IBuilder>(
@@ -544,9 +544,9 @@ namespace dd
             this->_logger->info("trying to determine number of classes...");
             _nclasses = caffe_proto::findNClasses(this->_mlmodel._def, _bbox);
             if (_nclasses < 0)
-              throw MLLibBadParamException(
-                  "failed detecting the number of classes, specify it through "
-                  "API with nclasses");
+              throw MLLibBadParamException("failed detecting the number of "
+                                           "classes, specify it through "
+                                           "API with nclasses");
             this->_logger->info("found {} classes", _nclasses);
           }
 
@@ -636,8 +636,7 @@
           }
         else
           {
-            if (this->_mlmodel._model.find("net_tensorRT.onnx")
-                != std::string::npos)
+            if (this->_mlmodel.is_onnx_source())
               _explicit_batch = true;
           }
 
@@ -671,9 +670,16 @@
 
     if (_bbox)
       {
+        if (_dims.nbDims < 3)
+          throw MLLibBadParamException(
+              "Bbox model requires 3 output dimensions, found "
+              + std::to_string(_dims.nbDims));
+
         _outputIndex1 = _engine->getBindingIndex("keep_count");
         _buffers.resize(3);
-        _floatOut.resize(_max_batch_size * _top_k * 7);
+        int det_out_size = _max_batch_size * _top_k * _dims.d[2];
+        // int det_out_size = _max_batch_size * _top_k * 7;
+        _floatOut.resize(det_out_size);
         _keepCount.resize(_max_batch_size);
         if (inputc._bw)
           cudaMalloc(&_buffers.data()[_inputIndex],
@@ -684,7 +690,7 @@ namespace dd
                      _max_batch_size * 3 * inputc._height * inputc._width
                          * sizeof(float));
         cudaMalloc(&_buffers.data()[_outputIndex0],
-                   _max_batch_size * _top_k * 7 * sizeof(float));
+                   det_out_size * sizeof(float));
         cudaMalloc(&_buffers.data()[_outputIndex1],
                    _max_batch_size * sizeof(int));
       }
@@ -816,7 +822,7 @@
           {
             cudaMemcpyAsync(_floatOut.data(),
                             _buffers.data()[_outputIndex0],
-                            num_processed * _top_k * 7 * sizeof(float),
+                            _floatOut.size() * sizeof(float),
                             cudaMemcpyDeviceToHost, cstream);
             cudaMemcpyAsync(_keepCount.data(),
                             _buffers.data()[_outputIndex1],
@@ -837,10 +843,10 @@
         // GAN/raw output
         else if (!extract_layer.empty())
           {
-            cudaMemcpyAsync(
-                _floatOut.data(), _buffers.data()[_outputIndex0],
-                num_processed * _floatOut.size() * sizeof(float),
-                cudaMemcpyDeviceToHost, cstream);
+            cudaMemcpyAsync(_floatOut.data(),
+                            _buffers.data()[_outputIndex0],
+                            _floatOut.size() * sizeof(float),
+                            cudaMemcpyDeviceToHost, cstream);
             cudaStreamSynchronize(cstream);
           }
         else // classification / regression
@@ -868,12 +874,21 @@
 
       if (_bbox)
         {
          int results_height = _top_k;
-          const int det_size = 7;
 
+          // preproc yolox
+          if (_template == "yolox")
+            {
+              _floatOut = yolo_utils::parse_yolo_output(
+                  _floatOut, num_processed, results_height, _nclasses,
+                  inputc._width, inputc._height);
+            };
+
+          const int det_size = 7;
           const float *outr = _floatOut.data();
 
           for (int j = 0; j < num_processed; j++)
             {
+              int k = 0;
               std::vector<double> probs;
               std::vector<std::string> cats;
diff --git a/src/backends/tensorrt/tensorrtlib.h b/src/backends/tensorrt/tensorrtlib.h
index 4c40b8651..170392ec5 100644
--- a/src/backends/tensorrt/tensorrtlib.h
+++ b/src/backends/tensorrt/tensorrtlib.h
@@ -122,6 +122,8 @@ namespace dd
     bool _writeEngine = true;
     std::string _arch;
     int _gpuid = 0;
+    std::string
+        _template; /**< template for models that require specific treatment */
 
     //!< The TensorRT engine used to run the network
     std::shared_ptr<nvinfer1::ICudaEngine> _engine = nullptr;
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 27a2e0c53..fe78c5855 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -350,12 +350,19 @@ if (USE_TENSORRT)
     "resnet_onnx_trt.tar.gz"
     "resnet_onnx_trt"
     )
+#  DOWNLOAD_DATASET(
+#    "ONNX yolox model"
+#    "https://deepdetect.com/models/init/desktop/images/detection/yolox_onnx_trt.tar.gz"
+#    "examples/trt"
+#    "yolox_onnx_trt.tar.gz"
+#    "yolox_onnx_trt"
+#    )
   DOWNLOAD_DATASET(
-    "ONNX yolox model"
-    "https://deepdetect.com/models/init/desktop/images/detection/yolox_onnx_trt.tar.gz"
+    "ONNX yolox model without wrapper"
+    "https://deepdetect.com/models/init/desktop/images/detection/yolox_onnx_trt_nowrap.tar.gz"
     "examples/trt"
-    "yolox_onnx_trt.tar.gz"
-    "yolox_onnx_trt"
+    "yolox_onnx_trt_nowrap.tar.gz"
+    "yolox_onnx_trt_nowrap"
     )
   DOWNLOAD_DATASET(
     "ONNX CycleGAN model"
diff --git a/tests/ut-tensorrtapi.cc b/tests/ut-tensorrtapi.cc
index f48ea074b..f9f5111ef 100644
--- a/tests/ut-tensorrtapi.cc
+++ b/tests/ut-tensorrtapi.cc
@@ -40,7 +40,7 @@
 static std::string squeez_repo = "../examples/trt/squeezenet_ssd_trt/";
 static std::string refinedet_repo = "../examples/trt/faces_512/";
 static std::string squeezv1_repo = "../examples/trt/squeezenet_v1/";
 static std::string resnet_onnx_repo = "../examples/trt/resnet_onnx_trt/";
-static std::string yolox_onnx_repo = "../examples/trt/yolox_onnx_trt/";
+static std::string yolox_onnx_repo = "../examples/trt/yolox_onnx_trt_nowrap/";
 static std::string cyclegan_onnx_repo
     = "../examples/trt/cyclegan_resnet_attn_onnx_trt/";
@@ -256,7 +256,7 @@ TEST(tensorrtapi, service_predict_bbox_onnx)
         + yolox_onnx_repo
         + "\"},\"parameters\":{\"input\":{\"connector\":\"image\",\"height\":"
           "640,\"width\":640,\"rgb\":true},\"mllib\":{\"template\":\"yolox\","
-          "\"maxBatchSize\":1,\"maxWorkspaceSize\":256,\"gpuid\":0,"
+          "\"maxBatchSize\":2,\"maxWorkspaceSize\":256,\"gpuid\":0,"
           "\"nclasses\":80}}}";
   std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
   ASSERT_EQ(created_str, joutstr);
@@ -266,7 +266,7 @@ TEST(tensorrtapi, service_predict_bbox_onnx)
       = "{\"service\":\"" + sname
         + "\",\"parameters\":{\"input\":{},\"output\":{\"bbox\":true,"
          "\"confidence_threshold\":0.8}},\"data\":[\""
-        + resnet_onnx_repo + "cat.jpg\"]}";
+        + resnet_onnx_repo + "cat.jpg\",\"" + yolox_onnx_repo + "dog.jpg\"]}";
   joutstr = japi.jrender(japi.service_predict(jpredictstr));
   JDoc jd;
   std::cout << "joutstr=" << joutstr << std::endl;
@@ -274,12 +274,19 @@ TEST(tensorrtapi, service_predict_bbox_onnx)
   ASSERT_TRUE(!jd.HasParseError());
   ASSERT_EQ(200, jd["status"]["code"]);
   ASSERT_TRUE(jd["body"]["predictions"].IsArray());
+  ASSERT_EQ(jd["body"]["predictions"].Size(), 2);
+
+  uint32_t cat_id = jd["body"]["predictions"][0]["uri"].GetString()
+                            == (resnet_onnx_repo + "cat.jpg")
+                        ? 0
+                        : 1;
+  uint32_t dog_id = 1 - cat_id;
 
-  auto &preds = jd["body"]["predictions"][0]["classes"];
+  auto &preds = jd["body"]["predictions"][cat_id]["classes"];
   ASSERT_EQ(preds.Size(), 1);
   std::string cl1 = preds[0]["cat"].GetString();
-  ASSERT_TRUE(cl1 == "14");
-  ASSERT_TRUE(preds[0]["prob"].GetDouble() > 0.85);
+  ASSERT_EQ(cl1, "15");
+  ASSERT_TRUE(preds[0]["prob"].GetDouble() > 0.9);
   auto &bbox = preds[0]["bbox"];
   ASSERT_TRUE(bbox["xmin"].GetDouble() < 50 && bbox["xmax"].GetDouble() > 200
               && bbox["ymin"].GetDouble() < 50
@@ -287,13 +294,24 @@ TEST(tensorrtapi, service_predict_bbox_onnx)
               && bbox["ymax"].GetDouble() > 200);
   // Check confidence threshold
   ASSERT_TRUE(preds[preds.Size() - 1]["prob"].GetDouble() >= 0.8);
+  // Check second pred
+  auto &preds2 = jd["body"]["predictions"][dog_id]["classes"];
+  ASSERT_EQ(preds2.Size(), 1);
+  std::string cl2 = preds2[0]["cat"].GetString();
+  ASSERT_EQ(cl2, "16");
+  ASSERT_TRUE(preds2[0]["prob"].GetDouble() > 0.8);
+  auto &bbox2 = preds2[0]["bbox"];
+  ASSERT_TRUE(bbox2["xmin"].GetDouble() < 50 && bbox2["xmax"].GetDouble() > 200
+              && bbox2["ymin"].GetDouble() < 50
+              && bbox2["ymax"].GetDouble() > 200);
+
   ASSERT_TRUE(fileops::file_exists(yolox_onnx_repo + "TRTengine_arch"
-                                   + get_trt_archi() + "_bs1"));
+                                   + get_trt_archi() + "_bs2"));
   jstr = "{\"clear\":\"lib\"}";
   joutstr = japi.jrender(japi.service_delete(sname, jstr));
   ASSERT_EQ(ok_str, joutstr);
   ASSERT_TRUE(!fileops::file_exists(yolox_onnx_repo + "TRTengine_arch"
-                                    + get_trt_archi() + "_bs1"));
+                                    + get_trt_archi() + "_bs2"));
 }
 
 TEST(tensorrtapi, service_predict_gan_onnx)
diff --git a/tools/torch/trace_yolox.py b/tools/torch/trace_yolox.py
index c9405c504..a75dfa7b5 100755
--- a/tools/torch/trace_yolox.py
+++ b/tools/torch/trace_yolox.py
@@ -23,6 +23,11 @@ def main():
     parser.add_argument('--num_classes', type=int, default=80, help="Number of classes of the model")
     parser.add_argument('--gpu', type=int, help="GPU id to run on GPU")
     parser.add_argument('--to_onnx', action="store_true", help="Export model to onnx")
+    parser.add_argument('--use_wrapper', action="store_true", help="In case of onnx export, if this option is present, the model will be wrapped so that its output matches dede expectations")
+    parser.add_argument('--top_k', type=int, default=200, help="When exporting to onnx, specify maximum returned prediction count")
+    parser.add_argument('--batch_size', type=int, default=1, help="When exporting to onnx, batch size of model")
+    parser.add_argument('--img_width', type=int, default=640, help="Width of the image when exporting with fixed image size")
+    parser.add_argument('--img_height', type=int, default=640, help="Height of the image when exporting with fixed image size")
 
     args = parser.parse_args()
 
@@ -78,18 +83,21 @@ def main():
     if args.to_onnx:
         model = replace_module(model, nn.SiLU, SiLU)
-        model = YoloXWrapper_TRT(model)
+        model = YoloXWrapper_TRT(model, topk = args.top_k, raw_output = not args.use_wrapper)
         model.to(device)
         model.eval()
 
         filename += ".onnx"
-        example = get_image_input(1, 640, 640)
+        example = get_image_input(args.batch_size, args.img_width, args.img_height)
+        # XXX: dynamic batch size not supported with wrapper
+        # XXX: dynamic batch size not yet supported in dede as well
+        dynamic_axes = None # {"input": {0: "batch"}} if not args.use_wrapper else None
 
         torch.onnx.export(
                 model, example, filename, export_params=True, verbose=args.verbose,
                 opset_version=11, do_constant_folding=True,
-                input_names=["input"], output_names=["detection_out", "keep_count"])
-                #, dynamic_axes={"input": {0: "batch"}})
+                input_names=["input"], output_names=["detection_out", "keep_count"],
+                dynamic_axes = dynamic_axes)
 
     else:
         # wrap model
         model = YoloXWrapper(model, args.num_classes, postprocess)
@@ -187,10 +195,11 @@ def forward(self, x, ids = None, bboxes = None, labels = None):
 
 
 class YoloXWrapper_TRT(torch.nn.Module):
-    def __init__(self, model, topk=200):
+    def __init__(self, model, topk=200, raw_output=False):
         super(YoloXWrapper_TRT, self).__init__()
         self.model = model
         self.topk = topk
+        self.raw_output = raw_output
 
     def to_xyxy(self, boxes):
         xyxy_boxes = boxes.new_zeros(boxes.shape)
@@ -204,8 +213,11 @@ def forward(self, x):
         # xmin, ymin, xmax, ymax, objectness, conf cls1, conf cl2...
         output = self.model(x)[0]
 
+        if self.raw_output:
+            return output, torch.zeros(output.shape[0])
+
         box_count = output.shape[1]
-        cls_scores, cls_pred = output[:,:,6:].max(dim=2, keepdim=True)
+        cls_scores, cls_pred = output[:,:,5:].max(dim=2, keepdim=True)
         batch_ids = torch.arange(output.shape[0], device=x.device).view(
                 -1, 1).repeat(1, output.shape[1]).unsqueeze(2)