/*
 * Decompiled with CFR 0.152.
 */
package keel.Algorithms.Neural_Networks.gann;

import java.io.BufferedWriter;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import keel.Algorithms.Neural_Networks.gann.Data;
import keel.Algorithms.Neural_Networks.gann.Genesis;
import keel.Algorithms.Neural_Networks.gann.SetupParameters;
import keel.Dataset.Attributes;

/**
 * Fully connected feed-forward network trained with pattern-by-pattern
 * back-propagation (momentum and weight decay).
 *
 * <p>Layer 0 is the input layer and layer {@code Nlayers - 1} the output
 * layer. Every non-input layer has a transfer function named in
 * {@link #transfer}: {@code "Log"} (scaled logistic), {@code "Htan"} (scaled
 * hyperbolic tangent) or anything else, which is treated as linear. Names are
 * compared case-insensitively.
 *
 * <p>Instances keep per-pattern scratch state ({@link #activation},
 * {@link #delta}), so a single instance must not be shared between threads
 * without external synchronisation.
 */
public class Network {
    /** Codes used for the transfer functions in the binary network file. */
    private static final int TRANSFER_LOG = 1;
    private static final int TRANSFER_HTAN = 2;
    private static final int TRANSFER_LIN = 3;

    /** Total number of layers, input and output layers included. */
    public int Nlayers;
    /** Number of input units (size of layer 0). */
    public int Ninputs;
    /** Number of output units (size of the last layer). */
    public int Noutputs;
    /** Number of units of every layer, indexed by layer. */
    public int[] Nhidden;
    /** {@code w[k][i][j]}: weight from unit j of layer k to unit i of layer k + 1. */
    public double[][][] w;
    /** Last weight change applied to each weight; same shape as {@link #w}. */
    public double[][][] momentum;
    /** Back-propagated error term of every unit, indexed by layer and unit. */
    public double[][] delta;
    /** Output of every unit for the last propagated pattern. */
    public double[][] activation;
    /** Amplitude of both the logistic and the hyperbolic tangent transfer functions. */
    public final double a = 1.7165;
    /** Slope of the logistic transfer function. */
    public final double b_log = 1.5;
    /** Slope of the hyperbolic tangent transfer function. */
    public final double b_htan = 0.6666;
    /** Transfer function name of layer k + 1, stored at index k. */
    public String[] transfer;

    /**
     * Creates an empty network. No array is allocated; call
     * {@link #LoadNetwork(String)} before using the instance.
     */
    public Network() {
    }

    /**
     * Creates a network with the topology described by {@code global} and
     * random initial weights.
     *
     * <p>The weights between layer k and layer k + 1 are drawn with
     * {@code Genesis.frandom(-r, r)} where {@code r = sqrt(3) / Nhidden[k]}.
     *
     * @param global run parameters; {@code Ninputs}, {@code Noutputs},
     *               {@code Nhidden_layers}, {@code Nhidden} and
     *               {@code transfer} are read
     */
    public Network(SetupParameters global) {
        this.transfer = new String[global.Nhidden_layers + 1];
        for (int i = 0; i < global.Nhidden_layers + 1; ++i) {
            this.transfer[i] = global.transfer[i];
        }
        this.Ninputs = global.Ninputs;
        this.Noutputs = global.Noutputs;
        this.Nlayers = global.Nhidden_layers + 2;

        // Fix every layer size BEFORE allocating. The output layer always has
        // Noutputs units, so its size is not taken from global.Nhidden; that
        // keeps the arrays consistent with Nhidden even when global.Nhidden
        // only lists the hidden layers.
        this.Nhidden = new int[this.Nlayers];
        this.Nhidden[0] = this.Ninputs;
        for (int i = 1; i < this.Nlayers - 1; ++i) {
            this.Nhidden[i] = global.Nhidden[i - 1];
        }
        this.Nhidden[this.Nlayers - 1] = this.Noutputs;
        this.allocateBuffers();

        for (int k = 0; k < this.Nlayers - 1; ++k) {
            double range = Math.sqrt(3.0) / (double)this.Nhidden[k];
            for (int i = 0; i < this.Nhidden[k + 1]; ++i) {
                for (int j = 0; j < this.Nhidden[k]; ++j) {
                    this.w[k][i][j] = Genesis.frandom(-range, range);
                }
            }
        }
    }

