/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.pqc.crypto.mlkem;

import org.bouncycastle.pqc.crypto.mlkem.CBD;
import org.bouncycastle.pqc.crypto.mlkem.MLKEMEngine;
import org.bouncycastle.pqc.crypto.mlkem.Ntt;
import org.bouncycastle.pqc.crypto.mlkem.Reduce;
import org.bouncycastle.pqc.crypto.mlkem.Symmetric;

/**
 * A polynomial with {@value #KYBER_N} {@code short} coefficients, together with the
 * parameter-set values (compressed size, eta1/eta2, symmetric primitives) read from the
 * {@link MLKEMEngine} at construction time.
 *
 * <p>Most operations mutate this instance in place, and several serialization methods
 * ({@link #compressPoly()}, {@link #toBytes()}, {@link #toMsg()}) first normalize the
 * coefficients via {@link #conditionalSubQ()} as a side effect. Instances carry no
 * synchronization.
 */
class Poly {
    /** Number of coefficients in one polynomial. */
    private static final int KYBER_N = 256;

    /** The modulus q used by every reduction/compression formula in this class. */
    private static final int KYBER_Q = 3329;

    /** 2^32 mod q (== 1353); multiplier used by {@link #convertToMont()}. */
    private static final int MONT_R2_MOD_Q = 1353;

    private short[] coeffs = new short[KYBER_N];
    private final int polyCompressedBytes;
    private final int eta1;
    private final int eta2;
    private final Symmetric symmetric;

    /**
     * Creates a zero polynomial configured for the given engine's parameter set.
     *
     * @param engine source of the compressed-poly size, eta values and symmetric primitives
     */
    public Poly(MLKEMEngine engine) {
        this.polyCompressedBytes = engine.getKyberPolyCompressedBytes();
        this.eta1 = engine.getKyberEta1();
        this.eta2 = MLKEMEngine.getKyberEta2();
        this.symmetric = engine.getSymmetric();
    }

    /**
     * Adds {@code poly} to this polynomial coefficient-wise, without modular reduction
     * (each sum is simply truncated to {@code short}).
     *
     * @param poly the addend
     */
    public void addCoeffs(Poly poly) {
        for (int i = 0; i < KYBER_N; i++) {
            this.setCoeffIndex(i, (short) (this.getCoeffIndex(i) + poly.getCoeffIndex(i)));
        }
    }

    /**
     * Writes into {@code r} the product of {@code a} and {@code b} as computed pairwise by
     * {@code Ntt.baseMult}: for each group of four coefficients, the first pair uses
     * {@code Ntt.nttZetas[64 + i]} and the second pair its negation.
     * NOTE(review): operands are presumably in the NTT domain and the result carries a
     * Montgomery factor (per the method name) — confirm against {@code Ntt.baseMult}.
     *
     * @param r destination polynomial
     * @param a first factor
     * @param b second factor
     */
    public static void baseMultMontgomery(Poly r, Poly a, Poly b) {
        for (int i = 0; i < KYBER_N / 4; i++) {
            int k = 4 * i;
            Ntt.baseMult(r, k, a.getCoeffIndex(k), a.getCoeffIndex(k + 1),
                    b.getCoeffIndex(k), b.getCoeffIndex(k + 1), Ntt.nttZetas[64 + i]);
            Ntt.baseMult(r, k + 2, a.getCoeffIndex(k + 2), a.getCoeffIndex(k + 3),
                    b.getCoeffIndex(k + 2), b.getCoeffIndex(k + 3), (short) (-1 * Ntt.nttZetas[64 + i]));
        }
    }

