#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>

#ifdef __WIN32__
#include <windows.h>
#include <wincrypt.h>
#endif

/*

dhbitty, a small public key encryption tool.

2012-08-06: Release. This is slightly different from the preview version.

Written by yarrkov; cipherdev.org

Public domain.

*/

// Shorthand for the platform's native unsigned integer type.
typedef unsigned int uint;

// Smaller of two values. This is a function-like macro, so each argument may
// be evaluated twice; only pass expressions without side effects.
#define MIN(a, b) ((a) < (b) ? (a) : (b))

//------------------------------------------------------------------------------

// Print `message` followed by a newline to stderr and terminate the process
// with exit status 1. This function does not return.
//
// The parameter is const-qualified because the text is only read; the callers
// in this file pass string literals, which must not be modified.
void failNow(const char *message) {
    fprintf(stderr, "%s\n", message);
    exit(1);
}

//------------------------------------------------------------------------------

// Big numbers are arrays of 15-bit limbs, least significant limb first, with
// one limb per 32-bit word so that sums and small products can accumulate
// before carries are propagated (see qFlow).

// Limbs in a reduced value: 17 * 15 = 255 bits.
#define B_WORDS 17
// Limbs in a qnum: twice B_WORDS, enough room for a full product.
#define B_FULL (B_WORDS * 2)
// Payload bits per limb.
#define B_BITS 15
// Mask selecting the payload bits of one limb (0x7fff).
#define B_MASK ((1 << B_BITS) - 1)
// The modulus is 2^(B_WORDS * B_BITS) - B_C, i.e. 2^255 - 19.
#define B_C 19
// Size in bytes of a value as read by loadInt and written by storeInt.
#define B_BYTES 32

typedef uint32_t qnum[B_FULL];

// The modulus 2^255 - 19 in limb form: the lowest limb is 0x8000 - 19, the
// other 16 limbs are all ones, and the upper half of the qnum is zero.
const qnum qnum_modulus = {
    0x7fed, 0x7fff, 0x7fff, 0x7fff, 0x7fff, 0x7fff,
    0x7fff, 0x7fff, 0x7fff, 0x7fff, 0x7fff, 0x7fff,
    0x7fff, 0x7fff, 0x7fff, 0x7fff, 0x7fff
};

// Clear `len` consecutive 32-bit words starting at `x`.
void setZero(uint32_t *x, uint len) {
    uint remaining = len;
    uint32_t *cursor = x;
    while (remaining > 0) {
        *cursor = 0;
        cursor += 1;
        remaining -= 1;
    }
}

// Copy all B_FULL limbs of `src` into `dst`, lowest limb first.
void qCopy(qnum dst, const qnum src) {
    uint limb = 0;
    while (limb != B_FULL) {
        dst[limb] = src[limb];
        limb += 1;
    }
}

// Carry propagation: walking from the lowest limb up, keep the low B_BITS
// bits of each of the B_FULL words and add everything above them to the next
// word. The carry out of the last word is discarded.
void qFlow(qnum x) {
    uint carry = 0;
    uint limb = 0;
    while (limb < B_FULL) {
        const uint total = x[limb] + carry;
        x[limb] = total & B_MASK;
        carry = total >> B_BITS;
        limb += 1;
    }
}

// dst = a + b over all B_FULL limbs, followed by carry propagation.
// No modular reduction is performed. `dst` may be the same array as `a` or `b`.
void qAdd2(qnum dst, const qnum a, const qnum b) {
    uint limb = 0;
    while (limb < B_FULL) {
        const uint32_t total = a[limb] + b[limb];
        dst[limb] = total;
        limb += 1;
    }
    qFlow(dst);
}

// dst = a - b, computed as a + (limb-wise complement of b) + 1 and then carry
// propagated. Because qFlow drops the final carry, the result is the
// difference modulo 2^(B_FULL * B_BITS). No modular reduction is performed.
void qSub2(qnum dst, const qnum a, const qnum b) {
    uint limb = 0;
    while (limb < B_FULL) {
        const uint32_t flipped = b[limb] ^ B_MASK;
        dst[limb] = a[limb] + flipped;
        limb += 1;
    }
    // The "+1" that completes the two's-complement negation of b.
    dst[0] += 1;
    qFlow(dst);
}

// Multiply-accumulate: add n times the low B_WORDS limbs of `src` onto the
// first B_WORDS limbs of `dst`. Carries are NOT propagated here; the caller
// is expected to run qFlow before limbs can overflow 32 bits.
void qMulSmall(qnum dst, const qnum src, uint n) {
    uint limb = 0;
    while (limb != B_WORDS) {
        const uint32_t term = src[limb] * n;
        dst[limb] += term;
        limb += 1;
    }
}

