shithub: purgatorio

ref: 38c0cf0737de03906729054ad7b7b011dd1ca475
dir: /os/ip/esp.c/

View raw version
#include	"u.h"
#include	"../port/lib.h"
#include	"mem.h"
#include	"dat.h"
#include	"fns.h"
#include	"../port/error.h"

#include	"ip.h"

#include	"libsec.h"

typedef struct Esphdr Esphdr;
typedef struct Esptail Esptail;
typedef struct Userhdr Userhdr;
typedef struct Esppriv Esppriv;
typedef struct Espcb Espcb;
typedef struct Algorithm Algorithm;
typedef struct Esprc4 Esprc4;

#define DPRINT if(0)print

enum
{
	IP_ESPPROTO	= 50,
	EsphdrSize	= 28,	// includes IP header
	IphdrSize	= 20,	// options have been striped
	EsptailSize	= 2,	// does not include pad or auth data
	UserhdrSize	= 4,	// user visable header size - if enabled
};

struct Esphdr
{
	/* ip header */
	uchar	vihl;		/* Version and header length */
	uchar	tos;		/* Type of service */
	uchar	length[2];	/* packet length */
	uchar	id[2];		/* Identification */
	uchar	frag[2];	/* Fragment information */
	uchar	Unused;	
	uchar	espproto;	/* Protocol */
	uchar	espplen[2];	/* Header plus data length */
	uchar	espsrc[4];	/* Ip source */
	uchar	espdst[4];	/* Ip destination */

	/* esp header */
	uchar	espspi[4];	/* Security parameter index */
	uchar	espseq[4];	/* Sequence number */
};

struct Esptail
{
	uchar	pad;
	uchar	nexthdr;
};

/* header as seen by the user */
struct Userhdr
{
	uchar	nexthdr;	// next protocol
	uchar	unused[3];
};

struct Esppriv
{
	ulong	in;
	ulong	inerrors;
};

/*
 *  protocol specific part of Conv
 */
struct Espcb
{
	int	incoming;
	int	header;		// user user level header
	ulong	spi;
	ulong	seq;		// last seq sent
	ulong	window;		// for replay attacks
	char	*espalg;
	void	*espstate;	// other state for esp
	int	espivlen;	// in bytes
	int	espblklen;
	int	(*cipher)(Espcb*, uchar *buf, int len);
	char	*ahalg;
	void	*ahstate;	// other state for esp
	int	ahlen;		// auth data length in bytes
	int	ahblklen;
	int	(*auth)(Espcb*, uchar *buf, int len, uchar *hash);
};

struct Algorithm
{
	char 	*name;
	int	keylen;		// in bits
	void	(*init)(Espcb*, char* name, uchar *key, int keylen);
};


enum {
	RC4forward	= 10*1024*1024,	// maximum skip forward
	RC4back = 100*1024,		// maximum look back
};

struct Esprc4
{
	ulong cseq;	// current byte sequence number
	RC4state current;

	int ovalid;	// old is valid
	ulong lgseq; // last good sequence
	ulong oseq;	// old byte sequence number
	RC4state old;
};

static	Conv* convlookup(Proto *esp, ulong spi);
static	char *setalg(Espcb *ecb, char **f, int n, Algorithm *alg);
static	void nullespinit(Espcb*, char*, uchar *key, int keylen);
static	void nullahinit(Espcb*, char*, uchar *key, int keylen);
static	void shaahinit(Espcb*, char*, uchar *key, int keylen);
static	void md5ahinit(Espcb*, char*, uchar *key, int keylen);
static	void desespinit(Espcb *ecb, char *name, uchar *k, int n);
static	void rc4espinit(Espcb *ecb, char *name, uchar *k, int n);
static	void espkick(void *x);

static Algorithm espalg[] =
{
	"null",			0,	nullespinit,
	"des_56_cbc",		64,	desespinit,
	"rc4_128",		128,	rc4espinit,
	nil,			0,	nil,
};

static Algorithm ahalg[] =
{
	"null",			0,	nullahinit,
	"hmac_sha1_96",		128,	shaahinit,
	"hmac_md5_96",		128,	md5ahinit,
	nil,			0,	nil,
};

