最小二乗法を実装した話
おはこんばんにちはてれじょんです。厳しい暑さが続いていますが皆さんお元気でしょうか?私は陸の孤島秋田に帰省しているのでそこまできつい暑さでは無いのですが、それでもまあ暑いです。死なない程度に生きていきたいものです。天気図も書きたいものです。(うちのスーパーコンピューターであるところの"1-LOUGH"*1 が秋田に無いので自明に録音ができなくて詰んでいます。助けて。)
ところで皆さんはまちカドまぞくを知っていますか?悪いことは言わないので知らないヒトは今すぐ見てください。今日ツイッターを見ていたところ「2期決定!」というのを発見して特大の喜びをしたのですが、リークだと知り一瞬で真顔に戻りました。公式で発表されてないことを確認する前に拡散した自分も悪いんですが、好きな作品の喜ばしい発表に対して素直に喜べないのは本当に本当に本当に本当に悲しい限りなので、リーカーは潔く自決してほしいですね。私も自決するべきですが。
さて本題ですが、今日はFortranの教科書*2にさらっと書かれていた最小二乗法を実装しました。点列を入力して偏微分して得られる連立方程式を解く……という簡単なステップですが、連立方程式を解くときにGauss-Jordan法を使ったり、係数行列を与える際に再帰呼び出しをしたりとで結構コードが長くなってしまったのでここに残しておきます。
どういう手順?
もっとくやしく
まずコードを貼っておきます。
module subprog use globals implicit none contains subroutine gauss_jordan (a0,x,b,n) !部分pivot選択ありです integer,intent(IN) :: n real(8),intent(IN) :: a0(n,n),b(n) real(8),intent(OUT) :: x(n) integer i,j,m real(8) a(n,n+1),am,tmp(n+1),aji a(1:n,1:n)=a0 do i=1,n a(i,n+1)=b(i) enddo do i=1,n m=i am=a(i,i) do j=i+1,n if(abs(am)<abs(a(j,i)))then am=a(j,i) m=j endif enddo if(am==0.0d0) stop "matrix A is not invertible!" if(i/=m)then tmp(1:n+1)=a(i,1:n+1) a(i,1:n+1)=a(m,1:n+1) a(m,1:n+1)=tmp(1:n+1) endif !消去開始 a(i,i:n+1)=a(i,i:n+1)/am a(i,i)=1.0d0 do j=1,n if(i==j)cycle aji=a(j,i) a(j,i:n+1)=a(j,i:n+1)-aji*a(i,i:n+1) a(j,i)=0.0d0 enddo enddo do i=1,n x(i)=a(i,n+1) enddo end subroutine gauss_jordan subroutine mat_test(n,a,b) integer n real(8),intent(INOUT) ::a(n,n),b(n) call random_seed call random_number(a(1:n,1:n)) call random_number(b(1:n)) end subroutine mat_test function fun(x) result(y) real(8),intent(IN) :: x real(8) y y=0.1d0*x**3-0.2d0*x**2+0.9d0*x+1.0d0 end function fun recursive subroutine set_vec(x,y,r,n,m,tmp,now) integer,intent(IN) :: n,m,now real(8),intent(IN) :: x(m),y(m) real(8),intent(INOUT) :: tmp(m),r(n+1) integer i real(8) :: sum=0.0d0 if(now==0)then do i=1,m sum=sum+y(i) enddo r(now+1)=sum tmp(1:m)=1.0d0 else call set_vec(x,y,r,n,m,tmp,now-1) do i=1,m tmp(i)=tmp(i)*x(i) sum=sum+y(i)*tmp(i) enddo r(now+1)=sum endif end subroutine set_vec recursive subroutine set_mat(x,y,c,n,m,tmp,now) integer,intent(IN) :: n,m,now real(8),intent(IN) :: x(m),y(m) real(8),intent(INOUT) :: c(n+1,n+1),tmp(m) integer i,j real(8) :: sum=0.0d0 if(now==0)then c(1,1)=dble(m) tmp(1:m)=1.0d0 else call set_mat(x,y,c,n,m,tmp,now-1) do i=1,m tmp(i)=tmp(i)*x(i) sum=sum+tmp(i) enddo if(now>n)then do j=0,2*n-now c(now-n+1+j,n+1-j)=sum enddo else do j=0,now c(1+j,now+1-j)=sum enddo endif endif end subroutine end module subprog program main use globals use subprog implicit none integer :: m,i,j,n,fo=11,fi=12 real(8) dx,x,er,sm real(8),ALLOCATABLE :: xi(:),yi(:),c(:,:),b(:),tmp(:),a(:) write(*,*)"input step m" read(*,*)m dx=20.0d0/dble(m) open(fo,file="output.d") do i=1,m x=-10.0d0+dble(i)*dx call random_seed call random_number(er) write(fo,*)x,fun(x)+50.0d0*er enddo close(fo) open(fi,file="output.d") allocate(xi(m),yi(m)) write(*,*)"input dimention n" read(*,*)n do i=1,m read(fi,*)xi(i),yi(i) enddo close(fi) !ここから係数行列Aと初期条件のベクトルを構築していく allocate(c(n+1,n+1),b(n+1),tmp(m),a(n+1)) call set_vec(xi,yi,b,n,m,tmp,n) call set_mat(xi,yi,c,n,m,tmp,2*n) deallocate(tmp,xi,yi) !Gauss-Jordan法で連立方程式を解き、出力 call gauss_jordan(c,a,b,n+1) do i=0,n write(*,*)i,':',a(i) enddo open(fo,file="output2.d") do i=1,m x=-10.0d0+dble(i)*dx sm=0.0d0 do j=1,n+1 sm=sm+a(j)*x**(j-1) enddo write(fo,*)x,sm enddo end program main
メインプログラムの最初のところではまずoutput.dに対して点列を与えています。ここでは区間]をm等分してそれぞれをとし、さらに適当な関数を与えてとすることで個の点列を得ました。
次は連立方程式を出してくるステップです。ここで、求める近似関数をとすれば、データとの残差の総和は
とかけます。ここでは次の多項式で近似することにしているのでとおけば
と書くことができます。この式をを独立変数とする関数だと見做すと、が最小となるときとなります。(それはそう)(を横軸にとって下に凸な放物線を考えるとそれはそうだなって気分になってくる)
なのでを素直にで偏微分してやることで
を得ます。ここでベクトルと次正方行列を
で定義すれば、先の(1)は元連立方程式
と同値になります。後はコレをGauss-Jordan法を用いて解くだけです。
Gauss-Jordan法とはなんぞやという話ですが、かっこいい名前が付いているもののやってることは普段連立方程式を解くときに使うのと同じアルゴリズムです。拡大係数行列を導入して左上から掃き出していきます。例えば3元1次連立方程式を解くときは下の図1のようになります。(tex打ちサボりました。)
更にここでは部分pivot選択というのをしています。掃き出し法では、ある行aを一つ選んで、その行aの定数倍を他の行bに足すことで他の行bの要素を消していますよね。実は、このとき絶対値が小さい行をaに選んでしまうと計算誤差が大きくなってしまいます。それを防ぐために、一度他の行を全部見てから、その中で最も絶対値が大きいものを選んでくるというのが部分pivot選択です。詳細は下の図2を参照です。図2では、有効数字4桁として5桁目を四捨五入して計算しました。
また、コードの通り、係数行列とベクトルを与えるときには再帰subroutineを使っています。特に前者のsubroutineであるところのset_matの方では、が一定となるような要素から埋めていくことで大幅に計算量を削減しています。
解析結果
としたときの計算結果は以下のようになりました。上から順に次の結果を図示しておきます。
もともとの多項式が3次だったので、のあたりではそこそこ精度がよかったのですが、のときはなんか誤差がでっかくなっちゃってるのにも気づきました。にすると落ち着いて来てました。最小二乗法完全理解マンがいたらコレの原因を教えてくれると嬉しいです。