// Reduce `val` modulo 2^255 - 19 in place.
//
// Since 2^255 is congruent to B_C modulo the modulus, the upper B_WORDS limbs
// are folded into the lower half by multiplying them by B_C and adding; this
// is done three times so that the overflow of each fold is absorbed too.
// The value is offset by +B_C before folding and by -B_C afterwards.
// NOTE(review): the offset presumably makes values in [modulus, 2^255) wrap
// to their canonical representative -- confirm.
void qReduce(qnum val) {
    qnum offset = {B_C};
    qnum folded;
    uint32_t *upper = &val[B_WORDS];
    uint pass = 0;

    qAdd2(val, val, offset);
    while (pass < 3) {
        setZero(folded, B_FULL);
        qMulSmall(folded, upper, B_C);
        setZero(upper, B_WORDS);
        qAdd2(val, val, folded);
        pass += 1;
    }
    qSub2(val, val, offset);
}

// Modular addition: dst = (a + b) reduced by qReduce.
// `dst` may be the same array as `a` or `b`.
void qAdd(qnum dst, const qnum a, const qnum b) {
    qAdd2(dst, a, b);
    qReduce(dst);
}

// Modular subtraction: dst = (a - b) reduced by qReduce.
// The modulus is added to `a` first so that the intermediate difference does
// not go negative when b is larger than a (both assumed reduced).
void qSub(qnum dst, const qnum a, const qnum b) {
    qnum t;
    // t = a + modulus
    qAdd2(t, a, qnum_modulus);
    // dst = t - b
    qSub2(dst, t, b);
    qReduce(dst);
}

// Modular multiplication: dst = (a * b) reduced by qReduce.
//
// Schoolbook multiplication over the low B_WORDS limbs: for every limb of b,
// a scaled copy of a is accumulated at the matching limb offset, and carries
// are propagated after each row so the 32-bit words cannot overflow.
// The product is built in a local buffer, so dst may alias a and/or b.
void qMul(qnum dst, const qnum a, const qnum b) {
    qnum product = {0};
    uint row = 0;
    while (row < B_WORDS) {
        qMulSmall(&product[row], a, b[row]);
        qFlow(product);
        row += 1;
    }
    qCopy(dst, product);
    qReduce(dst);
}

// Replace `val` by val^(2^255 - 21), i.e. val^(modulus - 2), which is the
// modular inverse of val for a prime modulus.
//
// Right-to-left square-and-multiply over B_WORDS * B_BITS exponent bits.
// Only the lowest word of the exponent is stored (the unsigned value of
// -(B_C + 2)); every exponent bit beyond the width of `uint` is treated as 1.
void qInvert(qnum val) {
    const uint exponent_bits = B_WORDS * B_BITS;
    const uint stored_bits = sizeof(uint) * 8;
    uint low_word = -(B_C + 2);
    uint bit = 0;
    qnum result = {1}, power;

    qCopy(power, val);
    while (bit < exponent_bits) {
        const int bit_is_set = (bit >= stored_bits) ? 1 : (int)(low_word & 1);
        if (bit_is_set) {
            qMul(result, result, power);
        }
        qMul(power, power, power);
        low_word >>= 1;
        bit += 1;
    }
    qCopy(val, result);
}

// Unpack B_BYTES little-endian bytes into limb form.
//
// Each byte is added into the limb that contains its lowest bit, shifted to
// its position inside that limb; bits that spill past B_BITS are moved to the
// following limb by the final qFlow. The value is not reduced.
void loadInt(qnum out, const uint8_t in[B_BYTES]) {
    uint bit = 0;
    setZero(out, B_FULL);
    while (bit < B_WORDS * B_BITS) {
        const uint32_t byte = in[bit / 8];
        out[bit / B_BITS] += byte << (bit % B_BITS);
        bit += 8;
    }
    qFlow(out);
}

void storeInt(uint8_t out[B_BYTES], const qnum in) {
    uint i, pos, shift, sum;
    for (i = 0; i < B_WORDS * B_BITS; i += 8) {
        pos = i / B_BITS;
        shift = i % B_BITS;
        sum = in[pos] >> shift;
        if (pos < B_WORDS - 1) {
            sum += in[pos + 1] << (B_BITS - shift);
        }
        out[i / 8] = sum & 0xff;
    }
}

//------------------------------------------------------------------------------

// A curve point in projective form: only the x and z coordinates are kept,
// each as a qnum. encodePoint() converts this to the single value x / z.
typedef struct {
    qnum x, z;
} cpoint;

// Build a point from its 32-byte encoding: x is the decoded integer and z is
// set to one.
void decodePoint(cpoint *res, const uint8_t in[32]) {
    uint limb;
    loadInt(res->x, in);
    res->z[0] = 1;
    for (limb = 1; limb < B_FULL; limb += 1) {
        res->z[limb] = 0;
    }
}