static char*
espconnect(Conv *c, char **argv, int argc)
{
	char *p, *pp;
	char *e = nil;
	ulong spi;
	Espcb *ecb = (Espcb*)c->ptcl;

	switch(argc) {
	default:
		e = "bad args to connect";
		break;
	case 2:
		p = strchr(argv[1], '!');
		if(p == nil){
			e = "malformed address";
			break;
		}
		*p++ = 0;
		parseip(c->raddr, argv[1]);
		findlocalip(c->p->f, c->laddr, c->raddr);
		ecb->incoming = 0;
		ecb->seq = 0;
		if(strcmp(p, "*") == 0) {
			qlock(c->p);
			for(;;) {
				spi = nrand(1<<16) + 256;
				if(convlookup(c->p, spi) == nil)
					break;
			}
			qunlock(c->p);
			ecb->spi = spi;
			ecb->incoming = 1;
			qhangup(c->wq, nil);
		} else {
			spi = strtoul(p, &pp, 10);
			if(pp == p) {
				e = "malformed address";
				break;
			}
			ecb->spi = spi;
			qhangup(c->rq, nil);
		}
		nullespinit(ecb, "null", nil, 0);
		nullahinit(ecb, "null", nil, 0);
	}
	Fsconnected(c, e);

	return e;
}


static int
espstate(Conv *c, char *state, int n)
{
	return snprint(state, n, "%s", c->inuse?"Open\n":"Closed\n");
}

static void
espcreate(Conv *c)
{
	c->rq = qopen(64*1024, Qmsg, 0, 0);
	c->wq = qopen(64*1024, Qkick, espkick, c);
}

static void
espclose(Conv *c)
{
	Espcb *ecb;

	qclose(c->rq);
	qclose(c->wq);
	qclose(c->eq);
	ipmove(c->laddr, IPnoaddr);
	ipmove(c->raddr, IPnoaddr);

	ecb = (Espcb*)c->ptcl;
	free(ecb->espstate);
	free(ecb->ahstate);
	memset(ecb, 0, sizeof(Espcb));
}

static void
espkick(void *x)
{
	Conv *c = x;
	Esphdr *eh;
	Esptail *et;
	Userhdr *uh;
	Espcb *ecb;
	Block *bp;
	int nexthdr;
	int payload;
	int pad;
	int align;
	uchar *auth;

	bp = qget(c->wq);
	if(bp == nil)
		return;

	qlock(c);
	ecb = c->ptcl;

	if(ecb->header) {
		/* make sure the message has a User header */
		bp = pullupblock(bp, UserhdrSize);
		if(bp == nil) {
			qunlock(c);
			return;
		}
		uh = (Userhdr*)bp->rp;
		nexthdr = uh->nexthdr;
		bp->rp += UserhdrSize;
	} else {
		nexthdr = 0;  // what should this be?
	}

	payload = BLEN(bp) + ecb->espivlen;

	/* Make space to fit ip header */
	bp = padblock(bp, EsphdrSize + ecb->espivlen);

	align = 4;
	if(ecb->espblklen > align)
		align = ecb->espblklen;
	if(align % ecb->ahblklen != 0)
		panic("espkick: ahblklen is important after all");
	pad = (align-1) - (payload + EsptailSize-1)%align;

	/*
	 * Make space for tail
	 * this is done by calling padblock with a negative size
	 * Padblock does not change bp->wp!
	 */
	bp = padblock(bp, -(pad+EsptailSize+ecb->ahlen));
	bp->wp += pad+EsptailSize+ecb->ahlen;

	eh = (Esphdr *)(bp->rp);
	et = (Esptail*)(bp->rp + EsphdrSize + payload + pad);

	// fill in tail
	et->pad = pad;
	et->nexthdr = nexthdr;

	ecb->cipher(ecb, bp->rp+EsphdrSize, payload+pad+EsptailSize);
	auth = bp->rp + EsphdrSize + payload + pad + EsptailSize;

	// fill in head
	eh->vihl = IP_VER4;
	hnputl(eh->espspi, ecb->spi);
	hnputl(eh->espseq, ++ecb->seq);
	v6tov4(eh->espsrc, c->laddr);
	v6tov4(eh->espdst, c->raddr);
	eh->espproto = IP_ESPPROTO;
	eh->frag[0] = 0;
	eh->frag[1] = 0;

	ecb->auth(ecb, bp->rp+IphdrSize, (EsphdrSize-IphdrSize)+payload+pad+EsptailSize, auth);

	qunlock(c);
	//print("esp: pass down: %uld\n", BLEN(bp));
	ipoput4(c->p->f, bp, 0, c->ttl, c->tos, c);
}

