サイン波の多項式近似
多項式近似
サイン波にガウスノイズを加えた点から、元のサイン波を近似してみる。
これはPRMLという本の第1章の話で、こちらのサイトを参考にさせてもらった。
誤差
を最小とするを求めれば近似多項式を作れる。最小を求めるには微分すれば良いので、
つまり
となるので、
を解けばいい。
Eigen3
C++の行列計算にはEigen3を使う。公式サイト
めちゃくちゃ使いやすいけど、あんまりネットじゃ見かけない。何でかな。
テンプレートで書かれているので、ヘッダをインクルードするだけで使える。Qt CreatorならEigenを良い感じのところに解凍して、例えば人にプロジェクトを渡すときならINCLUDEPATH += ./eigenとするだけ。
実装
さっき求めた式をEigenで解くクラスを作る。
curve_fitting.h
#ifndef CURVEFITTING_H #define CURVEFITTING_H #include <Eigen/Dense> class CurveFitting { public: explicit CurveFitting(int M); void MAP(const Eigen::VectorXd& x, const Eigen::VectorXd& t, const double lambda); double y(const double x) const; private: Eigen::VectorXd W_; double M_; }; #endif // CURVEFITTING_H
curve_fitting.cpp
#include "curvefitting.h" #include <iostream> using namespace Eigen; CurveFitting::CurveFitting(int M) : M_(M) { } void CurveFitting::MAP(const VectorXd& x, const VectorXd& t, const double lambda) { MatrixXd A(M_+1, M_+1); for (int i = 0; i < M_+1; ++i) { for (int j = 0; j < M_+1; ++j) { double temp = x.array().pow(i+j).sum(); // A_ij if (i == j) { temp += lambda; } A(i, j) = temp; } } VectorXd T(M_+1); for (int i = 0; i < M_+1; ++i) { T(i) = (x.array().pow(i) * t.array()).sum(); } W_ = A.colPivHouseholderQr().solve(T); } double CurveFitting::y(const double x) const { double result = W_(0); for (int i = 1; i < M_+1; ++i) { result += W_(i) * pow(x, i); } return result; }
lambdaはMAP推定に使うけど、とりあえず0.0にして無視しておく。
それでグラフの描画側は、
plot.cpp
#include "plot.h" #include <qwt_plot_curve.h> #include <qwt_legend.h> #include <random> #include "curvefitting.h" #include <Eigen/Dense> using namespace Eigen; Plot::Plot(QWidget *parent) : QwtPlot(parent) { setCanvasBackground(QColor(Qt::white)); // 凡例 insertLegend(new QwtLegend(), QwtPlot::BottomLegend); // 曲線の設定 curve1_ = new QwtPlotCurve("sine"); curve1_->setRenderHint(QwtPlotItem::RenderAntialiased); curve1_->setPen(QPen(Qt::red)); curve1_->attach(this); dots1_ = new QwtPlotCurve("sine+noise"); dots1_->setRenderHint(QwtPlotItem::RenderAntialiased); dots1_->setPen(QPen(Qt::blue, 5, Qt::SolidLine, Qt::RoundCap)); dots1_->setStyle(QwtPlotCurve::Dots); dots1_->attach(this); curve2_ = new QwtPlotCurve("estimated sine"); curve2_->setRenderHint(QwtPlotItem::RenderAntialiased); curve2_->setPen(QPen(Qt::green)); curve2_->attach(this); // 曲線の描画 plotCurve(); } void Plot::plotCurve() { const int kArraySize = 1000; const int kNumDots = 10; VectorXd x(kArraySize); VectorXd y(kArraySize); CurveFitting *curve_fitting = new CurveFitting(3); // 次数3 std::mt19937 engine( static_cast<unsigned long>(time(0)) ); std::normal_distribution<double> dist(0.0, 0.2); VectorXd xdots(kNumDots); VectorXd ydots(kNumDots); for (int i = 0; i < kArraySize; ++i) { x(i) = i / (kArraySize-1.0); y(i) = sin(2.0*M_PI*x(i)); } for (int i = 0; i < kNumDots; ++i) { xdots(i) = i / (kNumDots-1.0); ydots(i) = sin(2.0*M_PI*xdots(i)) + dist(engine); } // xdots = VectorXd::LinSpaced(0.0, 1.0, 10); // ydots = xdots.unaryExpr([&](const double x) { // return sin(2.0*M_PI*x) + dist(engine); // }); curve_fitting->MAP(xdots, ydots, 0.0); std::vector<double> model(kArraySize); for (int i = 0; i < kArraySize; ++i) { model[i] = curve_fitting->y(x(i)); } curve1_->setSamples(&x[0], &y[0], kArraySize); dots1_->setSamples(&xdots[0], &ydots[0], kNumDots); curve2_->setSamples(&x[0], &model[0], kArraySize); replot(); delete curve_fitting; }