// Serialise a point to 32 bytes: divide x by z (multiply by the modular
// inverse of z) and store the result little-endian. `pt` is not modified.
void encodePoint(uint8_t out[32], const cpoint *pt) {
    qnum z_inverse, x_over_z;
    qCopy(z_inverse, pt->z);
    qInvert(z_inverse);
    qMul(x_over_z, z_inverse, pt->x);
    storeInt(out, x_over_z);
}

// One ladder step on x/z-only points. On return *pta holds the result of the
// doubling formula applied to the old *pta, and *ptb holds the result of the
// differential-addition formula applied to the old *pta and *ptb.
// Only q->x is read from q; it enters the addition as the x coordinate of the
// difference of the two points (curveMul passes the base point here).
// The temporaries a and b are reused, so the statement order matters.
void curveOp(cpoint *pta, cpoint *ptb, const cpoint *q) {
    // a24 is the constant 23362 + 3 * 2^15 = 121666 in limb form.
    qnum a, aa, b, bb, c, d, da, cb, e, a24 = {23362, 3};
    // a = xa + za, aa = a^2
    qAdd(a, pta->x, pta->z);
    qMul(aa, a, a);
    // b = xa - za, bb = b^2
    qSub(b, pta->x, pta->z);
    qMul(bb, b, b);
    // e = aa - bb (equals 4 * xa * za)
    qSub(e, aa, bb);
    // c = xb + zb, d = xb - zb
    qAdd(c, ptb->x, ptb->z);
    qSub(d, ptb->x, ptb->z);
    // Cross products between the two points.
    qMul(da, d, a);
    qMul(cb, c, b);
    // New ptb.x = (da + cb)^2; `a` is overwritten here.
    qAdd(a, da, cb);
    qMul(ptb->x, a, a);
    // New ptb.z = q.x * (da - cb)^2; `b` is overwritten here.
    qSub(b, da, cb);
    qMul(b, b, b);
    qMul(ptb->z, q->x, b);
    // New pta.x = aa * bb.
    qMul(pta->x, aa, bb);
    // New pta.z = e * (bb + a24 * e).
    qMul(a, a24, e);
    qAdd(a, a, bb);
    qMul(pta->z, e, a);
}

// Conditionally exchange two byte buffers of `len` bytes without branching on
// the condition: for d == 1 the contents are swapped, for d == 0 they are
// left as they are. The selection is done with two complementary byte masks
// derived from d, so the same instructions run in both cases.
void cndSwap(void *a, void *b, uint len, int d) {
    uint8_t *first = a, *second = b;
    const uint8_t take_other = -d;
    const uint8_t keep_own = ~(-d);
    uint pos = 0;
    while (pos < len) {
        const uint8_t old_first = first[pos];
        first[pos] = (take_other & second[pos]) ^ (keep_own & old_first);
        second[pos] = (keep_own & second[pos]) ^ (take_other & old_first);
        pos += 1;
    }
}

// Scalar multiplication by ladder: processes the 256 bits of the
// little-endian scalar `n` from the most significant bit down, and replaces
// *q with the resulting point (still in x/z form).
//
// The working pair starts as pta = (x = 1, z = 1) and ptb = *q. For every bit
// the pair is conditionally swapped with cndSwap, stepped with curveOp (with
// the original *q as the difference point), and swapped back, so no branch
// depends on scalar bits.
void curveMul(cpoint *q, const uint8_t n[32]) {
    cpoint pta = {{1}, {1}};
    cpoint ptb = *q;
    int bit = 255;
    while (bit >= 0) {
        const int swap = (n[bit >> 3] >> (bit & 7)) & 1;
        cndSwap(&pta, &ptb, sizeof(pta), swap);
        curveOp(&pta, &ptb, q);
        cndSwap(&pta, &ptb, sizeof(pta), swap);
        bit -= 1;
    }
    *q = pta;
}

// Diffie-Hellman primitive: multiply the point encoded in `g` by the scalar
// `e` and write the encoded result back into `g`.
//
// A private copy of the scalar is clamped before use: the three lowest bits
// and bit 255 are cleared and bit 254 is set. `e` itself is not modified.
void DH(uint8_t g[32], const uint8_t e[32]) {
    uint8_t scalar[32];
    cpoint point;

    memcpy(scalar, e, 32);
    scalar[0] &= 0xf8;
    scalar[31] = (scalar[31] & 0x7f) | 0x40;

    decodePoint(&point, g);
    curveMul(&point, scalar);
    encodePoint(g, &point);
}

// Write the 32-byte encoding of the base point: first byte 9, all others 0.
void initG(uint8_t dst[32]) {
    int pos;
    dst[0] = 9;
    for (pos = 1; pos < 32; pos += 1) {
        dst[pos] = 0;
    }
}

//------------------------------------------------------------------------------

