読者です 読者をやめる 読者になる 読者になる

EMアルゴリズムでMoGを解く練習

定番ネタですが、EMアルゴリズムのプログラムを書きました。

目的
対数尤度関数を最大化するパラメータを求めます。今回はMoG (Mixture of Gaussians) なので、モデルは
Pr(x|\theta)=\sum_{k=1}^{K} \lambda_{k} \mathrm{Norm}_{x}(\mu_{k},\Sigma_{k})
となり、各成分 k=1,\dots,K のパラメータ \lambda_k, \mu_k, \Sigma_k がほしいです。そのためには
\hat{\theta}=\operatorname*{argmax}_{\theta}\sum_{i=1}^{I}\log Pr(x_i | \theta)=\operatorname*{argmax}_{\theta}\sum_{i=1}^{I} \log\left(\sum_{k=1}^{K} \lambda_{k} \mathrm{Norm}_{x_i}(\mu_{k},\Sigma_{k})\right)
を計算すればいいみたいです。
E-Step
r_{ik}=\frac{\lambda_k \mathrm{Norm}_{x_i}(\mu_k , \Sigma_{k})}{\sum_{j=1}^{K} \lambda_j \mathrm{Norm}_{x_i}(\mu_j , \Sigma_{j})}
M-Step
LaTeXで数式を書くのは私には無理なので、詳しくは http://www.computervisionmodels.com/ のSlide 7を見てください。
プログラム
結果はこんな感じになりました。

さらに繰り返すとおかしなことになるので、どこかバグがあるようです。
半年ぐらいプログラミングをしていないので、リハビリも兼ねているんですがC++むずかしいですね。
Form1.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.IO;
using System.Windows.Forms;
using OxyPlot;
using OxyPlot.Series;
using OxyPlot.Axes;
using EM_CPP;

namespace EMAlgorithm
{
    /// <summary>
    /// Demo window for fitting a two-component Mixture of Gaussians with EM.
    /// button1 loads and plots faithful.txt. Each click on button2 colours the points
    /// by the current component densities, runs one EM iteration and redraws the
    /// component contours.
    /// </summary>
    public partial class Form1 : Form
    {
        /// <summary>Shifted/scaled data points; null until <see cref="LoadDataset"/> has run.</summary>
        private List<ScatterPoint> _dataset;
        private readonly EM _em = new EM();
        private readonly PlotModel _model = new PlotModel("Example");

        public Form1()
        {
            InitializeComponent();
        }

        /// <summary>
        /// Reads whitespace-separated "x y" number pairs from <paramref name="filename"/>
        /// and stores them, shifted and scaled, in <see cref="_dataset"/>.
        /// </summary>
        /// <remarks>
        /// The points are materialised once here. A lazy query would re-read the file
        /// every time the dataset is enumerated.
        /// </remarks>
        /// <exception cref="IOException">The file cannot be read.</exception>
        /// <exception cref="FormatException">A token is not a number.</exception>
        private void LoadDataset(string filename)
        {
            _dataset = File.ReadLines(filename)
                           .Where(line => !string.IsNullOrWhiteSpace(line))
                           // A null separator splits on any whitespace; empty entries are
                           // dropped so repeated spaces do not produce empty tokens.
                           .Select(line => line.Split((char[])null, StringSplitOptions.RemoveEmptyEntries))
                           .Select(tokens => new ScatterPoint
                           {
                               // Shift/scale the raw columns so the data lands inside the
                               // fixed [-2.5, 2.5] window used by DrawNormalDistribution.
                               X = ParseNumber(tokens[0]) - 3.5,
                               Y = (ParseNumber(tokens[1]) - 70) * 5 / 60,
                               Value = 0,
                           })
                           .ToList();
        }

