Efficient Graph-Based Image Segmentationのお勉強
はじめに
画像中の似ているところを見つけてラベリングをする手法があります。いわゆるImage segmentationのことで、Efficient Graph-Based Image Segmentation, 2004.が有名そうなので調べてみました。
アルゴリズム
Union-Find Treeを使うのが特徴的です。aとbが同じグループに属するかを調べる、aとbのグループを併合する、が簡単にできます。プログラミングコンテストチャレンジブックをちらちら見ながら作りました。Segmentationのアルゴリズムはhttp://irohalog.hatenablog.com/entry/2014/10/05/213948が詳しいです。
実装
もともとSalient Object Detection: A Benchmark, 2014.を見ていたらDRFIって手法が良いらしくて、それがEfficient Graph-Based Image Segmentationを使っていたので調べました。10年以上前の論文ですが、今でも通用するようです。
作ったんですが、論文の画像を再現できないので何か間違っているみたいです。何かは分かりません(-_-)
void Segmentation::Run(const std::string filename) { using namespace cv; const Mat1b image = imread(filename, 0); if (image.empty()) { std::cerr << "Error: No image file available" << std::endl; return; } // Gaussian filter to smooth the image cv::GaussianBlur(image, image, cv::Size(5, 5), 0.5); const int k = 500; std::vector<Edge> edges = MakeEdge(image); // 0. Sort Edge, by non-decreasing edge weight std::sort(edges.begin(), edges.end(), [](Edge a, Edge b) {return a.weight < b.weight;}); // 1. Start with a segmentation S0, where each vertex vi is in its own component tree_.reset(new UnionFindTree(image.rows, image.cols, k)); // 2. Repeat step 3 for q=1,...,m // 3. Construct Sq given Sq-1 for (Edge& edge : edges) { // If vi and vj are in disjoint components of Sq-1 and weight is small // compared to the internal difference of both those components, then merge CompareMerge(edge); } // Reduce min components for (Edge& edge : edges) { const double min = 500; ReduceMin(edge, min); } cv::Mat3b label(image.rows, image.cols); MakeLabeledImage(edges, label); imshow("Input Image", image); imshow("Result Image", label); cv::imwrite("a.jpg", label); cv::waitKey(0); }
細かいコード
template <typename T> std::vector<Edge> MakeEdge(const T& image) { std::vector<Edge> edges; for (int y = 0; y < image.rows; y++) { for (int x = 0; x < image.cols; x++) { if (x < image.cols - 1) edges.push_back(Edge(image, x, y, x+1, y)); if (y < image.rows - 1) edges.push_back(Edge(image, x, y, x, y+1)); if (x < image.cols - 1 && y < image.rows - 1) edges.push_back(Edge(image, x, y, x+1, y+1)); if (x < image.cols - 1 && y > 0) edges.push_back(Edge(image, x, y, x+1, y-1)); } } return std::move(edges); } void Segmentation::CompareMerge(Edge& edge) { int a = tree_->Find(edge.vi); int b = tree_->Find(edge.vj); if (a == b) return; if (edge.weight <= tree_->Threshold(a) && edge.weight <= tree_->Threshold(b)) { tree_->Merge(a, b); tree_->UpdateThreshold(tree_->Find(a), edge.weight); } } void Segmentation::ReduceMin(Edge& edge, double min) { int a = tree_->Find(edge.vi); int b = tree_->Find(edge.vj); if (a == b) return; if (tree_->Size(a) < min || tree_->Size(b) < min) tree_->Merge(a, b); } void Segmentation::MakeLabeledImage(const std::vector<Edge>& edges, cv::Mat3b& label) { std::random_device rd; std::mt19937 mt(rd()); std::uniform_int_distribution<int> random(0, 255); std::vector<cv::Vec3b> colors(label.rows*label.cols); for (int i = 0; i < colors.size(); i++) colors[i] = cv::Vec3b(random(mt), random(mt), random(mt)); for (int y = 0; y < label.rows; y++) { for (int x = 0; x < label.cols; x++) { label(y, x) = colors[tree_->Find(x + label.cols*y)]; } } }
union-find tree
class UnionFindTree { public: UnionFindTree(int height, int width, int k) : parent_(height*width), rank_(height*width), size_(height*width), threshold_(height*width), k_(k) { std::iota(parent_.begin(), parent_.end(), 0); std::fill(rank_.begin(), rank_.end(), 0); std::fill(size_.begin(), size_.end(), 1); std::fill(threshold_.begin(), threshold_.end(), k_ / 1.0); } ~UnionFindTree() {} int Find(int x) { if (parent_[x] == x) return x; else return parent_[x] = Find(parent_[x]); } void Merge(int x, int y) { x = Find(x); y = Find(y); if (x == y) return; if (rank_[x] < rank_[y]) { parent_[x] = y; size_[y] += size_[x]; } else { parent_[y] = x; size_[x] += size_[y]; if (rank_[x] == rank_[y]) rank_[x]++; } } bool IsSame(int x, int y) { return Find(x) == Find(y); } int Size(int x) const { return size_[x]; } double Threshold(int x) { return threshold_[x]; } void UpdateThreshold(int a, double weight) { threshold_[a] = weight + (double)k_ / Size(a); } private: std::vector<int> parent_; std::vector<int> rank_; std::vector<int> size_; std::vector<double> threshold_; int k_; };
Edge
class Edge { public: template <typename T> Edge(const T& image, int x1, int y1, int x2, int y2) { vi = x1 + image.cols * y1; vj = x2 + image.cols * y2; weight = CalcWeight(image, x1, y1, x2, y2); } int vi; int vj; double weight; private: static double CalcWeight(const cv::Mat1b& image, int x1, int y1, int x2, int y2) { return std::abs((double)image(y1, x1) - image(y2, x2)); } static double CalcWeight(const cv::Mat3b& image, int x1, int y1, int x2, int y2) { double w = 0.0; for (int i = 0; i < 3; i++) { w += std::pow((double)image(y1, x1)[i] - image(y2, x2)[i], 2.0); } return sqrt(w); } };