Efficient Graph-Based Image Segmentationのお勉強

はじめに

画像中の似ているところを見つけてラベリングをする手法があります。いわゆるImage segmentationのことで、Efficient Graph-Based Image Segmentation, 2004.が有名そうなので調べてみました。

アルゴリズム

Union-Find Treeを使うのが特徴的です。aとbが同じグループに属するかを調べる、aとbのグループを併合する、が簡単にできます。プログラミングコンテストチャレンジブックをちらちら見ながら作りました。Segmentationのアルゴリズムhttp://irohalog.hatenablog.com/entry/2014/10/05/213948が詳しいです。

実装

もともとSalient Object Detection: A Benchmark, 2014.を見ていたらDRFIって手法が良いらしくて、それがEfficient Graph-Based Image Segmentationを使っていたので調べました。10年以上前の論文ですが、今でも通用するようです。
作ったんですが、論文の画像を再現できないので何か間違っているみたいです。何かは分かりません(-_-)
f:id:wildpie:20150211175255j:plain

void Segmentation::Run(const std::string filename)
{
  using namespace cv;
  const Mat1b image = imread(filename, 0);
  if (image.empty())
  {
    std::cerr << "Error: No image file available" << std::endl;
    return;
  }
  // Gaussian filter to smooth the image
  cv::GaussianBlur(image, image, cv::Size(5, 5), 0.5);

  const int k = 500;

  std::vector<Edge> edges = MakeEdge(image);
  // 0. Sort Edge, by non-decreasing edge weight 
  std::sort(edges.begin(), edges.end(), [](Edge a, Edge b) {return a.weight < b.weight;});
  // 1. Start with a segmentation S0, where each vertex vi is in its own component
  tree_.reset(new UnionFindTree(image.rows, image.cols, k));
  // 2. Repeat step 3 for q=1,...,m
  // 3. Construct Sq given Sq-1
  for (Edge& edge : edges)
  {
    // If vi and vj are in disjoint components of Sq-1 and weight is small 
    // compared to the internal difference of both those components, then merge
    CompareMerge(edge);
  }

  // Reduce min components
  for (Edge& edge : edges)
  {
    const double min = 500;
    ReduceMin(edge, min);
  }

  cv::Mat3b label(image.rows, image.cols);
  MakeLabeledImage(edges, label);

  imshow("Input Image", image);
  imshow("Result Image", label);
  cv::imwrite("a.jpg", label);

  cv::waitKey(0);
}

細かいコード

template <typename T>
std::vector<Edge> MakeEdge(const T& image)
{
  std::vector<Edge> edges;

  for (int y = 0; y < image.rows; y++)
  {
    for (int x = 0; x < image.cols; x++)
    {
      if (x < image.cols - 1)
        edges.push_back(Edge(image, x, y, x+1, y));

      if (y < image.rows - 1)
        edges.push_back(Edge(image, x, y, x, y+1));

      if (x < image.cols - 1 && y < image.rows - 1)
        edges.push_back(Edge(image, x, y, x+1, y+1));

      if (x < image.cols - 1 && y > 0)
        edges.push_back(Edge(image, x, y, x+1, y-1));
    }
  }

  return std::move(edges);
}

void Segmentation::CompareMerge(Edge& edge)
{
  int a = tree_->Find(edge.vi);
  int b = tree_->Find(edge.vj);
  if (a == b) return;

  if (edge.weight <= tree_->Threshold(a)
      && edge.weight <= tree_->Threshold(b))
  {
    tree_->Merge(a, b);
    tree_->UpdateThreshold(tree_->Find(a), edge.weight);
  }
}

void Segmentation::ReduceMin(Edge& edge, double min)
{
  int a = tree_->Find(edge.vi);
  int b = tree_->Find(edge.vj);
  if (a == b) return;

  if (tree_->Size(a) < min || tree_->Size(b) < min)
    tree_->Merge(a, b);
}

void Segmentation::MakeLabeledImage(const std::vector<Edge>& edges, cv::Mat3b& label)
{
  std::random_device rd;
  std::mt19937 mt(rd());
  std::uniform_int_distribution<int> random(0, 255);
  std::vector<cv::Vec3b> colors(label.rows*label.cols);
  for (int i = 0; i < colors.size(); i++)
    colors[i] = cv::Vec3b(random(mt), random(mt), random(mt));

  for (int y = 0; y < label.rows; y++)
  {
    for (int x = 0; x < label.cols; x++)
    {
      label(y, x) = colors[tree_->Find(x + label.cols*y)];
    }
  }
}

union-find tree

class UnionFindTree
{
public:
  UnionFindTree(int height, int width, int k) 
    : parent_(height*width), rank_(height*width), size_(height*width), threshold_(height*width), k_(k)
  {
    std::iota(parent_.begin(), parent_.end(), 0);
    std::fill(rank_.begin(), rank_.end(), 0);
    std::fill(size_.begin(), size_.end(), 1);
    std::fill(threshold_.begin(), threshold_.end(), k_ / 1.0);
  }

  ~UnionFindTree() {}

  int Find(int x)
  {
    if (parent_[x] == x)
      return x;
    else
      return parent_[x] = Find(parent_[x]);
  }

  void Merge(int x, int y)
  {
    x = Find(x);
    y = Find(y);
    if (x == y) return;

    if (rank_[x] < rank_[y])
    {
      parent_[x] = y;
      size_[y] += size_[x];
    }
    else
    {
      parent_[y] = x;
      size_[x] += size_[y];
      if (rank_[x] == rank_[y]) rank_[x]++;
    }
  }

  bool IsSame(int x, int y)
  {
    return Find(x) == Find(y);
  }

  int Size(int x) const
  {
    return size_[x];
  }

  double Threshold(int x)
  {
    return threshold_[x];
  }

  void UpdateThreshold(int a, double weight)
  {
    threshold_[a] = weight + (double)k_ / Size(a);
  }

private:
  std::vector<int> parent_;
  std::vector<int> rank_;
  std::vector<int> size_;
  std::vector<double> threshold_;
  int k_;
};

Edge

class Edge
{
public:
	template <typename T>
	Edge(const T& image, int x1, int y1, int x2, int y2) 
	{
		vi = x1 + image.cols * y1;
		vj = x2 + image.cols * y2;
		weight = CalcWeight(image, x1, y1, x2, y2);
	}

	int vi;
	int vj;
	double weight;

private:
	static double CalcWeight(const cv::Mat1b& image, int x1, int y1, int x2, int y2)
	{
		return std::abs((double)image(y1, x1) - image(y2, x2));
	}

	static double CalcWeight(const cv::Mat3b& image, int x1, int y1, int x2, int y2)
	{
		double w = 0.0;
		for (int i = 0; i < 3; i++)
		{
			w += std::pow((double)image(y1, x1)[i] - image(y2, x2)[i], 2.0);
		}

		return sqrt(w);
	}
};