    /**
     * Allocates {@link #w}, {@link #momentum}, {@link #delta} and
     * {@link #activation} (all zero-filled) for the topology currently held
     * in {@link #Nlayers} and {@link #Nhidden}.
     */
    private void allocateBuffers() {
        this.w = new double[this.Nlayers - 1][][];
        this.momentum = new double[this.Nlayers - 1][][];
        this.delta = new double[this.Nlayers][];
        this.activation = new double[this.Nlayers][];
        this.delta[0] = new double[this.Nhidden[0]];
        this.activation[0] = new double[this.Nhidden[0]];
        for (int k = 1; k < this.Nlayers; ++k) {
            this.w[k - 1] = new double[this.Nhidden[k]][this.Nhidden[k - 1]];
            this.momentum[k - 1] = new double[this.Nhidden[k]][this.Nhidden[k - 1]];
            this.delta[k] = new double[this.Nhidden[k]];
            this.activation[k] = new double[this.Nhidden[k]];
        }
    }

    /**
     * Maps a transfer function name to the integer code stored in network
     * files: "Log" is 1, "Htan" is 2 and any other name 3 (linear).
     */
    private static int transferCode(String name) {
        if (name.compareToIgnoreCase("Log") == 0) {
            return TRANSFER_LOG;
        }
        if (name.compareToIgnoreCase("Htan") == 0) {
            return TRANSFER_HTAN;
        }
        return TRANSFER_LIN;
    }

    /**
     * Returns the index, relative to {@code offset}, of the largest of the
     * {@code count} values starting at {@code values[offset]}. Ties are
     * resolved in favour of the lowest index; 0 is returned when
     * {@code count <= 1} (the array is not read in that case).
     */
    private static int argMax(double[] values, int offset, int count) {
        int best = 0;
        for (int j = 1; j < count; ++j) {
            if (values[offset + best] < values[offset + j]) {
                best = j;
            }
        }
        return best;
    }

    /**
     * Trains in rounds of {@code global.cycles} back-propagation steps on
     * {@code data.train}, measuring the validation error after each round.
     * Training goes on while a round lowers the validation error to at most
     * {@code (1 - global.improve)} times its previous value, and stops as soon
     * as the validation error reaches zero.
     *
     * <p>For classification the error is the misclassification rate
     * ({@code 1 - accuracy}); for regression it is the value returned by
     * {@link #TestNetworkInRegression}. Any other problem type prints a
     * message and terminates the JVM.
     *
     * @param global run parameters ({@code problem}, {@code bp_type},
     *               {@code cycles}, {@code improve}, pattern counts)
     * @param data   training and validation patterns
     */
    public void TrainNetworkWithCrossvalidation(SetupParameters global, Data data) {
        double new_error = this.validationError(global, data);
        double old_error;
        do {
            if (global.bp_type.compareToIgnoreCase("BPstd") == 0) {
                this.BackPropagation(global, global.cycles, data.train, global.n_train_patterns);
            }
            old_error = new_error;
            new_error = this.validationError(global, data);
            // old_error > 0 guard: once the error is exactly zero nothing can
            // improve, and "0 <= 0" would otherwise keep the loop running forever.
        } while (old_error > 0.0 && new_error <= (1.0 - global.improve) * old_error);
    }

    /**
     * Error of the network on the validation set, lower is better. The
     * classification test returns an accuracy, which is converted to an error
     * rate here so the early-stopping comparison has the right orientation.
     */
    private double validationError(SetupParameters global, Data data) {
        if (global.problem.compareToIgnoreCase("Classification") == 0) {
            return 1.0 - this.TestNetworkInClassification(global, data.validation, global.n_val_patterns);
        }
        if (global.problem.compareToIgnoreCase("Regression") == 0) {
            return this.TestNetworkInRegression(global, data.validation, global.n_val_patterns);
        }
        System.err.println("Type of problem incorrectly defined");
        System.exit(1);
        return 0.0; // not reached; keeps the compiler satisfied
    }

