package rsalgos;
/**
 * Implementation of the AKG algorithm
 * @author Bogumil Kaminski & Przemyslaw Szufel
 */
import static java.lang.Math.abs;
import static java.lang.Math.sqrt;

import java.util.Arrays;

import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.distribution.NormalDistribution;

import simtools.CalculateEv;
import clustersim.RSpoint;

/**
 * AKG ranking-and-selection policy (presumably "asynchronous knowledge gradient"; it implements
 * {@link AsynchcronousRSalgorithm}).
 *
 * <p>Keeps an independent normal belief (mean {@code mu[i]}, variance {@code sigmas2[i]}) for every
 * alternative. Each call to {@link #getNextPoint} does two things:
 * <ol>
 * <li>It folds the single observation {@code y} into the belief of the evaluated alternative,
 * using a conjugate normal update with known observation-noise variance.</li>
 * <li>It returns the alternative with the largest value of {@link #calculateAkgEv}. That value
 * accounts for the count returned by {@code RSpoint.getS()} (presumably evaluations still in
 * flight on that alternative; confirm against {@code RSpoint}).</li>
 * </ol>
 *
 * <p>Instances are stateful and perform no synchronization; callers must serialize access.
 */
public class AKG_RSalgorithm implements AsynchcronousRSalgorithm {

	/** Prior mean assigned to every alternative before any observation. */
	private static final double PRIOR_MEAN = 0.0;
	/**
	 * Prior variance assigned to every alternative. NOTE(review): presumably chosen to match the
	 * distribution the true means are drawn from (N(0, 1)); confirm.
	 */
	private static final double PRIOR_VARIANCE = 1.0;
	/** Observation-noise variance used by the no-argument constructor. */
	private static final double DEFAULT_NOISE_VARIANCE = 1.0;

	/** Known variance of a single observation, used in the Bayesian belief update. */
	private final double noiseVariance;

	/** Current belief variance per alternative; {@code null} until first initialised. */
	private double[] sigmas2 = null;
	/** Current belief mean per alternative; {@code null} until first initialised. */
	private double[] mu = null;

	/** Standard normal N(0, 1); only its CDF and density are used, so it is shared. */
	private final NormalDistribution normal = new NormalDistribution();

	/** f(z) = z * Phi(z) + phi(z), where Phi/phi are the standard normal CDF/density. */
	private final UnivariateFunction f = new UnivariateFunction() {
		@Override
		public double value(double x) {
			return x * normal.cumulativeProbability(x) + normal.density(x);
		}
	};

	/** Creates the policy with unit observation-noise variance. */
	public AKG_RSalgorithm() {
		this(DEFAULT_NOISE_VARIANCE);
	}

	/**
	 * Creates the policy with a given observation-noise variance.
	 *
	 * @param noiseVariance variance of a single observation; must be positive and finite
	 * @throws IllegalArgumentException if {@code noiseVariance} is not positive and finite
	 */
	public AKG_RSalgorithm(double noiseVariance) {
		// The negated comparison also rejects NaN.
		if (!(noiseVariance > 0) || Double.isInfinite(noiseVariance)) {
			throw new IllegalArgumentException("noiseVariance must be positive and finite: " + noiseVariance);
		}
		this.noiseVariance = noiseVariance;
	}

	/**
	 * Returns the initial point for worker {@code w}, initialising the prior beliefs on first use.
	 *
	 * @param points all alternatives; must not be empty
	 * @param w worker index; alternatives are dealt round-robin by {@code w % points.length}
	 * @param pointCounter not used by this method
	 * @return the alternative assigned to worker {@code w}
	 * @throws IllegalArgumentException if {@code points} is empty
	 */
	@Override
	public RSpoint getPointInit_k_0(RSpoint[] points, int w, int pointCounter) {
		requireNonEmpty(points);
		initPriorsIfNeeded(points.length);
		// With identical priors and no observations there is nothing to rank on, so the
		// alternatives are simply split evenly across the available workers.
		return points[w % points.length];
	}

