読者です 読者をやめる 読者になる 読者になる

サイン波の多項式近似

多項式近似

サイン波にガウスノイズを加えた点から、元のサイン波を近似してみる。
これはPRMLという本の第1章の話で、こちらのサイトを参考にさせてもらった。
誤差
E(\mathbf{\omega}) = {1\over 2}\sum_{n=1}^N \{y(x_n,\mathbf{\omega})-t_n\}^2 = {1\over 2}\sum_{n=1}^N \(\sum_{j=0}^M \omega_jx_n ^j-t_n \)^2
を最小とする\mathbf{\omega}を求めれば近似多項式を作れる。最小を求めるには微分すれば良いので、
\begin{align} \frac{\partial E(\mathbf{\omega})}{\partial \omega_i} &= \sum_{n=1}^N \(\sum_{j=0}^M \omega_jx_n ^j-t_n\)\frac{\partial}{\partial \omega_i}\(\sum_{k=0}^M \omega_kx_n ^k-t_n\)\\ &= \sum_{n=1}^N \(\sum_{j=0}^M \omega_jx_n ^j-t_n\) x_n^i \\&=0 \end{align}
つまり
\sum_{n=1}^N \sum_{j=0}^M \omega_j x_n^j x_n^i = \sum_{n=1}^N t_n x_n^i
となるので、
\sum_{j=0}^M A_{ij}\omega_j=T_i
\begin{align}A_{ij}=\sum_{n=1}^N x_n^{i+j} &   T_i=\sum_{n=1}^N x_n^i t_n\end{align}
を解けばいい。

Eigen3

C++の行列計算にはEigen3を使う。公式サイト
めちゃくちゃ使いやすいけど、あんまりネットじゃ見かけない。何でかな。

テンプレートで書かれているので、ヘッダをインクルードするだけで使える。Qt CreatorならEigenを良い感じのところに解凍して、例えば人にプロジェクトを渡すときならINCLUDEPATH += ./eigenとするだけ。

実装

さっき求めた式をEigenで解くクラスを作る。
curve_fitting.h

#ifndef CURVEFITTING_H
#define CURVEFITTING_H

#include <Eigen/Dense>

class CurveFitting
{
public:
    explicit CurveFitting(int M);
    void MAP(const Eigen::VectorXd& x,
             const Eigen::VectorXd& t,
             const double lambda);
    double y(const double x) const;

private:
    Eigen::VectorXd W_;
    double M_;
};

#endif // CURVEFITTING_H

curve_fitting.cpp

#include "curvefitting.h"
#include <iostream>
using namespace Eigen;

CurveFitting::CurveFitting(int M) : M_(M)
{
}

void CurveFitting::MAP(const VectorXd& x,
                       const VectorXd& t,
                       const double lambda)
{
    MatrixXd A(M_+1, M_+1);
    for (int i = 0; i < M_+1; ++i) {
        for (int j = 0; j < M_+1; ++j) {
            double temp = x.array().pow(i+j).sum(); // A_ij
            if (i == j) {
                temp += lambda;
            }
            A(i, j) = temp;
        }
    }

    VectorXd T(M_+1);
    for (int i = 0; i < M_+1; ++i) {
        T(i) = (x.array().pow(i) * t.array()).sum();
    }

    W_ = A.colPivHouseholderQr().solve(T);
}

double CurveFitting::y(const double x) const
{
    double result = W_(0);
    for (int i = 1; i < M_+1; ++i) {
        result += W_(i) * pow(x, i);
    }
    return result;
}

lambdaはMAP推定に使うけど、とりあえず0.0にして無視しておく。
それでグラフの描画側は、
plot.cpp

#include "plot.h"
#include <qwt_plot_curve.h>
#include <qwt_legend.h>
#include <random>
#include "curvefitting.h"
#include <Eigen/Dense>
using namespace Eigen;

Plot::Plot(QWidget *parent) :
    QwtPlot(parent)
{
    setCanvasBackground(QColor(Qt::white));
    // 凡例
    insertLegend(new QwtLegend(), QwtPlot::BottomLegend);

    // 曲線の設定
    curve1_ = new QwtPlotCurve("sine");
    curve1_->setRenderHint(QwtPlotItem::RenderAntialiased);
    curve1_->setPen(QPen(Qt::red));
    curve1_->attach(this);

    dots1_ = new QwtPlotCurve("sine+noise");
    dots1_->setRenderHint(QwtPlotItem::RenderAntialiased);
    dots1_->setPen(QPen(Qt::blue, 5, Qt::SolidLine, Qt::RoundCap));
    dots1_->setStyle(QwtPlotCurve::Dots);
    dots1_->attach(this);

    curve2_ = new QwtPlotCurve("estimated sine");
    curve2_->setRenderHint(QwtPlotItem::RenderAntialiased);
    curve2_->setPen(QPen(Qt::green));
    curve2_->attach(this);

    // 曲線の描画
    plotCurve();
}

void Plot::plotCurve()
{
    const int kArraySize = 1000;
    const int kNumDots = 10;

    VectorXd x(kArraySize);
    VectorXd y(kArraySize);

    CurveFitting *curve_fitting = new CurveFitting(3); // 次数3

    std::mt19937 engine( static_cast<unsigned long>(time(0)) );
    std::normal_distribution<double> dist(0.0, 0.2);

    VectorXd xdots(kNumDots);
    VectorXd ydots(kNumDots);

    for (int i = 0; i < kArraySize; ++i) {
        x(i) = i / (kArraySize-1.0);
        y(i) = sin(2.0*M_PI*x(i));
    }

    for (int i = 0; i < kNumDots; ++i) {
        xdots(i) = i / (kNumDots-1.0);
        ydots(i) = sin(2.0*M_PI*xdots(i)) + dist(engine);
    }
//    xdots = VectorXd::LinSpaced(0.0, 1.0, 10);
//    ydots = xdots.unaryExpr([&](const double x) {
//        return sin(2.0*M_PI*x) + dist(engine);
//    });

    curve_fitting->MAP(xdots, ydots, 0.0);
    std::vector<double> model(kArraySize);
    for (int i = 0; i < kArraySize; ++i) {
        model[i] = curve_fitting->y(x(i));
    }

    curve1_->setSamples(&x[0], &y[0], kArraySize);
    dots1_->setSamples(&xdots[0], &ydots[0], kNumDots);
    curve2_->setSamples(&x[0], &model[0], kArraySize);

    replot();
    delete curve_fitting;
}

それで実行すると

こんな感じで元のサイン波を推定できている。