    /**
     * Runs {@code global.cycles} back-propagation steps on the given patterns
     * when {@code global.bp_type} is "BPstd"; does nothing otherwise.
     *
     * @param global    run parameters
     * @param data      patterns, one per row (inputs followed by targets)
     * @param npatterns number of rows of {@code data} to sample from
     */
    public void TrainNetwork(SetupParameters global, double[][] data, int npatterns) {
        if (global.bp_type.compareToIgnoreCase("BPstd") == 0) {
            this.BackPropagation(global, global.cycles, data, npatterns);
        }
    }

    /**
     * Fraction of the first {@code npatterns} rows of {@code data} that are
     * classified correctly. The predicted class is the output unit with the
     * largest activation; the expected class is the largest of the
     * {@code Noutputs} target values stored after the inputs.
     *
     * @param global    unused
     * @param data      patterns, one per row (inputs followed by targets)
     * @param npatterns number of rows to evaluate
     * @return accuracy in [0, 1]; NaN when {@code npatterns} is 0
     */
    public double TestNetworkInClassification(SetupParameters global, double[][] data, int npatterns) {
        double ok = 0.0;
        for (int i = 0; i < npatterns; ++i) {
            this.GenerateOutput(data[i]);
            int predicted = argMax(this.activation[this.Nlayers - 1], 0, this.Noutputs);
            int expected = argMax(data[i], this.Ninputs, this.Noutputs);
            if (expected == predicted) {
                ok += 1.0;
            }
        }
        return ok / (double)npatterns;
    }

    /**
     * Error measure of the network on the first {@code npatterns} rows of
     * {@code data}, lower is better.
     *
     * <p>For each pattern a running sum of squared output errors is kept, and
     * the square root of that running sum is accumulated after every output.
     * The total is divided by {@code npatterns * Noutputs}. With one output
     * this is the mean absolute error.
     *
     * <p>NOTE(review): with several outputs the running sum makes earlier
     * outputs count more than later ones; this looks unintended but is kept
     * as is because existing results depend on it.
     *
     * @param global    unused
     * @param data      patterns, one per row (inputs followed by targets)
     * @param npatterns number of rows to evaluate
     * @return the accumulated error divided by {@code npatterns * Noutputs}
     */
    public double TestNetworkInRegression(SetupParameters global, double[][] data, int npatterns) {
        double RMS = 0.0;
        for (int i = 0; i < npatterns; ++i) {
            this.GenerateOutput(data[i]);
            double error = 0.0;
            for (int j = 0; j < this.Noutputs; ++j) {
                error += Math.pow(this.activation[this.Nlayers - 1][j] - data[i][this.Ninputs + j], 2.0);
                RMS += Math.sqrt(error);
            }
        }
        return RMS / (double)(npatterns * this.Noutputs);
    }