// XOR `len` bytes from `src` into `dst`, byte by byte.
void xorBlock(void *dst, const void *src, uint len) {
    uint8_t *target = dst;
    const uint8_t *source = src;
    uint pos = 0;
    while (pos != len) {
        target[pos] ^= source[pos];
        pos += 1;
    }
}

// Rotate the 32-bit unsigned value `a` left by `b` bits. `b` must be in the
// range 1..31 (a shift by 32 would be undefined), and `a` is evaluated twice,
// so it must not have side effects.
#define ROL32(a, b) (((a)<<(b))|((a)>>(32-(b))))

// Working state of the sponge used for key derivation, keystream generation
// and authentication.
typedef struct {
    // 64 words (256 bytes) of permutation state. Input blocks are XORed into,
    // and output is copied from, its first 128 bytes.
    uint32_t state[64];
    // Input bytes collected by spongeAbsorb until a full 128-byte block is ready.
    uint8_t buffer[128];
    // unflushed: number of pending bytes in `buffer`.
    // mode: 0 while absorbing; set to 1 by spongeFinish once padding is applied.
    uint unflushed, mode;
} sponge_ctx;

// Reset a sponge to its initial state: all-zero permutation state, empty
// input buffer, absorbing mode (mode == 0).
void spongeInit(sponge_ctx *ctx) {
    memset(ctx, 0, sizeof(*ctx));
}

// The sponge permutation: 64 * 16 = 1024 word updates over the 64-word state.
// Update number r (counting from 1) rewrites word r mod 64 by XORing in a
// rotated, multiplied mix of its two neighbours (indices r-1 and r+1 mod 64)
// and the counter r.
void spongeScramble(uint32_t v[64]) {
    uint32_t step = 1;
    while (step <= 64 * 16) {
        const uint32_t before = v[(step + 63) & 63];
        const uint32_t after = v[(step + 1) & 63];
        const uint32_t mix = (before << 1) ^ after ^ step;
        v[step & 63] ^= ROL32(mix, 24) * 9;
        step += 1;
    }
}

// Feed `len` input bytes into the sponge. Bytes are collected in ctx->buffer;
// each time 128 bytes have accumulated the block is XORed into the state, the
// state is scrambled, and the buffer is cleared for the next block.
void spongeAbsorb(sponge_ctx *ctx, const uint8_t *in, uint len) {
    uint consumed = 0;
    while (consumed < len) {
        ctx->buffer[ctx->unflushed] = in[consumed];
        consumed += 1;
        ctx->unflushed += 1;
        if (ctx->unflushed != 128) {
            continue;
        }
        // A full block is ready: mix it in and start over with an empty buffer.
        xorBlock(ctx->state, ctx->buffer, 128);
        spongeScramble(ctx->state);
        memset(ctx->buffer, 0, sizeof(ctx->buffer));
        ctx->unflushed = 0;
    }
}

// Close the absorbing phase; does nothing if it was already closed.
//
// A single padding byte 0x01 is absorbed, the lowest bit of the last state
// word is flipped, whatever is left in the buffer is XORed into the state
// (all zeros if the padding byte happened to complete a block), and the state
// is scrambled once. Afterwards ctx->mode is 1.
void spongeFinish(sponge_ctx *ctx) {
    const uint8_t padding = 1;
    if (ctx->mode != 0) {
        return;
    }
    spongeAbsorb(ctx, &padding, 1);
    ctx->state[63] ^= 1;
    xorBlock(ctx->state, ctx->buffer, 128);
    spongeScramble(ctx->state);
    ctx->mode = 1;
}

// Produce `len` output bytes. The absorbing phase is closed first if needed.
// Output is generated in blocks: the state is scrambled and up to 128 bytes
// are copied from its start.
//
// Every call starts on a fresh block, so the unused tail of a partial block
// is thrown away: callers that need one continuous stream across calls must
// request multiples of 128 bytes.
void spongeSqueeze(sponge_ctx *ctx, uint8_t *out, uint len) {
    uint produced = 0;
    spongeFinish(ctx);
    while (produced < len) {
        uint chunk = len - produced;
        if (chunk > 128) {
            chunk = 128;
        }
        spongeScramble(ctx->state);
        memcpy(&out[produced], ctx->state, chunk);
        produced += chunk;
    }
}

//------------------------------------------------------------------------------

