/* p2p message demo
 * pesco, 2009
 */

#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include "monty.h"
#include "byteorder.h"

#define LOCALPORT 1337
#define HOSTIDLEN 8


/* IO functions */

/* Print the n-word bignum X (X[0] is the least significant word) to f as a
 * single line of uppercase hex, most significant word first, each word as
 * exactly 8 zero-padded digits.  This is the layout readbignum() parses.
 * PRIX32 is used because uint32_t is not guaranteed to be unsigned int. */
void printbignum(FILE *f, int n, uint32_t *X)
{
	int i;

	for(i = n-1; i >= 0; i--)
		fprintf(f, "%.8" PRIX32, X[i]);
	fprintf(f, "\n");
}

/* Read the next hexadecimal digit from f, silently skipping any other
 * characters (whitespace, separators, ...).
 * Returns the digit's value 0..15, or -1 if end of file or a read error is
 * reached before a hex digit is found.  (Previously EOF made this loop
 * forever, because getc() keeps returning EOF.) */
int readhexdigit(FILE *f)
{
	int c;

	while((c = getc(f)) != EOF) {
		if(c >= '0' && c <= '9')
			return c - '0';
		if(c >= 'a' && c <= 'f')
			return c - 'a' + 10;
		if(c >= 'A' && c <= 'F')
			return c - 'A' + 10;
	}
	return -1;
}
	
/* Read an n-word bignum from f as 8*n hex digits, most significant word
 * first, storing it in X[0..n-1] with X[0] the least significant word.
 * Digits are obtained through readhexdigit(), so non-hex characters in the
 * input are skipped.
 * If readhexdigit() returns a negative value (no digit available), the word
 * being read and all less significant words are set to 0 and the function
 * returns early, instead of OR-ing the negative value into the result. */
void readbignum(FILE *f, int n, uint32_t *X)
{
	int i, j;

	for(i = n-1; i >= 0; i--) {
		uint32_t x = 0;

		for(j = 0; j < 8; j++) {
			int d = readhexdigit(f);

			if(d < 0) {
				/* input exhausted: zero the unread words and stop */
				while(i >= 0)
					X[i--] = 0;
				return;
			}
			x = (x << 4) | (uint32_t)d;
		}

		X[i] = x;
	}
}

/* One entry loaded from the "hosts" file; entries form a singly linked
 * list built by prepending in main(). */
struct host {
	char         name[32];   /* host name as read from the file, NUL-terminated */
	uint32_t    *gaR;        /* heap buffer holding the host's public value
	                          * (presumably g^a in Montgomery form, by analogy
	                          * with gbR/gabR in main) — TODO confirm */
	struct host *next;       /* following entry, NULL at the end of the list */
};

/* Prompt on stdout for this station's id and store it in sender as the
 * NUL-terminated prefix "<id> ".  At most 11 characters of id are kept; the
 * remainder of an over-long input line is discarded so it is not later read
 * as a message.  An empty id is rejected and the prompt repeats.
 * Returns the prefix length (excluding the NUL), or -1 on EOF/read error on
 * stdin or if the prefix does not fit in size bytes. */
static int read_station_id(char *sender, size_t size)
{
	char id[12];

	for(;;) {
		size_t len;
		int c, r;

		printf("our station id? ");
		fflush(stdout);
		if(!fgets(id, sizeof id, stdin))
			return -1;

		len = strlen(id);
		if(len > 0 && id[len-1] == '\n') {
			id[--len] = '\0';
		} else {
			/* line did not fit in id[]: drop the rest of it */
			while((c = getchar()) != EOF && c != '\n')
				;
		}
		if(len == 0)
			continue;

		r = snprintf(sender, size, "<%s> ", id);
		if(r < 0 || (size_t)r >= size)
			return -1;
		return r;
	}
}

/* Interactive sender for the p2p message demo.
 *
 * Loads host entries from the file "hosts" (a name followed by an n-word hex
 * bignum, most significant word first), asks for this station's id, then
 * reads lines of the form "host: message" from stdin.  For each line the
 * message block is built, encrypted against the named host's stored value,
 * and sent as one UDP datagram to the local node at 127.0.0.1:LOCALPORT.
 *
 * Message block (n*4 bytes, zeroed first): HOSTIDLEN bytes copied from the
 * recipient's gaR, then the "<id> " sender prefix, then the message text,
 * truncated if it does not fit.
 * Datagram: gbR (n words) followed by the encrypted block mR (n words),
 * passed through htol() (byte-order helper from byteorder.h).
 */