void
espiput(Proto *esp, Ipifc*, Block *bp)
{
	Esphdr *eh;
	Esptail *et;
	Userhdr *uh;
	Conv *c;
	Espcb *ecb;
	uchar raddr[IPaddrlen], laddr[IPaddrlen];
	Fs *f;
	uchar *auth;
	ulong spi;
	int payload, nexthdr;

	f = esp->f;

	bp = pullupblock(bp, EsphdrSize+EsptailSize);
	if(bp == nil) {
		netlog(f, Logesp, "esp: short packet\n");
		return;
	}

	eh = (Esphdr*)(bp->rp);
	spi = nhgetl(eh->espspi);
	v4tov6(raddr, eh->espsrc);
	v4tov6(laddr, eh->espdst);

	qlock(esp);
	/* Look for a conversation structure for this port */
	c = convlookup(esp, spi);
	if(c == nil) {
		qunlock(esp);
		netlog(f, Logesp, "esp: no conv %I -> %I!%d\n", raddr,
			laddr, spi);
		icmpnoconv(f, bp);
		freeblist(bp);
		return;
	}

	qlock(c);
	qunlock(esp);

	ecb = c->ptcl;
	// too hard to do decryption/authentication on block lists
	if(bp->next)
		bp = concatblock(bp);

	if(BLEN(bp) < EsphdrSize + ecb->espivlen + EsptailSize + ecb->ahlen) {
		qunlock(c);
		netlog(f, Logesp, "esp: short block %I -> %I!%d\n", raddr,
			laddr, spi);
		freeb(bp);
		return;
	}

	eh = (Esphdr*)(bp->rp);
	auth = bp->wp - ecb->ahlen;
	if(!ecb->auth(ecb, eh->espspi, auth-eh->espspi, auth)) {
		qunlock(c);
print("esp: bad auth %I -> %I!%ld\n", raddr, laddr, spi);
		netlog(f, Logesp, "esp: bad auth %I -> %I!%d\n", raddr,
			laddr, spi);
		freeb(bp);
		return;
	}

	payload = BLEN(bp)-EsphdrSize-ecb->ahlen;
	if(payload<=0 || payload%4 != 0 || payload%ecb->espblklen!=0) {
		qunlock(c);
		netlog(f, Logesp, "esp: bad length %I -> %I!%d payload=%d BLEN=%d\n", raddr,
			laddr, spi, payload, BLEN(bp));
		freeb(bp);
		return;
	}
	if(!ecb->cipher(ecb, bp->rp+EsphdrSize, payload)) {
		qunlock(c);
print("esp: cipher failed %I -> %I!%ld: %r\n", raddr, laddr, spi);
		netlog(f, Logesp, "esp: cipher failed %I -> %I!%d: %r\n", raddr,
			laddr, spi);
		freeb(bp);
		return;
	}

	payload -= EsptailSize;
	et = (Esptail*)(bp->rp + EsphdrSize + payload);
	payload -= et->pad + ecb->espivlen;
	nexthdr = et->nexthdr;
	if(payload <= 0) {
		qunlock(c);
		netlog(f, Logesp, "esp: short packet after decrypt %I -> %I!%d\n", raddr,
			laddr, spi);
		freeb(bp);
		return;
	}

	// trim packet
	bp->rp += EsphdrSize + ecb->espivlen;
	bp->wp = bp->rp + payload;
	if(ecb->header) {
		// assume UserhdrSize < EsphdrSize
		bp->rp -= UserhdrSize;
		uh = (Userhdr*)bp->rp;
		memset(uh, 0, UserhdrSize);
		uh->nexthdr = nexthdr;
	}

	if(qfull(c->rq)){
		netlog(f, Logesp, "esp: qfull %I -> %I.%uld\n", raddr,
			laddr, spi);
		freeblist(bp);
	}else {
//print("esp: pass up: %uld\n", BLEN(bp));
		qpass(c->rq, bp);
	}

	qunlock(c);
}