    /**
     * On-line back-propagation. The momentum terms are reset, then for each
     * of {@code cycles} iterations one pattern is picked with
     * {@code Genesis.irandom(0.0, npatterns)}, propagated forward, and every
     * weight is changed by
     * {@code eta * delta * activation + alpha * previousChange - lambda * weight}.
     *
     * @param global    supplies {@code eta}, {@code alpha} and {@code lambda}
     * @param cycles    number of single-pattern updates to perform
     * @param data      patterns, one per row (inputs followed by targets)
     * @param npatterns upper bound handed to {@code Genesis.irandom}
     */
    private void BackPropagation(SetupParameters global, int cycles, double[][] data, int npatterns) {
        for (int k = 0; k < this.Nlayers - 1; ++k) {
            for (int i = 0; i < this.Nhidden[k + 1]; ++i) {
                for (int j = 0; j < this.Nhidden[k]; ++j) {
                    this.momentum[k][i][j] = 0.0;
                }
            }
        }

        int last = this.Nlayers - 1;
        for (int iter = 0; iter < cycles; ++iter) {
            int pattern = Genesis.irandom(0.0, npatterns);
            this.GenerateOutput(data[pattern]);

            // Output layer: error times the derivative of the transfer
            // function, written in terms of the unit's own output.
            int outKind = transferCode(this.transfer[last - 1]);
            for (int i = 0; i < this.Noutputs; ++i) {
                double out = this.activation[last][i];
                double err = data[pattern][this.Ninputs + i] - out;
                if (outKind == TRANSFER_LOG) {
                    this.delta[last][i] = err * b_log * out * (1.0 - out / a);
                } else if (outKind == TRANSFER_HTAN) {
                    this.delta[last][i] = err * (b_htan / a) * (a - out) * (a + out);
                } else {
                    this.delta[last][i] = err;
                }
            }

            // Hidden layers, from the last one down to the first: weighted sum
            // of the next layer's deltas times the local derivative.
            for (int k = this.Nlayers - 2; k > 0; --k) {
                int kind = transferCode(this.transfer[k - 1]);
                for (int i = 0; i < this.Nhidden[k]; ++i) {
                    double sum = 0.0;
                    for (int j = 0; j < this.Nhidden[k + 1]; ++j) {
                        sum += this.delta[k + 1][j] * this.w[k][j][i];
                    }
                    double act = this.activation[k][i];
                    if (kind == TRANSFER_LOG) {
                        sum *= b_log * act * (1.0 - act / a);
                    } else if (kind == TRANSFER_HTAN) {
                        sum *= (b_htan / a) * (a - act) * (a + act);
                    }
                    this.delta[k][i] = sum;
                }
            }

            // Weight update; runs only after every delta has been computed so
            // that the deltas are based on the pre-update weights.
            for (int k = this.Nlayers - 2; k >= 0; --k) {
                for (int i = 0; i < this.Nhidden[k + 1]; ++i) {
                    for (int j = 0; j < this.Nhidden[k]; ++j) {
                        double change = global.eta * this.delta[k + 1][i] * this.activation[k][j] + global.alpha * this.momentum[k][i][j] - global.lambda * this.w[k][i][j];
                        this.w[k][i][j] += change;
                        this.momentum[k][i][j] = change;
                    }
                }
            }
        }
    }

    /**
     * Propagates a pattern through the network, leaving the result of every
     * layer in {@link #activation}.
     *
     * @param input pattern whose first {@code Nhidden[0]} values are the
     *              network inputs; further values are ignored
     */
    public void GenerateOutput(double[] input) {
        System.arraycopy(input, 0, this.activation[0], 0, this.Nhidden[0]);
        for (int k = 1; k < this.Nlayers; ++k) {
            // The transfer function is a property of the layer: resolve it once.
            int kind = transferCode(this.transfer[k - 1]);
            for (int i = 0; i < this.Nhidden[k]; ++i) {
                double sum = 0.0;
                for (int j = 0; j < this.Nhidden[k - 1]; ++j) {
                    sum += this.activation[k - 1][j] * this.w[k - 1][i][j];
                }
                if (kind == TRANSFER_LOG) {
                    sum = this.logistic(sum);
                } else if (kind == TRANSFER_HTAN) {
                    sum = this.htan(sum);
                }
                this.activation[k][i] = sum;
            }
        }
    }

    /**
     * Propagates a pattern and copies the output layer into {@code output}.
     *
     * @param input  pattern whose first {@code Nhidden[0]} values are the inputs
     * @param output receives the {@code Noutputs} output activations
     */
    public void GenerateOutput(double[] input, double[] output) {
        this.GenerateOutput(input);
        System.arraycopy(this.activation[this.Nlayers - 1], 0, output, 0, this.Noutputs);
    }

    /** Scaled logistic: {@code a / (1 + exp(-b_log * x))}. */
    private double logistic(double x) {
        return a / (1.0 + Math.exp(-b_log * x));
    }

