読者です 読者をやめる 読者になる 読者になる

C#で散布図を表示する練習

Qtの場合ですと、qwtというなかなか使いやすいグラフライブラリがあります。C#にもいくつかライブラリがあるようですが、どれが良いのかよくわかりません。ある方に聞いたところChartが良いと教えてくれたので、試しに散布図を作ってみました。

表示させるデータは
http://research.microsoft.com/en-us/um/people/cmbishop/PRML/webdatasets/datasets.htm
のClassificationです。
データの形式は
1.208985 0.421448 0.000000
0.504542 -0.285730 1.000000
みたいにx,y,labelとなっています。

ラベル別に点を赤と青で分け、その識別線を緑で描画しています。直線は描画範囲外で点を二つ評価して、むりやり直線を引いているだけで、非線形判別の場合はどうするのかはわかりません。qwtだと識別の結果をグラデーションで表示できますが、Chartはどうなんでしょう。

プログラムはデータ1点ごとにオンライン学習のSCW(Soft Confidence-Weigted Learning)を実行して、識別線を動かしています。下の図のように徐々に線が更新されるイメージです。



ソースコード
Form1.cs

using System;
using System.Collections.Generic;
using System.Drawing;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using System.Windows.Forms;
using System.Windows.Forms.DataVisualization.Charting;
using ConfidenceWeightedNative;

namespace ConfidenceWeighted
{
    public partial class Form1 : Form
    {
        private readonly CPPWrapper _wrapper = new CPPWrapper();
        private readonly List<Data> _dataset = new List<Data>();

        public Form1()
        {
            InitializeComponent();

            InitializeChart();
        }

        private void InitializeChart()
        {
            chart1.ChartAreas[0].Axes[0].Title = "x";
            chart1.ChartAreas[0].Axes[1].Title = "y";
            chart1.ChartAreas[0].AxisX.MajorGrid.Enabled = false;
            chart1.ChartAreas[0].AxisY.MajorGrid.Enabled = false;

            chart1.ChartAreas[0].AxisX.Minimum = -3.5;
            chart1.ChartAreas[0].AxisX.Maximum = 3.5;
            chart1.ChartAreas[0].AxisY.Minimum = -3.5;
            chart1.ChartAreas[0].AxisY.Maximum = 3.5;

            chart1.Series[0].ChartType = SeriesChartType.Point;
            chart1.Series[0].MarkerStyle = MarkerStyle.Circle;
            chart1.Series[0].MarkerColor = Color.Red;
            chart1.Series[0].LegendText = "a";

            var b = new Series
            {
                ChartType = SeriesChartType.Point,
                MarkerStyle = MarkerStyle.Circle,
                MarkerColor = Color.DarkBlue,
                LegendText = "b"
            };
            chart1.Series.Add(b);

            // プロットするデータの読み込み
            LoadDataset();
        }

        private Task ProcessInCpp()
        {
            return Task.Run(() =>
                {
                    foreach (var p in _dataset)
                    {
                        // 点のプロット
                        if (p.label == 1)
                        {
                            chart1.Series[0].Points.AddXY(p.x, p.y);
                        }
                        else
                        {
                            chart1.Series[1].Points.AddXY(p.x, p.y);
                        }

                        var resultW = _wrapper.Process(p);

                        DrawLine(resultW);

                        System.Threading.Thread.Sleep(200);

                        textBox1.Text = (resultW[0] + " " + resultW[1] + " " + resultW[2]);
                    }
                });
        }

        private async void button1_Click(object sender, EventArgs e)
        {
            button1.Enabled = false;
            await ProcessInCpp();
        }

        // 識別の直線を引く
        private bool _isFirst = true;
        private void DrawLine(List<double> w)
        {
            var line = new Series
            {
                ChartType = SeriesChartType.Line,
                //MarkerStyle = MarkerStyle.Circle,
                //MarkerColor = Color.ForestGreen,
                Color = Color.ForestGreen,
                LegendText = "line"
            };
            if (_isFirst)
            {
                chart1.Series.Add(line);
            }
            _isFirst = false;

            double a = w[1];
            double b = w[2];
            double c = w[0];
            var f = new Func<double, double>(x => -a/b * x - c/b);
            double x1 = -3.5;
            double y1 = f(x1);
            double x2 = 3.5;
            double y2 = f(x2);
            chart1.Series[2].Points.Clear();
            chart1.Series[2].Points.AddXY(x1, y1);
            chart1.Series[2].Points.AddXY(x2, y2);
        }

        // データセットの読み込み
        private void LoadDataset()
        {
            const string filepath = @"classification.txt";
            try
            {
                var lines = File.ReadAllLines(filepath);

                var filterd = lines
                    .Select(line => line.Split(' '))
                    .Select(tokens => new
                        {
                            x = tokens[0].Trim(),
                            y = tokens[1].Trim(),
                            label = tokens[2].Trim(),
                        });

                foreach (var p in filterd)
                {
                    double x;
                    double y;
                    double label;
                    double.TryParse(p.x, out x);
                    double.TryParse(p.y, out y);
                    double.TryParse(p.label, out label);
                    label = (label == 1.0) ? 1.0 : -1.0;

                    var data = new Data {label = label, x = x, y = y};
                    _dataset.Add(data);
                }
            }
            catch (Exception)
            {
                System.Windows.Forms.MessageBox.Show(filepath + " がありません", "Error");
            }
        }
    }
}