// Memory-hard mixing step of the KDF: condenses the 16-word key `k` into one
// 32-bit value using a scratch table of 2^23 words (32 MiB).
//
// The table is filled with its own indices, key words are XORed in at every
// 257th position, and then two rounds are run, each consisting of a
// sequential pass and a strided pass (step 2049, wrapping around) in which
// every visited word is replaced by a rotate-multiply chain value.
// The last chain value is returned. Terminates the program via failNow if the
// scratch table cannot be allocated.
uint32_t zen32(uint32_t k[16]) {
    const uint len = 1 << 23;
    uint32_t *v = malloc(len * sizeof *v);
    uint32_t x = 0, p, r, i;

    // The table is large; without this check a failed allocation would be
    // dereferenced below.
    if (v == NULL) {
        failNow("Out of memory.");
    }

    for (r = 0; r < len; r += 1) {
        v[r] = r;
    }

    for (r = 0; r < len; r += 257) {
        v[r] ^= k[r % 16];
    }

    for (i = 0; i < 2; i += 1) {
        // Sequential pass.
        for (r = 0; r < len; r += 1) {
            v[r] = x = ROL32(x ^ v[r], 13) * 9;
        }
        // Strided pass; 2049 is odd, so every index is visited exactly once.
        for (r = p = 0; r < len; r += 1) {
            v[p] = x = ROL32(x ^ v[p], 13) * 9;
            p = (p + 2049) % len;
        }
    }

    free(v);
    return x;
}

// Derive a 32-byte key from a NUL-terminated passphrase.
//
// Stage 1 hashes the fixed salt "dhbitty" (including its terminating NUL,
// 8 bytes in total) and the passphrase into a 64-byte intermediate key.
// Stage 2 runs the memory-hard zen32 over that key. Stage 3 hashes the
// intermediate key together with the zen32 result into the 32 output bytes.
void slothKdf(const char *passphrase, uint8_t out[32]) {
    static const char salt[8] = "dhbitty";
    uint32_t stage_key[16];
    uint32_t slow_value;
    sponge_ctx hash;

    spongeInit(&hash);
    spongeAbsorb(&hash, (const uint8_t*)salt, 8);
    spongeAbsorb(&hash, (const uint8_t*)passphrase, strlen(passphrase));
    spongeSqueeze(&hash, (uint8_t*)stage_key, 64);

    slow_value = zen32(stage_key);

    spongeInit(&hash);
    spongeAbsorb(&hash, (const uint8_t*)stage_key, 64);
    spongeAbsorb(&hash, (const uint8_t*)&slow_value, 4);
    spongeSqueeze(&hash, out, 32);
}

//------------------------------------------------------------------------------

// Fill `dst` with `len` bytes from the operating system's random source:
// CryptGenRandom on Windows, /dev/urandom elsewhere. Any failure terminates
// the program via failNow; the function never returns with `dst` only partly
// filled.
void getRandomBytes(uint8_t *dst, uint len) {
    memset(dst, 0, len);

    #ifdef __WIN32__
    HCRYPTPROV provider;
    BOOL generated;
    // CRYPT_VERIFYCONTEXT: no persistent key container is needed just to draw
    // random bytes, and without it the call can fail when the user has no
    // default container. The result must be checked, otherwise `provider`
    // would be used uninitialised.
    if (!CryptAcquireContext(&provider, NULL, NULL, PROV_RSA_FULL,
                             CRYPT_VERIFYCONTEXT)) {
        failNow("Failed to use CryptGenRandom.");
    }
    generated = CryptGenRandom(provider, len, dst);
    CryptReleaseContext(provider, 0);
    if (!generated) {
        failNow("Failed to use CryptGenRandom.");
    }
    #else
    FILE *h;
    size_t got;
    h = fopen("/dev/urandom", "rb");
    if (h == NULL) {
        // Without this check a missing device would pass NULL to fread.
        failNow("Failed to use /dev/urandom.");
    }
    got = fread(dst, 1, len, h);
    fclose(h);
    if (got != len) {
        failNow("Failed to use /dev/urandom.");
    }
    #endif

    if (len > 8) { // Opportunistic test; always runs in this app.
        // More than 8 random bytes that are all zero almost certainly mean
        // the buffer was never written.
        uint8_t b_or = 0;
        uint i;
        for (i = 0; i < len; i += 1) {
            b_or |= dst[i];
        }
        if (b_or == 0) {
            failNow("Random number generation failure.");
        }
    }
}

//------------------------------------------------------------------------------

// Derive a key pair from a passphrase: the private key is the slothKdf output
// and the public key is the base point (initG) multiplied by it via DH.
void getKeysFromPassphrase(const char *passphrase, uint8_t private_key[32],
                           uint8_t public_key[32]) {
    slothKdf(passphrase, private_key);
    initG(public_key);
    DH(public_key, private_key);
}

// Set up the two sponges used for one message. Both absorb the shared key and
// the nonce; they are then separated by absorbing a different label
// (including its terminating NUL): "encrypt" for the keystream sponge and
// "authenticate" for the MAC sponge.
void getSpongeContexts(const uint8_t key[32], const uint8_t nonce[16],
                       sponge_ctx *cipher, sponge_ctx *mac) {
    static const uint8_t cipher_label[8] = "encrypt";
    static const uint8_t mac_label[13] = "authenticate";

    spongeInit(cipher);
    spongeAbsorb(cipher, key, 32);
    spongeAbsorb(cipher, nonce, 16);

    // Fork: the MAC sponge starts from the same key/nonce state.
    *mac = *cipher;

    spongeAbsorb(cipher, cipher_label, 8);
    spongeAbsorb(mac, mac_label, 13);
}

