/**
 * This class represents a complex matrix
 *
 * @author Jonas L. Jensen
 **/

import java.lang.Math;

public class Matrix {
    int m, n;
    ComplexNumber[][] A;

    /**
     * Initializes a m x n matrix, containing only zeros
     *
     * @param m height of matrix
     * @param n width of matrix
     **/
    public Matrix(int m, int n) {
	this.m = m;
	this.n = n;
	A = new ComplexNumber[m+1][n+1];

	for(int i=1; i<=m; i++)
	    for(int j=1; j<=n; j++)
		setValue(i, j, new ComplexNumber());
    }


    /**
     * Returns the n x n identity matrix, with one's in the diagonal 
     * and zeros in every other entry. Eg. identity(3) returns:
     * <code><p>[  1.0   0.0   0.0   ]
     * <p>[  0.0   1.0   0.0   ]
     * <p>[  0.0   0.0   1.0   ]</code>
     *
     * @param n width of matrix
     * @return the n x n identity matrix
     **/
    public static Matrix identity(int n) {
	Matrix I = new Matrix(n, n);
	
	for(int i=1; i<=n; i++)
	    I.setValue(i, i, ComplexNumber.ONE);

	return I;
    }


    /**
     * Returns an array of size 2, containing height and width of the matrix,
     * in that order.
     *
     * @return { width , height } of matrix
     **/
    public int[] getSize() {
	int[] a = { m , n };
	return a;
    }


    /**
     * Adds 2 matrices of the same size, by adding each entry.
     * B must have same size as this. If not, an exception is thrown.
     * 
     * @throws InvalidSizeException If B don't have the same size as this
     * @param  B the matrix to add to this
     * @return the sum of this and B
     **/
    public Matrix add(Matrix B) throws InvalidSizeException {
	if ((B.getSize()[0] == m) && (B.getSize()[1] == n)) {
	    Matrix C = new Matrix(m,n);

	    for(int i=1; i<=m; i++)
		for(int j=1; j<=n; j++)
		    C.setValue(i, j, this.getValue(i,j).add(B.getValue(i,j)));

	    return C;

	} else {
	    throw new InvalidSizeException("matrices must be same size to add...");
	}
    }
    


    /**
     * Sets the (i,j)'th entry in this to val. 
     * <p>
     * (i,j) must satisfy: 1 <= i <= m and 1 <= j <= n. If not, an exception is thrown.
     *
     * @param i line of entry
     * @param j column of entry
     * @param val new value
     **/
    public void setValue(int i, int j, ComplexNumber val) throws MatrixIndexOutOfBoundsException {
	if ( (i<1) || (i>m) || (j<1) || (j>n) )
	    throw new MatrixIndexOutOfBoundsException("invalid entry");

	A[i][j] = val;
    }


    /**
     * Returns the value of the (i,j)'th entry of this. 
     * <p>
     * (i,j) must satisfy: 1 <= i <= m and 1 <= j <= n. If not, an exception is thrown.
     *
     * @param i line of entry
     * @param j column of entry
     * @return the (i,j)'th entry of this
     **/      
    public ComplexNumber getValue(int i, int j) throws MatrixIndexOutOfBoundsException {
	if ( (i<1) || (i>m) || (j<1) || (j>n) )
	    throw new MatrixIndexOutOfBoundsException("invalid entry");

	return A[i][j];
    }


    /**
     * Returns this scaled by z. Meaning every entry in this is 
     * multiplicated by z.
     *
     * @param z the scalar.
     * @return z*this
     */
    public Matrix scale(ComplexNumber z) {
	Matrix B = new Matrix(m,n);

	for(int i=1; i<=m; i++)
	    for(int j=1; j<=n; j++)
		B.setValue(i, j, this.getValue(i, j).multiplicate(z));

	return B;
    }



    /**
     * Returns a matrix with value this times B, done by matrix multiplication.
     * B must have height equal to this' width. If not, an exception is thrown.
     *
     * @throws InvalidSizeException If B's height don't equal this' width
     * @param B the matrix to multiplicate with
     * @return this * B
     */
    public Matrix multiplicate(Matrix B) throws InvalidSizeException {
	if (B.getSize()[0] == n) {
	    int h = m; int w = B.getSize()[1];
	    Matrix C = new Matrix(h,w);
	    
	    for(int i=1; i<=h; i++) {
		for(int j=1; j<=w; j++) {
		    ComplexNumber sum = new ComplexNumber();
		    for(int k=1; k<=n; k++) {
			sum = sum.add(this.getValue(i, k).multiplicate(B.getValue(k, j)));
		    }
		    C.setValue(i, j, sum);
		}
	    }

	    return C;
	} else {
	    throw new InvalidSizeException("the specified matrix must have width " + n);
	}
    }

						       
    /**
     * Swaps row i and j in this
     *
     * @param i row to swap
     * @param j row to swap
     */
    public void rowSwap(int i, int j) throws MatrixIndexOutOfBoundsException {
	if ( (i<1) || (i>m) || (j<1) || (j>n) )
	    throw new MatrixIndexOutOfBoundsException("invalid entry");

	for(int k=1; k<=n; k++) {
	    ComplexNumber tmp = this.getValue(i,k);
	    this.setValue(i,k,this.getValue(j,k));
	    this.setValue(j,k,tmp);
	}
    }