    /**
     * Compresses this polynomial to 4 bits per coefficient (128 bytes) or 5 bits per
     * coefficient (160 bytes), depending on the parameter set.
     *
     * <p>Side effect: calls {@link #conditionalSubQ()} on this polynomial first.
     *
     * @return a new array of {@code polyCompressedBytes} bytes
     * @throws IllegalStateException if {@code polyCompressedBytes} is neither 128 nor 160
     */
    public byte[] compressPoly() {
        byte[] t = new byte[8];
        byte[] r = new byte[this.polyCompressedBytes];
        int pos = 0;
        this.conditionalSubQ();
        if (this.polyCompressedBytes == 128) {
            // 4 bits per coefficient: each group of 8 coefficients packs into 4 bytes.
            for (int i = 0; i < KYBER_N / 8; i++) {
                for (int j = 0; j < 8; j++) {
                    // Multiply-and-shift (80635 ~= 2^28 / q) used in place of
                    // (((c << 4) + q / 2) / q) & 0xF, presumably to avoid a data-dependent
                    // division by q. The product can exceed 2^31 and wrap; only bits 28..31
                    // survive the shift and mask, and those bits are the same whether the
                    // shift is arithmetic or logical, so the wrap is harmless.
                    int d = this.getCoeffIndex(8 * i + j);
                    d <<= 4;
                    d += 1665;
                    d *= 80635;
                    d >>= 28;
                    t[j] = (byte) (d & 0xF);
                }
                r[pos] = (byte) (t[0] | t[1] << 4);
                r[pos + 1] = (byte) (t[2] | t[3] << 4);
                r[pos + 2] = (byte) (t[4] | t[5] << 4);
                r[pos + 3] = (byte) (t[6] | t[7] << 4);
                pos += 4;
            }
        } else if (this.polyCompressedBytes == 160) {
            // 5 bits per coefficient: each group of 8 coefficients packs into 5 bytes.
            for (int i = 0; i < KYBER_N / 8; i++) {
                for (int j = 0; j < 8; j++) {
                    // Same technique with 40318 ~= 2^27 / q and 1664 == q / 2; the possible
                    // int wrap is again discarded by the shift and 5-bit mask.
                    int d = this.getCoeffIndex(8 * i + j);
                    d <<= 5;
                    d += 1664;
                    d *= 40318;
                    d >>= 27;
                    t[j] = (byte) (d & 0x1F);
                }
                r[pos] = (byte) (t[0] | t[1] << 5);
                r[pos + 1] = (byte) (t[1] >> 3 | t[2] << 2 | t[3] << 7);
                r[pos + 2] = (byte) (t[3] >> 1 | t[4] << 4);
                r[pos + 3] = (byte) (t[4] >> 4 | t[5] << 1 | t[6] << 6);
                r[pos + 4] = (byte) (t[6] >> 2 | t[7] << 3);
                pos += 5;
            }
        } else {
            throw new IllegalStateException("PolyCompressedBytes is neither 128 or 160!");
        }
        return r;
    }

    /**
     * Applies {@code Reduce.conditionalSubQ} to every coefficient.
     * NOTE(review): presumably maps each coefficient into [0, q) — confirm in {@code Reduce}.
     */
    public void conditionalSubQ() {
        for (int i = 0; i < KYBER_N; i++) {
            this.setCoeffIndex(i, Reduce.conditionalSubQ(this.getCoeffIndex(i)));
        }
    }

    /**
     * Multiplies every coefficient by 2^32 mod q and passes the product through
     * {@code Reduce.montgomeryReduce}. Assuming that reduction divides by 2^16 (not visible
     * here — TODO confirm), the net effect is multiplication by 2^16 mod q, i.e. conversion
     * into Montgomery form.
     */
    public void convertToMont() {
        for (int i = 0; i < KYBER_N; i++) {
            this.setCoeffIndex(i, Reduce.montgomeryReduce(this.getCoeffIndex(i) * MONT_R2_MOD_Q));
        }
    }