	/**
	 * Incorporates observation {@code y} of {@code point} and selects the next alternative to
	 * evaluate.
	 *
	 * @param y observed value for {@code point}
	 * @param point the alternative that produced {@code y}, identified by {@code point.i}.
	 *        NOTE(review): if {@code point.i} is not a valid index into {@code points}, no belief
	 *        is updated and {@code y} is silently ignored.
	 * @param points all alternatives; must not be empty
	 * @param w worker index (not used by this method)
	 * @param pointCounter used only in the debug output
	 * @param debug if {@code true}, prints intermediate vectors to {@code System.out} in R syntax
	 * @return the alternative with the largest non-NaN AKG value (first one on ties)
	 * @throws IllegalArgumentException if {@code points} is empty
	 * @throws IllegalStateException if every AKG value is NaN
	 */
	@Override
	public RSpoint getNextPoint(double y, RSpoint point, final RSpoint[] points, int w, int pointCounter, boolean debug) {
		requireNonEmpty(points);
		initPriorsIfNeeded(points.length);

		final int n = points.length;
		final double sigma2e = noiseVariance;
		final int s[] = new int[n]; // getS() for every alternative, read once
		final double sigmas2ki[] = new double[n]; // belief variance at time k
		final double sigmas2ki_si[] = new double[n]; // variance reduction after s[i] more observations
		final double sigmas2ki_si_p1[] = new double[n]; // variance reduction after s[i] + 1 more observations
		final double muk[] = new double[n]; // belief mean at time k
		for (int i = 0; i < n; i++) {
			s[i] = points[i].getS();
			if (point.i == i) {
				// Conjugate normal update with known noise variance sigma2e.
				sigmas2ki[i] = 1 / (1 / sigmas2[i] + 1 / sigma2e);
				muk[i] = (mu[i] / sigmas2[i] + y / sigma2e) * sigmas2ki[i];
			} else {
				sigmas2ki[i] = sigmas2[i];
				muk[i] = mu[i];
			}
			// Var(mu now) - Var(mu after m more observations) = sigma^2 - 1/(1/sigma^2 + m/sigma2e).
			sigmas2ki_si[i] = sigmas2ki[i] - 1 / (1 / sigmas2ki[i] + s[i] / sigma2e);
			sigmas2ki_si_p1[i] = sigmas2ki[i] - 1 / (1 / sigmas2ki[i] + (s[i] + 1) / sigma2e);
		}

		// Commit the posterior before selection so the observation is never lost if selection fails.
		// Nothing below reads these fields; it works on the local arrays only.
		sigmas2 = sigmas2ki;
		mu = muk;

		final double[] muk_maxOfOthers_s = Tools.maxOfOtherElems(muk, s);

		int xks = -1;
		double bestArg = Double.NEGATIVE_INFINITY;
		final double akgp[] = new double[n];
		for (int i = 0; i < n; i++) {
			final double arg = calculateAkgEv(sigmas2ki_si_p1, muk, muk_maxOfOthers_s, s, sigmas2ki_si, i);
			akgp[i] = arg;
			// NaN never wins; the first finite maximum is kept on ties.
			if (!Double.isNaN(arg) && (xks < 0 || arg > bestArg)) {
				bestArg = arg;
				xks = i;
			}
		}
		if (debug) {
			// Only needed for the diagnostic dump, so computed lazily.
			final double[] muk_maxOfOthers = Tools.maxOfOtherElems(muk);
			System.out.println("muk=c("+Tools.str(muk)+")");
			System.out.println("sigmas2ki_si_p1= c("+Tools.str(sigmas2ki_si_p1)+");wavesigma <- function(i) {sqrt(sigmas2ki_si_p1[i])} ");
			System.out.println("usk_maxOfOthers=c("+Tools.str(muk_maxOfOthers)+")");
			System.out.println("usk_maxOfOthers_s=c("+Tools.str(muk_maxOfOthers_s)+")");
			System.out.println("pc="+pointCounter+"; args="+Tools.str(akgp));
		}
		if (xks < 0) {
			throw new IllegalStateException("No alternative has a non-NaN AKG value: " + Arrays.toString(akgp));
		}
		return points[xks];
	}

