てれじょんのメモ帳

twitter(@television_met)に書ききれないことを投稿します。

最小二乗法を実装した話


これはなに?

 おはこんばんにちはてれじょんです。厳しい暑さが続いていますが皆さんお元気でしょうか?私は陸の孤島秋田に帰省しているのでそこまできつい暑さでは無いのですが、それでもまあ暑いです。死なない程度に生きていきたいものです。天気図も書きたいものです。(うちのスーパーコンピューターであるところの"1-LOUGH"*1 が秋田に無いので自明に録音ができなくて詰んでいます。助けて。)
 ところで皆さんはまちカドまぞくを知っていますか?悪いことは言わないので知らないヒトは今すぐ見てください。今日ツイッターを見ていたところ「2期決定!」というのを発見して特大の喜びをしたのですが、リークだと知り一瞬で真顔に戻りました。公式で発表されてないことを確認する前に拡散した自分も悪いんですが、好きな作品の喜ばしい発表に対して素直に喜べないのは本当に本当に本当に本当に悲しい限りなので、リーカーは潔く自決してほしいですね。私も自決するべきですが。
 さて本題ですが、今日はFortranの教科書*2にさらっと書かれていた最小二乗法を実装しました。点列を入力して偏微分して得られる連立方程式を解く……という簡単なステップですが、連立方程式を解くときにGauss-Jordan法を使ったり、係数行列を与える際に再帰呼び出しをしたりとで結構コードが長くなってしまったのでここに残しておきます。

どういう手順?

  1. m個のデータをインプットします。ここでは良いデータが見つからなかったので適当な多項式に誤差項を乱数で与えました。
  2. 多項式とデータの残差の二乗の総和を、求める多項式の係数で偏微分します。これでn+1本の連立方程式が得られます。
  3. 連立方程式をときます。ここではよく知られたGauss-Jordan法を使いました。

もっとくやしく

まずコードを貼っておきます。

module subprog
   use globals
   implicit none
contains 
   subroutine gauss_jordan (a0,x,b,n)
      !部分pivot選択ありです
      integer,intent(IN) :: n
      real(8),intent(IN) :: a0(n,n),b(n)
      real(8),intent(OUT) :: x(n)
      integer i,j,m
      real(8) a(n,n+1),am,tmp(n+1),aji

      a(1:n,1:n)=a0
      
      do i=1,n
         a(i,n+1)=b(i)
      enddo

      do i=1,n
         m=i
         am=a(i,i)
         do j=i+1,n
            if(abs(am)<abs(a(j,i)))then
               am=a(j,i)
               m=j
            endif
         enddo

         if(am==0.0d0) stop "matrix A is not invertible!"
         if(i/=m)then
            tmp(1:n+1)=a(i,1:n+1)
            a(i,1:n+1)=a(m,1:n+1)
            a(m,1:n+1)=tmp(1:n+1)
         endif

         !消去開始
         a(i,i:n+1)=a(i,i:n+1)/am
         a(i,i)=1.0d0

         do j=1,n
            if(i==j)cycle
            aji=a(j,i)
            a(j,i:n+1)=a(j,i:n+1)-aji*a(i,i:n+1)
            a(j,i)=0.0d0
         enddo
      enddo

      do i=1,n
         x(i)=a(i,n+1)
      enddo
   end subroutine gauss_jordan

   subroutine mat_test(n,a,b)
      integer n
      real(8),intent(INOUT) ::a(n,n),b(n)
      call random_seed
      call random_number(a(1:n,1:n))
      call random_number(b(1:n))
   end subroutine mat_test

   function fun(x) result(y)
      real(8),intent(IN) :: x
      real(8) y
      y=0.1d0*x**3-0.2d0*x**2+0.9d0*x+1.0d0
   end function fun

   recursive subroutine set_vec(x,y,r,n,m,tmp,now)
   integer,intent(IN) :: n,m,now
   real(8),intent(IN) :: x(m),y(m)
   real(8),intent(INOUT) :: tmp(m),r(n+1)
   integer i
   real(8) :: sum=0.0d0
   
   if(now==0)then
      do i=1,m
      sum=sum+y(i)
      enddo
      r(now+1)=sum
      tmp(1:m)=1.0d0
   else
      call set_vec(x,y,r,n,m,tmp,now-1)
      do i=1,m
         tmp(i)=tmp(i)*x(i)
         sum=sum+y(i)*tmp(i)
      enddo
      r(now+1)=sum
   endif
   end subroutine set_vec

   recursive subroutine set_mat(x,y,c,n,m,tmp,now)
      integer,intent(IN) :: n,m,now
      real(8),intent(IN) :: x(m),y(m)
      real(8),intent(INOUT) :: c(n+1,n+1),tmp(m)
      integer i,j
      real(8) :: sum=0.0d0
      
      if(now==0)then
         c(1,1)=dble(m)
         tmp(1:m)=1.0d0
      else
         call set_mat(x,y,c,n,m,tmp,now-1)
         do i=1,m
            tmp(i)=tmp(i)*x(i)
            sum=sum+tmp(i)
         enddo
         if(now>n)then
            do j=0,2*n-now
               c(now-n+1+j,n+1-j)=sum
            enddo
         else
            do j=0,now
               c(1+j,now+1-j)=sum
            enddo
         endif
      endif
   end subroutine 