    /**
     * Adds row i to row j
     *
     * @param i row to add to row j
     * @param j row to be added by row i
     */
    public void rowAdd(int i, int j) throws MatrixIndexOutOfBoundsException {
	if ( (i<1) || (i>m) || (j<1) || (j>n) )
	    throw new MatrixIndexOutOfBoundsException("invalid entry");

	for(int k=1; k<=n; k++)
	    this.setValue(j,k,this.getValue(j,k).add(this.getValue(i,k)));
    }


    /**
     * Adds row i x times to row j
     *
     * @param i row to add to row j x times
     * @param j row to be added by row i x times
     * @param x number of times row i is addad to row j
     */
    public void rowAdd(int i, int j, ComplexNumber x) {
	if ( (i<1) || (i>m) || (j<1) || (j>n) )
	    throw new MatrixIndexOutOfBoundsException("invalid entry");

	for(int k=1; k<=n; k++)
	    setValue(j,k,getValue(j,k).add(getValue(i,k).multiplicate(x)));
    }


    /**
     * Multiplicates row i by x
     *
     * @param i row to multiplicate with x
     * @param x multiplication factor
     */
    public void rowMultiplicate(int i, ComplexNumber x) {
	if ( (i<1) || (i>m) )
	    throw new MatrixIndexOutOfBoundsException("invalid entry");

	for(int k=1; k<=n; k++)
	    setValue(i,k,getValue(i,k).multiplicate(x));
    }



    public static ComplexNumber det(Matrix B) throws InvalidSizeException {
	int[] s = B.getSize();
	int h=s[0]; int w=s[1];
	if( !(h==w) ) throw new InvalidSizeException("Matrix must be n x n to calculate determinant");
	

	//The recursion starts here...
	if( (h==1) && (w==1) ) return B.getValue(1,1);

	ComplexNumber sum = new ComplexNumber();

	//Calculates by the first row...
	for(int i=1; i<=w; i++) {

	    //Creates a new (h-1,w-1) matrix, and copies entries from B
	    Matrix C = new Matrix(w-1,h-1);
	    for(int j=2; j<=h; j++) {
		int c = 0;

		for(int k=1; k<=w; k++) {
		    //Don't copy entry, if i==k...
		    if(!(i==k)) {
			c++;
			C.setValue(j-1, c, B.getValue(j, k));
		    }
		}
	    }
	    //sums up the minor determinants.
	    sum = sum.add( B.getValue(1,i).multiplicate(new ComplexNumber(Math.pow(-1,i+1),0).multiplicate(det(C))) );
	}
	return sum;
    }

    
    public Matrix inverse() throws NotInvertibleException {
	return Matrix.inverse(this);
    }



    public static Matrix inverse(Matrix B) throws NotInvertibleException {
	int[] s = B.getSize();
	int h=s[0]; int w=s[1];

	if (h != w) throw new InvalidSizeException("Must be quadratic to be invertible");
	if (det(B).equals(ComplexNumber.ZERO)) throw new NotInvertibleException("det(B) == 0");

	Matrix BC = new Matrix(h, 2*w);
	Matrix C = identity(3);

	//Creates a new matrix BC of the form [B | I]
	for(int i=1; i<=h; i++)
	    for(int j=1; j<=w; j++)
		BC.setValue(i, j, B.getValue(i,j));

	for(int i=1; i<=h; i++)
	    for(int j=w+1; j<=2*w; j++)
		BC.setValue(i, j, C.getValue(i,j-w));

	BC.rref();
	
	for(int i=1; i<=h; i++)
	    for(int j=1; j<=w; j++)
		C.setValue(i,j, BC.getValue(i, j+w));
		
	return C;
    }



    public Matrix transpose() {
	Matrix B = new Matrix(n,m);

	for(int i=1; i<=m; i++)
	    for(int j=1; j<=n; j++)
		B.setValue(j,i, getValue(i,j));

	return B;
    }


    /**
     * Returns the Frobenius-norm (sum of norm of all entrys)
     *
     * @return the norm of this
     */
    public double norm() {
	double sum = 0;

	for(int i=1; i<=m; i++)
	    for(int j=1; j<=n; j++)
		sum += getValue(i,j).norm();

	return sum;
    }