    /**
     * Inverse of {@link #compressPoly()}: unpacks 4- or 5-bit values and scales each back to
     * roughly {@code round(x * q / 2^d)}, overwriting all coefficients.
     *
     * @param compressed packed input; must hold at least {@code polyCompressedBytes} bytes
     * @throws IllegalStateException if {@code polyCompressedBytes} is neither 128 nor 160
     */
    public void decompressPoly(byte[] compressed) {
        if (this.polyCompressedBytes == 128) {
            // Each byte holds two 4-bit values, low nibble first.
            for (int i = 0; i < KYBER_N / 2; i++) {
                int b = compressed[i] & 0xFF;
                this.setCoeffIndex(2 * i, (short) (((b & 0x0F) * KYBER_Q + 8) >> 4));
                this.setCoeffIndex(2 * i + 1, (short) (((b >> 4) * KYBER_Q + 8) >> 4));
            }
        } else if (this.polyCompressedBytes == 160) {
            // Each group of 5 bytes holds eight 5-bit values.
            byte[] t = new byte[8];
            int pos = 0;
            for (int i = 0; i < KYBER_N / 8; i++) {
                t[0] = (byte) (compressed[pos] & 0xFF);
                t[1] = (byte) ((compressed[pos] & 0xFF) >> 5 | (compressed[pos + 1] & 0xFF) << 3);
                t[2] = (byte) ((compressed[pos + 1] & 0xFF) >> 2);
                t[3] = (byte) ((compressed[pos + 1] & 0xFF) >> 7 | (compressed[pos + 2] & 0xFF) << 1);
                t[4] = (byte) ((compressed[pos + 2] & 0xFF) >> 4 | (compressed[pos + 3] & 0xFF) << 4);
                t[5] = (byte) ((compressed[pos + 3] & 0xFF) >> 1);
                t[6] = (byte) ((compressed[pos + 3] & 0xFF) >> 6 | (compressed[pos + 4] & 0xFF) << 2);
                t[7] = (byte) ((compressed[pos + 4] & 0xFF) >> 3);
                pos += 5;
                for (int j = 0; j < 8; j++) {
                    this.setCoeffIndex(8 * i + j, (short) (((t[j] & 0x1F) * KYBER_Q + 16) >> 5));
                }
            }
        } else {
            throw new IllegalStateException("PolyCompressedBytes is neither 128 or 160!");
        }
    }

    /**
     * Decodes 12-bit little-endian packed coefficients (two coefficients per 3 bytes),
     * overwriting all coefficients. No modular reduction is applied.
     *
     * @param bytes packed input; must hold at least 384 bytes
     */
    public void fromBytes(byte[] bytes) {
        for (int i = 0; i < KYBER_N / 2; i++) {
            int b0 = bytes[3 * i] & 0xFF;
            int b1 = bytes[3 * i + 1] & 0xFF;
            int b2 = bytes[3 * i + 2] & 0xFF;
            this.setCoeffIndex(2 * i, (short) ((b0 | b1 << 8) & 0xFFF));
            this.setCoeffIndex(2 * i + 1, (short) ((b1 >> 4 | b2 << 4) & 0xFFF));
        }
    }

    /**
     * Expands a 32-byte message into coefficients: bit j of byte i becomes coefficient
     * {@code 8 * i + j}, set to 0 for a clear bit or (q + 1) / 2 for a set bit.
     *
     * @param msg the message; must be exactly 32 bytes
     * @throws IllegalArgumentException if {@code msg.length != 32}
     */
    public void fromMsg(byte[] msg) {
        if (msg.length != KYBER_N / 8) {
            throw new IllegalArgumentException("KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!");
        }
        for (int i = 0; i < KYBER_N / 8; i++) {
            for (int j = 0; j < 8; j++) {
                // mask is 0 for a clear bit and -1 (all ones) for a set bit, so the value is
                // selected without branching on the message bit.
                int mask = -(((msg[i] & 0xFF) >> j) & 1);
                this.setCoeffIndex(8 * i + j, (short) (mask & ((KYBER_Q + 1) / 2)));
            }
        }
    }

    /** Returns coefficient {@code n}. */
    public short getCoeffIndex(int n) {
        return this.coeffs[n];
    }

    /** Returns the live internal coefficient array (not a copy); writes to it mutate this poly. */
    public short[] getCoeffs() {
        return this.coeffs;
    }

    /**
     * Fills this polynomial with noise: expands {@code seed}/{@code nonce} through
     * {@code symmetric.prf} into {@code 256 * eta1 / 4} bytes and samples them via
     * {@code CBD.mlkemCBD} with parameter eta1.
     */
    public void getEta1Noise(byte[] seed, byte nonce) {
        byte[] buf = new byte[KYBER_N * this.eta1 / 4];
        this.symmetric.prf(buf, seed, nonce);
        CBD.mlkemCBD(this, buf, this.eta1);
    }

