shithub: mc

ref: ba9519781df047104b22bd2285e2d608253812f1
dir: /lib/crypto/chacha20.myr/

View raw version
use std

pkg crypto =
	/*
	ChaCha20 cipher state: sixteen 32-bit words. As filled in by the
	setup functions in this file: words 0..3 hold the constant string,
	words 4..11 the key, words 12..13 the block counter and words
	14..15 the nonce.
	*/
	type chacha20ctx = struct
		input	: uint32[16]
	;;

	/* load a 16-byte (128-bit) or 32-byte (256-bit) key into x; any other length dies */
	const chacha20keysetup	: (x : chacha20ctx#, k : byte[:] -> void) 
	/* load the nonce from the first 8 bytes of iv into x and reset the block counter to zero */
	const chacha20ivsetup	: (x : chacha20ctx#, iv : byte[:] -> void)
	/*
	XOR the ChaCha20 keystream into m and write the result to c; since
	this is a plain XOR, the same call also decrypts. m and c must have
	the same length. The block counter stored in x is advanced past
	every 64-byte keystream block used (a partial final block counts as
	a whole one).
	*/
	const chacha20encrypt	: (x : chacha20ctx#, m : byte[:], c : byte[:] -> void)
;;

/*
ChaCha constant strings, loaded little-endian into state words 0..3 by
chacha20keysetup: sigma for 32-byte keys, tau for 16-byte keys.
*/
const sigma = "expand 32-byte k"
const tau = "expand 16-byte k"

/*
Load a ChaCha20 key into the state x.

k must be 16 bytes (128-bit key) or 32 bytes (256-bit key); any other
length dies with "invalid key length". The result is as follows:
- words 0..3 get the constant string matching the key size;
- words 4..7 get the first 16 key bytes;
- words 8..11 get the next 16 bytes of a 32-byte key, or the same 16
  bytes again for a 16-byte key.
The counter and nonce words (12..15) are left for chacha20ivsetup.
*/
const chacha20keysetup = {x, k
	var constants
	var hi	/* key bytes that feed state words 8..11 */

	/*
	Validate the length before slicing k, so that a short key reports
	"invalid key length" instead of faulting on an out-of-range slice.
	*/
	if k.len * 8 == 256
		constants = sigma
		hi = k[16:32]
	elif k.len * 8 == 128
		/* a 128-bit key is used for both halves of the key words */
		constants = tau
		hi = k[0:16]
	else
		std.die("invalid key length")
	;;

	x.input[0] = std.getle32(constants[0:4])
	x.input[1] = std.getle32(constants[4:8])
	x.input[2] = std.getle32(constants[8:12])
	x.input[3] = std.getle32(constants[12:16])

	x.input[4] = std.getle32(k[0:4])
	x.input[5] = std.getle32(k[4:8])
	x.input[6] = std.getle32(k[8:12])
	x.input[7] = std.getle32(k[12:16])

	x.input[8] = std.getle32(hi[0:4])
	x.input[9] = std.getle32(hi[4:8])
	x.input[10] = std.getle32(hi[8:12])
	x.input[11] = std.getle32(hi[12:16])
}

/*
Load the nonce into the state x and restart the keystream.

Words 12..13 (the 64-bit block counter) are reset to zero. Words
14..15 receive the first 8 bytes of iv, read little-endian. Only
iv[0:8] is read: extra bytes are ignored, and a shorter iv faults on
the slice.
*/
const chacha20ivsetup = {x, iv
	/* block counter, low and high word */
	x.input[12] = 0
	x.input[13] = 0
	/* nonce */
	x.input[14] = std.getle32(iv[0:4])
	x.input[15] = std.getle32(iv[4:8])
}


/*
XOR the ChaCha20 keystream for state x into m, writing the result to c.
Because this is a plain XOR, the same call also decrypts.

- m and c must have equal lengths (asserted).
- The keystream is produced 64 bytes at a time. Each block is the
  state after 20 rounds (10 column/diagonal double rounds), added
  word-wise to the input state.
- A final block shorter than 64 bytes is staged in a local buffer, and
  only its real length is copied back to c.
- The block counter (words 12..13 of x) is stored back advanced by the
  number of blocks used, so a partial final block counts as a whole one.
- An empty message is a no-op and leaves the counter untouched.
*/
const chacha20encrypt = {x, m, c
	var x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 : uint32;
	var j0, j1, j2, j3, j4, j5, j6, j7, j8, j9, j10, j11, j12, j13, j14, j15 : uint32;
	var ctarget = [][:]
	var tmp : byte[64];
	/*
	Bytes still to process. This must be tracked separately from m.len:
	once a short final block is staged, m and c are re-pointed at the
	full 64-byte tmp, so m.len would wrongly read 64.
	*/
	var bytes = m.len

	std.assert(m.len == c.len, "mismatch between message and ciphertext lengths\n")
	if bytes == 0
		-> void
	;;

	j0 = x.input[0]
	j1 = x.input[1]
	j2 = x.input[2]
	j3 = x.input[3]
	j4 = x.input[4]
	j5 = x.input[5]
	j6 = x.input[6]
	j7 = x.input[7]
	j8 = x.input[8]
	j9 = x.input[9]
	j10 = x.input[10]
	j11 = x.input[11]
	j12 = x.input[12]
	j13 = x.input[13]
	j14 = x.input[14]
	j15 = x.input[15]

	while true
		if bytes < 64
			/* stage the short tail so the full-block code below can run on it */
			std.slcp(tmp[:bytes], m[:bytes])
			m = tmp[:]
			ctarget = c
			c = tmp[:]
		;;
		x0 = j0
		x1 = j1
		x2 = j2
		x3 = j3
		x4 = j4
		x5 = j5
		x6 = j6
		x7 = j7
		x8 = j8
		x9 = j9
		x10 = j10
		x11 = j11
		x12 = j12
		x13 = j13
		x14 = j14
		x15 = j15
		/* do the rounds: each pass is one column round followed by one diagonal round */
		for var i = 20; i > 0; i -= 2
			/* column quarter-rounds: (0,4,8,12) (1,5,9,13) (2,6,10,14) (3,7,11,15) */
			x0 = (x0 + x4)
			x12 = (((x12 ^ x0 : uint32) << 16) | ((x12 ^ x0) >> (32 - 16)))
			x8 = (x8 + x12)
			x4 = (((x4 ^ x8 : uint32) << 12) | ((x4 ^ x8) >> (32 - 12)))
			x0 = (x0 + x4)
			x12 = (((x12 ^ x0 : uint32) << 8) | ((x12 ^ x0) >> (32 - 8)))
			x8 = (x8 + x12)
			x4 = (((x4 ^ x8 : uint32) << 7) | ((x4 ^ x8) >> (32 - 7)))

			x1 = (x1 + x5)
			x13 = ((((x13 ^ x1) : uint32) << 16) | ((x13 ^ x1) >> (32 - 16)))
			x9 = (x9 + x13)
			x5 = ((((x5 ^ x9) : uint32) << 12) | ((x5 ^ x9) >> (32 - 12)))
			x1 = (x1 + x5)
			x13 = ((((x13 ^ x1) : uint32) << 8) | ((x13 ^ x1) >> (32 - 8)))
			x9 = (x9 + x13)
			x5 = ((((x5 ^ x9) : uint32) << 7) | ((x5 ^ x9) >> (32 - 7)))

			x2 = (x2 + x6)
			x14 = ((((x14 ^ x2) : uint32) << 16) | ((x14 ^ x2) >> (32 - 16)))
			x10 = (x10 + x14)
			x6 = ((((x6 ^ x10) : uint32) << 12) | ((x6 ^ x10) >> (32 - 12)))
			x2 = (x2 + x6)
			x14 = ((((x14 ^ x2) : uint32) << 8) | ((x14 ^ x2) >> (32 - 8)))
			x10 = (x10 + x14)
			x6 = ((((x6 ^ x10) : uint32) << 7) | ((x6 ^ x10) >> (32 - 7)))

			x3 = (x3 + x7)
			x15 = ((((x15 ^ x3) : uint32) << 16) | ((x15 ^ x3) >> (32 - 16)))
			x11 = (x11 + x15)
			x7 = ((((x7 ^ x11) : uint32) << 12) | ((x7 ^ x11) >> (32 - 12)))
			x3 = (x3 + x7)
			x15 = ((((x15 ^ x3) : uint32) << 8) | ((x15 ^ x3) >> (32 - 8)))
			x11 = (x11 + x15)
			x7 = ((((x7 ^ x11) : uint32) << 7) | ((x7 ^ x11) >> (32 - 7)))

			/* diagonal quarter-rounds: (0,5,10,15) (1,6,11,12) (2,7,8,13) (3,4,9,14) */
			x0 = (x0 + x5)
			x15 = ((((x15 ^ x0) : uint32) << 16) | ((x15 ^ x0) >> (32 - 16)))
			x10 = (x10 + x15)
			x5 = ((((x5 ^ x10) : uint32) << 12) | ((x5 ^ x10) >> (32 - 12)))
			x0 = (x0 + x5)
			x15 = ((((x15 ^ x0) : uint32) << 8) | ((x15 ^ x0) >> (32 - 8)))
			x10 = (x10 + x15)
			x5 = ((((x5 ^ x10) : uint32) << 7) | ((x5 ^ x10) >> (32 - 7)))

			x1 = (x1 + x6)
			x12 = ((((x12 ^ x1) : uint32) << 16) | ((x12 ^ x1) >> (32 - 16)))
			x11 = (x11 + x12)
			x6 = ((((x6 ^ x11) : uint32) << 12) | ((x6 ^ x11) >> (32 - 12)))
			x1 = (x1 + x6)
			x12 = ((((x12 ^ x1) : uint32) << 8) | ((x12 ^ x1) >> (32 - 8)))
			x11 = (x11 + x12)
			x6 = ((((x6 ^ x11) : uint32) << 7) | ((x6 ^ x11) >> (32 - 7)))

			x2 = (x2 + x7)
			x13 = ((((x13 ^ x2) : uint32) << 16) | ((x13 ^ x2) >> (32 - 16)))
			x8 = (x8 + x13)
			x7 = ((((x7 ^ x8) : uint32) << 12) | ((x7 ^ x8) >> (32 - 12)))
			x2 = (x2 + x7)
			x13 = ((((x13 ^ x2) : uint32) << 8) | ((x13 ^ x2) >> (32 - 8)))
			x8 = (x8 + x13)
			x7 = ((((x7 ^ x8) : uint32) << 7) | ((x7 ^ x8) >> (32 - 7)))

			x3 = (x3 + x4)
			x14 = ((((x14 ^ x3) : uint32) << 16) | ((x14 ^ x3) >> (32 - 16)))
			x9 = (x9 + x14)
			x4 = ((((x4 ^ x9) : uint32) << 12) | ((x4 ^ x9) >> (32 - 12)))
			x3 = (x3 + x4)
			x14 = ((((x14 ^ x3) : uint32) << 8) | ((x14 ^ x3) >> (32 - 8)))
			x9 = (x9 + x14)
			x4 = ((((x4 ^ x9) : uint32) << 7) | ((x4 ^ x9) >> (32 - 7)))
		;;
		/* feed-forward: add the input state to the permuted state */
		x0 = x0 + j0
		x1 = x1 + j1
		x2 = x2 + j2
		x3 = x3 + j3
		x4 = x4 + j4
		x5 = x5 + j5
		x6 = x6 + j6
		x7 = x7 + j7
		x8 = x8 + j8
		x9 = x9 + j9
		x10 = x10 + j10
		x11 = x11 + j11
		x12 = x12 + j12
		x13 = x13 + j13
		x14 = x14 + j14
		x15 = x15 + j15

		/* all of m's words are read before any of c is written, which matters when c aliases tmp */
		x0 = x0 ^ std.getle32(m[0:4]);
		x1 = x1 ^ std.getle32(m[4:8]);
		x2 = x2 ^ std.getle32(m[8:12]);
		x3 = x3 ^ std.getle32(m[12:16]);
		x4 = x4 ^ std.getle32(m[16:20]);
		x5 = x5 ^ std.getle32(m[20:24]);
		x6 = x6 ^ std.getle32(m[24:28]);
		x7 = x7 ^ std.getle32(m[28:32]);
		x8 = x8 ^ std.getle32(m[32:36]);
		x9 = x9 ^ std.getle32(m[36:40]);
		x10 = x10 ^ std.getle32(m[40:44]);
		x11 = x11 ^ std.getle32(m[44:48]);
		x12 = x12 ^ std.getle32(m[48:52]);
		x13 = x13 ^ std.getle32(m[52:56]);
		x14 = x14 ^ std.getle32(m[56:60]);
		x15 = x15 ^ std.getle32(m[60:64]);

		/* advance the 64-bit block counter held in j12 (low) and j13 (high) */
		j12++
		if j12 == 0
			j13++
		;;

		std.putle32(c[0:4], x0);
		std.putle32(c[4:8], x1);
		std.putle32(c[8:12], x2);
		std.putle32(c[12:16], x3);
		std.putle32(c[16:20], x4);
		std.putle32(c[20:24], x5);
		std.putle32(c[24:28], x6);
		std.putle32(c[28:32], x7);
		std.putle32(c[32:36], x8);
		std.putle32(c[36:40], x9);
		std.putle32(c[40:44], x10);
		std.putle32(c[44:48], x11);
		std.putle32(c[48:52], x12);
		std.putle32(c[52:56], x13);
		std.putle32(c[56:60], x14);
		std.putle32(c[60:64], x15);

		if bytes <= 64
			if bytes < 64
				/* copy only the real tail out of the staging buffer */
				std.slcp(ctarget[:bytes], c[:bytes])
			;;
			x.input[12] = j12;
			x.input[13] = j13;
			-> void
		;;
		bytes -= 64
		c = c[64:]
		m = m[64:]
	;;
}