    /**
     * By rowoperations, turn the matrix into reduced-row-echelon-form.
     */
    public void rref() {
	int j;

	for(int i=1; i<=m; i++) {
	    try {
		for(j=1; getValue(i,j).equals(ComplexNumber.ZERO) && j<= n; j++);

		rowMultiplicate(i, getValue(i,j).inverse()); //Creates the pivot
		for(int k=1; k<=m; k++) //rowadds, so we that have zeros under and above the pivot
		    if (k != i)
			rowAdd(i, k, getValue(k,j).multiplicate(new ComplexNumber(-1,0)));
		    
	    } catch (MatrixIndexOutOfBoundsException e) {
	    }
	}


	//Swaps the rows in place...
	int pivot = 1;
	for(j=1; j<=n; j++) {
	    for(int i=pivot; i<=m; i++) {
		if (getValue(i,j).equals(ComplexNumber.ONE)) {
		    rowSwap(i, pivot);
		    pivot++;
		}
	    }
	}
    }


    /**
     * Returns a String-representation of the matrix, like toString(),
     * but with a prefix 'name = '. Eg output("I") on the 3x3 identity:
     *  <code><br>    [ 1.0  0.0  0.0 ]
     *  <br>I = [ 0.0  1.0  0.0 ]
     *  <br>    [ 0.0  0.0  1.0 ]</code>
     *
     * @param name Name of the matrix
     * @return String representation of matrix
     */
    public String output(String name) {
	String out = "";
	String[] line = new String[m+1];

	// Caluclates the number of whitespaces in front of the matrix
	int prefixlen = name.length() + 3;
	String prefix = "";
	for(int i=1; i<=prefixlen; i++)
	    prefix = prefix + " ";


	// Puts whitespaces and a [ at the start of each line...
	for(int i=1; i<=m; i++) {
	    if (i == Math.floor(m/2 + 1)) {
		line[i] = name + " = " + "[  ";
	    } else {
		line[i] = prefix + "[  ";
	    }
	}


	for(int j=1; j<=n; j++) {
	    // Puts the entry at the end of each line, and finds out
	    // what the max length of the lines is...	    
	    int maxlength=0;
	    for(int i=1; i<=m; i++) {
		line[i] = line[i] + this.getValue(i,j) + "  ";
		maxlength = Math.max(maxlength, line[i].length());
	    }

	    // Puts whitespaces after each line, so they match in length with
	    // the longest...
	    for(int k=1; k<=m; k++) {
		int spacesneeded = maxlength - line[k].length();
		for(int l=0; l<=spacesneeded; l++)
		    line[k] = line[k] + " ";
	    }
	}

	// Puts a ] at the end of each line...
	for(int i=1; i<=m; i++) {
	    line[i] = line[i] + "]";
	    out = out + line[i] + "\n";
	}
	
	return out;
    }    


    /**
     * Returns a string, showing the matrix. The 3 x 3 identity 
     * will be returned as:
     *  <code><br>[ 1.0  0.0  0.0 ]
     *  <br>[ 0.0  1.0  0.0 ]
     *  <br>[ 0.0  0.0  1.0 ]</code>
     *
     * @return a string showing the matrix
     **/
    public String output() {
	String out = "";
	String[] line = new String[m+1];

	// Puts a [ at the start of each line...
	for(int i=1; i<=m; i++)
	    line[i] = "[  ";


	for(int j=1; j<=n; j++) {
	    // Puts the entry at the end of each line, and finds out
	    // what the max length of the lines is...	    
	    int maxlength=0;
	    for(int i=1; i<=m; i++) {
		line[i] = line[i] + this.getValue(i,j) + "  ";
		maxlength = Math.max(maxlength, line[i].length());
	    }

	    // Puts whitespaces after each line, so they match in length with
	    // the longest...
	    for(int k=1; k<=m; k++) {
		int spacesneeded = maxlength - line[k].length();
		for(int l=0; l<=spacesneeded; l++)
		    line[k] = line[k] + " ";
	    }
	}

	// Puts a ] at the end of each line...
	for(int i=1; i<=m; i++) {
	    line[i] = line[i] + "]";
	    out = out + line[i] + "\n";
	}
	
	return out;
    }   


    
    /**
     * Returns a string, showing the matrix. The 3 x 3 identity 
     * will be returned as:
     *  <code><br>[ 1.0  0.0  0.0 ]
     *  <br>[ 0.0  1.0  0.0 ]
     *  <br>[ 0.0  0.0  1.0 ]</code>
     *
     * @return a string showing the matrix
     **/
    public String toString() {
	return output();
    }    
}