//------------------------------------------------------------------------------

// Write a 32-byte key as 64 lowercase hexadecimal characters plus a
// terminating NUL; `out` must have room for 65 bytes.
// Returns the number of characters written, not counting the NUL (64).
uint keyToHex(char *out, const uint8_t *key) {
    static const char digits[] = "0123456789abcdef";
    uint written = 0;
    uint idx;
    for (idx = 0; idx < 32; idx += 1) {
        out[written] = digits[key[idx] >> 4];
        out[written + 1] = digits[key[idx] & 0x0f];
        written += 2;
    }
    out[written] = 0;
    return written;
}

// Parse a 32-byte key from 64 hexadecimal characters at the start of `hex`.
//
// Only the digits 0-9 and lowercase a-f are accepted, which is the form
// keyToHex produces; anything else, including a string that ends early,
// terminates the program via failNow. Characters after the 64th are ignored.
//
// Digits are decoded directly: this accepts exactly the same inputs as a
// scan-and-reprint round trip, but never uses a value that failed to parse
// and never reads past the terminating NUL of a short string.
void hexToKey(uint8_t *public_key, const char *hex) {
    uint i, j;
    for (i = 0; i < 32; i += 1) {
        uint value = 0;
        for (j = 0; j < 2; j += 1) {
            const char c = hex[i * 2 + j];
            uint digit = 0;
            if (c >= '0' && c <= '9') {
                digit = c - '0';
            } else if (c >= 'a' && c <= 'f') {
                digit = c - 'a' + 10;
            } else {
                // Also reached for the NUL terminator of a too-short string.
                failNow("Tried to read a broken public key.");
            }
            value = value * 16 + digit;
        }
        public_key[i] = value;
    }
}

// Generate a random 23-character passphrase (plus terminating NUL) over a
// 32-symbol alphabet of digits and lowercase letters without i, l, o and u.
// Each character is chosen by the low 5 bits of one random byte.
void generatePassphrase(char passphrase[24]) {
    static const char alphabet[] = "0123456789abcdefghjkmnpqrstvwxyz";
    char *cursor = passphrase;
    char *const end = passphrase + 23;

    getRandomBytes((uint8_t*)passphrase, 23);
    while (cursor != end) {
        *cursor = alphabet[(uint)(*cursor) % 32];
        cursor += 1;
    }
    *end = 0;
}

// Prompt for a "username:passphrase" line on stdin and derive the key pair
// from it.
//
// gen == 0: whatever is entered is used as is (decryption).
// gen != 0: an empty line makes the program generate and print a random
//           passphrase, and anything shorter than 10 characters is rejected
//           and asked for again.
//
// Terminates the program via failNow if no line can be read.
void getKeysFromUser(uint8_t private_key[32], uint8_t public_key[32], int gen) {
    char passphrase[256];
    uint len;
    while (1) {
        printf("username:passphrase (this is visible!): ");
        // On EOF or a read error the buffer contents are indeterminate, and in
        // generate mode the loop would otherwise never end.
        if (fgets(passphrase, sizeof(passphrase), stdin) == NULL) {
            failNow("Couldn't read passphrase.");
        }
        // Strip the line terminator only if it is actually there. Input that
        // ends without a newline (EOF, or a line filling the whole buffer)
        // keeps its last character, and an empty string no longer makes the
        // length underflow.
        len = strlen(passphrase);
        if (len > 0 && passphrase[len - 1] == '\n') {
            len -= 1;
            passphrase[len] = 0;
        }
        if (!gen) {
            break;
        }
        if (len == 0) {
            generatePassphrase(passphrase);
            printf("generated: %s\n", passphrase);
            len = strlen(passphrase);
        }
        if (len >= 10) {
            break;
        }
        printf("Please pick a stronger passphrase, such as seven words by diceware.\n");
    }
    getKeysFromPassphrase(passphrase, private_key, public_key);
}

// "generate" command: ask for a passphrase, derive the key pair and write the
// public key as 64 hex characters (no trailing newline) to the file `out_fn`.
// Terminates the program via failNow if the file cannot be created or written.
void uiGenerate(char *out_fn) {
    int hex_key_len;
    size_t written;
    char hex_key[256];
    uint8_t private_key[32], public_key[32];
    FILE *h;

    getKeysFromUser(private_key, public_key, 1);
    hex_key_len = keyToHex(hex_key, public_key);

    h = fopen(out_fn, "w");
    if (h == NULL) {
        // Without this check an unwritable path would pass NULL to fwrite.
        failNow("Couldn't open output public key file.");
    }
    written = fwrite(hex_key, 1, hex_key_len, h);
    // fclose flushes buffered data, so its result is part of the write check.
    if (fclose(h) != 0 || written != (size_t)hex_key_len) {
        failNow("Couldn't write to public key file.");
    }
}