int main(int argc, char **argv)
{
	/* 1536-bit modulus from RFC 3526, little-endian word order */
	uint32_t N[] = {0xFFFFFFFF, 0xFFFFFFFF, 0xCA237327, 0xF1746C08, 0x4ABC9804, 0x670C354E,
	                0x7096966D, 0x9ED52907, 0x208552BB, 0x1C62F356, 0xDCA3AD96, 0x83655D23,
	                0xFD24CF5F, 0x69163FA8, 0x1C55D39A, 0x98DA4836, 0xA163BF05, 0xC2007CB8,
	                0xECE45B3D, 0x49286651, 0x7C4B1FE6, 0xAE9F2411, 0x5A899FA5, 0xEE386BFB,
	                0xF406B7ED, 0x0BFF5CB6, 0xA637ED6B, 0xF44C42E9, 0x625E7EC6, 0xE485B576,
	                0x6D51C245, 0x4FE1356D, 0xF25F1437, 0x302B0A6D, 0xCD3A431B, 0xEF9519B3,
	                0x8E3404DD, 0x514A0879, 0x3B139B22, 0x020BBEA6, 0x8A67CC74, 0x29024E08,
	                0x80DC1CD1, 0xC4C6628B, 0x2168C234, 0xC90FDAA2, 0xFFFFFFFF, 0xFFFFFFFF};
	/* 2^3072 mod N; multiplying by it with monty_mul converts a value into
	 * Montgomery representation (see the conversions below) */
	uint32_t RR[] = {0x32c695e0, 0xf115d27d, 0x67478c73, 0x8e0e3e21, 0x8397f245, 0xd0ab92e1,
	                 0xbcd49d68, 0xf466ee5f, 0x3b01e018, 0x8f2331b1, 0x98b5fb62, 0x7e8cd2ac,
	                 0x7a58f170, 0xb9052bb4, 0xdb102d39, 0xb004a750, 0x93ae1ceb, 0x04a541ff,
	                 0x8e434130, 0x07cd0a62, 0x04b9f796, 0x1c729c7e, 0x196b7e88, 0xb8fe6121,
	                 0x0223b76b, 0x8e1abd78, 0xd46fec23, 0x22c296e9, 0xb270521b, 0xd62a0eea,
	                 0xd4053f54, 0xdc541a4e, 0x969b7f02, 0xf8056564, 0xa87c7b37, 0x0be49647,
	                 0x67984460, 0x57b59348, 0x9a36a51f, 0x102630fa, 0xcc2456ef, 0xe9c3fa02,
	                 0x7929a1c7, 0xae594104, 0x6cc1ebd2, 0xee9c9a21, 0x59541c01, 0xe3b33c72};
	int n = sizeof(N) / sizeof(uint32_t);
	uint32_t R[n];          /* 1, montgomery rep */
	uint32_t one[n];

	uint32_t g[n];          /* generator = 2 */
	uint32_t gR[n];         /* generator, montgomery representation */
	uint32_t b[n];          /* random number, bob */
	uint32_t gbR[n];        /* g^b, montgomery rep */
	uint32_t gabR[n];       /* g^ab, montgomery rep */
	uint32_t mR[n];         /* message, montgomery rep */
	uint32_t m[n];          /* plaintext message block */
	uint32_t pkt[2*n];      /* outgoing datagram; word-typed so htol() gets aligned words */

	FILE *fhosts;
	struct host *h, *hosts = NULL;
	int nhosts;
	char name[32];          /* scratch for host names; fscanf width is 31 + NUL */
	struct sockaddr_in localnode;  /* 127.0.0.1:LOCALPORT */
	int sockfd;

	char sender[16];        /* "<id> " prefix, NUL-terminated */
	int senderlen;

	/* initialize */
	memset(one, 0, sizeof one); one[0] = 1;
	memset(g,   0, sizeof g);   g[0] = 2;

	/* convert stuff to montgomery representation */
	monty_mul(n, N, R, one, RR);
	monty_mul(n, N, gR, g, RR);

	/* read hosts file: repeated "<name> <hex bignum>" entries */
	fhosts = fopen("hosts", "r");
	if(fhosts == NULL) {
		perror("hosts");
		return -1;
	}
	nhosts = 0;
	while(fscanf(fhosts, "%31s", name) == 1) {
		h = malloc(sizeof *h);
		if(!h) {
			perror("malloc");
			return -1;
		}
		h->gaR = malloc(sizeof(N));
		if(!h->gaR) {
			perror("malloc");
			free(h);
			return -1;
		}

		readbignum(fhosts, n, h->gaR);
		if(ferror(fhosts) || feof(fhosts)) {
			/* the key ran into end of file: entry is incomplete */
			fprintf(stderr, "hosts: incomplete key for %s, ignored\n", name);
			free(h->gaR);
			free(h);
			break;
		}
		snprintf(h->name, sizeof h->name, "%s", name);

		h->next = hosts;
		hosts = h;
		nhosts++;
	}
	fclose(fhosts);
	printf("%d hosts loaded\n", nhosts);

	/* open socket */
	sockfd = socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP);
	if(sockfd < 0) {
		perror("socket");
		return -1;
	}
	memset(&localnode, 0, sizeof localnode);
	localnode.sin_family = AF_INET;
	localnode.sin_addr.s_addr = htonl(INADDR_LOOPBACK);  /* 127.0.0.1 */
	localnode.sin_port = htons(LOCALPORT);

	/* ask for sending station id */
	senderlen = read_station_id(sender, sizeof sender);
	if(senderlen < 0) {
		fprintf(stderr, "no station id\n");
		close(sockfd);
		return -1;
	}

	/* host id plus the largest sender prefix must fit in the block, and the
	 * host id must be whole words because ltoh() is applied word-wise to
	 * everything after it */
	assert(HOSTIDLEN % 4 == 0);
	assert(HOSTIDLEN + (int)sizeof sender < n*4);

	printf("enter messages, format \"host: message\"\n");
	for(;;) {
		char buf[1024];
		char *to, *msg;
		char *p_hostid, *p_sender, *p_msg;
		size_t msglen, room;
		ssize_t r;

		/* read input */
		if(!fgets(buf, sizeof buf, stdin))
			break;
		to = strtok(buf, ":");
		msg = strtok(NULL, "");
		if(!to || !msg) {
			printf("no parse\n");
			continue;
		}
		while(*msg && isspace((unsigned char)*msg))
			msg++;
		msglen = strlen(msg);
		if(msglen > 0 && msg[msglen-1] == '\n')
			msg[--msglen] = '\0';

		/* look up recipient in host list */
		for(h = hosts; h; h = h->next) {
			if(!strcmp(h->name, to))
				break;
		}
		if(!h) {
			printf("%s: unknown host\n", to);
			continue;
		}

		/* input marshaling: hostid | "<id> " | message, zero-padded */
		memset(m, 0, sizeof m);
		p_hostid = (char *)m;
		p_sender = p_hostid + HOSTIDLEN;
		p_msg    = p_sender + senderlen;
		room = sizeof m - HOSTIDLEN - (size_t)senderlen;
		memcpy(p_hostid, h->gaR, HOSTIDLEN);
		memcpy(p_sender, sender, (size_t)senderlen);
		if(msglen > room) {
			fprintf(stderr, "message truncated to %zu bytes\n", room);
			msglen = room;
		}
		memcpy(p_msg, msg, msglen);
		/* byte-order fix-up of the text words; the host id words are
		 * already native uint32_t values copied from gaR */
		ltoh(m + HOSTIDLEN/4, n - HOSTIDLEN/4);

		/* encrypt: fresh random b, gbR = g^b, gabR = (gaR)^b, then the
		 * message (in montgomery rep) is combined with gabR by muladd() —
		 * see monty.h for muladd's exact semantics */
		mrand(n, N, b);
		monty_exp(n, N, R, gbR, gR, b);
		monty_exp(n, N, R, gabR, h->gaR, b);
		monty_mul(n, N, mR, m, RR);
		muladd(n, mR, 0, 1, gabR);

		/* output marshaling: gbR followed by mR */
		memcpy(pkt, gbR, sizeof gbR);
		memcpy(pkt + n, mR, sizeof mR);
		htol(pkt, 2*n);

		/* transmit packet to local node */
		r = sendto(sockfd, pkt, sizeof pkt, 0,
			(struct sockaddr *)&localnode,
			sizeof localnode);
		if(r < 0)
			perror("sendto");
	}

	close(sockfd);
	while(hosts) {
		h = hosts->next;
		free(hosts->gaR);
		free(hosts);
		hosts = h;
	}
	return 0;
}

