mirror of https://github.com/Qortal/Brooklyn
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
3.4 KiB
92 lines
3.4 KiB
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "NonMaxSuppression.hpp"

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

namespace od |
|
{ |
|
|
|
static std::vector<unsigned int> GenerateRangeK(unsigned int k) |
|
{ |
|
std::vector<unsigned int> range(k); |
|
std::iota(range.begin(), range.end(), 0); |
|
return range; |
|
} |
|
|
|
|
|
/** |
|
* @brief Returns the intersection over union for two bounding boxes |
|
* |
|
* @param[in] First detect containing bounding box. |
|
* @param[in] Second detect containing bounding box. |
|
* @return Calculated intersection over union. |
|
* |
|
*/ |
|
static double IntersectionOverUnion(DetectedObject& detect1, DetectedObject& detect2) |
|
{ |
|
uint32_t area1 = (detect1.GetBoundingBox().GetHeight() * detect1.GetBoundingBox().GetWidth()); |
|
uint32_t area2 = (detect2.GetBoundingBox().GetHeight() * detect2.GetBoundingBox().GetWidth()); |
|
|
|
float yMinIntersection = std::max(detect1.GetBoundingBox().GetY(), detect2.GetBoundingBox().GetY()); |
|
float xMinIntersection = std::max(detect1.GetBoundingBox().GetX(), detect2.GetBoundingBox().GetX()); |
|
|
|
float yMaxIntersection = std::min(detect1.GetBoundingBox().GetY() + detect1.GetBoundingBox().GetHeight(), |
|
detect2.GetBoundingBox().GetY() + detect2.GetBoundingBox().GetHeight()); |
|
float xMaxIntersection = std::min(detect1.GetBoundingBox().GetX() + detect1.GetBoundingBox().GetWidth(), |
|
detect2.GetBoundingBox().GetX() + detect2.GetBoundingBox().GetWidth()); |
|
|
|
double areaIntersection = std::max(yMaxIntersection - yMinIntersection, 0.0f) * |
|
std::max(xMaxIntersection - xMinIntersection, 0.0f); |
|
double areaUnion = area1 + area2 - areaIntersection; |
|
|
|
return areaIntersection / areaUnion; |
|
} |
|
|
|
std::vector<int> NonMaxSuppression(DetectedObjects& inputDetections, float iouThresh) |
|
{ |
|
// Sort indicies of detections by highest score to lowest. |
|
std::vector<unsigned int> sortedIndicies = GenerateRangeK(inputDetections.size()); |
|
std::sort(sortedIndicies.begin(), sortedIndicies.end(), |
|
[&inputDetections](int idx1, int idx2) |
|
{ |
|
return inputDetections[idx1].GetScore() > inputDetections[idx2].GetScore(); |
|
}); |
|
|
|
std::vector<bool> visited(inputDetections.size(), false); |
|
std::vector<int> outputIndiciesAfterNMS; |
|
|
|
for (int i=0; i < inputDetections.size(); ++i) |
|
{ |
|
// Each new unvisited detect should be kept. |
|
if (!visited[sortedIndicies[i]]) |
|
{ |
|
outputIndiciesAfterNMS.emplace_back(sortedIndicies[i]); |
|
visited[sortedIndicies[i]] = true; |
|
} |
|
|
|
// Look for detections to suppress. |
|
for (int j=i+1; j<inputDetections.size(); ++j) |
|
{ |
|
// Skip if already kept or suppressed. |
|
if (!visited[sortedIndicies[j]]) |
|
{ |
|
// Detects must have the same label to be suppressed. |
|
if (inputDetections[sortedIndicies[j]].GetLabel() == inputDetections[sortedIndicies[i]].GetLabel()) |
|
{ |
|
auto iou = IntersectionOverUnion(inputDetections[sortedIndicies[i]], |
|
inputDetections[sortedIndicies[j]]); |
|
if (iou > iouThresh) |
|
{ |
|
visited[sortedIndicies[j]] = true; |
|
} |
|
} |
|
} |
|
} |
|
} |
|
return outputIndiciesAfterNMS; |
|
} |
|
|
|
} // namespace od
|
|
|