// Number of bytes processed per I/O round (4 MiB). This is a multiple of 128,
// which matters because spongeSqueeze discards the rest of a partial 128-byte
// block: with this size the keystream stays continuous from chunk to chunk.
#define FILE_CHUNK_LEN (1024*1024*4)

// "encrypt" command: encrypt `plaintext_fn` for the owner of the public key
// stored in `pub_fn` and write the result to `ciphertext_fn`.
//
// Output layout: the two public keys (recipient and sender, in sorted order,
// 32 bytes each), a 16-byte random nonce, the ciphertext, and a 16-byte
// authentication tag computed over the ciphertext.
//
// Every failure (unreadable input, allocation failure, short write)
// terminates the program via failNow instead of leaving a truncated file
// that is reported as a success.
void uiEncrypt(char *pub_fn, char *plaintext_fn, char *ciphertext_fn) {
    FILE *in_h, *out_h;
    size_t read_len;
    int write_ok = 1;
    char hex_key[128];
    uint8_t private_key[32], public_keys[2][32], shared_key[32];
    uint8_t nonce[16], tag[16], *keystream, *data;
    sponge_ctx cipher, mac;

    // Read and parse recipient public key.
    in_h = fopen(pub_fn, "r");
    if (in_h == NULL) {
        failNow("Couldn't open public key file.");
    }
    read_len = fread(hex_key, 1, 127, in_h);
    hex_key[read_len] = 0;
    fclose(in_h);

    hexToKey(public_keys[0], hex_key);

    // Request passphrase and compute shared secret.
    getKeysFromUser(private_key, public_keys[1], 1);
    memcpy(shared_key, public_keys[0], 32);
    DH(shared_key, private_key);

    // Open input and output files.
    in_h = fopen(plaintext_fn, "rb");
    if (in_h == NULL) {
        failNow("Couldn't open input plaintext file.");
    }
    out_h = fopen(ciphertext_fn, "wb");
    if (out_h == NULL) {
        failNow("Couldn't open output ciphertext file.");
    }

    // Working buffers for one chunk of data and the matching keystream.
    data = malloc(FILE_CHUNK_LEN);
    keystream = malloc(FILE_CHUNK_LEN);
    if (data == NULL || keystream == NULL) {
        failNow("Out of memory.");
    }

    // Initialize stream cipher and MAC contexts.
    getRandomBytes(nonce, 16);
    getSpongeContexts(shared_key, nonce, &cipher, &mac);

    // Sort the public keys. This is not branchless.
    cndSwap(public_keys[0], public_keys[1], 32,
            (memcmp(public_keys[0], public_keys[1], 32) > 0) ? 1 : 0);

    // "Header".
    write_ok &= (fwrite(public_keys, 2, 32, out_h) == 32);
    write_ok &= (fwrite(nonce, 1, 16, out_h) == 16);

    // Let's stream! Encrypt-then-MAC: the tag covers the ciphertext bytes.
    // The loop ends with the first short (possibly empty) read.
    do {
        read_len = fread(data, 1, FILE_CHUNK_LEN, in_h);
        spongeSqueeze(&cipher, keystream, read_len);
        xorBlock(data, keystream, read_len);
        spongeAbsorb(&mac, data, read_len);
        write_ok &= (fwrite(data, 1, read_len, out_h) == read_len);
    } while (read_len == FILE_CHUNK_LEN);

    // A short read is only the end of the file if no error is pending.
    if (ferror(in_h)) {
        failNow("Couldn't read input plaintext file.");
    }
    free(data);
    free(keystream);
    fclose(in_h);

    spongeSqueeze(&mac, tag, 16);
    write_ok &= (fwrite(tag, 1, 16, out_h) == 16);
    // fclose flushes buffered data, so its result is part of the write check.
    if (fclose(out_h) != 0 || !write_ok) {
        failNow("Couldn't write to output ciphertext file.");
    }
}

