/*
 * Decompiled with CFR 0.152.
 */
package org.bouncycastle.crypto.test;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.security.SecureRandom;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import junit.framework.Assert;
import junit.framework.TestCase;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.crypto.InvalidCipherTextException;
import org.bouncycastle.crypto.hpke.AEAD;
import org.bouncycastle.crypto.hpke.HPKE;
import org.bouncycastle.crypto.hpke.HPKEContext;
import org.bouncycastle.crypto.hpke.HPKEContextWithEncapsulation;
import org.bouncycastle.crypto.params.AsymmetricKeyParameter;
import org.bouncycastle.test.TestResourceFinder;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.encoders.Hex;

/**
 * Tests for the HPKE implementation. It covers pairwise sender/receiver round trips through the
 * one-shot and context APIs, plus known-answer tests read from the {@code crypto/hpke.txt} test
 * resource.
 * <p>
 * Suite identifiers used by the pairwise tests follow RFC 9180: mode 0 = base, 2 = auth;
 * KEM 16 = DHKEM(P-256, HKDF-SHA256), 18 = DHKEM(P-521, HKDF-SHA512); KDF 1 = HKDF-SHA256;
 * AEAD 1 = AES-128-GCM.
 */
public class HPKETestVectors
extends TestCase {

    /** Base-mode one-shot seal/open round trip; a tampered ciphertext must fail the GCM tag check. */
    public void testBaseOneShotPairwise() throws Exception {
        // Explicit casts: the HPKE constructor presumably takes (byte, short, short, short).
        HPKE hpke = new HPKE((byte)0, (short)16, (short)1, (short)1);
        AsymmetricCipherKeyPair kp = hpke.generatePrivateKey();
        byte[][] output = hpke.seal(kp.getPublic(), "info".getBytes(), "aad".getBytes(), "message".getBytes(), null, null, null);
        byte[] ct = output[0];
        byte[] encap = output[1];
        byte[] message = hpke.open(encap, kp, "info".getBytes(), "aad".getBytes(), ct, null, null, null);
        Assert.assertTrue("Failed", Arrays.areEqual(message, "message".getBytes()));
        try {
            byte[] brokenCt = Arrays.concatenate(ct, "eh".getBytes());
            hpke.open(encap, kp, "info".getBytes(), "aad".getBytes(), brokenCt, null, null, null);
            Assert.fail("no exception");
        }
        catch (InvalidCipherTextException e) {
            Assert.assertEquals("Failed", "mac check in GCM failed", e.getMessage());
        }
    }

    /**
     * Auth-mode one-shot seal/open round trip. Opening must fail both for a tampered ciphertext
     * and when the wrong sender public key is supplied.
     */
    public void testAuthOneShotPairwise() throws Exception {
        HPKE hpke = new HPKE((byte)2, (short)18, (short)1, (short)1);
        AsymmetricCipherKeyPair receiver = hpke.generatePrivateKey();
        AsymmetricCipherKeyPair sender = hpke.generatePrivateKey();
        byte[][] output = hpke.seal(receiver.getPublic(), "info".getBytes(), "aad".getBytes(), "message".getBytes(), null, null, sender);
        byte[] ct = output[0];
        byte[] encap = output[1];
        byte[] message = hpke.open(encap, receiver, "info".getBytes(), "aad".getBytes(), ct, null, null, sender.getPublic());
        Assert.assertTrue("Failed", Arrays.areEqual(message, "message".getBytes()));
        try {
            byte[] brokenCt = Arrays.concatenate(ct, "eh".getBytes());
            hpke.open(encap, receiver, "info".getBytes(), "aad".getBytes(), brokenCt, null, null, sender.getPublic());
            Assert.fail("no exception");
        }
        catch (InvalidCipherTextException e) {
            Assert.assertEquals("Failed", "mac check in GCM failed", e.getMessage());
        }
        try {
            // Wrong sender key (the receiver's own public key) must not authenticate.
            hpke.open(encap, receiver, "info".getBytes(), "aad".getBytes(), ct, null, null, receiver.getPublic());
            Assert.fail("no exception");
        }
        catch (InvalidCipherTextException e) {
            Assert.assertEquals("Failed", "mac check in GCM failed", e.getMessage());
        }
    }

    /** Base-mode context API: sender and receiver contexts must interoperate. */
    public void testBasePairwise() throws Exception {
        HPKE hpke = new HPKE((byte)0, (short)16, (short)1, (short)1);
        AsymmetricCipherKeyPair receiver = hpke.generatePrivateKey();
        HPKEContextWithEncapsulation ctxS = hpke.setupBaseS(receiver.getPublic(), "info".getBytes());
        HPKEContext ctxR = hpke.setupBaseR(ctxS.getEncapsulation(), receiver, "info".getBytes());
        assertContextsInterop(ctxS, ctxR);
    }

    /** Auth-mode context API: sender and receiver contexts must interoperate. */
    public void testAuthPairwise() throws Exception {
        HPKE hpke = new HPKE((byte)2, (short)16, (short)1, (short)1);
        AsymmetricCipherKeyPair receiver = hpke.generatePrivateKey();
        AsymmetricCipherKeyPair sender = hpke.generatePrivateKey();
        HPKEContextWithEncapsulation ctxS = hpke.setupAuthS(receiver.getPublic(), "info".getBytes(), sender);
        HPKEContext ctxR = hpke.setupAuthR(ctxS.getEncapsulation(), receiver, "info".getBytes(), sender.getPublic());
        assertContextsInterop(ctxS, ctxR);
    }

    /**
     * Checks that a sender/receiver context pair export the same 512-byte secret. It also checks
     * that 128 random (aad, message) pairs sealed by the sender open to the original message on
     * the receiver.
     */
    private static void assertContextsInterop(HPKEContextWithEncapsulation ctxS, HPKEContext ctxR) throws Exception {
        Assert.assertTrue(Arrays.areEqual(ctxS.export("context".getBytes(), 512), ctxR.export("context".getBytes(), 512)));
        byte[] aad = new byte[32];
        byte[] message = new byte[128];
        SecureRandom random = new SecureRandom();
        for (int round = 0; round < 128; round++) {
            random.nextBytes(aad);
            random.nextBytes(message);
            byte[] ct = ctxS.seal(aad, message);
            Assert.assertTrue(Arrays.areEqual(message, ctxR.open(aad, ct)));
        }
    }

    /**
     * Known-answer tests from {@code crypto/hpke.txt}.
     * <p>
     * Vectors are separated by blank lines. Each vector is a set of {@code key = value} lines,
     * optionally followed by an {@code encryptionsSTART}/{@code encryptionsSTOP} section and an
     * {@code exportsSTART}/{@code exportsSTOP} section. Within a section, each entry's
     * {@code key = value} lines are terminated by a line containing only {@code <}.
     */
    public void testVectors() throws Exception {
        InputStream src = TestResourceFinder.findTestResource("crypto", "hpke.txt");
        Assert.assertNotNull("test resource crypto/hpke.txt not found", src);
        BufferedReader bin = new BufferedReader(new InputStreamReader(src, "UTF-8"));
        try {
            Map<String, String> vector = new HashMap<String, String>();
            List<Encryption> encryptions = new ArrayList<Encryption>();
            List<Export> exports = new ArrayList<Export>();
            String line;
            while ((line = bin.readLine()) != null) {
                line = line.trim();
                if (line.length() == 0) {
                    // A blank line ends the current vector.
                    if (!vector.isEmpty()) {
                        runVector(vector, encryptions, exports);
                    }
                    vector.clear();
                    encryptions.clear();
                    exports.clear();
                }
                else if (line.equals("encryptionsSTART")) {
                    for (Map<String, String> entry : readSection(bin, "encryptionsSTOP")) {
                        encryptions.add(new Encryption(hexField(entry, "aad"), hexField(entry, "ct"),
                            hexField(entry, "nonce"), hexField(entry, "pt")));
                    }
                }
                else if (line.equals("exportsSTART")) {
                    for (Map<String, String> entry : readSection(bin, "exportsSTOP")) {
                        String length = entry.get("L");
                        Assert.assertNotNull("missing test vector field: L", length);
                        exports.add(new Export(hexField(entry, "exporter_context"), Integer.parseInt(length),
                            hexField(entry, "exported_value")));
                    }
                }
                else {
                    putKeyValue(line, vector);
                }
            }
            // The final vector may not be followed by a blank line; do not silently skip it.
            if (!vector.isEmpty()) {
                runVector(vector, encryptions, exports);
            }
        }
        finally {
            bin.close();
        }
    }

    /**
     * Reads section entries up to {@code stopMarker}. Each entry is the map of {@code key = value}
     * lines preceding a line containing only {@code <}. Fields after the last {@code <} are ignored.
     *
     * @throws IOException if the input ends before {@code stopMarker} is seen
     */
    private static List<Map<String, String>> readSection(BufferedReader in, String stopMarker) throws IOException {
        List<Map<String, String>> entries = new ArrayList<Map<String, String>>();
        Map<String, String> entry = new HashMap<String, String>();
        String line;
        while ((line = in.readLine()) != null) {
            line = line.trim();
            if (line.equals(stopMarker)) {
                return entries;
            }
            if (line.equals("<")) {
                entries.add(entry);
                entry = new HashMap<String, String>();
            }
            else {
                putKeyValue(line, entry);
            }
        }
        throw new IOException("unexpected end of input before " + stopMarker);
    }

    /** Stores a trimmed {@code key = value} line in {@code fields}; lines without '=' are ignored. */
    private static void putKeyValue(String line, Map<String, String> fields) {
        int eq = line.indexOf('=');
        if (eq > -1) {
            fields.put(line.substring(0, eq).trim(), line.substring(eq + 1).trim());
        }
    }

    /** Hex-decodes the field stored under {@code key}, failing the test if the field is absent. */
    private static byte[] hexField(Map<String, String> fields, String key) {
        String value = fields.get(key);
        Assert.assertNotNull("missing test vector field: " + key, value);
        return Hex.decode(value);
    }

    /**
     * Runs one known-answer vector:
     * <ul>
     * <li>The raw AEAD, seeded with the vector's key and base nonce, must reproduce each listed
     * ciphertext.</li>
     * <li>Key (de)serialization and deterministic key derivation must match the vector.</li>
     * <li>A receiver context set up for the vector's mode must open every ciphertext and reproduce
     * every exported value.</li>
     * </ul>
     * The vector's {@code enc}, {@code shared_secret}, {@code key_schedule_context},
     * {@code secret} and {@code exporter_secret} fields are not checked here.
     */
    private static void runVector(Map<String, String> vector, List<Encryption> encryptions, List<Export> exports) throws Exception {
        byte mode = Byte.parseByte(vector.get("mode"));
        short kemId = Short.parseShort(vector.get("kem_id"));
        short kdfId = Short.parseShort(vector.get("kdf_id"));
        // Parsed as int then narrowed. Presumably this is because an id such as RFC 9180's
        // export-only AEAD (0xFFFF) exceeds Short.MAX_VALUE -- TODO confirm against hpke.txt.
        short aeadId = (short)Integer.parseInt(vector.get("aead_id"));
        byte[] info = hexField(vector, "info");
        byte[] ikmR = hexField(vector, "ikmR");
        byte[] ikmE = hexField(vector, "ikmE");
        byte[] skRm = hexField(vector, "skRm");
        byte[] skEm = hexField(vector, "skEm");
        byte[] pkRm = hexField(vector, "pkRm");
        byte[] pkEm = hexField(vector, "pkEm");
        byte[] key = hexField(vector, "key");
        byte[] baseNonce = hexField(vector, "base_nonce");

        // Modes 2 and 3 authenticate a sender key; modes 1 and 3 use a pre-shared key.
        boolean authMode = mode == 2 || mode == 3;
        boolean pskMode = mode == 1 || mode == 3;
        byte[] ikmS = null;
        byte[] skSm = null;
        byte[] pkSm = null;
        if (authMode) {
            ikmS = hexField(vector, "ikmS");
            skSm = hexField(vector, "skSm");
            pkSm = hexField(vector, "pkSm");
        }
        byte[] psk = null;
        byte[] pskId = null;
        if (pskMode) {
            psk = hexField(vector, "psk");
            pskId = hexField(vector, "psk_id");
        }

        HPKE hpke = new HPKE(mode, kemId, kdfId, aeadId);

        AEAD aead = new AEAD(aeadId, key, baseNonce);
        for (Encryption encryption : encryptions) {
            byte[] gotCt = aead.seal(encryption.aad, encryption.pt);
            Assert.assertTrue("AEAD failed Sealing:", Arrays.areEqual(gotCt, encryption.ct));
        }

        AsymmetricCipherKeyPair derivedKeyPairR = hpke.deriveKeyPair(ikmR);
        AsymmetricCipherKeyPair kp = hpke.deserializePrivateKey(skRm, pkRm);
        Assert.assertTrue("serialize public key failed", Arrays.areEqual(pkRm, hpke.serializePublicKey(kp.getPublic())));
        Assert.assertTrue("serialize private key failed", Arrays.areEqual(skRm, hpke.serializePrivateKey(kp.getPrivate())));
        Assert.assertTrue("receiver derived public key pair incorrect", Arrays.areEqual(pkRm, hpke.serializePublicKey(derivedKeyPairR.getPublic())));
        Assert.assertTrue("receiver derived secret key pair incorrect", Arrays.areEqual(skRm, hpke.serializePrivateKey(derivedKeyPairR.getPrivate())));
        if (authMode) {
            AsymmetricCipherKeyPair derivedSenderKeyPair = hpke.deriveKeyPair(ikmS);
            Assert.assertTrue("sender derived public key pair incorrect", Arrays.areEqual(pkSm, hpke.serializePublicKey(derivedSenderKeyPair.getPublic())));
            Assert.assertTrue("sender derived private key pair incorrect", Arrays.areEqual(skSm, hpke.serializePrivateKey(derivedSenderKeyPair.getPrivate())));
        }
        AsymmetricCipherKeyPair derivedEKeyPair = hpke.deriveKeyPair(ikmE);
        Assert.assertTrue("ephemeral derived public key pair incorrect", Arrays.areEqual(pkEm, hpke.serializePublicKey(derivedEKeyPair.getPublic())));
        Assert.assertTrue("ephemeral derived private key pair incorrect", Arrays.areEqual(skEm, hpke.serializePrivateKey(derivedEKeyPair.getPrivate())));

        AsymmetricKeyParameter senderPub = null;
        HPKEContext c;
        switch (mode) {
            case 0:
                c = hpke.setupBaseR(pkEm, kp, info);
                break;
            case 1:
                c = hpke.setupPSKR(pkEm, kp, info, psk, pskId);
                break;
            case 2:
                senderPub = hpke.deserializePublicKey(pkSm);
                c = hpke.setupAuthR(pkEm, kp, info, senderPub);
                break;
            case 3:
                senderPub = hpke.deserializePublicKey(pkSm);
                c = hpke.setupAuthPSKR(pkEm, kp, info, psk, pskId, senderPub);
                break;
            default:
                Assert.fail("invalid mode");
                return;
        }

        for (int i = 0; i < encryptions.size(); i++) {
            Encryption encryption = encryptions.get(i);
            if (i == 0) {
                // The single-shot API presumably sets up a fresh context each call, so only the
                // first ciphertext can be opened with it.
                byte[] message = hpke.open(pkEm, kp, info, encryption.aad, encryption.ct, psk, pskId, senderPub);
                Assert.assertTrue("Single-shot failed", Arrays.areEqual(message, encryption.pt));
            }
            byte[] gotPt = c.open(encryption.aad, encryption.ct);
            Assert.assertTrue("context failed Open", Arrays.areEqual(gotPt, encryption.pt));
        }
        for (Export export : exports) {
            byte[] gotVal = c.export(export.exporterContext, export.L);
            Assert.assertTrue("context failed Export", Arrays.areEqual(gotVal, export.exportedValue));
        }
    }

    /** One encryption entry of a test vector. */
    static class Encryption {
        final byte[] aad;
        final byte[] ct;
        final byte[] nonce;
        final byte[] pt;

        Encryption(byte[] aad, byte[] ct, byte[] nonce, byte[] pt) {
            this.aad = aad;
            this.ct = ct;
            this.nonce = nonce;
            this.pt = pt;
        }
    }

    /** One export entry of a test vector: context, requested length, and expected output. */
    static class Export {
        final byte[] exporterContext;
        final int L;
        final byte[] exportedValue;

        Export(byte[] exporterContext, int L, byte[] exportedValue) {
            this.exporterContext = exporterContext;
            this.L = L;
            this.exportedValue = exportedValue;
        }
    }
}