        /// <summary>
        /// Parses a number written with '.' as the decimal separator, independent of
        /// the UI culture (e.g. cultures that use ',' would otherwise misparse the file).
        /// </summary>
        private static double ParseNumber(string token)
        {
            return double.Parse(token, System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>Runs one EM iteration on the current dataset (X, Y, Value per point).</summary>
        private void Process()
        {
            var data = _dataset
                .Select(point => Tuple.Create(point.X, point.Y, point.Value))
                .ToList();
            _em.Process(data);
        }

        /// <summary>Loads the dataset and draws it together with the initial components.</summary>
        private void button1_Click(object sender, EventArgs e)
        {
            // I/O or parse errors propagate to the WinForms unhandled-exception handler.
            LoadDataset(@"faithful.txt");

            // Start from an empty series list so repeated clicks do not stack duplicates.
            _model.Series.Clear();

            PlotData();

            DrawNormalDistribution(0);
            DrawNormalDistribution(1);

            this.plot1.Model = _model;
            this.plot1.RefreshPlot(true);
        }

        /// <summary>
        /// Adds the data points as a scatter series coloured by their Value in [0, 1].
        /// The colour axis is added only once, because callers clear only the series.
        /// </summary>
        private void PlotData()
        {
            var scatter = new ScatterSeries()
            {
                MarkerType = MarkerType.Circle,
                MarkerSize = 4,
            };
            foreach (var point in _dataset)
            {
                scatter.Points.Add(point);
            }

            if (!_model.Axes.OfType<LinearColorAxis>().Any())
            {
                _model.Axes.Add(new LinearColorAxis()
                {
                    Position = AxisPosition.Right,
                    Minimum = 0,
                    Maximum = 1,
                    Palette = OxyPalette.Interpolate(255, OxyColors.Red, OxyColors.Blue),
                });
            }

            _model.Series.Add(scatter);
        }

        /// <summary>
        /// Draws the density-0.1 contour of mixture component <paramref name="num"/>
        /// (blue for component 0, red otherwise) over the [-2.5, 2.5]^2 window.
        /// </summary>
        private void DrawNormalDistribution(int num)
        {
            const int n = 100;
            const double x0 = -2.5;
            const double x1 = 2.5;
            const double y0 = -2.5;
            const double y1 = 2.5;
            Func<double, double, double> normal2D = (x, y) => _em.Normal2d(x, y, num);

            var xvalues = ArrayHelper.CreateVector(x0, x1, n);
            var yvalues = ArrayHelper.CreateVector(y0, y1, n);
            var densities = ArrayHelper.Evaluate(normal2D, xvalues, yvalues);

            var cs = new ContourSeries
            {
                Color = (num == 0) ? OxyColors.Blue : OxyColors.Red,
                FontSize = 20,
                ContourLevels = new double[] { 0.1 },
                LabelBackground = OxyColors.Undefined,
                ColumnCoordinates = yvalues,
                RowCoordinates = xvalues,
                Data = densities,
                LabelStep = 1,
                StrokeThickness = 7,
            };

            _model.Series.Add(cs);
        }

        /// <summary>
        /// Sets each point's Value to N0 / (N0 + N1), the share of component 0 in the
        /// summed component densities under the current parameters.
        /// </summary>
        /// <remarks>
        /// NOTE(review): mixing weights are not exposed by EM, so this ratio ignores
        /// them and is not the exact EM responsibility.
        /// </remarks>
        private void CalculatePlotValue()
        {
            _dataset = _dataset.Select(point =>
            {
                double density0 = _em.Normal2d(point.X, point.Y, 0);
                double density1 = _em.Normal2d(point.X, point.Y, 1);
                double total = density0 + density1;
                return new ScatterPoint
                {
                    X = point.X,
                    Y = point.Y,
                    // Both densities can underflow to 0 far from the means; use an even
                    // split instead of NaN in that case.
                    Value = total > 0 ? density0 / total : 0.5,
                };
            }).ToList();
        }

        /// <summary>
        /// Recolours the points with the current parameters, runs one EM iteration and
        /// redraws the component contours with the updated parameters.
        /// </summary>
        private void button2_Click(object sender, EventArgs e)
        {
            if (_dataset == null)
            {
                // Nothing loaded yet (button1 has not been clicked).
                return;
            }

            _model.Series.Clear();

            CalculatePlotValue();

            PlotData();

            Process();

            DrawNormalDistribution(0);
            DrawNormalDistribution(1);

            this.plot1.RefreshPlot(true);
        }
    }
}

EM.h

#pragma once

#define _USE_MATH_DEFINES
#include <math.h>

#include <vector>

#include <eigen/core>
#include <eigen/LU>

using namespace System::Collections::Generic;

namespace EM_CPP
{

// Two-dimensional Gaussian mixture whose parameters live in managed lists so they
// can be read from C#. Process() refines them by one EM iteration per call.
public ref class EM
{
public:
	EM();
	!EM();
	~EM();

	// Density of mixture component `num` at (_x, _y):
	//   N(x | mu, S) = exp(-(x-mu)^T S^{-1} (x-mu) / 2) / (2*pi*sqrt(|det S|))
	// `num` indexes mean_ / sigma_. The List indexer throws if it is out of range.
	double Normal2d(double _x, double _y, int num)
	{
		// Dynamic-size Eigen types are kept on purpose: fixed-size vectorizable
		// types need 16-byte alignment, which is not guaranteed here.
		Eigen::VectorXd diff(2);
		diff << _x - mean_[num]->Item1, _y - mean_[num]->Item2;

		Eigen::MatrixXd S(2, 2);
		S << sigma_[num]->Item1, sigma_[num]->Item2,
		     sigma_[num]->Item2, sigma_[num]->Item3;

		// Evaluate the quadratic form straight into a double. Holding an Eigen
		// product expression in `auto` leaves it unevaluated (Eigen pitfall).
		const double mahalanobis_sq = diff.dot(S.inverse() * diff);
		// fabs, not abs: abs may resolve to the int overload for a double argument.
		const double denominator = 2.0 * M_PI * sqrt(fabs(S.determinant()));
		return exp(-mahalanobis_sq / 2.0) / denominator;
	}

	// Peak density of component `num`, i.e. Normal2d evaluated at its own mean.
	double GetMaxGaussValue(int num)
	{
		return Normal2d(mean_[num]->Item1, mean_[num]->Item2, num);
	}

	// Runs one EM iteration on `dataset` (Item1 = x, Item2 = y) and replaces
	// mean_ / sigma_ with the updated parameters.
	void Process(List<System::Tuple<double, double, double>^>^ dataset);

	// Component means as (mu_x, mu_y), one entry per component.
	List<System::Tuple<double, double>^>^ mean_;
	// Symmetric component covariances stored as (s_xx, s_xy, s_yy).
	List<System::Tuple<double, double, double>^>^ sigma_;
};

} // namespace EM_CPP

EM.cpp

#include "EM.h"

#include <vector>

namespace EM_CPP
{

// Initial guess: two uncorrelated components with variance 0.5 on each axis,
// placed at (-1.5, 0.5) and (1.5, -0.5).
EM::EM()
{
	mean_ = gcnew List<System::Tuple<double, double>^>();
	sigma_ = gcnew List<System::Tuple<double, double, double>^>();

	mean_->Add(gcnew System::Tuple<double, double>(-1.5, 0.5));
	mean_->Add(gcnew System::Tuple<double, double>(1.5, -0.5));
	sigma_->Add(gcnew System::Tuple<double, double, double>(0.5, 0, 0.5));
	sigma_->Add(gcnew System::Tuple<double, double, double>(0.5, 0, 0.5));
}

// Finalizer: EM holds only managed handles, so there is nothing to release.
EM::!EM()
{

}

// Destructor (Dispose): forwards to the finalizer, per the C++/CLI dispose pattern.
EM::~EM()
{
	this->!EM();
}

// Runs ONE EM iteration for the mixture stored in mean_ / sigma_ and replaces
// both lists with the updated parameters.
//   dataset: points as (Item1 = x, Item2 = y); Item3 is ignored.
// The number of components K is taken from mean_->Count; sigma_ must hold at
// least that many entries. An empty dataset leaves the parameters unchanged.
void EM::Process(List<System::Tuple<double, double, double>^>^ dataset)
{
	const int D = 2;
	const int K = mean_->Count;

	// Copy the managed points into Eigen vectors.
	std::vector<Eigen::VectorXd> x;
	x.reserve(dataset->Count);
	for each (System::Tuple<double, double, double>^ data in dataset)
	{
		Eigen::VectorXd point(D);
		point << data->Item1, data->Item2;
		x.push_back(point);
	}
	const size_t I = x.size();
	if (I == 0 || K == 0)
		return; // nothing to fit; keep the current parameters

	// Mixing weights. TODO(review): EM has no member for the weights, so they
	// restart at 1/K on every call and the M-step value computed below is
	// discarded. Persisting them needs a lambda_ member declared in EM.h.
	std::vector<double> lambda(K, 1.0 / K);

	// E-step: r[i][k] = lambda_k N(x_i | k) / sum_j lambda_j N(x_i | j).
	// The normalizer is computed per point i, so each row of r sums to 1.
	// Normal2d reads mean_ / sigma_, which still hold the current parameters here.
	std::vector<std::vector<double>> r(I, std::vector<double>(K));
	for (size_t i = 0; i < I; i++)
	{
		double normalizer = 0;
		for (int k = 0; k < K; k++)
		{
			r[i][k] = lambda[k] * Normal2d(x[i](0), x[i](1), k);
			normalizer += r[i][k];
		}
		for (int k = 0; k < K; k++)
			// All densities underflowed to 0: split the point evenly instead of NaN.
			r[i][k] = (normalizer > 0) ? r[i][k] / normalizer : 1.0 / K;
	}

	// M-step. Start from the current parameters so that a component with no
	// responsibility mass keeps its previous mean/covariance instead of
	// becoming NaN through a division by zero.
	std::vector<Eigen::VectorXd> mean(K);
	std::vector<Eigen::MatrixXd> sigma(K);
	for (int k = 0; k < K; k++)
	{
		mean[k].resize(D);
		mean[k] << mean_[k]->Item1, mean_[k]->Item2;
		sigma[k].resize(D, D);
		sigma[k] << sigma_[k]->Item1, sigma_[k]->Item2,
		            sigma_[k]->Item2, sigma_[k]->Item3;
	}

	for (int k = 0; k < K; k++)
	{
		// r_k: effective number of points assigned to component k.
		double r_k = 0;
		Eigen::VectorXd weighted_sum = Eigen::VectorXd::Zero(D);
		for (size_t i = 0; i < I; i++)
		{
			r_k += r[i][k];
			weighted_sum += r[i][k] * x[i];
		}

		// Every row of r sums to 1, so the total responsibility is I (up to rounding).
		lambda[k] = r_k / static_cast<double>(I);

		if (r_k <= 0)
			continue;

		mean[k] = weighted_sum / r_k;

		Eigen::MatrixXd scatter = Eigen::MatrixXd::Zero(D, D);
		for (size_t i = 0; i < I; i++)
		{
			const Eigen::VectorXd d = x[i] - mean[k];
			scatter += r[i][k] * d * d.transpose();
		}
		sigma[k] = scatter / r_k;
	}

	// Publish the updated parameters to the managed side.
	// sigma_ stores the symmetric covariance as (s_xx, s_xy, s_yy).
	mean_ = gcnew List<System::Tuple<double, double>^>(K);
	sigma_ = gcnew List<System::Tuple<double, double, double>^>(K);
	for (int k = 0; k < K; k++)
	{
		mean_->Add(gcnew System::Tuple<double, double>(mean[k](0), mean[k](1)));
		sigma_->Add(gcnew System::Tuple<double, double, double>(
			sigma[k](0, 0), sigma[k](0, 1), sigma[k](1, 1)));
	}
}

} // namespace EM_CPP