# This program chooses the number of variable clusters M by minimizing DIC.
# Because the clustering scheme nests sub-clusters within each variable
# cluster, we simplify the inference of M (and improve power) by clustering
# the variables using all subjects together.

# Start from an empty workspace (hidden objects included) so nothing stale
# from an earlier session leaks into this run.
rm(list = ls(all.names = TRUE))

library(MSBVAR)   # provides rmultnorm()
library(MCMCpack) # provides rdirichlet() and riwish()

# Run-wide settings.
basefn <- "cluster400" # base file name for the input .Rdata and all outputs
simulation <- 100      # number of data sets (slices of pst) to analyse
nloop <- 200           # number of loops (MCMC sweeps per candidate M)
k <- 10                # number of response variables

dzeta <- function(par1, par2, par3) { # return in log scale
  # Unnormalised log target for the concentration zeta.
  # It is the symmetric Dirichlet(par1, ..., par1) log density of a vector
  # of length par3 whose log-entries sum to par2. To that it adds the log of
  # a prior on zeta that is flat on (0, 1] and proportional to zeta^-2
  # above 1.
  #   par1: zeta, par2: sum of log Pi, par3: length of Pi
  log.prior.penalty <- if (par1 > 1) 2 * log(par1) else 0
  lgamma(par3 * par1) + (par1 - 1) * par2 - par3 * lgamma(par1) -
    log.prior.penalty
}

tob <- function(num) {
  # Positions (1-based, least significant bit first) of the set bits in the
  # binary expansion of `num`; e.g. tob(5) is c(1, 3).
  # Returns numeric(0) when num <= 0.
  positions <- numeric()
  bit <- 1
  repeat {
    if (num <= 0) break
    if (num %% 2 == 1) {
      positions <- c(positions, bit)
    }
    num <- floor(num / 2)
    bit <- bit + 1
  }
  positions
}