// "decrypt" command: ask for the passphrase, check that the derived public key
// is one of the two keys stored in the file header, decrypt `ciphertext_fn`
// into `plaintext_fn` and verify the authentication tag.
//
// The plaintext is written while the tag is still being computed; if the tag
// turns out wrong, or the input cannot be read completely, or the output
// cannot be written, the plaintext file is removed again and the program
// terminates via failNow. On success the other party's public key is printed.
void uiDecrypt(char *ciphertext_fn, char *plaintext_fn) {
    FILE *in_h, *out_h;
    size_t chunk;
    long remaining;
    int write_ok = 1, tag_read_ok;
    char hex_key[128];
    uint8_t private_key[32], public_keys[3][32], shared_key[32];
    uint8_t nonce[16], right_tag[16], tag[16], *keystream, *data, or_b;
    uint i;
    sponge_ctx cipher, mac;

    getKeysFromUser(private_key, public_keys[2], 0);

    in_h = fopen(ciphertext_fn, "rb");
    if (in_h == NULL) {
        failNow("Couldn't open input ciphertext file.");
    }
    // Determine the file size. ftell reports failure as -1, which must not be
    // mistaken for a (huge) length, and the size is kept in a long so that it
    // is not truncated to the width of an unsigned int.
    if (fseek(in_h, 0, SEEK_END) != 0) {
        failNow("Couldn't determine size of input ciphertext file.");
    }
    remaining = ftell(in_h);
    if (remaining < 0 || fseek(in_h, 0, SEEK_SET) != 0) {
        failNow("Couldn't determine size of input ciphertext file.");
    }
    // Header (two public keys and the nonce) plus the trailing tag.
    if (remaining < 32 * 2 + 16 + 16) {
        failNow("Input ciphertext file is too short.");
    }

    if (fread(public_keys, 2, 32, in_h) != 32 ||
        fread(nonce, 1, 16, in_h) != 16) {
        failNow("Couldn't read input ciphertext file.");
    }

    // Another branching swap. Hardly matters.
    // Afterwards public_keys[0] must be our own key and public_keys[1] is the
    // other party's key.
    cndSwap(public_keys[0], public_keys[1], 32,
            memcmp(public_keys[1], public_keys[2], 32) == 0 ? 1 : 0);
    if (memcmp(public_keys[0], public_keys[2], 32) != 0) {
        failNow("Incorrect passphrase!");
    }

    // Working buffers for one chunk of data and the matching keystream.
    data = malloc(FILE_CHUNK_LEN);
    keystream = malloc(FILE_CHUNK_LEN);
    if (data == NULL || keystream == NULL) {
        failNow("Out of memory.");
    }

    out_h = fopen(plaintext_fn, "wb");
    if (out_h == NULL) {
        failNow("Couldn't open output plaintext file.");
    }

    memcpy(shared_key, public_keys[1], 32);
    DH(shared_key, private_key);

    getSpongeContexts(shared_key, nonce, &cipher, &mac);

    remaining -= 32 * 2 + 16 + 16;

    // Process the ciphertext in chunks. The loop ends with the first chunk
    // shorter than FILE_CHUNK_LEN, which is empty when the length is an exact
    // multiple of the chunk size.
    do {
        chunk = (size_t)MIN(FILE_CHUNK_LEN, remaining);
        if (fread(data, 1, chunk, in_h) != chunk) {
            // Never decrypt bytes that were not actually read.
            fclose(out_h);
            remove(plaintext_fn);
            failNow("Couldn't read input ciphertext file.");
        }
        spongeAbsorb(&mac, data, chunk);
        spongeSqueeze(&cipher, keystream, chunk);
        xorBlock(data, keystream, chunk);
        write_ok &= (fwrite(data, 1, chunk, out_h) == chunk);
        remaining -= (long)chunk;
    } while (chunk == FILE_CHUNK_LEN);

    free(data);
    free(keystream);
    // fclose flushes buffered data, so its result is part of the write check.
    if (fclose(out_h) != 0 || !write_ok) {
        remove(plaintext_fn);
        failNow("Couldn't write to output plaintext file.");
    }

    spongeSqueeze(&mac, right_tag, 16);
    tag_read_ok = (fread(&tag, 1, 16, in_h) == 16);
    fclose(in_h);

    // Compare the tags without an early exit; a tag that could not be read
    // counts as a mismatch.
    or_b = tag_read_ok ? 0 : 1;
    for (i = 0; i < 16; i += 1) {
        or_b |= tag_read_ok ? (tag[i] ^ right_tag[i]) : 0;
    }

    if (or_b != 0) {
        remove(plaintext_fn);
        failNow("Tag was wrong; data seems to be corrupted!\n");
    }

    printf("This is the public key of file's secondary owner:\n");
    keyToHex(hex_key, public_keys[1]);
    printf("%s\n", hex_key);
}

int main(int argc, char **argv) {
    if (argc == 3 && strcmp(argv[1], "generate") == 0) {
        uiGenerate(argv[2]);
    } else
    if (argc == 5 && strcmp(argv[1], "encrypt") == 0) {
        uiEncrypt(argv[2], argv[3], argv[4]);
    } else
    if (argc == 4 && strcmp(argv[1], "decrypt") == 0) {
        uiDecrypt(argv[2], argv[3]);
    } else {
        printf("Usage:\n");
        printf("dhbitty generate publickey.txt\n");
        printf("dhbitty encrypt publickey.txt plaintext.tar ciphertext.tar.dhbt\n");
        printf("dhbitty decrypt ciphertext.tar.dhbt plaintext.tar\n");
    }
    printf("Done.\n");

    return 0;
}