char*
espctl(Conv *c, char **f, int n)
{
	Espcb *ecb = c->ptcl;
	char *e = nil;

	if(strcmp(f[0], "esp") == 0)
		e = setalg(ecb, f, n, espalg);
	else if(strcmp(f[0], "ah") == 0)
		e = setalg(ecb, f, n, ahalg);
	else if(strcmp(f[0], "header") == 0)
		ecb->header = 1;
	else if(strcmp(f[0], "noheader") == 0)
		ecb->header = 0;
	else
		e = "unknown control request";
	return e;
}

void
espadvise(Proto *esp, Block *bp, char *msg)
{
	Esphdr *h;
	Conv *c;
	ulong spi;

	h = (Esphdr*)(bp->rp);

	spi = nhgets(h->espspi);
	qlock(esp);
	c = convlookup(esp, spi);
	if(c != nil) {
		qhangup(c->rq, msg);
		qhangup(c->wq, msg);
	}
	qunlock(esp);
	freeblist(bp);
}

int
espstats(Proto *esp, char *buf, int len)
{
	Esppriv *upriv;

	upriv = esp->priv;
	return snprint(buf, len, "%lud %lud\n",
		upriv->in,
		upriv->inerrors);
}

static int
esplocal(Conv *c, char *buf, int len)
{
	Espcb *ecb = c->ptcl;
	int n;

	qlock(c);
	if(ecb->incoming)
		n = snprint(buf, len, "%I!%uld\n", c->laddr, ecb->spi);
	else
		n = snprint(buf, len, "%I\n", c->laddr);
	qunlock(c);
	return n;
}

static int
espremote(Conv *c, char *buf, int len)
{
	Espcb *ecb = c->ptcl;
	int n;

	qlock(c);
	if(ecb->incoming)
		n = snprint(buf, len, "%I\n", c->raddr);
	else
		n = snprint(buf, len, "%I!%uld\n", c->raddr, ecb->spi);
	qunlock(c);
	return n;
}

static	Conv*
convlookup(Proto *esp, ulong spi)
{
	Conv *c, **p;
	Espcb *ecb;

	for(p=esp->conv; *p; p++){
		c = *p;
		ecb = c->ptcl;
		if(ecb->incoming && ecb->spi == spi)
			return c;
	}
	return nil;
}

static char *
setalg(Espcb *ecb, char **f, int n, Algorithm *alg)
{
	uchar *key;
	int i, nbyte, nchar;
	int c;

	if(n < 2)
		return "bad format";
	for(; alg->name; alg++)
		if(strcmp(f[1], alg->name) == 0)
			break;
	if(alg->name == nil)
		return "unknown algorithm";

	if(n != 3)
		return "bad format";
	nbyte = (alg->keylen + 7) >> 3;
	nchar = strlen(f[2]);
	for(i=0; i<nchar; i++) {
		c = f[2][i];
		if(c >= '0' && c <= '9')
			f[2][i] -= '0';
		else if(c >= 'a' && c <= 'f')
			f[2][i] -= 'a'-10;
		else if(c >= 'A' && c <= 'F')
			f[2][i] -= 'A'-10;
		else
			return "bad character in key";
	}
	key = smalloc(nbyte);
	for(i=0; i<nchar && i*2<nbyte; i++) {
		c = f[2][nchar-i-1];
		if(i&1)
			c <<= 4;
		key[i>>1] |= c;
	}

	alg->init(ecb, alg->name, key, alg->keylen);
	free(key);
	return nil;
}

static int
nullcipher(Espcb*, uchar*, int)
{
	return 1;
}

static void
nullespinit(Espcb *ecb, char *name, uchar*, int)
{
	ecb->espalg = name;
	ecb->espblklen = 1;
	ecb->espivlen = 0;
	ecb->cipher = nullcipher;
}

static int
nullauth(Espcb*, uchar*, int, uchar*)
{
	return 1;
}