big <- function(data.idx, M.range = 2:5) {
  # Fit the clustering model to data set `data.idx` once for every candidate
  # number of variable clusters M in `M.range`, and keep the M whose DIC is
  # smallest.
  #
  # Args:
  #   data.idx: index into the third dimension of the global array `pst`;
  #             column 1 of pst[, , data.idx] is the covariate x and
  #             columns 2..(k + 1) are the k responses.
  #   M.range:  candidate numbers of clusters (default 2:5, as before).
  #
  # Returns:
  #   list(M = best M, dic = its DIC). M stays 3 if no candidate produced a
  #   DIC that could be compared.
  #
  # Side effect: appends one line per candidate M to the file named `fn`.
  # Reads globals set by the main program: pst, n, k, fixed, nknot, b.len,
  # knot, sigma.b, a, c, nu, S, v, bin.pos, b.nloop, d.nloop, fn.

  # One Gibbs/MH sweep. `prev` holds the current state; the returned list is
  # the new state plus `dic`, this sweep's deviance (only accumulated when
  # draw = TRUE). Uses M, y, xx and xtx from the enclosing call of big().
  iteration <- function(draw = FALSE, prev) {
    Pi <- prev$Pi
    dk <- prev$dk
    pos <- prev$pos
    pat <- prev$pat
    pat.count <- prev$pat.count
    bhat <- prev$bhat
    idx.bm <- prev$idx.bm
    index <- prev$index
    Sigma <- prev$Sigma
    sigma <- prev$sigma
    zeta <- prev$zeta

    # --- Cluster weights Pi and memberships dk, one response at a time.
    dk.p <- dk
    xbeta <- xx %*% t(bhat) # n x M: column m is cluster m's fitted curve
    for (i in seq_len(k)) {
      Pi[, i] <- rdirichlet(1, rep(zeta, M) + dk[, i])
      # log weight of each cluster for response i
      pdk <- log(Pi[, i]) - 0.5 * colSums((y[, i] - xbeta)^2) / Sigma[i, i]
      pdk <- pdk - max(pdk) # stabilise before exponentiating
      dk[, i] <- rmultinom(1, 1, exp(pdk))
    }

    # Reject the membership update if it left any cluster empty.
    if (any(rowSums(dk) == 0)) dk <- dk.p

    # Encode each cluster as a bit mask over responses (bit i - 1 set when
    # response i belongs to it), then order clusters by that code.
    bin <- as.vector(dk %*% 2^bin.pos)
    dk <- dk[order(bin), , drop = FALSE]
    tmp <- sort(bin)[-M] # the last code is implied by the others

    # Pattern bookkeeping: each column of `pat` is c(codes, count, index).
    newpat <- TRUE
    if (pat.count > 0) {
      if (all(pat[1:(M - 1), pos] == tmp)) {
        newpat <- FALSE
      } else {
        for (i in seq_len(pat.count)) {
          if (all(pat[1:(M - 1), i] == tmp)) {
            newpat <- FALSE
            pos <- i
            break
          }
        }
      }
    }
    if (newpat) {
      pat.count <- pat.count + 1
      pos <- pat.count # so the fast check above hits on a repeat
      if (draw) {
        idx.bm <- c(idx.bm, pat.count)
        pat <- cbind(pat, c(tmp, 1, index))
      } else {
        pat <- cbind(pat, c(tmp, 0, pat.count))
      }
    } else if (draw) {
      pat[M, pos] <- pat[M, pos] + 1
      pat[M + 1, pos] <- index
      idx.bm <- c(idx.bm, pos)
    }

    # --- Per-cluster coefficients, spline variance and covariance block.
    dic <- 0
    for (m in seq_len(M)) {
      clu <- which(dk[m, ] == 1) # responses in cluster m
      kM <- length(clu)
      # prior precisions of bhat[m, ]; sigma may have been updated by the
      # previous cluster in this loop
      iV <- 1 / c(rep(sigma.b, fixed), rep(sigma, nknot))

      # drop = FALSE keeps a 1 x 1 matrix when kM == 1.
      Sigma.m <- Sigma[clu, clu, drop = FALSE]
      iSigma.m <- tryCatch(solve(Sigma.m), error = function(e) NULL)
      if (is.null(iSigma.m)) {
        # Singular block: fall back to a diagonal inverse. nrow = kM stops
        # diag() from reading a scalar as a matrix size.
        iSigma.m <- diag(abs(1 / diag(Sigma.m)), nrow = kM)
      }

      # Normal full conditional of bhat[m, ]. Every response in the cluster
      # shares the design row xx[j, ], so with X_j = kM stacked copies of it:
      #   sum_j t(X_j) iSigma X_j = sum(iSigma) * t(xx) %*% xx
      #   sum_j t(X_j) iSigma y_j = t(xx) %*% (y[, clu] %*% colSums(iSigma))
      prec <- sum(iSigma.m) * xtx + diag(iV)
      rhs <- crossprod(xx, y[, clu, drop = FALSE] %*% colSums(iSigma.m))
      post.cov <- solve(prec)
      bhat[m, ] <- rmultnorm(1, post.cov %*% rhs, post.cov)

      # Spline variance: IG(a + nknot/2, c + sum(b^2)/2), drawn as the
      # reciprocal of a Gamma with that shape and *rate*.
      sigma <- 1 / rgamma(1, shape = a + nknot / 2,
                          rate = c + sum(bhat[m, -(1:fixed)]^2) / 2)

      # Residuals of the fixed (polynomial) part only.
      # NOTE(review): the spline terms are not subtracted here, unlike the
      # membership update above -- confirm this is intended.
      fit.fixed <- xx[, 1:fixed] %*% bhat[m, 1:fixed]
      err <- y[, clu, drop = FALSE] - matrix(rep(fit.fixed, kM), nrow = n)
      Sigma.m <- riwish(n + nu, crossprod(err) + S[clu, clu, drop = FALSE])
      Sigma[clu, clu] <- Sigma.m

      if (draw) {
        # Deviance (up to a constant):
        #   sum_j t(e_j) Sigma.m^-1 e_j + n * log|Sigma.m|
        dic <- dic + sum((err %*% solve(Sigma.m)) * err)
        det.m <- det(Sigma.m)
        if (det.m <= 0) {
          # n * log|Sigma.m| -> -Inf as the determinant reaches 0; a large
          # negative constant stands in for that limit.
          dic <- dic - 9999999
        } else {
          dic <- dic + n * log(det.m)
        }
      }
    }

    # --- Metropolis-Hastings update of zeta.
    # Proposal: can ~ LogNormal(log(zeta), v). Its Hastings correction is
    # q(zeta | can) / q(can | zeta) = can / zeta.
    # NOTE(review): only the last response's weights Pi[, k] enter the
    # target; confirm whether all columns of Pi were intended.
    sum.log.P <- sum(log(Pi[, k]))
    can <- rlnorm(1, log(zeta), v)
    log.ratio <- dzeta(can, sum.log.P, M) - dzeta(zeta, sum.log.P, M) +
      log(can) - log(zeta)
    ratio <- min(exp(log.ratio), 1)
    if (is.na(ratio)) ratio <- 0 # e.g. Inf - Inf in the log densities
    if (runif(1) < ratio) zeta <- can

    if (draw) index <- index + 1
    list(dic = dic, Pi = Pi, dk = dk, pos = pos, pat = pat,
         pat.count = pat.count, bhat = bhat, idx.bm = idx.bm, index = index,
         Sigma = Sigma, sigma = sigma, zeta = zeta)
  }

  x <- pst[, 1, data.idx]
  y <- pst[, 1:k + 1, data.idx]

  M.optim <- 3
  dic.optim <- Inf

  # Design matrix: intercept, x, x^2, then truncated quadratic spline bases
  # at the interior knots knot[2], ..., knot[nknot + 1].
  xx <- matrix(0, nrow = n, ncol = b.len)
  xx[, 1:fixed] <- cbind(1, x, x^2)
  for (i in seq_len(nknot)) {
    xx[, i + fixed] <- as.numeric(knot[i + 1] < x) * (x - knot[i + 1])^2
  }
  xtx <- crossprod(xx) # reused by every bhat update

  for (M in M.range) {
    Pi <- matrix(1, nrow = M, ncol = k) / M
    sigma <- 1 # variance of the spline coefficients
    # Initial memberships: response i starts in cluster ((i - 1) %% M) + 1.
    dk <- matrix(rep(diag(1, M), ceiling(k / M)), nrow = M)[, 1:k]
    bhat <- matrix(0, nrow = M, ncol = b.len) # row m: cluster m coefficients
    Sigma <- riwish(nu, S)
    zeta <- 0.75

    state <- list(Pi = Pi, dk = dk, pos = 1, pat = numeric(), pat.count = 0,
                  bhat = bhat, idx.bm = numeric(), index = 1, Sigma = Sigma,
                  sigma = sigma, zeta = zeta)

    # burn in
    for (i in seq_len(b.nloop)) state <- iteration(draw = FALSE, state)

    # draw: keep every sweep's deviance
    dic <- numeric(d.nloop)
    for (i in seq_len(d.nloop)) {
      state <- iteration(draw = TRUE, state)
      dic[i] <- state$dic
    }

    # DIC = mean deviance + var(deviance) / 2
    dic.avg <- mean(dic)
    dic.var <- var(dic)
    dic.f <- dic.avg + dic.var / 2
    cat("simu", data.idx, "\tM=", M, "\tDIC=", dic.f, "\tDIC.avg=", dic.avg,
        "\tDIC.var=", dic.var, "\n", file = fn, append = TRUE)
    if (isTRUE(dic.f < dic.optim)) {
      dic.optim <- dic.f
      M.optim <- M
    }
  }
  list(M = M.optim, dic = dic.optim)
}