    /**
     * Scaled hyperbolic tangent: {@code a * tanh(b_htan * x)}.
     *
     * <p>{@link Math#tanh(double)} is used instead of a quotient of
     * exponentials because the latter evaluates to Infinity / Infinity = NaN
     * once {@code exp} overflows, which would poison every later activation
     * and weight.
     */
    private double htan(double x) {
        return a * Math.tanh(b_htan * x);
    }

    /**
     * Writes the topology, the transfer function codes and the weights in
     * binary form ({@link DataOutputStream} encoding). On failure a message is
     * printed and the JVM is terminated.
     *
     * @param file_name destination file
     * @param append    when true the network is appended to an existing file
     */
    public void SaveNetwork(String file_name, boolean append) {
        try {
            DataOutputStream dataOut = new DataOutputStream(new FileOutputStream(file_name, append));
            dataOut.writeInt(this.Nlayers);
            for (int i = 0; i < this.Nlayers; ++i) {
                dataOut.writeInt(this.Nhidden[i]);
            }
            for (int i = 0; i < this.Nlayers - 1; ++i) {
                dataOut.writeInt(transferCode(this.transfer[i]));
            }
            for (int k = 0; k < this.Nlayers - 1; ++k) {
                for (int i = 0; i < this.Nhidden[k + 1]; ++i) {
                    for (int j = 0; j < this.Nhidden[k]; ++j) {
                        dataOut.writeDouble(this.w[k][i][j]);
                    }
                }
            }
            dataOut.close();
        }
        catch (FileNotFoundException ex) {
            System.err.println("Unable to create network file");
            System.exit(1);
        }
        catch (IOException ex) {
            System.err.println("IO exception");
            System.exit(1);
        }
    }

    /**
     * Reads the first network stored in a file written by
     * {@link #SaveNetwork(String, boolean)} and replaces this network with it.
     *
     * <p>All arrays are re-allocated for the topology found in the file, so
     * the method works on an instance created with the no-argument
     * constructor and on one that previously had a different topology. On
     * failure (missing file, truncated or invalid contents) a message is
     * printed and the JVM is terminated.
     *
     * @param file_name file to read
     */
    public void LoadNetwork(String file_name) {
        try {
            DataInputStream dataIn = new DataInputStream(new FileInputStream(file_name));

            // Read the whole header into locals first so that this object is
            // only modified once the topology is known to be usable.
            int layers = dataIn.readInt();
            if (layers < 2) {
                throw new IOException("Invalid number of layers: " + layers);
            }
            int[] sizes = new int[layers];
            for (int i = 0; i < layers; ++i) {
                sizes[i] = dataIn.readInt();
                if (sizes[i] < 0) {
                    throw new IOException("Invalid size for layer " + i + ": " + sizes[i]);
                }
            }
            String[] names = new String[layers - 1];
            for (int i = 0; i < layers - 1; ++i) {
                int t = dataIn.readInt();
                names[i] = t == TRANSFER_LOG ? "Log" : (t == TRANSFER_HTAN ? "Htan" : "Lin");
            }

            this.Nlayers = layers;
            this.Nhidden = sizes;
            this.transfer = names;
            this.Ninputs = sizes[0];
            this.Noutputs = sizes[layers - 1];
            this.allocateBuffers();

            for (int k = 0; k < this.Nlayers - 1; ++k) {
                for (int i = 0; i < this.Nhidden[k + 1]; ++i) {
                    for (int j = 0; j < this.Nhidden[k]; ++j) {
                        this.w[k][i][j] = dataIn.readDouble();
                    }
                }
            }
            dataIn.close();
        }
        catch (FileNotFoundException ex) {
            System.err.println("Unable to load network file");
            System.exit(1);
        }
        catch (IOException ex) {
            System.err.println("IO exception");
            System.exit(1);
        }
    }