    /** Same as {@link #getEta1Noise(byte[], byte)} but using eta2. */
    public void getEta2Noise(byte[] seed, byte nonce) {
        byte[] buf = new byte[KYBER_N * this.eta2 / 4];
        this.symmetric.prf(buf, seed, nonce);
        CBD.mlkemCBD(this, buf, this.eta2);
    }

    /** Replaces the coefficients with the result of {@code Ntt.invNtt}. */
    public void polyInverseNttToMont() {
        this.setCoeffs(Ntt.invNtt(this.getCoeffs()));
    }

    /** Replaces the coefficients with the result of {@code Ntt.ntt}, then {@link #reduce()}s. */
    public void polyNtt() {
        this.setCoeffs(Ntt.ntt(this.getCoeffs()));
        this.reduce();
    }

    /**
     * Stores {@code poly - this} (note the operand order) into this polynomial,
     * coefficient-wise and without modular reduction.
     *
     * @param poly the minuend
     */
    public void polySubtract(Poly poly) {
        for (int i = 0; i < KYBER_N; i++) {
            this.setCoeffIndex(i, (short) (poly.getCoeffIndex(i) - this.getCoeffIndex(i)));
        }
    }

    /** Applies {@code Reduce.barretReduce} to every coefficient. */
    public void reduce() {
        for (int i = 0; i < KYBER_N; i++) {
            this.setCoeffIndex(i, Reduce.barretReduce(this.getCoeffIndex(i)));
        }
    }

    /** Sets coefficient {@code n} to {@code s}. */
    public void setCoeffIndex(int n, short s) {
        this.coeffs[n] = s;
    }

    /** Replaces the coefficient array by reference (the array is not copied). */
    public void setCoeffs(short[] sArray) {
        this.coeffs = sArray;
    }

    /**
     * Serializes the coefficients as 12-bit little-endian values, two per 3 bytes.
     *
     * <p>Side effect: calls {@link #conditionalSubQ()} on this polynomial first.
     *
     * @return a new 384-byte array
     */
    public byte[] toBytes() {
        byte[] r = new byte[3 * KYBER_N / 2];
        this.conditionalSubQ();
        for (int i = 0; i < KYBER_N / 2; i++) {
            short t0 = this.getCoeffIndex(2 * i);
            short t1 = this.getCoeffIndex(2 * i + 1);
            r[3 * i] = (byte) t0;
            r[3 * i + 1] = (byte) (t0 >> 8 | t1 << 4);
            r[3 * i + 2] = (byte) (t1 >> 4);
        }
        return r;
    }

    /**
     * Decodes one message bit per coefficient: the bit is 1 exactly when the coefficient lies
     * strictly between q/4 (832) and q - q/4 (2497).
     *
     * <p>Side effect: calls {@link #conditionalSubQ()} on this polynomial first.
     *
     * @return a new array of {@code MLKEMEngine.getKyberIndCpaMsgBytes()} bytes
     */
    public byte[] toMsg() {
        final int lower = KYBER_Q >>> 2;
        final int upper = KYBER_Q - lower;
        byte[] msg = new byte[MLKEMEngine.getKyberIndCpaMsgBytes()];
        this.conditionalSubQ();
        for (int i = 0; i < KYBER_N / 8; i++) {
            for (int j = 0; j < 8; j++) {
                int c = this.getCoeffIndex(8 * i + j);
                // Both differences are negative (sign bit set) only when lower < c < upper;
                // the range test is computed with arithmetic rather than a branch.
                int bit = ((lower - c) & (c - upper)) >>> 31;
                msg[i] |= (byte) (bit << j);
            }
        }
        return msg;
    }

    /** Returns the coefficients formatted as {@code [c0, c1, ..., c255]}. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        for (int i = 0; i < this.coeffs.length; i++) {
            sb.append(this.coeffs[i]);
            if (i != this.coeffs.length - 1) {
                sb.append(", ");
            }
        }
        sb.append("]");
        return sb.toString();
    }
}