end module subprog

program main
   use globals
   use subprog
   implicit none
   integer :: m,i,j,n,fo=11,fi=12
   real(8) dx,x,er,sm
   real(8),ALLOCATABLE :: xi(:),yi(:),c(:,:),b(:),tmp(:),a(:)
   
   write(*,*)"input step m"
   read(*,*)m
   dx=20.0d0/dble(m)

   open(fo,file="output.d")
   do i=1,m
      x=-10.0d0+dble(i)*dx
      call random_seed
      call random_number(er)
      write(fo,*)x,fun(x)+50.0d0*er
   enddo
   close(fo)

   open(fi,file="output.d")
   allocate(xi(m),yi(m))
   write(*,*)"input dimention n"
   read(*,*)n
   do i=1,m
      read(fi,*)xi(i),yi(i)
   enddo
   close(fi)

   !ここから係数行列Aと初期条件のベクトルを構築していく
   allocate(c(n+1,n+1),b(n+1),tmp(m),a(n+1))
   call set_vec(xi,yi,b,n,m,tmp,n)
   call set_mat(xi,yi,c,n,m,tmp,2*n)
   deallocate(tmp,xi,yi)

   !Gauss-Jordan法で連立方程式を解き、出力
   call gauss_jordan(c,a,b,n+1)
   do i=0,n
      write(*,*)i,':',a(i)
   enddo

   open(fo,file="output2.d")
   do i=1,m
      x=-10.0d0+dble(i)*dx
      sm=0.0d0
      do j=1,n+1
         sm=sm+a(j)*x**(j-1)
      enddo
      write(fo,*)x,sm
   enddo

end program main

メインプログラムの最初のところではまずoutput.dに対して点列を与えています。ここでは区間[-10,10]をm等分してそれぞれをx_1,x_2,...,x_mとし、さらに適当な関数gを与えてy_i=g(x_i)+err(errは乱数による誤差項)とすることでm個の点列を得ました。

次は連立方程式を出してくるステップです。ここで、求める近似関数をf(x)とすれば、データとの残差の総和R

 R=\sum_{i=1}^{m}\left(y_i-f(x_i)\right)^2

とかけます。ここではn次の多項式で近似することにしているのでf(x):=a_0+a_1x+a_2x^2+\cdots+a_nx^nとおけば

 R=\sum_{i=1}^m\left(y_i-a_0-a_1x_i-a_2x_i^2-\cdots-a_nx_i^n\right)^2