	/**
	 * Computes sqrt(v) * f(-|muk[i] - muk_maxOfOthers[i]| / sqrt(v)), with
	 * v = sigmas2ki_si_p1[i] and f(z) = z * Phi(z) + phi(z). This is the closed-form
	 * knowledge-gradient value for alternative {@code i}.
	 * (The misspelt name is kept for compatibility with existing callers.)
	 *
	 * @param sigmas2ki_si_p1 variance reduction per alternative
	 * @param muk belief means
	 * @param muk_maxOfOthers for each index, the best competing mean
	 * @param i alternative index
	 * @return the knowledge-gradient value for alternative {@code i}
	 */
	public double kalculateKgEv(final double[] sigmas2ki_si_p1, final double[] muk, final double[] muk_maxOfOthers, final int i) {
		final double sd = sqrt(sigmas2ki_si_p1[i]);
		return sd * f.value(-abs((muk[i] - muk_maxOfOthers[i]) / sd));
	}

	/**
	 * Computes the AKG value of alternative {@code i} and delegates to
	 * {@code CalculateEv.calculateEV_KG}.
	 *
	 * <p>It builds F(x) = 0 for x &lt; muk_maxOfOthers_s[i]. Otherwise F(x) is the product of
	 * Phi((x - muk[j]) / sd_j) over:
	 * <ul>
	 * <li>j = i, using variance sigmas2ki_si_p1[i];</li>
	 * <li>every j with s[j] &gt; 0, using variance sigmas2ki_si[j].</li>
	 * </ul>
	 * Alternatives with s[j] == 0 are skipped because their variance reduction is zero. F looks
	 * like the CDF of a maximum of normal variables. Presumably calculateEV_KG turns it into an
	 * expectation; confirm against {@code CalculateEv}.
	 *
	 * @param sigmas2ki_si_p1 variance reduction after s[i] + 1 observations
	 * @param muk belief means
	 * @param muk_maxOfOthers_s per-index lower threshold below which F is 0
	 * @param s per-alternative counts from {@code getS()}
	 * @param sigmas2ki_si variance reduction after s[j] observations
	 * @param i alternative index
	 * @return the value returned by {@code CalculateEv.calculateEV_KG(F)}
	 */
	public double calculateAkgEv(final double[] sigmas2ki_si_p1, final double[] muk, final double[] muk_maxOfOthers_s, final int s[], final double[] sigmas2ki_si,  final int i) {
		final int n = muk.length;
		// F is evaluated many times during integration, so the standard deviations are
		// computed once here instead of on every call.
		final double[] sd = new double[n];
		for (int j = 0; j < n; j++) {
			if (j == i) {
				sd[j] = sqrt(sigmas2ki_si_p1[j]);
			} else if (s[j] > 0) {
				sd[j] = sqrt(sigmas2ki_si[j]);
			} else {
				sd[j] = Double.NaN; // never read: such j are skipped in F
			}
		}
		final double threshold = muk_maxOfOthers_s[i];
		UnivariateFunction F = new UnivariateFunction() {
			@Override
			public double value(double x) {
				if (x < threshold) return 0.0;
				double res = 1;
				for (int j = 0; j < n; j++) {
					if (j == i || s[j] > 0) {
						res *= normal.cumulativeProbability((x - muk[j]) / sd[j]);
					}
				}
				return res;
			}
		};
		return CalculateEv.calculateEV_KG(F);
	}

	/** Creates the prior belief arrays for {@code n} alternatives on first use; later calls do nothing. */
	private void initPriorsIfNeeded(int n) {
		if (sigmas2 == null) {
			sigmas2 = new double[n];
			Arrays.fill(sigmas2, PRIOR_VARIANCE);
			mu = new double[n];
			Arrays.fill(mu, PRIOR_MEAN);
		}
	}

	/** Rejects an empty alternative set, for which no point can be returned. */
	private static void requireNonEmpty(RSpoint[] points) {
		if (points.length == 0) {
			throw new IllegalArgumentException("points must not be empty");
		}
	}
}