Wrapper.h

#pragma once

#using <System.Windows.Forms.dll>

namespace ConfidenceWeightedNative
{
	public ref class Data
	{
	public:
		double x;
		double y;
		double label; // 1 or -1
	};

	class SCW;

	public ref class CPPWrapper
	{
	public:
		CPPWrapper();
		~CPPWrapper();
		!CPPWrapper();

		System::Collections::Generic::List<double>^ Process(Data^ p);

	private:
		SCW *scw_;
	};
}

Wrapper.cpp

#include "Wrapper.h"

#include <tuple>

#include "SCW.h"

namespace ConfidenceWeightedNative
{
	using namespace System::Collections::Generic;

	CPPWrapper::CPPWrapper()
		: scw_(nullptr)
	{
		scw_ = new SCW();
	}

	CPPWrapper::~CPPWrapper()
	{
		this->!CPPWrapper();
	}

	CPPWrapper::!CPPWrapper()
	{
		if (scw_ != nullptr)
		{
			delete scw_;
			scw_ = nullptr;
		}
	}

	List<double>^ CPPWrapper::Process(Data^ p)
	{
		double x = p->x;
		double y = p->y;
		double label = p->label;
		// std::make_tuple(p->label, p->x, p->y);がなぜかできない
		auto weight = scw_->Process(std::make_tuple(label, x, y));

		auto result = gcnew List<double>();

		for (size_t i = 0; i < weight.size(); i++)
			result->Add(weight[i]);
		
		return result;
	}
} // namespace ConfidenceWeightedNative

SCW.h

#pragma once

#include <vector>

#include <Eigen/core>

namespace ConfidenceWeightedNative
{
	class SCW
	{
	public:
		SCW();

		std::vector<double> Process(const std::tuple<double, double, double>& dataset);
	private:
		const double eta;
		const double phi;
		const double C;

		double LossFunction(const Eigen::VectorXd& w, const Eigen::VectorXd& x_t, const double y_t) const;
		double CalcAlpha(const Eigen::VectorXd& x, const double y) const;
		double CalcBeta(const double alpha, const Eigen::VectorXd& x) const;

		Eigen::VectorXd mu_;
		Eigen::MatrixXd sigma_;
	};
}

SCW.cpp

#include "SCW.h"

#include <Eigen/Core>
using namespace Eigen;

namespace ConfidenceWeightedNative
{
	double erf(double x)
	{
		double re  = -1.2732395447351626861510701069801e+0;

		double sx = x * x;

		return (x >= 0 ? +1 : -1) * sqrt(1 - exp(re * sx) * 
(1 + sx * sx * (0.1101999999998335 / (sx + 7.2085122317063002) + 0.0219999999999997)));
	} 

	inline double cdf(double x)
	{
		return 0.5 + 0.5 * erf(x / sqrt(2.0));
	}

	SCW::SCW()
		:eta(0.9), phi(1.0/cdf(eta)), C(1.0)
	{
		mu_ = VectorXd::Zero(1+2);
		sigma_ = MatrixXd::Identity(1+2, 1+2);
	}

	std::vector<double> SCW::Process(const std::tuple<double, double, double>& p)
	{
		// receive an example xt
		// offset, x, y
		VectorXd x_t(1+2);
		x_t << 1.0, std::get<1>(p), std::get<2>(p);

		// make prediction
		const double yhat_t = (mu_.dot(x_t) >= 0.0) ? 1.0 : 0.0;

		// receive true label
		const double y_t = std::get<0>(p);

		// suffer loss
		const double loss = LossFunction(mu_, x_t, y_t);

		if (loss > 0.0)
		{
			// update coefficients
			const double alpha_t = CalcAlpha(x_t, y_t);
			const double beta_t = CalcBeta(alpha_t, x_t);

			// update
			mu_.noalias() += alpha_t * y_t * sigma_ * x_t;
			sigma_ -= beta_t * sigma_ * x_t * x_t.transpose() * sigma_;
		}

		std::vector<double> result_weight(1+2);
		for (int i = 0; i < 1+2; i++)
			result_weight[i] = mu_(i);

		return result_weight;
	}

	double SCW::LossFunction(const VectorXd& w, const VectorXd& x_t, const double y_t) const
	{
		const double m = y_t * w.dot(x_t);
		if (m >= 1.0)
		{
			return 0.0;
		}
		else
		{
			return 1.0 - m;
		}
	}

	double SCW::CalcAlpha(const VectorXd& x, const double y) const
	{
		const double v_t = x.transpose() * sigma_ * x;
		const double m_t = y * mu_.dot(x);
		const double psi = 1.0 + phi * phi * 0.5;
		const double zeta = 1.0 + phi * phi;
		const double phi2 = phi * phi;
		const double phi4 = phi2 * phi2;

		const double alpha = (-m_t * psi + sqrt(m_t*m_t*phi4*0.25 + v_t*phi2*zeta)) / (v_t * zeta);

		return std::min(C, std::max(0.0, alpha));
	}

	double SCW::CalcBeta(const double alpha, const VectorXd& x) const
	{
		const double v = x.transpose() * sigma_ * x;
		const double temp = -alpha * v * phi + sqrt(alpha*alpha*v*v*phi*phi+4.0*v);
		const double u = temp * temp * 0.25;

		return (alpha * phi) / (sqrt(u) + v * alpha * phi);
	}
} // namespace ConfidenceWeightedNative