static void
nullahinit(Espcb *ecb, char *name, uchar*, int)
{
	ecb->ahalg = name;
	ecb->ahblklen = 1;
	ecb->ahlen = 0;
	ecb->auth = nullauth;
}

void
seanq_hmac_sha1(uchar hash[SHA1dlen], uchar *t, long tlen, uchar *key, long klen)
{
	uchar ipad[65], opad[65];
	int i;
	DigestState *digest;
	uchar innerhash[SHA1dlen];

	for(i=0; i<64; i++){
		ipad[i] = 0x36;
		opad[i] = 0x5c;
	}
	ipad[64] = opad[64] = 0;
	for(i=0; i<klen; i++){
		ipad[i] ^= key[i];
		opad[i] ^= key[i];
	}
	digest = sha1(ipad, 64, nil, nil);
	sha1(t, tlen, innerhash, digest);
	digest = sha1(opad, 64, nil, nil);
	sha1(innerhash, SHA1dlen, hash, digest);
}

static int
shaauth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
{
	uchar hash[SHA1dlen];
	int r;

	memset(hash, 0, SHA1dlen);
	seanq_hmac_sha1(hash, t, tlen, (uchar*)ecb->ahstate, 16);
	r = memcmp(auth, hash, ecb->ahlen) == 0;
	memmove(auth, hash, ecb->ahlen);
	return r;
}

static void
shaahinit(Espcb *ecb, char *name, uchar *key, int klen)
{
	if(klen != 128)
		panic("shaahinit: bad keylen");
	klen >>= 8;	// convert to bytes

	ecb->ahalg = name;
	ecb->ahblklen = 1;
	ecb->ahlen = 12;
	ecb->auth = shaauth;
	ecb->ahstate = smalloc(klen);
	memmove(ecb->ahstate, key, klen);
}

void
seanq_hmac_md5(uchar hash[MD5dlen], uchar *t, long tlen, uchar *key, long klen)
{
	uchar ipad[65], opad[65];
	int i;
	DigestState *digest;
	uchar innerhash[MD5dlen];

	for(i=0; i<64; i++){
		ipad[i] = 0x36;
		opad[i] = 0x5c;
	}
	ipad[64] = opad[64] = 0;
	for(i=0; i<klen; i++){
		ipad[i] ^= key[i];
		opad[i] ^= key[i];
	}
	digest = md5(ipad, 64, nil, nil);
	md5(t, tlen, innerhash, digest);
	digest = md5(opad, 64, nil, nil);
	md5(innerhash, MD5dlen, hash, digest);
}

static int
md5auth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
{
	uchar hash[MD5dlen];
	int r;

	memset(hash, 0, MD5dlen);
	seanq_hmac_md5(hash, t, tlen, (uchar*)ecb->ahstate, 16);
	r = memcmp(auth, hash, ecb->ahlen) == 0;
	memmove(auth, hash, ecb->ahlen);
	return r;
}

static void
md5ahinit(Espcb *ecb, char *name, uchar *key, int klen)
{
	if(klen != 128)
		panic("md5ahinit: bad keylen");
	klen >>= 3;	// convert to bytes


	ecb->ahalg = name;
	ecb->ahblklen = 1;
	ecb->ahlen = 12;
	ecb->auth = md5auth;
	ecb->ahstate = smalloc(klen);
	memmove(ecb->ahstate, key, klen);
}

static int
descipher(Espcb *ecb, uchar *p, int n)
{
	uchar tmp[8];
	uchar *pp, *tp, *ip, *eip, *ep;
	DESstate *ds = ecb->espstate;

	ep = p + n;
	if(ecb->incoming) {
		memmove(ds->ivec, p, 8);
		p += 8;
		while(p < ep){
			memmove(tmp, p, 8);
			block_cipher(ds->expanded, p, 1);
			tp = tmp;
			ip = ds->ivec;
			for(eip = ip+8; ip < eip; ){
				*p++ ^= *ip;
				*ip++ = *tp++;
			}
		}
	} else {
		memmove(p, ds->ivec, 8);
		for(p += 8; p < ep; p += 8){
			pp = p;
			ip = ds->ivec;
			for(eip = ip+8; ip < eip; )
				*pp++ ^= *ip++;
			block_cipher(ds->expanded, p, 0);
			memmove(ds->ivec, p, 8);
		}
	}
	return 1;
}
	