    /**
     * Prints every weight matrix to standard output, one row per destination
     * unit (numbered from 1).
     */
    public void PrintWeights() {
        for (int k = 0; k < this.Nlayers - 1; ++k) {
            System.out.println("Hidden[" + k + "] -> Hidden[" + (k + 1) + "]");
            System.out.println("Node\tWeights");
            for (int i = 0; i < this.Nhidden[k + 1]; ++i) {
                System.out.print(i + 1 + "\t");
                for (int j = 0; j < this.Nhidden[k]; ++j) {
                    System.out.print(this.w[k][i][j] + " ");
                }
                System.out.println();
            }
        }
    }

    /**
     * Tells whether the network assigns a pattern to its expected class.
     *
     * @param pattern inputs followed by the {@code Noutputs} target values
     * @return true when the output unit with the largest activation has the
     *         same index as the largest target value
     */
    public boolean NetClassifyPattern(double[] pattern) {
        this.GenerateOutput(pattern);
        int predicted = argMax(this.activation[this.Nlayers - 1], 0, this.Noutputs);
        int expected = argMax(pattern, this.Ninputs, this.Noutputs);
        return expected == predicted;
    }

    /**
     * Class predicted by the network for a pattern.
     *
     * @param pattern pattern whose first {@code Nhidden[0]} values are the inputs
     * @return index of the output unit with the largest activation
     */
    public int NetGetClassOfPattern(double[] pattern) {
        this.GenerateOutput(pattern);
        return argMax(this.activation[this.Nlayers - 1], 0, this.Noutputs);
    }

    /**
     * Index of the largest of the first {@code Noutputs} values of
     * {@code pattern}. Not called from within this class.
     */
    private int GetClassOfPattern(double[] pattern) {
        return argMax(pattern, 0, this.Noutputs);
    }

    /**
     * Writes expected and obtained outputs for a set of patterns, preceded by
     * a header built from {@code Attributes}.
     *
     * <p>For classification each line holds the nominal value of the expected
     * class followed by the one of the predicted class. For any other problem
     * type each line holds the {@code Noutputs} target values followed by the
     * {@code Noutputs} network outputs. On failure the JVM is terminated.
     *
     * @param file_name destination file (overwritten)
     * @param data      patterns, one per row (inputs followed by targets)
     * @param n         number of rows to write
     * @param problem   "Classification" (case-insensitive) selects the
     *                  nominal format
     */
    public void SaveOutputFile(String file_name, double[][] data, int n, String problem) {
        try {
            FileOutputStream file = new FileOutputStream(file_name);
            BufferedWriter f = new BufferedWriter(new OutputStreamWriter(file));
            f.write("@relation " + Attributes.getRelationName() + "\n");
            f.write(Attributes.getInputAttributesHeader());
            f.write(Attributes.getOutputAttributesHeader());
            f.write(Attributes.getInputHeader() + "\n");
            f.write(Attributes.getOutputHeader() + "\n");
            f.write("@data\n");
            boolean classification = problem.compareToIgnoreCase("Classification") == 0;
            for (int i = 0; i < n; ++i) {
                if (classification) {
                    int expected = argMax(data[i], this.Ninputs, this.Noutputs);
                    int algorithmSolution = this.NetGetClassOfPattern(data[i]);
                    f.write(Attributes.getOutputAttributes()[0].getNominalValue(expected) + " ");
                    f.write(Attributes.getOutputAttributes()[0].getNominalValue(algorithmSolution));
                    f.newLine();
                    f.flush();
                } else {
                    for (int j = 0; j < this.Noutputs; ++j) {
                        f.write(Double.toString(data[i][this.Ninputs + j]) + " ");
                    }
                    this.GenerateOutput(data[i]);
                    for (int j = 0; j < this.Noutputs; ++j) {
                        f.write(Double.toString(this.activation[this.Nlayers - 1][j]) + " ");
                    }
                    f.newLine();
                }
            }
            f.close();
            file.close();
        }
        catch (FileNotFoundException e) {
            System.err.println("Cannot created output file");
            System.exit(-1);
        }
        catch (IOException e) {
            e.printStackTrace();
            System.exit(-1);
        }
    }
}