と書くことができます。この式をa_k (k=0,1,\cdots,n)を独立変数とする関数だと見做すと、Rが最小となるとき\partial R/ \partial a_k=0となります。(それはそう)(a_kを横軸にとって下に凸な放物線を考えるとそれはそうだなって気分になってくる)
なのでRを素直にa_k偏微分してやることで


\sum_{i=1}^m 2\left(y_i-a_0-a_1x_i-a_2x_i^2-\cdots-a_nx_i^n\right)(-x_i^k)=0 \\
\therefore~~\sum_{i=1}^m \sum_{j=0}^n a_jx_i^{k+j}=\sum_{i=1}^m y_ix_i^k~~~~\cdots\cdots(1)


を得ます。ここでベクトル\bm{b}=(b_k)~(k=0,1,\cdots,n)(n+1)次正方行列C=(c_{k,j})~(k,j=0,1,\cdots,n)

 b_k:=\sum_{i=1}^my_ix_i^k \\
c_{k,j}:=\sum_{i=1}^mx_i^{k+j}

で定義すれば、先の(1)はn+1連立方程式

C\bm{a}=\bm{b}

と同値になります。後はコレをGauss-Jordan法を用いて解くだけです。
Gauss-Jordan法とはなんぞやという話ですが、かっこいい名前が付いているもののやってることは普段連立方程式を解くときに使うのと同じアルゴリズムです。拡大係数行列を導入して左上から掃き出していきます。例えば3元1次連立方程式を解くときは下の図1のようになります。(tex打ちサボりました。)

f:id:jhonson1415:20200827153829p:plain
図1.掃き出し法で連立方程式を解く様子

更にここでは部分pivot選択というのをしています。掃き出し法では、ある行aを一つ選んで、その行aの定数倍を他の行bに足すことで他の行bの要素を消していますよね。実は、このとき絶対値が小さい行をaに選んでしまうと計算誤差が大きくなってしまいます。それを防ぐために、一度他の行を全部見てから、その中で最も絶対値が大きいものを選んでくるというのが部分pivot選択です。詳細は下の図2を参照です。図2では、有効数字4桁として5桁目を四捨五入して計算しました。

f:id:jhonson1415:20200827154120p:plain
図2.部分ピボット選択しなきゃいけない理由

また、コードの通り、係数行列Cとベクトル\bm{b}を与えるときには再帰subroutineを使っています。特に前者のsubroutineであるところのset_matの方では、(k+j)が一定となるような要素から埋めていくことで大幅に計算量を削減しています。

解析結果

g(x):=0.1x^3-0.2x^2+0.9x+1.0~~,m=300としたときの計算結果は以下のようになりました。上から順にn=3,10,100次の結果を図示しておきます。

f:id:jhonson1415:20200827155446p:plain
図3.n=3のときの解析結果
f:id:jhonson1415:20200827155516p:plain
図4.n=10のときの解析結果
f:id:jhonson1415:20200827155533p:plain
図5.n=100のときの解析結果

もともとの多項式が3次だったので、n=3のあたりではそこそこ精度がよかったのですが、n=6のときはなんか誤差がでっかくなっちゃってるのにも気づきました。n=7にすると落ち着いて来てました。最小二乗法完全理解マンがいたらコレの原因を教えてくれると嬉しいです。

f:id:jhonson1415:20200827155610p:plainf:id:jhonson1415:20200827155614p:plain
図6.n=6,7のときの解析結果

さいごに

 実際に解析で使えるようなプログラムを実装するのは楽しいですね。次はLaplace方程式の数値解法を実装して記事にできたら良いなと思ってます。また、Gauss-Jordan法は式の数が10万本くらいになると崩壊してくるので、別のアルゴリズムも実装していきたいですね。あと本当に最後に、まちカドまぞくの2期、本当に期待しています。明日辺り発狂する僕が見られると思います。

*1:一郎と読みます、ノートパソコンは次郎です

*2:第2版 数値計算のためのFortran90/95プログラミング入門、牛島省著、森北出版