static void
desespinit(Espcb *ecb, char *name, uchar *k, int n)
{
	uchar key[8];
	uchar ivec[8];
	int i;
	
	// bits to bytes
	n = (n+7)>>3;
	if(n > 8)
		n = 8;
	memset(key, 0, sizeof(key));
	memmove(key, k, n);
	for(i=0; i<8; i++)
		ivec[i] = nrand(256);
	ecb->espalg = name;
	ecb->espblklen = 8;
	ecb->espivlen = 8;
	ecb->cipher = descipher;
	ecb->espstate = smalloc(sizeof(DESstate));
	setupDESstate(ecb->espstate, key, ivec);
}

static int
rc4cipher(Espcb *ecb, uchar *p, int n)
{
	Esprc4 *esprc4;
	RC4state tmpstate;
	ulong seq;
	long d, dd;

	if(n < 4)
		return 0;

	esprc4 = ecb->espstate;
	if(ecb->incoming) {
		seq = nhgetl(p);
		p += 4;
		n -= 4;
		d = seq-esprc4->cseq;
		if(d == 0) {
			rc4(&esprc4->current, p, n);
			esprc4->cseq += n;
			if(esprc4->ovalid) {
				dd = esprc4->cseq - esprc4->lgseq;
				if(dd > RC4back)
					esprc4->ovalid = 0;
			}
		} else if(d > 0) {
print("missing packet: %uld %ld\n", seq, d);
			// this link is hosed
			if(d > RC4forward) {
				strcpy(up->errstr, "rc4cipher: skipped too much");
				return 0;
			}
			esprc4->lgseq = seq;
			if(!esprc4->ovalid) {
				esprc4->ovalid = 1;
				esprc4->oseq = esprc4->cseq;
				memmove(&esprc4->old, &esprc4->current, sizeof(RC4state));
			}
			rc4skip(&esprc4->current, d);
			rc4(&esprc4->current, p, n);
			esprc4->cseq = seq+n;
		} else {
print("reordered packet: %uld %ld\n", seq, d);
			dd = seq - esprc4->oseq;
			if(!esprc4->ovalid || -d > RC4back || dd < 0) {
				strcpy(up->errstr, "rc4cipher: too far back");
				return 0;
			}
			memmove(&tmpstate, &esprc4->old, sizeof(RC4state));
			rc4skip(&tmpstate, dd);
			rc4(&tmpstate, p, n);
			return 1;
		}

		// move old state up
		if(esprc4->ovalid) {
			dd = esprc4->cseq - RC4back - esprc4->oseq;
			if(dd > 0) {
				rc4skip(&esprc4->old, dd);
				esprc4->oseq += dd;
			}
		}
	} else {
		hnputl(p, esprc4->cseq);
		p += 4;
		n -= 4;
		rc4(&esprc4->current, p, n);
		esprc4->cseq += n;
	}
	return 1;
}

static void
rc4espinit(Espcb *ecb, char *name, uchar *k, int n)
{	
	Esprc4 *esprc4;

	// bits to bytes
	n = (n+7)>>3;
	esprc4 = smalloc(sizeof(Esprc4));
	memset(esprc4, 0, sizeof(Esprc4));
	setupRC4state(&esprc4->current, k, n);
	ecb->espalg = name;
	ecb->espblklen = 4;
	ecb->espivlen = 4;
	ecb->cipher = rc4cipher;
	ecb->espstate = esprc4;
}
	
void
espinit(Fs *fs)
{
	Proto *esp;

	esp = smalloc(sizeof(Proto));
	esp->priv = smalloc(sizeof(Esppriv));
	esp->name = "esp";
	esp->connect = espconnect;
	esp->announce = nil;
	esp->ctl = espctl;
	esp->state = espstate;
	esp->create = espcreate;
	esp->close = espclose;
	esp->rcv = espiput;
	esp->advise = espadvise;
	esp->stats = espstats;
	esp->local = esplocal;
	esp->remote = espremote;
	esp->ipproto = IP_ESPPROTO;
	esp->nc = Nchans;
	esp->ptclsize = sizeof(Espcb);

	Fsproto(fs, esp);
}