# +----------------------------------------------------+
# |                  Lekce 5 - demo 4                  |
# +----------------------------------------------------+
#  METHOD OF CONJUGATE GRADIENTS
#
#  PROBLEM:       Speed of convergence of CG for various matrices
#  THIS VERSION:  matrices with prescribed eigenvalues
#    

import numpy as np
#import scipy.linalg as linalg
import matplotlib.pyplot as plt

# Matrix generator: cluster eigval
def GenMat(N, n_low=20, n_high=3):
    """Return a random N x N matrix Q^T D Q with three eigenvalue clusters.

    Q is the orthogonal factor of a QR decomposition of a random matrix and
    D is diagonal, so (up to rounding) the eigenvalues of the result are the
    diagonal entries of D:
      * the first ``n_low`` entries lie in [1, 2),
      * the last ``n_high`` entries lie in [2500, 2501),
      * all entries in between lie in [50, 51).

    Parameters
    ----------
    N : int
        Matrix size.
    n_low : int, optional
        Size of the low cluster (default 20, the previously hard-coded value).
    n_high : int, optional
        Size of the high cluster (default 3, the previously hard-coded value).

    Returns
    -------
    numpy.ndarray
        The (N, N) test matrix.
    """
    q = np.linalg.qr(np.random.rand(N, N))[0]
    d = np.random.rand(N)
    # Start of the high cluster.  Using ``N - n_high`` (clipped at 0) rather
    # than a negative index keeps n_high=0 meaning "no high cluster"
    # (d[-0:] would select the whole array); for n_high >= 1 it selects
    # exactly the same entries as d[-n_high:].
    hi = max(N - n_high, 0)
    d[:n_low] += 1.0
    d[n_low:hi] += 50.0
    d[hi:] += 2500.0
    # Q^T diag(d) Q: scaling the rows of Q by d gives exactly diag(d) @ Q
    # without allocating diag(d) or doing an extra dense N^3 product.
    return q.T @ (d[:, None] * q)

# Matrix generator: Bad eigval - geometric
def GenMat2(N):
    """Return a random N x N matrix Q^T D Q with geometrically spaced eigenvalues.

    Q is the orthogonal QR factor of a random matrix.  D starts at N and
    each following diagonal entry is the previous one divided by
    N**(1/N), so the spectrum runs from N down to N**(1/N).
    """
    basis = np.linalg.qr(np.random.rand(N, N))[0]
    spectrum = np.zeros(N)
    spectrum[0] = N
    ratio = N ** (1 / N)
    # Build the sequence by repeated division (same rounding as a recurrence).
    for j in range(1, N):
        spectrum[j] = spectrum[j - 1] / ratio
    return basis.T @ (np.diag(spectrum) @ basis)

# Matrix generator: arithmetic
def GenMat3(N):
    """Return a random N x N matrix Q^T D Q whose eigenvalues are 1, 2, ..., N.

    Q is the orthogonal QR factor of a random matrix and D = diag(1, ..., N).
    """
    basis = np.linalg.qr(np.random.rand(N, N))[0]
    spectrum = np.arange(1, N + 1, dtype=float)
    return basis.T @ (np.diag(spectrum) @ basis)

#CG solver:
def CGsolve(A, b, n=None):
    """Solve A x = b with conjugate gradients and record the error history.

    CG is only guaranteed to work for symmetric positive definite A.  The
    iteration starts from x = 0 and runs a fixed 2*n steps (presumably to
    show finite-precision behaviour beyond the theoretical n-step
    termination).  It stops early only if the residual becomes exactly
    zero; another step would then compute 0/0.

    Parameters
    ----------
    A : (n, n) array
        System matrix.
    b : (n,) array
        Right-hand side.
    n : int, optional
        System size; defaults to ``len(b)``.

    Returns
    -------
    x : (n,) ndarray
        Final CG iterate.
    er : (2*n,) ndarray
        ``er[i]`` is the Euclidean norm of ``x - x_exact`` after step i+1.
        The exact solution comes from ``np.linalg.solve``.  After an early
        stop, the remaining entries repeat the final error.
    """
    b = np.asarray(b)
    if n is None:
        n = len(b)
    # Exact solution, used only to measure the convergence.
    xe = np.linalg.solve(A, b)
    er = np.zeros(2 * n)

    # CG algorithm.
    x = np.zeros(n)
    r = b
    p = r
    gm = np.dot(r, r)
    for i in range(2 * n):
        if gm == 0.0:
            # Exact zero residual: x is final.  Continuing would give p = 0
            # and al = 0/0 = NaN, which would corrupt x and er.
            er[i:] = np.linalg.norm(x - xe)
            break
        w = np.dot(A, p)
        al = gm / np.dot(p, w)
        x = x + al * p
        r = r - al * w
        gm_old = gm
        gm = np.dot(r, r)
        be = gm / gm_old
        p = r + be * p
        er[i] = np.linalg.norm(x - xe)
    return x, er

# ---- Experiment: run CG on the three test matrices --------------------------
N = 500
b = np.ones(N)

A = GenMat(N)
kap0 = np.linalg.cond(A)
print('cond=', kap0)
xx, e0 = CGsolve(A, b, N)

A = GenMat2(N)
kap2 = np.linalg.cond(A)
print('cond=', kap2)
xx, e2 = CGsolve(A, b, N)

A = GenMat3(N)
kap3 = np.linalg.cond(A)
print('cond=', kap3)
xx, e3 = CGsolve(A, b, N)

# ---- Plot: log10 of the error vs. iteration number --------------------------
# Dashed lines show 2 * e[0] * ((sqrt(kappa)-1)/(sqrt(kappa)+1))**(k-1).  This
# is the classic kappa-based CG estimate (stated for the A-norm error),
# anchored at the first recorded error and used here as a reference slope.
it = 170
steps = np.arange(1, it + 1)  # iteration numbers 1..it on the x axis
plt.ylim((-16, 5))
for err, kap, color, label in (
    (e0, kap0, 'g', 'clustered eigenvalues'),
    (e2, kap2, 'b', 'geometric eigenvalues'),
    (e3, kap3, 'r', 'arithmetic eigenvalues (1..N)'),
):
    rate = (np.sqrt(kap) - 1) / (np.sqrt(kap) + 1)
    plt.plot(steps, np.log10(err[:it]), color + 'o-', label=label)
    plt.plot(steps, np.log10(2 * err[0] * rate ** (steps - 1)), color + '--',
             label=label + ' (kappa estimate)')
plt.xlabel('iteration')
plt.ylabel('log10 of ||x_k - x||_2')
plt.legend()
plt.show()