################
# main program #
################
loadfn <- paste0(basefn, ".Rdata")
# read data
load(loadfn)
# NOTE(review): the .Rdata file presumably supplies pst, n and rang.x, which
# are used below and by big() but are not defined in this script.

# Split the loops into burn-in and retained draws; rounding keeps both counts
# whole when nloop * 0.3 is fractional.
b.nloop <- round(nloop * 0.3) # number of burn-in iterations
d.nloop <- nloop - b.nloop    # number of draws

fn <- paste0(basefn, "_noDP_dic.txt")
cat("Simu\tM.optim\tDIC\n", file = fn, append = FALSE)

# Model settings and priors. The number of clusters M is chosen inside big();
# zeta (the Dirichlet concentration) is updated there by Metropolis-Hastings
# with a log-normal(log(zeta), v) proposal.
bin.pos <- 0:(k - 1) # bit position of each response in a cluster's code
nknot <- 15          # number of interior spline knots

nknot1 <- nknot + 2  # length of the knot grid, including both end points
nu <- k + 2          # inverse-Wishart degrees of freedom for Sigma
# Inverse-gamma IG(a, c) prior parameters for sigma. `c` masks base::c as a
# variable only (calls c(...) still reach the function); big() reads it by name.
a <- c <- 0.5
sigma.b <- 100         # prior variance of the fixed-effect coefficients
S <- diag(rep(0.5, k)) # inverse-Wishart scale matrix for Sigma
v <- 1.8               # log-scale sd of the log-normal proposal for zeta

fixed <- 3             # fixed-effect columns: intercept, x, x^2
b.len <- fixed + nknot # coefficients per cluster
knot <- seq(from = rang.x[1], to = rang.x[2], length.out = nknot1)

set.seed(2013)

parameter.list <- lapply(seq_len(simulation), big)
save(parameter.list, file = paste0(basefn, "_res.Rdata"))

#end
