/*
 * shithub: mc
 *
 * ref: 7ed35649bb5b7ce79c7a01f1b1d8cb850b8d388a
 * dir: /parse/infer.c/
 *
 * View raw version
 */
#include <stdlib.h>
#include <stdio.h>
#include <stdarg.h>
#include <inttypes.h>
#include <inttypes.h>
#include <ctype.h>
#include <string.h>
#include <assert.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>
#include <assert.h>

#include "util.h"
#include "parse.h"

typedef struct Traitmap Traitmap;
typedef struct Enttype Enttype;

/* One node of the trait lookup structure; children are indexed by type
 * kind (see the sub[] walks in tryconstrain() and traitsfor()). */
struct Traitmap {
	Bitset *traits;		/* traits recorded directly at this node */
	Traitmap *sub[Ntypes];	/* child nodes, indexed by a Type's ->type */

	Htab *name;	/* name => bitset(traits) */
	Type **filter;		/* patterns checked with tymatchrank() */
	size_t nfilter;
	Trait **filtertr;	/* filtertr[i] is the trait granted by filter[i] */
	size_t nfiltertr;
};

/* Forward declarations: these helpers are referenced before the point
 * where they are defined. */
static void infernode(Node **np, Type *ret, int *sawret);
static void inferexpr(Node **np, Type *ret, int *sawret);
static void inferdecl(Node *n);
static int tryconstrain(Type *ty, Trait *tr, int update);

static Type *tyfreshen(Tysubst *subst, Type *orig);
static Type *tf(Type *t);
static Type *basetype(Type *a);

static Type *unify(Node *ctx, Type *a, Type *b);
static Type *tyfix(Node *ctx, Type *orig, int noerr);
static void typesub(Node *n, int noerr);

/* tracking where we are in the inference */
static int ingeneric;
static int inaggr;
static int indentdepth;
/* unifysrc[tid] is the source location recorded for a type id by
 * marksrc(); nunifysrc is the number of allocated entries. */
static Srcloc *unifysrc;
static size_t nunifysrc;

/* post-inference checking/unification */
static Htab *delayed;
/* postcheck[i] is rechecked later in scope postcheckscope[i];
 * both lists are appended to together by delayedcheck(). */
static Node **postcheck;
static size_t npostcheck;
static Stab **postcheckscope;
static size_t npostcheckscope;

/* generic declarations to be specialized */
static Node **genericdecls;
static size_t ngenericdecls;
static Node **impldecl;
static size_t nimpldecl;

/* specializations of generics */
/* specializations[i], specializationscope[i] and genericdecls[i] are
 * appended together by the add*specialization() helpers. */
static Node **specializations;
static size_t nspecializations;
static Stab **specializationscope;
static size_t nspecializationscope;
/* root of the trait lookup structure used by tryconstrain()/traitsfor() */
static Traitmap *traitmap;

/* Formats a call expression `n` into buf (at most sz bytes) for use in
 * error messages, as "name(arg:type, ...): rettype". Non-variable
 * callees and arguments are rendered as <e> / <eN> placeholders. */
static void
ctxstrcall(char *buf, size_t sz, Node *n)
{
	char *p, *end, *sep, *t;
	size_t nargs, i;
	Node **args;
	Type *et;

	args = n->expr.args;
	nargs = n->expr.nargs;
	p = buf;
	end = buf + sz;
	sep = "";

	if (exprop(args[0]) == Ovar)
		p += bprintf(p, end - p, "%s(", namestr(args[0]->expr.args[0]));
	else
		p += bprintf(p, end - p, "<e>(");
	/* args[0] is the callee; the actual arguments start at 1 */
	for (i = 1; i < nargs; i++) {
		et = tyfix(NULL, exprtype(args[i]), 1);
		if (et != NULL)
			t = tystr(et);
		else
			t = strdup("?");

		if (exprop(args[i]) == Ovar)
			p += bprintf(p, end - p, "%s%s:%s", sep, namestr(args[i]->expr.args[0]), t);
		else
			p += bprintf(p, end - p, "%s<e%zd>:%s", sep, i, t);
		sep = ", ";
		/* t is heap-allocated on both branches above */
		free(t);
	}
	if (exprtype(args[0])->nsub)
		t = tystr(tyfix(NULL, exprtype(args[0])->sub[0], 1));
	else
		t = strdup("unknown");
	p += bprintf(p, end - p, "): %s", t);
	free(t);
}

/* Returns a newly allocated string naming the type of an expression or
 * declaration node, or "unknown" if the node has no type yet or its
 * base type is still a type variable. The caller frees the result. */
static char *
nodetystr(Node *n)
{
	Type *t;

	t = NULL;
	if (n->type == Nexpr && exprtype(n) != NULL)
		t = tyfix(n, exprtype(n), 1);
	else if (n->type == Ndecl && decltype(n) != NULL)
		t = tyfix(n, decltype(n), 1);

	if (t && tybase(t)->type != Tyvar)
		return tystr(t);
	else
		return strdup("unknown");
}

/* Records `l` as the source location for the resolved form of `t`,
 * growing the unifysrc table as needed. The first location recorded
 * for a type id wins; later calls do not overwrite it. */
static void
marksrc(Type *t, Srcloc l)
{
	t = tf(t);
	if (t->tid >= nunifysrc) {
		unifysrc = zrealloc(unifysrc, nunifysrc*sizeof(Srcloc), (t->tid + 1)*sizeof(Srcloc));
		nunifysrc = t->tid + 1;
	}
	/* a line <= 0 means no location has been stored in this slot yet */
	if (unifysrc[t->tid].line <= 0)
		unifysrc[t->tid] = l;
}

/* Returns a newly allocated "\n\t<type> from <file>:<line>" note for
 * `ty` if a source location was recorded by marksrc(), or an empty
 * string otherwise. The caller frees the result. */
static char *
srcstr(Type *ty)
{
	char src[128];
	Srcloc l;
	char *s;

	src[0] = 0;
	if (nunifysrc > ty->tid && unifysrc[ty->tid].line > 0) {
		l = unifysrc[ty->tid];
		s = tystr(ty);
		snprintf(src, sizeof src, "\n\t%s from %s:%d", s, fname(l), lnum(l));
		/* the text has been copied into src; release the tystr() buffer */
		free(s);
	}
	return strdup(src);
}

/* Tries to give a good string describing the context
 * for the sake of error messages.
 *
 * Returns a newly allocated string; "???" when n is NULL. Declarations
 * are shown as "name:type", names as themselves, and expressions in a
 * form depending on their operator class. Any other node kind is
 * shown by its nodestr[] entry. */
static char *
ctxstr(Node *n)
{
	char *t, *t1, *t2, *t3;
	char *s, *d;
	size_t nargs;
	Node **args;
	char buf[512];

	if (!n)
		return strdup("???");
	switch (n->type) {
	default:
		s = strdup(nodestr[n->type]);
		break;
	case Ndecl:
		d = declname(n);
		t = nodetystr(n);
		bprintf(buf, sizeof buf, "%s:%s", d, t);
		s = strdup(buf);
		free(t);
		break;
	case Nname:	s = strdup(namestr(n));	break;
	case Nexpr:
		args = n->expr.args;
		nargs = n->expr.nargs;
		t1 = NULL;
		t2 = NULL;
		t3 = NULL;
		if (exprop(n) == Ovar)
			d = namestr(args[0]);
		else
			d = opstr[exprop(n)];
		/* type strings for the expression and its first three operands */
		t = nodetystr(n);
		if (nargs >= 1)
			t1 = nodetystr(args[0]);
		if (nargs >= 2)
			t2 = nodetystr(args[1]);
		if (nargs >= 3)
			t3 = nodetystr(args[2]);

		switch (opclass[exprop(n)]) {
		case OTpre:	bprintf(buf, sizeof buf, "%s<e:%s>", oppretty[exprop(n)], t1);	break;
		case OTpost:	bprintf(buf, sizeof buf, "<e:%s>%s", t1, oppretty[exprop(n)]);	break;
		case OTzarg:	bprintf(buf, sizeof buf, "%s", oppretty[exprop(n)]);	break;
		case OTbin:
			bprintf(buf, sizeof buf, "<e1:%s> %s <e2:%s>", t1, oppretty[exprop(n)], t2);
			break;
		case OTmisc:
			switch (exprop(n)) {
			case Ovar:	bprintf(buf, sizeof buf, "%s:%s", namestr(args[0]), t);	break;
			case Ocall:	ctxstrcall(buf, sizeof buf,n);	break;
			case Oidx:
				if (exprop(args[0]) == Ovar)
					bprintf(buf, sizeof buf, "%s[<e1:%s>]", namestr(args[0]->expr.args[0]), t2);
				else
					bprintf(buf, sizeof buf, "<sl:%s>[<e1%s>]", t1, t2);
				break;
			case Oslice:
				if (exprop(args[0]) == Ovar)
					bprintf(buf, sizeof buf, "%s[<e1:%s>:<e2:%s>]", namestr(args[0]->expr.args[0]), t2, t3);
				else
					bprintf( buf, sizeof buf, "<sl:%s>[<e1%s>:<e2:%s>]", t1, t2, t3);
				break;
			case Omemb:
				bprintf(buf, sizeof buf, "<%s>.%s", t1, namestr(args[1]));
				break;
			default:
				bprintf(buf, sizeof buf, "%s:%s", d, t);
				break;
			}
			break;
		default: bprintf(buf, sizeof buf, "%s", d); break;
		}
		/* free(NULL) is a no-op, so unset operand strings are fine */
		free(t);
		free(t1);
		free(t2);
		free(t3);
		s = strdup(buf);
		break;
	}
	return s;
}

/* Queues the use `n` of a generic declaration for specialization:
 * records the use, the scope it appeared in, and the declaration
 * looked up through n->expr.did. */
static void
addspecialization(Node *n, Stab *stab)
{
	Node *dcl;

	dcl = decls[n->expr.did];
	lappend(&specializationscope, &nspecializationscope, stab);
	lappend(&specializations, &nspecializations, n);
	lappend(&genericdecls, &ngenericdecls, dcl);
}

/* Queues a specialization of the single prototype of the Tcdisp trait
 * for the declared type of `n`, unless that prototype's impl table
 * already has an entry for the type. */
static void
adddispspecialization(Node *n, Stab *stab)
{
	Trait *tr;
	Type *ty;

	tr = traittab[Tcdisp];
	ty = decltype(n);
	assert(tr->nproto == 1);
	/* already implemented for this type: nothing to queue */
	if (hthas(tr->proto[0]->decl.impls, ty))
		return;
	lappend(&specializationscope, &nspecializationscope, stab);
	lappend(&specializations, &nspecializations, n);
	lappend(&genericdecls, &ngenericdecls, tr->proto[0]);
}

/* Queues specializations of every Tciter trait prototype for the
 * sequence type of the iteration statement `n`. Slices, arrays and
 * pointers are skipped entirely, as are prototypes whose impl table
 * already has an entry for the sequence type. */
static void
additerspecialization(Node *n, Stab *stab)
{
	Trait *tr;
	Type *ty;
	size_t i;

	tr = traittab[Tciter];
	ty = exprtype(n->iterstmt.seq);
	if (ty->type == Tyslice || ty->type == Tyarray || ty->type == Typtr)
		return;
	ty = tyfreshen(NULL, ty);
	for (i = 0; i < tr->nproto; i++) {
		ty = exprtype(n->iterstmt.seq);
		if (hthas(tr->proto[i]->decl.impls, ty))
			continue;
		lappend(&specializationscope, &nspecializationscope, stab);
		lappend(&specializations, &nspecializations, n);
		lappend(&genericdecls, &ngenericdecls, tr->proto[i]);
	}
}

/* Queues node `n` for checking after inference, together with the
 * scope `s` it should be checked in. */
static void
delayedcheck(Node *n, Stab *s)
{
	lappend(&postcheck, &npostcheck, n);
	lappend(&postcheckscope, &npostcheckscope, s);
}

/* Reports that types `a` and `b` are incompatible near `ctx`. `msg`
 * is an optional extra explanation (may be NULL). Where known, the
 * source locations the types came from are appended to the message. */
static void
typeerror(Type *a, Type *b, Node *ctx, char *msg)
{
	char *t1, *t2, *s1, *s2, *c;

	t1 = tystr(tyfix(NULL, a, 1));
	t2 = tystr(tyfix(NULL, b, 1));
	s1 = srcstr(a);
	s2 = srcstr(b);
	c = ctxstr(ctx);
	if (msg)
		fatal(ctx, "type \"%s\" incompatible with \"%s\" near %s: %s%s%s", t1, t2, c, msg, s1, s2);
	else
		fatal(ctx, "type \"%s\" incompatible with \"%s\" near %s%s%s", t1, t2, c, s1, s2);
	/* only reached if fatal() returns; keeps the helper leak-free */
	free(t1);
	free(t2);
	free(s1);
	free(s2);
	free(c);
}

/* Sets the enclosing type environment of `e` to `super`. Does nothing
 * when `e` is NULL. */
static void
setsuperenv(Tyenv *e, Tyenv *super)
{
	Tyenv *te;

	if (!e)
		return;
	/* verify that we don't accidentally create loops */
	for (te = super; te; te = te->super)
		assert(te->super != e);
	e->super = super;
}

/* Set a scope's enclosing scope up correctly. */
static void
setsuper(Stab *st, Stab *super)
{
	Stab *s;

	/* verify that we don't accidentally create loops */
	for (s = super; s; s = s->super)
		assert(s->super != st);
	st->super = super;
}

/* Checks whether a type directly contains itself.
 * Recursive types that contain themselves through
 * pointers or slices are fine, but any other self-inclusion
 * would lead to a value of infinite size.
 *
 * `bs` holds the ids of the types on the current path; returns 1 as
 * soon as a type already on the path is reached again, 0 otherwise. */
static int
occurs_rec(Type *sub, Bitset *bs)
{
	size_t i;

	sub = tf(sub);
	if (bshas(bs, sub->tid))
		return 1;
	bsput(bs, sub->tid);
	switch (sub->type) {
	case Typtr:
	case Tyslice:
		/* indirection breaks the cycle; don't descend */
		break;
	case Tystruct:
		for (i = 0; i < sub->nmemb; i++)
			if (occurs_rec(decltype(sub->sdecls[i]), bs))
				return 1;
		break;
	case Tyunion:
		for (i = 0; i < sub->nmemb; i++) {
			/* constructors without a payload cannot contain anything */
			if (!sub->udecls[i]->etype)
				continue;
			if (occurs_rec(sub->udecls[i]->etype, bs))
				return 1;
		}
		break;
	default:
		for (i = 0; i < sub->nsub; i++)
			if (occurs_rec(sub->sub[i], bs))
				return 1;
		break;
	}
	/* leaving this type: take it off the current path */
	bsdel(bs, sub->tid);
	return 0;
}

/* Returns nonzero if the resolved form of `a` is reached while walking
 * `b` with occurs_rec() (which does not descend through pointers or
 * slices). */
static int
occursin(Type *a, Type *b)
{
	Bitset *bs;
	int r;

	bs = mkbs();
	a = tf(a);
	/* seed the path with `a` so that reaching it inside `b` is a hit */
	bsput(bs, a->tid);
	r = occurs_rec(b, bs);
	/* NOTE(review): bsfree() is assumed to be the release counterpart of mkbs() */
	bsfree(bs);
	return r;
}

/* Returns nonzero if type `t` contains itself other than through a
 * pointer or slice (see occurs_rec()). */
static int
occurs(Type *t)
{
	Bitset *bs;
	int r;

	bs = mkbs();
	r = occurs_rec(t, bs);
	/* NOTE(review): bsfree() is assumed to be the release counterpart of mkbs() */
	bsfree(bs);
	return r;
}

/* Returns 1 if `t` is, or contains, a type parameter or generic type.
 * `visited` holds the ids of types already examined, so that cyclic
 * types terminate. */
static int
needfreshenrec(Type *t, Bitset *visited)
{
	size_t i;

	t = tysearch(t);
	if (bshas(visited, t->tid))
		return 0;
	bsput(visited, t->tid);
	switch (t->type) {
	case Typaram: return 1;
	case Tygeneric: return 1;
	case Tyname:
		for (i = 0; i < t->narg; i++)
			if (needfreshenrec(t->arg[i], visited))
				return 1;
		return needfreshenrec(t->sub[0], visited);
	case Tystruct:
		for (i = 0; i < t->nmemb; i++)
			if (needfreshenrec(decltype(t->sdecls[i]), visited))
				return 1;
		break;
	case Tyunion:
		for (i = 0; i < t->nmemb; i++)
			if (t->udecls[i]->etype && needfreshenrec(t->udecls[i]->etype, visited))
				return 1;
		break;
	default:
		for (i = 0; i < t->nsub; i++)
			if (needfreshenrec(t->sub[i], visited))
				return 1;
		break;
	}
	return 0;
}

/* Returns nonzero if `t` mentions a type parameter or generic type
 * anywhere inside it (see needfreshenrec()). */
static int
needfreshen(Type *t)
{
	Bitset *visited;
	int ret;

	visited = mkbs();
	ret = needfreshenrec(t, visited);
	/* NOTE(review): bsfree() is assumed to be the release counterpart of mkbs() */
	bsfree(visited);
	return ret;
}

/* Freshens the type of a declaration.
 *
 * Returns `orig` unchanged when needfreshen() says nothing in it needs
 * replacing. Otherwise the type is specialized with tyspecialize(),
 * using the caller's substitution if one is given or a temporary one
 * if `subst` is NULL. The specialized type gets orig's spec list, and
 * its seqaux is set when basetype() finds one. */
static Type *
tyfreshen(Tysubst *subst, Type *orig)
{
	Type *ty, *base;

	if (!needfreshen(orig))
		return orig;
	if (!subst) {
		subst = mksubst();
		ty = tyspecialize(orig, subst, delayed);
		/* NOTE(review): substfree() is assumed to be the release
		 * counterpart of mksubst(); the temporary is ours to drop */
		substfree(subst);
	} else {
		ty = tyspecialize(orig, subst, delayed);
	}
	ty->spec = orig->spec;
	ty->nspec = orig->nspec;
	base = basetype(ty);
	if (base)
		ty->seqaux = base;
	return ty;
}

/* Resolves a type and all its subtypes recursively.
 *
 * Each type is processed once (guarded by t->resolved). Aggregate
 * members are inferred, the traits named in the type's spec list are
 * added to t->trneed, and a type that contains itself is rejected.
 *
 * NOTE(review): reconstructed from a copy that had lost its
 * single-token lines (return, ++/-- counters, recursive calls).
 * Upstream may also switch type environments here; verify against
 * the repository. */
static void
tyresolve(Type *t)
{
	size_t i, j;
	Trait *tr;

	if (t->resolved)
		return;
	/* type resolution should never throw errors about non-generic
	 * showing up within a generic type, so we push and pop a generic
	 * around resolution */
	ingeneric++;
	t->resolved = 1;
	/* Walk through aggregate type members */
	switch (t->type) {
	case Tystruct:
		inaggr++;
		for (i = 0; i < t->nmemb; i++)
			infernode(&t->sdecls[i], NULL, NULL);
		inaggr--;
		break;
	case Tyunion:
		inaggr++;
		for (i = 0; i < t->nmemb; i++) {
			t->udecls[i]->utype = t;
			t->udecls[i]->utype = tf(t->udecls[i]->utype);
			if (t->udecls[i]->etype) {
				tyresolve(t->udecls[i]->etype);
				t->udecls[i]->etype = tf(t->udecls[i]->etype);
			}
		}
		inaggr--;
		break;
	case Tyarray:
		/* unsized arrays are only accepted inside an aggregate */
		if (!inaggr && !t->asize)
			lfatal(t->loc, "unsized array type outside of struct");
		infernode(&t->asize, NULL, NULL);
		break;
	case Typaram:
		if (!boundtype(t))
			lfatal(t->loc, "type parameter %s is undefined in generic context", tystr(t));
		break;
	default:
		break;
	}

	/* collect the traits required by the type's spec list */
	for (i = 0; i < t->nspec; i++) {
		for (j = 0; j < t->spec[i]->ntrait; j++) {
			tr = gettrait(curstab(), t->spec[i]->trait[j]);
			if (!tr)
				lfatal(t->loc, "trait %s does not exist", ctxstr(t->spec[i]->trait[j]));
			if (!t->trneed)
				t->trneed = mkbs();
			bsput(t->trneed, tr->uid);
			if (nameeq(t->spec[i]->trait[j], traittab[Tciter]->name))
				t->seqaux = t->spec[i]->aux;
		}
	}

	for (i = 0; i < t->nsub; i++) {
		t->sub[i] = tf(t->sub[i]);
		if (t->sub[i] == t) {
			lfatal(t->loc,
				"%s occurs within %s, leading to infinite type\n",
				tystr(t->sub[i]), tystr(t));
		}
		tyresolve(t->sub[i]);
	}
	ingeneric--;
	if (occurs(t))
		lfatal(t->loc, "type %s includes itself", tystr(t));
}

/* Follows the tytab[] links from `t` to the last type in the chain
 * and returns it. NULL is passed through unchanged. */
Type *
tysearch(Type *t)
{
	if (!t)
		return t;
	while (tytab[t->tid])
		t = tytab[t->tid];
	return t;
}

/* Look up the best type to date in the unification table, returning it.
 *
 * Unresolved named types (Tyunres) met along the way are looked up in
 * the current scope, or in the namespace given by the name, and linked
 * into the table. A missing namespace or type is a fatal error. */
static Type *
tylookup(Type *t)
{
	Type *lu;
	Stab *ns;

	assert(t != NULL);
	lu = NULL;
	while (1) {
		if (!tytab[t->tid] && t->type == Tyunres) {
			ns = curstab();
			if (t->name->name.ns)
				ns = getns(t->name->name.ns);
			if (!ns)
				fatal(t->name, "no namespace \"%s\"", t->name->name.ns);
			lu = gettype(ns, t->name);
			if (!lu)
				fatal(t->name, "no type %s", tystr(t));
			if (lu && lu != t)
				tytab[t->tid] = lu;
		}

		/* end of the chain */
		if (!tytab[t->tid])
			break;
		/* compress paths: shift the link up one level */
		if (tytab[tytab[t->tid]->tid])
			tytab[t->tid] = tytab[tytab[t->tid]->tid];
		t = tytab[t->tid];
	}
	return t;
}

/* Maps each generic parameter of `t` to the resolved form of the
 * matching argument of `orig` in `subst`, then returns `t` freshened
 * under that substitution. */
static Type *
tysubstmap(Tysubst *subst, Type *t, Type *orig)
{
	size_t i;

	for (i = 0; i < t->ngparam; i++)
		substput(subst, t->gparam[i], tf(orig->arg[i]));
	t = tyfreshen(subst, t);
	return t;
}

/* Same as tysubstmap(), but with a temporary substitution that is
 * created and released here. */
static Type *
tysubst(Type *t, Type *orig)
{
	Tysubst *subst;

	subst = mksubst();
	t = tysubstmap(subst, t, orig);
	/* NOTE(review): substfree() is assumed to be the release
	 * counterpart of mksubst() */
	substfree(subst);
	return t;
}

/* find the most accurate type mapping we have (ie,
 * the end of the unification chain).
 *
 * An unresolved type naming a generic is instantiated with its
 * arguments, and a type parameter is replaced by the type it is bound
 * to, if any.
 *
 * NOTE(review): reconstructed from a copy that had lost its
 * single-token lines; the tyresolve() calls are restored on that
 * basis. Upstream may also switch type environments around them;
 * verify against the repository. */
static Type *
tf(Type *orig)
{
	int isgeneric;
	Type *t, *tt;

	assert(orig != NULL);
	t = tylookup(orig);
	isgeneric = t->type == Tygeneric;
	/* count generics we are inside of while resolving/substituting */
	ingeneric += isgeneric;
	tyresolve(t);
	/* If this is an instantiation of a generic type, we want the params to
	 * match the instantiation */
	if (orig->type == Tyunres && t->type == Tygeneric) {
		if (t->ngparam != orig->narg) {
			lfatal(orig->loc, "%s incompatibly specialized with %s, declared on %s:%d",
					tystr(orig), tystr(t), file->file.files[t->loc.file], t->loc.line);
		}
		t = tysubst(t, orig);
	} else if (orig->type == Typaram) {
		tt = boundtype(t);
		if (tt) {
			tyresolve(tt);
			t = tt;
		}
	}
	ingeneric -= isgeneric;
	return t;
}

/* set the type of any typable node (expression, declaration, literal
 * or function) to the resolved form of `t`. Any other node kind is a
 * fatal internal error. For anything but a type variable, the node's
 * location is also recorded with marksrc(). */
static void
settype(Node *n, Type *t)
{
	t = tf(t);
	switch (n->type) {
	case Nexpr:	n->expr.type = t;	break;
	case Ndecl:	n->decl.type = t;	break;
	case Nlit:	n->lit.type = t;	break;
	case Nfunc:	n->func.type = t;	break;
	default: die("untypable node %s", nodestr[n->type]); break;
	}
	if (t->type != Tyvar)
		marksrc(t, n->loc);
}

/* Creates a fresh type variable at `l` that requires the traits
 * recorded in the trait map for the type kind `other`. */
static Type*
mktylike(Srcloc l, Ty other)
{
	Type *t;

	t = mktyvar(l);
	/* not perfect in general, but good enough for all places mktylike is used. */
	t->trneed = bsdup(traitmap->sub[other]->traits);
	return t;
}

/* Gets the type of a literal value, computing and caching it in
 * n->lit.type on first use. Integer and float literals get a type
 * variable constrained like Tyint / Tyflt64 rather than a concrete
 * type. */
static Type *
littype(Node *n)
{
	Type *t;

	t = NULL;
	if (!n->lit.type) {
		switch (n->lit.littype) {
		case Lvoid:	t = mktype(n->loc, Tyvoid);	break;
		case Lchr:	t = mktype(n->loc, Tychar);	break;
		case Lbool:	t = mktype(n->loc, Tybool);	break;
		case Lint:	t = mktylike(n->loc, Tyint);	break;
		case Lflt:	t = mktylike(n->loc, Tyflt64);	break;
		case Lstr:	t = mktyslice(n->loc, mktype(n->loc, Tybyte));	break;
		case Llbl:	t = mktyptr(n->loc, mktype(n->loc, Tyvoid));	break;
		case Lfunc:	t = n->lit.fnval->func.type;	break;
		}
		n->lit.type = t;
	}
	return n->lit.type;
}

/* For a union type, returns a fresh type variable and records the
 * union as its fallback in the `delayed` table. Any other type is
 * returned unchanged. */
static Type *
delayeducon(Type *fallback)
{
	Type *t;

	if (fallback->type != Tyunion)
		return fallback;
	t = mktylike(fallback->loc, fallback->type);
	htput(delayed, t, fallback);
	return t;
}

/* Finds the type of any typable node and returns its resolved form.
 * Any other node kind is a fatal internal error. */
static Type *
type(Node *n)
{
	Type *t;

	switch (n->type) {
	case Nlit:	t = littype(n);	break;
	case Nexpr:	t = n->expr.type;	break;
	case Ndecl:	t = decltype(n);	break;
	case Nfunc:	t = n->func.type;	break;
	default:
		t = NULL;
		die("untypeable node %s", nodestr[n->type]);
		break;
	}
	return tf(t);
}

/* Looks up the union constructor named by the first argument of the
 * expression `n`, in the current scope or in the namespace given by
 * the name. Fatal if the namespace or constructor is missing, or if
 * the number of arguments does not match whether the constructor
 * carries a payload. */
static Ucon *
uconresolve(Node *n)
{
	Ucon *uc;
	Node **args;
	Stab *ns;

	args = n->expr.args;
	ns = curstab();
	if (args[0]->name.ns)
		ns = getns(args[0]->name.ns);
	if (!ns)
		fatal(n, "no namespace %s\n", args[0]->name.ns);
	uc = getucon(ns, args[0]);
	if (!uc)
		fatal(n, "no union constructor `%s", ctxstr(args[0]));
	/* args[0] is the constructor name, so a payload makes nargs == 2 */
	if (!uc->etype && n->expr.nargs > 1)
		fatal(n, "nullary union constructor `%s passed arg ", ctxstr(args[0]));
	else if (uc->etype && n->expr.nargs != 2)
		fatal(n, "union constructor `%s needs arg ", ctxstr(args[0]));
	return uc;
}

/* Scores how well type `to` matches the pattern `pat`.
 * Returns -1 if they cannot match, otherwise a non-negative rank that
 * accumulates over the matched structure. A type-parameter pattern
 * matches any type satisfying its required traits (rank 0); a type
 * variable pattern matches anything (rank 1).
 *
 * this doesn't walk through named types, so it can't recurse infinitely. */
int
tymatchrank(Type *pat, Type *to)
{
	int match, q;
	size_t i;
	Ucon *puc, *tuc;

	pat = tysearch(pat);
	to = tysearch(to);
	if (pat->type == Typaram) {
		if (!pat->trneed)
			return 0;
		for (i = 0; bsiter(pat->trneed, &i); i++)
			if (!tryconstrain(to, traittab[i], 0))
				return -1;
		return 0;
	} else if (pat->type == Tyvar) {
		return 1;
	} else if (pat->type != to->type) {
		return -1;
	}

	match = 0;
	switch (pat->type) {
	case Tystruct:
		if (pat->nmemb != to->nmemb)
			return -1;
		for (i = 0; i < pat->nmemb; i++) {
			if (!streq(declname(pat->sdecls[i]),declname( to->sdecls[i])))
				return -1;
			q = tymatchrank(decltype(pat->sdecls[i]), decltype(to->sdecls[i]));
			if (q < 0)
				return -1;
			match += q;
		}
		break;
	case Tyunion:
		if (pat->nmemb != to->nmemb)
			return -1;
		for (i = 0; i < pat->nmemb; i++) {
			if (!nameeq(pat->udecls[i], to->udecls[i]))
				return -1;
			puc = pat->udecls[i];
			tuc = to->udecls[i];
			if (puc->etype && tuc->etype) {
				q = tymatchrank(puc->etype, tuc->etype);
				if (q < 0)
					return -1;
				match += q;
			} else if (puc->etype != tuc->etype) {
				/* only one of the two constructors has a payload */
				return -1;
			}
		}
		break;
	case Tyname:
		if (!nameeq(pat->name, to->name))
			return -1;
		if (pat->narg != to->narg)
			return -1;
		for (i = 0; i < pat->narg; i++) {
			q = tymatchrank(pat->arg[i], to->arg[i]);
			if (q < 0)
				return -1;
			match += q;
		}
		break;
	case Tyarray:
		/* unsized arrays are ok */
		if (pat->asize && to->asize) {
			/* NOTE(review): `!!` rejects arrays whose sizes ARE equal,
			 * while checksize() treats a nonzero litvaleq() as "sizes
			 * match". This looks inverted; behavior kept as found. */
			if (!!litvaleq(pat->asize->expr.args[0], to->asize->expr.args[0]))
				return -1;
		} else if (pat->asize != to->asize) {
			return -1;
		} else {
			return tymatchrank(pat->sub[0], to->sub[0]);
		}
		break;
	default:
		if (pat->nsub != to->nsub)
			break;
		if (pat->type == to->type)
			match = 1;
		for (i = 0; i < pat->nsub; i++) {
			q = tymatchrank(pat->sub[i], to->sub[i]);
			if (q < 0)
				return -1;
			match += q;
		}
		break;
	}
	return match;
}

/* Checks whether type `base` satisfies trait `tr`; returns 1 if so and
 * 0 otherwise. A type variable always succeeds, and when `update` is
 * nonzero the trait is added to the variable's required set.
 *
 * The inner loop walks the trait map alongside the type, descending
 * through single-subtype types. The outer loop retries with the type a
 * Tyname wraps (its sub[0]). */
static int
tryconstrain(Type *base, Trait *tr, int update)
{
	Traitmap *tm;
	Bitset *bs;
	Type *ty;
	size_t i;

	while(1) {
		ty = base;
		tm = traitmap->sub[ty->type];
		while (1) {
			if (ty->type == Typaram)
				if (ty->trneed && bshas(ty->trneed, tr->uid))
					return 1;
			if (ty->type == Tyvar) {
				if (!ty->trneed)
					ty->trneed = mkbs();
				if (update)
					bsput(ty->trneed, tr->uid);
				return 1;
			}
			if (bshas(tm->traits, tr->uid))
				return 1;
			if (tm->name && ty->type == Tyname) {
				bs = htget(tm->name, ty->name);
				if (bs && bshas(bs, tr->uid))
					return 1;
			}
			for (i = 0; i < tm->nfilter; i++) {
				/* only filters registered for this trait are relevant */
				if (tm->filtertr[i]->uid != tr->uid)
					continue;
				if (tymatchrank(tm->filter[i], ty) >= 0)
					return 1;
			}
			if (!ty->sub || ty->nsub != 1)
				break;
			ty = ty->sub[0];
			tm = tm->sub[ty->type];
			if (!tm)
				break;
		}
		if (base->type != Tyname)
			break;
		base = base->sub[0];
	}
	return 0;
}

/* Constrains a type to implement the required constraints. On
 * type variables, the constraint is added to the required
 * constraint list. Otherwise, the type is checked to see
 * if it has the required constraint; failure is a fatal error
 * reported at `ctx`. */
static void
constrain(Node *ctx, Type *base, Trait *tr)
{
	if (!tryconstrain(base, tr, 1))
		fatal(ctx, "%s needs trait %s near %s", tystr(base), namestr(tr->name), ctxstr(ctx));
}

/* Adds to `dst` every trait the trait map records for type `base`.
 * The walk has the same shape as tryconstrain(): descend through
 * single-subtype types, then retry with the type a Tyname wraps. A
 * type variable ends the inner walk without adding anything. */
static void
traitsfor(Type *base, Bitset *dst)
{
	Traitmap *tm;
	Bitset *bs;
	Type *ty;
	size_t i;

	while(1) {
		ty = base;
		tm = traitmap->sub[ty->type];
		while (1) {
			if (ty->type == Tyvar)
				break;
			if (ty->type == Tyname && ty->ngparam == 0)
				bs = htget(tm->name, ty->name);
			else
				bs = tm->traits;
			if (bs)
				bsunion(dst, bs);
			for (i = 0; i < tm->nfilter; i++) {
				if (tymatchrank(tm->filter[i], ty) >= 0)
					bsput(dst, tm->filtertr[i]->uid);
			}
			if (!tm->sub[ty->type] || ty->nsub != 1)
				break;
			tm = tm->sub[ty->type];
			ty = ty->sub[0];
		}
		if (base->type != Tyname)
			break;
		base = base->sub[0];
	}
}

/* Verifies that the traits of `a` and `b` are compatible, as decided
 * by bsissubset(), and reports a fatal error at `ctx` naming the
 * traits `b` lacks if they are not.
 *
 * A type with no explicit trneed set has its traits computed into a
 * temporary bitset with traitsfor(); those temporaries are released
 * before returning. */
static void
verifytraits(Node *ctx, Type *a, Type *b)
{
	char traitbuf[64], abuf[64], bbuf[64];
	char bsrc[128];
	Bitset *abs, *bbs;
	size_t i, n;
	Srcloc l;
	char *sep;

	abs = a->trneed;
	if (!abs) {
		abs = mkbs();
		traitsfor(a, abs);
	}
	bbs = b->trneed;
	if (!bbs) {
		bbs = mkbs();
		traitsfor(b, bbs);
	}
	if (!bsissubset(abs, bbs)) {
		sep = "";
		n = 0;
		*traitbuf = 0;
		/* list the traits that a has but b does not */
		for (i = 0; bsiter(abs, &i); i++) {
			if (!bshas(bbs, i))
				n += bprintf(traitbuf + n, sizeof(traitbuf) - n, "%s%s", sep,
						namestr(traittab[i]->name));
			sep = ",";
		}
		tyfmt(abuf, sizeof abuf, a);
		tyfmt(bbuf, sizeof bbuf, b);
		bsrc[0] = 0;
		if (nunifysrc > b->tid && unifysrc[b->tid].line > 0) {
			l = unifysrc[b->tid];
			snprintf(bsrc, sizeof bsrc, "\n\t%s from %s:%d", bbuf, fname(l), lnum(l));
		}
		fatal(ctx, "%s needs trait %s near %s%s%s",
			bbuf, traitbuf, ctxstr(ctx), srcstr(a), srcstr(b));
	}
	/* release only the bitsets allocated above, never a type's own.
	 * NOTE(review): bsfree() is assumed to be the release counterpart
	 * of mkbs(). */
	if (!a->trneed)
		bsfree(abs);
	if (!b->trneed)
		bsfree(bbs);
}

/* Merges the constraints on types. When `b` is a type variable the
 * required-trait sets are combined; otherwise the traits are checked
 * with verifytraits(). */
static void
mergetraits(Node *ctx, Type *a, Type *b)
{
	if (b->type == Tyvar) {
		/* make sure that if a = b, both have same traits */
		if (a->trneed && b->trneed)
			bsunion(b->trneed, a->trneed);
		else if (a->trneed)
			b->trneed = bsdup(a->trneed);
		else if (b->trneed)
			a->trneed = bsdup(b->trneed);
	} else {
		verifytraits(ctx, a, b);
	}
}

/* Computes the 'rank' of the type; ie, in which
 * direction should we unify. A lower ranked type
 * should be mapped to the higher ranked (ie, more
 * specific) type.
 *
 * Returns 0 for a plain type variable, 1 for a type variable with an
 * associated iterator type, and 2 for any other type. */
static int
tyrank(Type *t)
{
	if (t->type == Tyvar) {
		/* has associated iterator type */
		if (t->seqaux)
			return 1;
		else
			return 0;
	}
	return 2;
}

/* Unifies two union types member by member. Members are paired by
 * name; paired members must both carry a payload (which is then
 * unified) or both lack one. Any mismatch is a fatal error at `ctx`. */
static void
unionunify(Node *ctx, Type *u, Type *v)
{
	size_t i, j;
	int found;

	if (u->nmemb != v->nmemb)
		fatal(ctx, "can't unify %s and %s near %s%s%s\n",
			tystr(u), tystr(v), ctxstr(ctx),
			srcstr(u), srcstr(v));

	for (i = 0; i < u->nmemb; i++) {
		found = 0;
		for (j = 0; j < v->nmemb; j++) {
			if (!nameeq(u->udecls[i]->name, v->udecls[j]->name))
				continue;
			found = 1;
			if (u->udecls[i]->etype == NULL && v->udecls[j]->etype == NULL)
				continue;
			else if (u->udecls[i]->etype && v->udecls[j]->etype)
				unify(ctx, u->udecls[i]->etype, v->udecls[j]->etype);
			else
				fatal(ctx, "can't unify %s and %s near %s%s%s",
					tystr(u), tystr(v), ctxstr(ctx),
					srcstr(u), srcstr(v));
		}
		/* u has a constructor that v lacks */
		if (!found)
			fatal(ctx, "can't unify %s and %s near %s%s%s",
				tystr(u), tystr(v), ctxstr(ctx),
				srcstr(u), srcstr(v));
	}
}

/* Unifies two struct types member by member. Members are paired by
 * name and their types unified; a differing member count or a member
 * of `u` with no counterpart in `v` is a fatal error at `ctx`. */
static void
structunify(Node *ctx, Type *u, Type *v)
{
	size_t i, j;
	int found;
	char *ud, *vd;

	if (u->nmemb != v->nmemb)
		fatal(ctx, "can't unify %s and %s near %s%s%s",
			tystr(u), tystr(v), ctxstr(ctx),
			srcstr(u), srcstr(v));

	for (i = 0; i < u->nmemb; i++) {
		found = 0;
		for (j = 0; j < v->nmemb; j++) {
			ud = namestr(u->sdecls[i]->decl.name);
			vd = namestr(v->sdecls[j]->decl.name);
			if (strcmp(ud, vd) == 0) {
				found = 1;
				unify(ctx, type(u->sdecls[i]), type(v->sdecls[j]));
			}
		}
		/* we had at least one missing member */
		if (!found)
			fatal(ctx, "can't unify %s and %s near %s%s%s",
				tystr(u), tystr(v), ctxstr(ctx),
				srcstr(u), srcstr(v));
	}
}

/* Unifies the members of two aggregate types. Each side is first
 * replaced by its entry in the `delayed` table, if it has one, and
 * then reduced with tybase(). Only distinct union/union or
 * struct/struct pairs are processed. */
static void
membunify(Node *ctx, Type *u, Type *v)
{
	if (hthas(delayed, u))
		u = htget(delayed, u);
	u = tybase(u);
	if (hthas(delayed, v))
		v = htget(delayed, v);
	v = tybase(v);
	if (u->type == Tyunion && v->type == Tyunion && u != v)
		unionunify(ctx, u, v);
	else if (u->type == Tystruct && v->type == Tystruct && u != v)
		structunify(ctx, u, v);
}

/* Finds the auxiliary type recorded in seqaux for `a`, looking through
 * named types. If none is recorded, a slice, array or pointer yields
 * its subtype instead. Returns the resolved type, or NULL if there is
 * none. */
static Type *
basetype(Type *a)
{
	Type *t;

	t = a->seqaux;
	while (!t && a->type == Tyname) {
		a = a->sub[0];
		t = a->seqaux;
	}
	if (!t && (a->type == Tyslice || a->type == Tyarray || a->type == Typtr))
		t = a->sub[0];
	if (t)
		t = tf(t);
	return t;
}

/* Reconciles the sizes of two array types being unified. Size
 * expressions are folded and must be literals. A missing size on one
 * side is copied from the other, and two present sizes must be equal.
 * Violations are fatal errors at ctx->loc. */
static void
checksize(Node *ctx, Type *a, Type *b)
{
	if (a->asize)
		a->asize = fold(a->asize, 1);
	if (b->asize)
		b->asize = fold(b->asize, 1);
	if (a->asize && exprop(a->asize) != Olit)
		lfatal(ctx->loc, "%s: array size is not constant near %s",
				tystr(a), ctxstr(ctx));
	if (b->asize && exprop(b->asize) != Olit)
		lfatal(ctx->loc, "%s: array size is not constant near %s",
				tystr(b), ctxstr(ctx));
	if (!a->asize) {
		a->asize = b->asize;
	} else if (!b->asize) {
		b->asize = a->asize;
	} else if (a->asize && b->asize) {
		if (!litvaleq(a->asize->expr.args[0], b->asize->expr.args[0]))
			lfatal(ctx->loc, "array size of %s does not match %s near %s",
				tystr(a), tystr(b), ctxstr(ctx));
	}
}

/* Returns nonzero if `t` is a named type with at least one type
 * argument. */
static int
hasargs(Type *t)
{
	return t->type == Tyname && t->narg > 0;
}

/* Unifies two types, or errors if the types are not unifiable.
 *
 * Returns the type both now map to: the shared type if they were
 * already the same, otherwise the higher-ranked side. `ctx` is the
 * node used for error reporting and source tracking, and may be NULL.
 *
 * NOTE(review): reconstructed from a copy that had lost its
 * single-token lines. The body of the isreflect branch is restored
 * from memory of upstream (tagreflect() is assumed to exist); verify
 * against the repository. */
static Type *
unify(Node *ctx, Type *u, Type *v)
{
	Type *t, *r;
	Type *a, *b;
	Type *ea, *eb;
	size_t i;

	/* a ==> b */
	a = tf(u);
	b = tf(v);
	if (a->tid == b->tid)
		return a;

	/* we unify from lower to higher ranked types */
	if (tyrank(b) < tyrank(a)) {
		t = a;
		a = b;
		b = t;
	}

	/* Disallow recursive types */
	if (a->type == Tyvar && b->type != Tyvar) {
		if (occursin(a, b))
			fatal(ctx, "%s occurs within %s, leading to infinite type near %s\n",
				tystr(a), tystr(b), ctxstr(ctx));
	}

	r = NULL;
	if (a->type == Tyvar || tyeq(a, b)) {
		/* link a to b in the unification table */
		tytab[a->tid] = b;
		if (ctx) {
			marksrc(a, ctx->loc);
			marksrc(b, ctx->loc);
		}
	}
	if (a->type == Tyvar) {
		ea = basetype(a);
		eb = basetype(b);
		if (ea && eb)
			unify(ctx, ea, eb);
		r = b;
	}

	if (a->type == Tyarray && b->type == Tyarray) {
		checksize(ctx, a, b);
	}

	/* if the tyrank of a is 0 (ie, a raw tyvar), just unify.
	 * Otherwise, match up subtypes. */
	if (a->type == b->type && a->type != Tyvar) {
		if (a->type == Tyname)
			if (!nameeq(a->name, b->name))
				typeerror(a, b, ctx, "incompatible types");
		if (hasargs(a) && hasargs(b)) {
			/* Only Tygeneric and Tyname should be able to unify. And they
			 * should have the same names for this to be true. */
			if (!nameeq(a->name, b->name))
				typeerror(a, b, ctx, NULL);
			if (a->narg != b->narg)
				typeerror(a, b, ctx, "incompatible parameter lists");
			for (i = 0; i < a->narg; i++)
				unify(ctx, a->arg[i], b->arg[i]);
			r = b;
		}
		if (a->nsub != b->nsub) {
			verifytraits(ctx, a, b);
			if (tybase(a)->type == Tyfunc)
				typeerror(a, b, ctx, "function arity mismatch");
			else
				typeerror(a, b, ctx, "subtype counts incompatible");
		}
		for (i = 0; i < b->nsub; i++)
			unify(ctx, a->sub[i], b->sub[i]);
		r = b;
	} else if (a->type != Tyvar) {
		typeerror(a, b, ctx, NULL);
	}
	mergetraits(ctx, a, b);
	/* if either side is tagged reflect, tag both */
	if (a->isreflect || b->isreflect) {
		tagreflect(a);
		tagreflect(b);
	}
	membunify(ctx, a, b);

	/* if we have delayed types for a tyvar, transfer it over. */
	if (a->type == Tyvar && b->type == Tyvar) {
		if (hthas(delayed, a) && !hthas(delayed, b))
			htput(delayed, b, htget(delayed, a));
		else if (hthas(delayed, b) && !hthas(delayed, a))
			htput(delayed, a, htget(delayed, b));
	} else if (hthas(delayed, a)) {
		unify(ctx, htget(delayed, a), tybase(b));
	}

	return r;
}

/* Applies unifications to function calls.
 * Function application requires a slightly
 * different approach to unification.
 *
 * If the callee's type is still unknown, a function type is built from
 * the arguments and unified with it. Otherwise each argument is
 * unified with the matching parameter, stopping at a variadic
 * parameter. The call node gets the function's return type (sub[0]).
 * Calling a non-function or passing the wrong number of arguments is
 * a fatal error. */
static void
unifycall(Node *n)
{
	size_t i;
	Type *ft;

	ft = type(n->expr.args[0]);
	if (ft->type == Tyvar) {
		/* the first arg is the function itself, so it shouldn't be counted */
		ft = mktyfunc(n->loc, &n->expr.args[1], n->expr.nargs - 1, mktyvar(n->loc));
		unify(n, ft, type(n->expr.args[0]));
	} else if (tybase(ft)->type != Tyfunc) {
		fatal(n, "calling uncallable type %s", tystr(ft));
	}
	/* first arg: function itself */
	for (i = 1; i < n->expr.nargs; i++) {
		if (i == ft->nsub)
			fatal(n, "%s arity mismatch (expected %zd args, got %zd)",
					ctxstr(n->expr.args[0]), ft->nsub - 1, n->expr.nargs - 1);

		/* the remaining arguments are variadic: nothing to unify against */
		if (ft->sub[i]->type == Tyvalist) {
			break;
		}
		unify(n->expr.args[0], ft->sub[i], type(n->expr.args[i]));
	}
	/* ran out of arguments before running out of fixed parameters */
	if (i < ft->nsub && ft->sub[i]->type != Tyvalist)
		fatal(n, "%s arity mismatch (expected %zd args, got %zd)",
				ctxstr(n->expr.args[0]), ft->nsub - 1, i - 1);
	settype(n, ft->sub[0]);
}

/* Unifies the type arguments of `a` and `b` pairwise. Does nothing
 * unless both are Tyunres or Tyname; differing argument counts are a
 * fatal error at `ctx`. */
static void
unifyparams(Node *ctx, Type *a, Type *b)
{
	size_t i;

	/* The only types with unifiable params are Tyunres and Tyname.
	 * Tygeneric should always be freshened, and no other types have
	 * parameters attached.
	 * FIXME: Is it possible to have parameterized typarams? */
	if (a->type != Tyunres && a->type != Tyname)
		return;
	if (b->type != Tyunres && b->type != Tyname)
		return;

	if (a->narg != b->narg)
		fatal(ctx, "mismatched arg list sizes: %s with %s near %s", tystr(a), tystr(b),
				ctxstr(ctx));
	for (i = 0; i < a->narg; i++)
		unify(ctx, a->arg[i], b->arg[i]);
}

/* Sets up the variable reference `n` to refer to the declaration `s`:
 * copies the declaration id and constness, computes the reference's
 * type (instantiating generic declarations), and queues delayed
 * checks and specializations where needed. Referring to a hidden
 * declaration is a fatal error. Returns the computed type.
 *
 * NOTE(review): reconstructed from a copy that had lost its
 * single-token lines. The tyresolve(param) and substfree(subst) calls
 * are restored from memory of upstream; verify against the
 * repository. */
static Type *
initvar(Node *n, Node *s)
{
	Type *t, *param;
	Tysubst *subst;

	if (s->decl.ishidden)
		fatal(n, "attempting to refer to hidden decl %s", ctxstr(n));

	param = n->expr.param;
	if (param)
		tyresolve(param);
	if (s->decl.isgeneric) {
		subst = mksubst();
		if (param)
			substput(subst, s->decl.trait->param, param);
		t = tysubstmap(subst, tf(s->decl.type), s->decl.type);
		/* no explicit param given: take the one the substitution found */
		if (s->decl.trait && !param) {
			param = substget(subst, s->decl.trait->param);
			if (!param)
				fatal(n, "ambiguous trait decl %s", ctxstr(s));
		}
		substfree(subst);
	} else {
		t = s->decl.type;
	}
	n->expr.did = s->decl.did;
	n->expr.isconst = s->decl.isconst;
	if (param) {
		n->expr.param = param;
		delayedcheck(n, curstab());
	}
	if (s->decl.isgeneric && !ingeneric) {
		t = tyfreshen(NULL, t);
		addspecialization(n, curstab());
		if (t->type == Tyvar) {
			settype(n, mktyvar(n->loc));
			delayedcheck(n, curstab());
		} else {
			settype(n, t);
		}
	} else {
		settype(n, t);
	}
	return t;
}

/* Finds out if the member reference is actually
 * referring to a namespaced name, instead of a struct
 * member. If it is, it transforms it into the variable
 * reference we should have, instead of the Omemb expr
 * that we do have.
 *
 * Returns `n` unchanged when it is not such a reference. Otherwise
 * the new variable node is returned and also stored through `ret`.
 * A name missing from the namespace is a fatal error. */
static Node *
checkns(Node *n, Node **ret)
{
	Node *var, *name, *nsname;
	Node **args;
	Stab *stab;
	Node *s;

	/* check that this is a namespaced declaration */
	args = n->expr.args;
	if (n->type != Nexpr || exprop(n) != Omemb || exprop(args[0]) != Ovar)
		return n;
	name = args[0]->expr.args[0];
	stab = getns(namestr(name));
	/* a local declaration with the same name takes precedence */
	if (!stab || getdcl(curstab(), name))
		return n;

	/* substitute the namespaced name */
	nsname = mknsname(n->loc, namestr(name), namestr(args[1]));
	s = getdcl(stab, args[1]);
	if (!s)
		fatal(n, "undeclared var %s.%s", nsname->name.ns, nsname->name.name);
	var = mkexpr(n->loc, Ovar, nsname, NULL);
	var->expr.idx = n->expr.idx;
	initvar(var, s);
	*ret = var;
	return var;
}

/* Infers a struct literal: infers every member initializer and gives
 * the literal a fresh type variable. The node is queued for a delayed
 * check. *isconst is set to 1 only if every initializer is constant. */
static void
inferstruct(Node *n, int *isconst)
{
	size_t i;

	*isconst = 1;
	/* we want to check outer nodes before inner nodes when unifying nested structs */
	delayedcheck(n, curstab());
	for (i = 0; i < n->expr.nargs; i++) {
		infernode(&n->expr.args[i], NULL, NULL);
		if (!n->expr.args[i]->expr.isconst)
			*isconst = 0;
	}
	settype(n, mktyvar(n->loc));
}

/* Computes the number of elements of the array literal `n`. An
 * element with an explicit index extends the size to index + 1 when
 * that is larger; an element without one adds one to the size.
 * Explicit indices must fold to integer or character literals,
 * otherwise a fatal error is raised. */
static int64_t
arraysize(Node *n)
{
	int64_t sz, off, i;
	Node **args, *idx;

	sz = 0;
	off = 0;
	args = n->expr.args;
	for (i = 0; i < (int64_t)n->expr.nargs; i++) {
		if (args[i]->expr.idx) {
			args[i]->expr.idx = fold(args[i]->expr.idx, 1);
			idx = args[i]->expr.idx;
			if (exprop(idx) != Olit)
				fatal(idx, "nonconstant array initializer index near %s\n",
						ctxstr(idx));
			if (idx->expr.args[0]->lit.littype == Lchr)
				off = idx->expr.args[0]->lit.chrval;
			else if (idx->expr.args[0]->lit.littype == Lint)
				off = idx->expr.args[0]->lit.intval;
			else
				fatal(idx, "noninteger array initializer index near %s\n",
						ctxstr(idx));
			if (off >= sz)
				sz = off + 1;
		} else {
			sz++;
		}
	}
	return sz;
}

/* Infers an array literal: builds an array type sized by arraysize()
 * with a fresh element type variable, infers every element and
 * unifies it with the element type. *isconst is set to 1 only if
 * every element is constant. */
static void
inferarray(Node *n, int *isconst)
{
	size_t i;
	Type *t;
	Node *len;

	*isconst = 1;
	len = mkintlit(n->loc, arraysize(n));
	t = mktyarray(n->loc, mktyvar(n->loc), len);
	for (i = 0; i < n->expr.nargs; i++) {
		infernode(&n->expr.args[i], NULL, NULL);
		unify(n, t->sub[0], type(n->expr.args[i]));
		if (!n->expr.args[i]->expr.isconst)
			*isconst = 0;
	}
	settype(n, t);
}

/* Infers a tuple literal: infers every element and gives the node a
 * tuple type built from the element types. The node's isconst flag is
 * and-ed with each element's, and the result is reported through
 * *isconst. */
static void
infertuple(Node *n, int *isconst)
{
	Type **types;
	size_t i;

	*isconst = 1;
	types = xalloc(sizeof(Type *) * n->expr.nargs);
	for (i = 0; i < n->expr.nargs; i++) {
		infernode(&n->expr.args[i], NULL, NULL);
		n->expr.isconst = n->expr.isconst && n->expr.args[i]->expr.isconst;
		types[i] = type(n->expr.args[i]);
	}
	*isconst = n->expr.isconst;
	settype(n, mktytuple(n->loc, types, n->expr.nargs));
}

/* Infers a union constructor expression: resolves the constructor,
 * instantiates its union type, unifies the payload (if any) with the
 * constructor's payload type, and types the node via delayeducon().
 * *isconst is 1 for a payload-less constructor and otherwise follows
 * the payload expression. */
static void
inferucon(Node *n, int *isconst)
{
	Ucon *uc;
	Type *t;

	*isconst = 1;
	uc = uconresolve(n);
	/* Hackety hack hack.
	 * the types in a generic union may be bound from the tyname that
	 * defined it, which is not accessible here.
	 * To make it compile, for now, we just bind the types in here.
	 */
	t = tysubst(tf(uc->utype), uc->utype);
	/* switch to the same constructor inside the instantiated union */
	uc = tybase(t)->udecls[uc->id];
	if (uc->etype) {
		inferexpr(&n->expr.args[1], NULL, NULL);
		unify(n, uc->etype, type(n->expr.args[1]));
		*isconst = n->expr.args[1]->expr.isconst;
	}
	settype(n, delayeducon(t));
}

/*
 * Infers types for the pattern expression at *np, as used for match
 * arms and iteration statements in this file.
 * Every Nexpr argument is processed recursively before the node itself
 * is dispatched on its operator. A variable that does not resolve to a
 * visible declaration becomes a new declaration (fresh type variable,
 * `val` as its initializer) that is appended to the *bind / *nbind
 * list for the caller to install. Operators not listed in the switch
 * are rejected with "invalid pattern".
 */
static void
inferpat(Node **np, Node *val, Node ***bind, size_t *nbind)
	size_t i;
	Node **args;
	Node *s, *n;
	Stab *ns;
	Type *t;

	n = *np;
	/* checkns() is handed np as well, so it may substitute the node */
	n = checkns(n, np);
	args = n->expr.args;
	for (i = 0; i < n->expr.nargs; i++)
		if (args[i]->type == Nexpr)
			inferpat(&args[i], val, bind, nbind);
	switch (exprop(n)) {
	case Otup:
	case Ostruct:
	case Oarr:
	case Olit:
	case Omemb:
		infernode(np, NULL, NULL);
		/* arithmetic expressions just need to be constant */
	case Oneg:
	case Oadd:
	case Osub:
	case Omul:
	case Odiv:
	case Obsl:
	case Obsr:
	case Oband:
	case Obor:
	case Obxor:
	case Obnot:
		infernode(np, NULL, NULL);
		if (!n->expr.isconst)
			fatal(n, "matching against non-constant expression near %s", ctxstr(n));
	case Oucon:
		inferucon(n, &n->expr.isconst);
	case Ovar:
		/* look the name up in the current scope, or in the namespace
		 * it is qualified with */
		ns = curstab();
		if (args[0]->name.ns)
			ns = getns(args[0]->name.ns);
		s = getdcl(ns, args[0]);
		if (s && !s->decl.ishidden) {
			if (s->decl.isgeneric)
				t = tysubst(s->decl.type, s->decl.type);
			else if (s->decl.isconst)
				t = s->decl.type;
				fatal(n, "pattern shadows variable declared on %s:%d near %s",
						fname(s->loc), lnum(s->loc), ctxstr(s));
		} else {
			/* no visible declaration: the pattern introduces a binding */
			t = mktyvar(n->loc);
			s = mkdecl(n->loc, n->expr.args[0], t);
			s->decl.init = val;
			settype(n, t);
			lappend(bind, nbind, s);
		settype(n, t);
		n->expr.did = s->decl.did;
		n->expr.isconst = s->decl.isconst;
	case Oaddr:
		infernode(np, NULL, NULL);
	case Ogap:
		infernode(np, NULL, NULL);	break;
	default: fatal(n, "invalid pattern"); break;

/*
 * Installs the `nbind` declarations in `bind` into the block node `n`:
 * each one is entered into the block's scope and inserted at index 0
 * of the block's statement list.
 */
addbindings(Node *n, Node **bind, size_t nbind)
	size_t i;

	/* order of binding shouldn't matter, so push them into the block
	 * in reverse order. */
	for (i = 0; i < nbind; i++) {
		putdcl(n->block.scope, bind[i]);
		linsert(&n->block.stmts, &n->block.nstmts, 0, bind[i]);

/*
 * Infers the sub-expressions of `n`.
 * Each argument that is itself an Nexpr is passed through checkns()
 * and inferexpr() with the same `ret` / `sawret` context. When the
 * operator is marked in opispure[], n->expr.isconst is set to the
 * conjunction of the arguments' constness. The node's resulting
 * isconst flag is written to `*exprconst`.
 */
static void
infersub(Node *n, Type *ret, int *sawret, int *exprconst)
	Node **args;
	size_t i, nargs;
	int isconst;

	/* Ovar has no subexpressions */
	if (exprop(n) == Ovar)
	args = n->expr.args;
	nargs = n->expr.nargs;
	isconst = 1;
	for (i = 0; i < nargs; i++) {
		/* Nlit, Nvar, etc should not be inferred as exprs */
		if (args[i]->type == Nexpr) {
			/* Omemb can sometimes resolve to a namespace. We have to check
			 * this. Icky. */
			checkns(args[i], &args[i]);
			inferexpr(&args[i], ret, sawret);
			isconst = isconst && args[i]->expr.isconst;
	if (opispure[exprop(n)])
		n->expr.isconst = isconst;
	*exprconst = n->expr.isconst;

/*
 * Infers the type of the expression node at *np.
 *
 * `ret` is the return type that an Oret inside this expression is
 * unified against; a NULL `ret` makes Oret a fatal error. `sawret`,
 * when non-NULL, is set to 1 once an Oret has been processed.
 *
 * The node's index expression is inferred first, then the node goes
 * through checkns() (which receives np as well) and is dispatched on
 * its operator. Most cases infer the operands with infersub(), apply
 * trait constraints and unifications, and record the result with
 * settype(). Omemb and Ocast are queued with delayedcheck() instead of
 * being resolved here. Backend-only operators end in die().
 */
static void
inferexpr(Node **np, Type *ret, int *sawret)
	Node **args;
	size_t i, nargs;
	Node *s, *n;
	Type *t, *b;
	int isconst;
	Stab *ns;

	n = *np;
	assert(n->type == Nexpr);
	args = n->expr.args;
	nargs = n->expr.nargs;
	infernode(&n->expr.idx, NULL, NULL);
	n = checkns(n, np);
	switch (exprop(n)) {
		/* all operands are same type */
	case Oadd:	/* @a + @a -> @a */
	case Osub:	/* @a - @a -> @a */
	case Omul:	/* @a * @a -> @a */
	case Odiv:	/* @a / @a -> @a */
	case Oneg:	/* -@a -> @a */
	case Oaddeq:	/* @a += @a -> @a */
	case Osubeq:	/* @a -= @a -> @a */
	case Omuleq:	/* @a *= @a -> @a */
	case Odiveq:	/* @a /= @a -> @a */
		infersub(n, ret, sawret, &isconst);
		t = type(args[0]);
		constrain(n, t, traittab[Tcnum]);
		/* constness is recomputed from the operands while unifying */
		isconst = args[0]->expr.isconst;
		for (i = 1; i < nargs; i++) {
			isconst = isconst && args[i]->expr.isconst;
			t = unify(n, t, type(args[i]));
		n->expr.isconst = isconst;
		settype(n, t);
	case Omod:	/* @a % @a -> @a */
	case Obor:	/* @a | @a -> @a */
	case Oband:	/* @a & @a -> @a */
	case Obxor:	/* @a ^ @a -> @a */
	case Obsl:	/* @a << @a -> @a */
	case Obsr:	/* @a >> @a -> @a */
	case Obnot:	/* ~@a -> @a */
	case Opreinc:	/* ++@a -> @a */
	case Opredec:	/* --@a -> @a */
	case Opostinc:	/* @a++ -> @a */
	case Opostdec:	/* @a-- -> @a */
	case Omodeq:	/* @a %= @a -> @a */
	case Oboreq:	/* @a |= @a -> @a */
	case Obandeq:	/* @a &= @a -> @a */
	case Obxoreq:	/* @a ^= @a -> @a */
	case Obsleq:	/* @a <<= @a -> @a */
	case Obsreq:	/* @a >>= @a -> @a */
		infersub(n, ret, sawret, &isconst);
		t = type(args[0]);
		/* same as the arithmetic group, plus the Tcint constraint */
		constrain(n, t, traittab[Tcnum]);
		constrain(n, t, traittab[Tcint]);
		isconst = args[0]->expr.isconst;
		for (i = 1; i < nargs; i++) {
			isconst = isconst && args[i]->expr.isconst;
			t = unify(n, t, type(args[i]));
		n->expr.isconst = isconst;
		settype(n, t);
	case Oasn:	/* @a = @a -> @a */
		infersub(n, ret, sawret, &isconst);
		t = type(args[0]);
		for (i = 1; i < nargs; i++)
			t = unify(n, t, type(args[i]));
		settype(n, t);
		if (args[0]->expr.isconst)
			fatal(n, "attempting to assign constant \"%s\"", ctxstr(args[0]));

		/* operands same type, returning bool */
	case Olor:	/* @a || @b -> bool */
	case Oland:	/* @a && @b -> bool */
	case Oeq:	/* @a == @a -> bool */
	case One:	/* @a != @a -> bool */
	case Ogt:	/* @a > @a -> bool */
	case Oge:	/* @a >= @a -> bool */
	case Olt:	/* @a < @a -> bool */
	case Ole:	/* @a <= @b -> bool */
		infersub(n, ret, sawret, &isconst);
		t = type(args[0]);
		for (i = 1; i < nargs; i++)
			unify(n, t, type(args[i]));
		settype(n, mktype(Zloc, Tybool));

	case Olnot:	/* !bool -> bool */
		infersub(n, ret, sawret, &isconst);
		t = unify(n, type(args[0]), mktype(Zloc, Tybool));
		settype(n, t);

		/* reach into a type and pull out subtypes */
	case Oaddr:	/* &@a -> @a* */
		infersub(n, ret, sawret, &isconst);
		settype(n, mktyptr(n->loc, type(args[0])));
	case Oderef:	/* @a# ->  @a */
		infersub(n, ret, sawret, &isconst);
		t = unify(n, type(args[0]), mktyptr(n->loc, mktyvar(n->loc)));
		settype(n, t->sub[0]);
	case Oidx:	/* @a[@b::tcint] -> @a */
		infersub(n, ret, sawret, &isconst);
		/* b is the result type; it is attached to a fresh variable
		 * through seqaux before that variable is unified with the
		 * indexed operand */
		b = mktyvar(n->loc);
		t = mktyvar(n->loc);
		t->seqaux = b;
		unify(n, type(args[0]), t);
		constrain(n, type(args[0]), traittab[Tcidx]);
		constrain(n, type(args[1]), traittab[Tcint]);
		settype(n, b);
	case Oslice:	/* @a[@b::tcint,@b::tcint] -> @a[,] */
		infersub(n, ret, sawret, &isconst);
		b = mktyvar(n->loc);
		t = mktyvar(n->loc);
		t->seqaux = b;
		unify(n, type(args[0]), t);
		constrain(n, type(args[1]), traittab[Tcint]);
		constrain(n, type(args[1]), traittab[Tcnum]);
		constrain(n, type(args[2]), traittab[Tcint]);
		constrain(n, type(args[2]), traittab[Tcnum]);
		settype(n, mktyslice(n->loc, b));

		/* special cases */
	case Omemb:	/* @a.Ident -> @b, verify type(@a.Ident)==@b later */
		infersub(n, ret, sawret, &isconst);
		settype(n, mktyvar(n->loc));
		delayedcheck(n, curstab());
	case Osize:	/* sizeof(@a) -> size */
		infersub(n, ret, sawret, &isconst);
		settype(n, mktylike(n->loc, Tyuint));
	case Ocall:	/* (@a, @b, @c, ... -> @r)(@a, @b, @c, ...) -> @r */
		infersub(n, ret, sawret, &isconst);
	case Ocast:	/* (@a : @b) -> @b */
		infersub(n, ret, sawret, &isconst);
		delayedcheck(n, curstab());
	case Oret:	/* -> @a -> void */
		infersub(n, ret, sawret, &isconst);
		if (!ret)
			fatal(n, "returns are not valid near %s", ctxstr(n));
		if (sawret)
			*sawret = 1;
		t = unify(n, ret, type(args[0]));
		settype(n, t);
	case Obreak:
	case Ocontinue:
		/* nullary: nothing to infer. */
		settype(n, mktype(Zloc, Tyvoid));
	case Ojmp:	/* goto void* -> void */
		/* a label literal used as the target is replaced by the
		 * node getlbl() returns for it in the current scope */
		if (args[0]->type == Nlit && args[0]->lit.littype == Llbl)
			args[0] = getlbl(curstab(), args[0]->loc, args[0]->lit.lblname);
		infersub(n, ret, sawret, &isconst);
		settype(n, mktype(Zloc, Tyvoid));
	case Ovar:	/* a:@a -> @a */
		infersub(n, ret, sawret, &isconst);
		/* if we created this from a namespaced var, the type should be
		 * set, and the normal lookup is expected to fail. Since we're
		 * already done with this node, we can just return. */
		if (n->expr.type)
		ns = curstab();
		if (args[0]->name.ns)
			ns = getns(args[0]->name.ns);
		s = getdcl(ns, args[0]);
		if (!s)
			fatal(n, "undeclared var %s", ctxstr(args[0]));
		if (n->expr.param && !s->decl.trait)
			fatal(n, "var %s must refer to a trait decl", ctxstr(args[0]));
		initvar(n, s);
	case Ogap:	/* _ -> @a */
		if (n->expr.type)
		n->expr.type = mktyvar(n->loc);
	case Oucon:	inferucon(n, &n->expr.isconst);	break;
	case Otup:	infertuple(n, &n->expr.isconst);	break;
	case Ostruct:	inferstruct(n, &n->expr.isconst);	break;
	case Oarr:	inferarray(n, &n->expr.isconst);	break;
	case Olit:	/* <lit>:@a::tyclass -> @a */
		infersub(n, ret, sawret, &isconst);
		switch (args[0]->lit.littype) {
		case Lfunc:
			infernode(&args[0]->lit.fnval, NULL, NULL);
			/* FIXME: env capture means this is non-const */
			n->expr.isconst = 1;
		case Llbl:
			s = getlbl(curstab(), args[0]->loc, args[0]->lit.lblname);
			if (!s)
				fatal(n, "unable to find label %s in function scope\n", args[0]->lit.lblname);
			/* the caller's slot now refers to the label node */
			*np = s;
			n->expr.isconst = 1;
		settype(n, type(args[0]));
	case Otern:
		/* note that the operands are inferred with a NULL return type */
		infersub(n, NULL, sawret, &isconst);
		unify(n, type(args[0]), mktype(n->loc, Tybool));
		unify(n, type(args[1]), type(args[2]));
		settype(n, type(args[1]));
	case Oundef:
		infersub(n, ret, sawret, &isconst);
		settype(n, mktype(n->loc, Tyvoid));
	case Odef:
	case Odead:
		   n->expr.type = mktype(n->loc, Tyvoid);
	/* unexpected in frontend */
	case Obad: case Ocjmp: case Ovjmp: case Oset: case Oslbase: case Osllen:
	case Outag: case Ocallind: case Oblit: case Oclear: case Oudata: case Otrunc:
	case Oswiden: case Ozwiden: case Oint2flt: case Oflt2int: case Oflt2flt:
	case Ofadd: case Ofsub: case Ofmul: case Ofdiv: case Ofneg: case Ofeq:
	case Ofne: case Ofgt: case Ofge: case Oflt: case Ofle: case Oueq: case Oune:
	case Ougt: case Ouge: case Oult: case Oule: case Otupget: case Numops:
		die("Should not see %s in fe", opstr[exprop(n)]);

/*
 * Infers the function node `n`.
 * Each argument is inferred without a return context, then the body is
 * inferred with the function type's sub[0] as the type that return
 * expressions are unified against. If the body contained no return,
 * sub[0] is unified with void.
 */
static void
inferfunc(Node *n)
	size_t i;
	int sawret;

	sawret = 0;
	for (i = 0; i < n->func.nargs; i++)
		infernode(&n->func.args[i], NULL, NULL);
	infernode(&n->func.body, n->func.type->sub[0], &sawret);
	/* if there's no return stmt in the function, assume void ret */
	if (!sawret)
		unify(n, type(n)->sub[0], mktype(Zloc, Tyvoid));

/*
 * Processes the impl node `n` against the trait it names.
 * The trait is looked up in the current symbol table (fatal if absent)
 * and stored in n->impl.trait; the number of aux types must match the
 * trait's. For every declaration in the impl, the matching trait
 * prototype is located by name, the prototype's type is specialized
 * with the trait parameter mapped to the impl type and each trait aux
 * type mapped to the impl's, and the result is unified with the
 * declaration's type. The declaration is then entered into the file's
 * global table under a generated name (fatal if that name is already
 * present), recorded in the prototype's impls table keyed by the impl
 * type, marked const, given the trait's visibility, and appended to
 * impldecl. Declarations whose specialized type is Tygeneric or still
 * has parameters are also marked generic and recorded in the
 * prototype's gimpl / gtype lists.
 * NOTE(review): several member names after `->` and the statements
 * guarded by `if (generic)` are absent from this copy.
 */
static void
specializeimpl(Node *n)
	Node *dcl, *proto, *name, *sym;
	Tysubst *subst;
	Type *ty;
	Trait *tr;
	size_t i, j;
	int generic;
	char *traitns;

	tr = gettrait(curstab(), n->impl.traitname);
	if (!tr)
		fatal(n, "no trait %s\n", namestr(n->impl.traitname));
	n->impl.trait = tr;
	traitns = tr->name->name.ns;

	dcl = NULL;
	if (n->impl.naux != tr->naux)
		fatal(n, "%s incompatibly specialized with %zd types instead of %zd types",
			namestr(n->impl.traitname), n->impl.naux, tr->naux);
	/* replace the impl's types with what tf() returns for them */
	n->impl.type = tf(n->impl.type);
	for (i = 0; i < n->impl.naux; i++)
		n->impl.aux[i] = tf(n->impl.aux[i]);
	for (i = 0; i < n->impl.ndecls; i++) {
		/* look up the prototype */
		proto = NULL;
		dcl = n->impl.decls[i];

		   since the decls in an impl are not installed in a namespace,
		   their names are not updated when we call updatens() on the
		   symbol table. Because we need to do namespace dependent
		   comparisons for specializing, we need to set the namespace
		if (file->file.globls->name)
			setns(dcl->, traitns);
		for (j = 0; j < tr->nproto; j++) {
			if (nsnameeq(dcl->, tr->proto[j]-> {
				proto = tr->proto[j];
		if (!proto)
			fatal(n, "declaration %s missing in %s, near %s", namestr(dcl->,
					namestr(tr->name), ctxstr(n));

		/* infer and unify types */
		verifytraits(n, tr->param, n->impl.type);
		subst = mksubst();
		substput(subst, tr->param, n->impl.type);
		for (j = 0; j < tr->naux; j++)
			substput(subst, tr->aux[j], n->impl.aux[j]);
		ty = tyspecialize(type(proto), subst, delayed);

		generic = hasparams(ty);
		if (generic)

		unify(n, type(dcl), ty);

		/* and put the specialization into the global stab */
		name = genericname(proto, n->impl.type, ty);
		sym = getdcl(file->file.globls, name);
		if (sym)
			fatal(n, "trait %s already specialized with %s on %s:%d",
				namestr(tr->name), tystr(n->impl.type),
				fname(sym->loc), lnum(sym->loc));
		dcl-> = name;
		putdcl(file->file.globls, dcl);
		htput(proto->decl.impls, n->impl.type, dcl);
		dcl->decl.isconst = 1;
		if (ty->type == Tygeneric || hasparams(ty)) {
			dcl->decl.isgeneric = 1;
			lappend(&proto->decl.gimpl, &proto->decl.ngimpl, dcl);
			lappend(&proto->decl.gtype, &proto->decl.ngtype, ty);
		dcl->decl.vis = tr->vis;
		lappend(&impldecl, &nimpldecl, dcl);

		if (generic)

/*
 * Infers the type of the declaration `n`.
 * Starts from tf(decltype(n)). When that type is Tygeneric but the
 * declaration itself is not generic, a freshened copy is made and
 * passed to unifyparams() together with the declared type. The
 * resulting type is set on the node. If there is an initializer, it is
 * inferred and unified with the declaration's type, and a const
 * declaration with a non-const initializer is a fatal error.
 */
static void
inferdecl(Node *n)
	Type *t;

	t = tf(decltype(n));
	if (t->type == Tygeneric && !n->decl.isgeneric) {
		t = tyfreshen(NULL, t);
		unifyparams(n, t, decltype(n));
	settype(n, t);
	if (n->decl.init) {
		inferexpr(&n->decl.init, NULL, NULL);
		unify(n, type(n), type(n->decl.init));
		if (n->decl.isconst && !n->decl.init->expr.isconst)
			fatal(n, "non-const initializer for \"%s\"", ctxstr(n));

/*
 * Walks the declarations and named types of the symbol table `s`.
 * For each declaration key the declaration is fetched and the current
 * environment is asserted to be the one captured before the loop.
 * For each type key the type is fetched (fatal if undefined), passed
 * through tysearch(), given the current environment as the super
 * environment of its own env when it has one, and written back with
 * updatetype().
 * NOTE(review): the first loop only fetches and asserts as shown here;
 * part of its body may be missing from this copy.
 */
static void
inferstab(Stab *s)
	void **k;
	size_t n, i;
	Node *dcl;
	Type *t;
	void *se;

	k = htkeys(s->dcl, &n);
	se = curenv();
	for (i = 0; i < n; i++) {
		dcl = htget(s->dcl, k[i]);
		assert(se == curenv());

	k = htkeys(s->ty, &n);
	for (i = 0; i < n; i++) {
		t = gettype(s, k[i]);
		if (!t)
			fatal(k[i], "undefined type %s", namestr(k[i]));
		t = tysearch(t);
		if (t->env)
			setsuperenv(t->env, curenv());
		updatetype(s, k[i], t);

/*
 * Infers types for the node at *np and, recursively, its children.
 * `ret` and `sawret` carry the enclosing function's return type and
 * the "saw a return" flag down to nested statements; conditions,
 * iterated sequences and match values are inferred with a NULL `ret`.
 * The node's `inferred` flag is tested and then set before the
 * dispatch on node type. Expressions are handed to inferexpr() with
 * the original np so the node can be replaced in place.
 */
static void
infernode(Node **np, Type *ret, int *sawret)
	size_t i, nbound;
	Node **bound, *n, *pat;
	Type *t, *b, *e;

	n = *np;
	if (!n)
	if (n->inferred)
	n->inferred = 1;
	switch (n->type) {
	case Nfile:
		for (i = 0; i < n->file.nstmts; i++)
			infernode(&n->file.stmts[i], NULL, sawret);
	case Ndecl:
		if (n->decl.isgeneric)
		setsuperenv(n->decl.env, curenv());
		if (hasparams(type(n)) && !ingeneric)
			fatal(n, "generic type in non-generic near %s", ctxstr(n));
		/* auto declarations must satisfy the Tcdisp trait */
		if (n->decl.isauto)
			constrain(n, type(n), traittab[Tcdisp]);
		if (n->decl.isgeneric)
	case Nblock:
		setsuper(n->block.scope, curstab());
		for (i = 0; i < n->block.nstmts; i++)
			infernode(&n->block.stmts[i], ret, sawret);
	case Nifstmt:
		infernode(&n->ifstmt.cond, NULL, sawret);
		infernode(&n->ifstmt.iftrue, ret, sawret);
		infernode(&n->ifstmt.iffalse, ret, sawret);
		unify(n, type(n->ifstmt.cond), mktype(n->loc, Tybool));
	case Nloopstmt:
		setsuper(n->loopstmt.scope, curstab());
		infernode(&n->loopstmt.init, ret, sawret);
		infernode(&n->loopstmt.cond, NULL, sawret);
		infernode(&n->loopstmt.step, ret, sawret);
		infernode(&n->loopstmt.body, ret, sawret);
		unify(n, type(n->loopstmt.cond), mktype(n->loc, Tybool));
	case Niterstmt:
		bound = NULL;
		nbound = 0;

		/* the element pattern is inferred first and its bindings are
		 * installed in the body before the body is inferred */
		inferpat(&n->iterstmt.elt, NULL, &bound, &nbound);
		addbindings(n->iterstmt.body, bound, nbound);

		infernode(&n->iterstmt.seq, NULL, sawret);
		infernode(&n->iterstmt.body, ret, sawret);

		e = type(n->iterstmt.elt);
		t = type(n->iterstmt.seq);
		constrain(n, t, traittab[Tciter]);
		b = basetype(t);
		if (b)
			unify(n, e, b);
			t->seqaux = e;
		delayedcheck(n, curstab());
	case Nmatchstmt:
		infernode(&n->matchstmt.val, NULL, sawret);
		for (i = 0; i < n->matchstmt.nmatches; i++) {
			infernode(&n->matchstmt.matches[i], ret, sawret);
			pat = n->matchstmt.matches[i]->match.pat;
			unify(pat, type(n->matchstmt.val),
	case Nmatch:
		bound = NULL;
		nbound = 0;
		inferpat(&n->match.pat, NULL, &bound, &nbound);
		addbindings(n->match.block, bound, nbound);
		infernode(&n->match.block, ret, sawret);
	case Nexpr:
		inferexpr(np, ret, sawret);
	case Nfunc:
		setsuper(n->func.scope, curstab());
		setsuperenv(n->func.env, curenv());
	case Nimpl:
	case Nlit:
	case Nname:
	case Nuse:
	case Nnone:
		die("Nnone should not be seen as node type!");

/*
 * Returns the final type for `orig`, after all unification
 * and default constraint selections.
 *
 * `ctx` is the node used for error reporting and as the context of any
 * unification performed here. A non-zero `noerr` suppresses the two
 * fatal errors raised in this function (incompatible delayed type and
 * underconstrained type variable).
 *
 * The int and flt64 reference types and their trait sets are built on
 * first use and kept in function-level statics.
 */
static Type *
tyfix(Node *ctx, Type *orig, int noerr)
	static Type *tyint, *tyflt;
	static Bitset *intset, *fltset;
	Type *t, *d, *base;
	Tyenv *env;
	size_t i;
	char buf[1024];

	if (!tyint) {
		tyint = mktype(Zloc, Tyint);
		intset = mkbs();
		traitsfor(tyint, intset);
	if (!tyflt) {
		tyflt = mktype(Zloc, Tyflt64);
		fltset = mkbs();
		traitsfor(tyflt, fltset);

	t = tysearch(tf(orig));
	env = t->env;
	if (env)
	base = orig->seqaux;
	/* a type variable with an entry in `delayed` is reconciled with
	 * that entry before defaults are considered */
	if (orig->type == Tyvar && hthas(delayed, orig)) {
		d = htget(delayed, orig);
		if (t->type == Tyvar) {
			/* tyvar is guaranteed to unify error-free */
			unify(ctx, t, d);
			t = tf(t);
		} else if (tybase(t)->type != d->type && !noerr) {
			fatal(ctx, "type %s not compatible with %s near %s\n",
			    tystr(t), tystr(d), ctxstr(ctx));
	/* default selection: int is tried before flt64 (presumably
	 * bsissubset(a, b) tests that a is contained in b -- confirm) */
	if (t->type == Tyvar && t->trneed) {
		if (bsissubset(t->trneed, intset))
			t = tyint;
		else if (bsissubset(t->trneed, fltset))
			t = tyflt;
	} else if (!t->fixed) {
		/* the flag is set before descending into component types */
		t->fixed = 1;
		if (t->type == Tyarray) {
			typesub(t->asize, noerr);
		} else if (t->type == Tystruct) {
			for (i = 0; i < t->nmemb; i++)
				typesub(t->sdecls[i], noerr);
		} else if (t->type == Tyunion) {
			for (i = 0; i < t->nmemb; i++) {
				if (t->udecls[i]->etype) {
					t->udecls[i]->etype =
						tyfix(ctx, t->udecls[i]->etype, noerr);
		} else if (t->type == Tyname) {
			for (i = 0; i < t->narg; i++)
				t->arg[i] = tyfix(ctx, t->arg[i], noerr);
		for (i = 0; i < t->nsub; i++)
			t->sub[i] = tyfix(ctx, t->sub[i], noerr);

	if (t->type == Tyvar && !noerr)
		fatal(ctx, "underconstrained type %s near %s", tyfmt(buf, 1024, t), ctxstr(ctx));
	/* carry the original's seqaux over to the result, fixing it too
	 * unless it is `orig` itself */
	if (base) {
		if (base != orig)
			base = tyfix(ctx, base, noerr);
		t->seqaux = base;
	if (env)
	return t;

/*
 * Post-inference hook for cast expressions; invoked from
 * postcheckpass() for Ocast nodes. Performs no checks at present.
 */
static void
checkcast(Node *n, Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
	/* FIXME: actually verify the casts. Right now, it's ok to leave this
	 * unimplemented because bad casts get caught by the backend. */

/*
 * Resolves the member access expression `n` (args[0] is the aggregate,
 * args[1] the member name) once the aggregate's type may be known.
 * Slices and arrays accept only the emulated "len" member, whose type
 * is constrained to be numeric and integral. For other types one level
 * of pointer is looked through; struct members are searched by name
 * and the expression's type is unified with the member's declared
 * type. If the aggregate's type is still a type variable, the node is
 * appended to `rem` (with the current scope in `remscope`) for a later
 * pass, or reported as underspecified when `rem` is NULL. A member
 * that is not found is a fatal error.
 */
static void
infercompn(Node *n, Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
	Node *aggr;
	Node *memb;
	Node **nl;
	Type *t;
	size_t i;
	int found;

	aggr = n->expr.args[0];
	memb = n->expr.args[1];

	found = 0;
	t = tybase(tf(type(aggr)));
	/* all array-like types have a fake "len" member that we emulate */
	if (t->type == Tyslice || t->type == Tyarray) {
		if (!strcmp(namestr(memb), "len")) {
			constrain(n, type(n), traittab[Tcnum]);
			constrain(n, type(n), traittab[Tcint]);
			found = 1;
 	 * otherwise, we search aggregate types for the member, and unify
	 * the expression with the member type; ie:
	 *	 x: aggrtype	y : memb in aggrtype
	 *	 ---------------------------------------
	 *			   x.y : membtype
	} else {
		if (tybase(t)->type == Typtr)
			t = tybase(tf(t->sub[0]));
		if (tybase(t)->type == Tyvar) {
			if (!rem)
				fatal(n, "underspecified type defined on %s:%d used near %s",
						fname(t->loc), lnum(t->loc), ctxstr(n));
			lappend(rem, nrem, n);
			lappend(remscope, nremscope, curstab());
		} else if (tybase(t)->type != Tystruct) {
			fatal(n, "type %s does not support member operators near %s",
					tystr(t), ctxstr(n));
		nl = t->sdecls;
		for (i = 0; i < t->nmemb; i++) {
			if (!strcmp(namestr(memb), declname(nl[i]))) {
				unify(n, type(n), decltype(nl[i]));
				found = 1;
	if (!found)
		fatal(aggr, "type %s has no member \"%s\" near %s", tystr(type(aggr)),
				ctxstr(memb), ctxstr(aggr));

/*
 * Post-inference check for the struct literal `n`.
 * The literal's type must resolve to a struct (a type that is still a
 * type variable is tested for separately first). For every
 * initializer, the struct member named by the initializer's index
 * expression is looked up and the member's type is unified with the
 * initializer's type. When a member cannot be found, the node is
 * appended to `rem` (with the current scope in `remscope`) if a
 * remainder list was supplied; otherwise it is a fatal error.
 */
static void
checkstruct(Node *n, Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
	Type *t, *et;
	Node *val, *name;
	size_t i, j;

	t = tybase(tf(n->lit.type));
	 * If we haven't inferred the type, and it's inside another struct,
	 * we'll eventually get to it.
	 * If, on the other hand, it is genuinely underspecified, we'll give
	 * a better error on it later.
	if (t->type == Tyvar)
	if (t->type != Tystruct)
		fatal(n, "struct literal is used as non struct type %s", tystr(n->lit.type));
	for (i = 0; i < n->expr.nargs; i++) {
		val = n->expr.args[i];
		name = val->expr.idx;

		et = NULL;
		for (j = 0; j < t->nmemb; j++) {
			if (!strcmp(namestr(t->sdecls[j]->, namestr(name))) {
				et = type(t->sdecls[j]);

		if (!et) {
			if (rem) {
				lappend(rem, nrem, n);
				lappend(remscope, nremscope, curstab());
			} else {
				fatal(n, "could not find member %s in struct %s, near %s",
						namestr(name), tystr(t), ctxstr(n));

		unify(val, et, type(val));

/*
 * Post-inference check for the variable reference `n`.
 * The referenced declaration is fetched from `decls` by n->expr.did.
 * If the expression carries a type parameter, the declaration's impls
 * table is consulted with tf(param) as the key and a matching entry
 * supplies the type. Otherwise, or when that yields no type, a
 * freshened copy of the declaration's own type is used. The chosen
 * type is unified with the expression's type.
 * The rem / remscope parameters are not used by this body.
 */
static void
checkvar(Node *n, Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
	Node *proto, *dcl;
	Type *ty;

	proto = decls[n->expr.did];
	ty = NULL;
	dcl = NULL;
	if (n->expr.param)
		dcl = htget(proto->decl.impls, tf(n->expr.param));
	if (dcl)
		ty = dcl->decl.type;
	if (!ty)
		ty = tyfreshen(NULL, type(proto));
	unify(n, type(n), ty);

/*
 * Ties the element type `base` of the iteration statement `n` to the
 * impl of the Tciter trait that best matches the sequence type `ty`.
 * Among the entries of impltab whose trait is traittab[Tciter], the
 * one with the highest tymatchrank() against `ty` is chosen (the first
 * such entry wins ties, since only a strictly greater rank replaces
 * it). If one was found, a substitution from the impl type's arguments
 * to `ty`'s arguments is built, the impl's first aux type is freshened
 * under it, and the result is unified with `base`.
 */
static void
fixiter(Node *n, Type *ty, Type *base)
	size_t i, bestidx;
	int r, bestrank;
	Type *b, *t, *orig;
	Tysubst *ts;
	Node *impl;

	ty = tysearch(ty);
	b = ty->seqaux;
	if (!b)
	bestrank = -1;
	bestidx = 0;
	for (i = 0; i < nimpltab; i++) {
		if (impltab[i]->impl.trait != traittab[Tciter])
		r = tymatchrank(impltab[i]->impl.type, ty);
		if (r > bestrank) {
			bestrank = r;
			bestidx = i;
	if (bestrank >= 0) {
		impl = impltab[bestidx];
		orig = impl->impl.type;
		t = tf(impl->impl.aux[0]);
		ts = mksubst();
		for (i = 0; i < ty->narg; i++)
			substput(ts, tf(orig->arg[i]), ty->arg[i]);
		t = tyfreshen(ts, t);
		unify(n, t, base);

/*
 * Runs one pass over the nodes queued in `postcheck`.
 * Expression nodes are dispatched on their operator to infercompn(),
 * checkcast(), checkstruct() or checkvar(); any other operator is an
 * internal error. Iteration statements go to fixiter(). The rem /
 * remscope lists are handed to the checkers so that they can record
 * nodes that need another pass.
 */
static void
postcheckpass(Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
	size_t i;
	Node *n;

	for (i = 0; i < npostcheck; i++) {
		n = postcheck[i];
		if (n->type == Nexpr) {
			switch (exprop(n)) {
			case Omemb:	infercompn(n, rem, nrem, remscope, nremscope);	break;
			case Ocast:	checkcast(n, rem, nrem, remscope, nremscope);	break;
			case Ostruct:	checkstruct(n, rem, nrem, remscope, nremscope);	break;
			case Ovar:	checkvar(n, rem, nrem, remscope, nremscope);	break;
			default:	die("should not see %s in postcheck\n", opstr[exprop(n)]);
		} else if (n->type == Niterstmt) {
			fixiter(n, type(n->iterstmt.seq), type(n->iterstmt.elt));

/*
 * Drives the post-inference checks: runs postcheckpass() repeatedly,
 * each time collecting the nodes that could not be resolved into a
 * fresh remainder list, which then replaces the global postcheck /
 * postcheckscope lists for the next round. The comparison of `nrem`
 * with `npostcheck` identifies a round in which the remainder is as
 * long as the input.
 * NOTE(review): the line carrying this function's name and parameter
 * list, and the body of the `nrem == npostcheck` branch, are absent
 * from this copy.
 */
static void
	size_t nrem, nremscope;
	Stab **remscope;
	Node **rem;

	while (1) {
		remscope = NULL;
		nremscope = 0;
		rem = NULL;
		nrem = 0;
		postcheckpass(&rem, &nrem, &remscope, &nremscope);
		if (nrem == npostcheck) {
		postcheck = rem;
		npostcheck = nrem;
		postcheckscope = remscope;
		npostcheckscope = nremscope;

/*
 * After inference, replace all types in the symbol table `s` with
 * the final computed types.
 *
 * Named types are re-resolved with tysearch(), written back, and
 * fixed with tyfix(). Declarations get their type replaced by the
 * tyfix() result; a const or generic declaration that is neither
 * imported nor extern and has no initializer is a fatal error.
 */
static void
stabsub(Stab *s)
	void **k;
	size_t n, i;
	Type *t;
	Node *d;
	char *dt;

	k = htkeys(s->ty, &n);
	for (i = 0; i < n; i++) {
		t = tysearch(gettype(s, k[i]));
		updatetype(s, k[i], t);
		tyfix(k[i], t, 0);

	k = htkeys(s->dcl, &n);
	for (i = 0; i < n; i++) {
		d = getdcl(s, k[i]);
		d->decl.type = tyfix(d, d->decl.type, 0);
		if (!d->decl.isconst && !d->decl.isgeneric)
		if (d->decl.trait)
		/* dt only selects the word used in the diagnostic below */
		dt = "const";
		if (d->decl.isgeneric)
			dt = "generic";
		if (!d->decl.isimport && !d->decl.isextern && !d->decl.init)
			fatal(d, "non-extern %s \"%s\" has no initializer", dt, ctxstr(d));

/*
 * Checks that the integer literal `n` fits the range of its type.
 * Types from Tyint8 through Tyint64 are checked as signed values
 * against svranges; the second branch checks the value as unsigned
 * against uvranges. Out-of-range literals are fatal errors. Both
 * tables are indexed directly by the type tag.
 *
 * NOTE(review): the second branch tests `Tybyte .. Tyint64`, reusing
 * the signed branch's upper bound; presumably Tyuint64 was intended.
 * Check against the order of the Ty enum, and note that uvranges has
 * no Tyuint entry, before changing it.
 * NOTE(review): sval/uval are int64_t/uint64_t but are printed with
 * %lld/%llu; the PRId64/PRIu64 macros would match on every platform.
 */
static void
checkrange(Node *n)
	Type *t;
	int64_t sval;
	uint64_t uval;
	static const int64_t svranges[][2] = {
		/* signed ints */
		[Tyint8] = {-128LL, 127LL}, [Tyint16] = {-32768LL, 32767LL},
		/* FIXME: this has been doubled allow for uints... */
		[Tyint32] = {-2147483648LL, 2 * 2147483647LL},
		[Tyint] = {-2147483648LL, 2 * 2147483647LL},
		[Tyint64] = {-9223372036854775808ULL, 9223372036854775807LL},

	static const uint64_t uvranges[][2] = {
		[Tybyte] = {0, 255ULL}, [Tyuint8] = {0, 255ULL}, [Tyuint16] = {0, 65535ULL},
		[Tyuint32] = {0, 4294967295ULL}, [Tyuint64] = {0, 18446744073709551615ULL},
		[Tychar] = {0, 4294967295ULL},

	/* signed types */
	t = type(n);
	if (t->type >= Tyint8 && t->type <= Tyint64) {
		sval = n->lit.intval;
		if (sval < svranges[t->type][0] || sval > svranges[t->type][1])
			fatal(n, "literal value %lld out of range for type \"%s\"", sval, tystr(t));
	} else if ((t->type >= Tybyte && t->type <= Tyint64) || t->type == Tychar) {
		uval = n->lit.intval;
		if (uval < uvranges[t->type][0] || uval > uvranges[t->type][1])
			fatal(n, "literal value %llu out of range for type \"%s\"", uval, tystr(t));

/*
 * Returns 1 when `t` is an acceptable type for `__init__`: a function
 * type with exactly one sub type whose base type is void. Returns 0
 * otherwise. The diagnostic at the call site spells this as (->void).
 */
static int
initcompatible(Type *t)
	if (t->type != Tyfunc)
		return 0;
	if (t->nsub != 1)
		return 0;
	if (tybase(t->sub[0])->type != Tyvoid)
		return 0;
	return 1;

/*
 * Returns 1 when `t` is an acceptable type for `main`, 0 otherwise.
 * Accepted are function types whose first sub type (the return type)
 * is void and which take either no argument or one argument of type
 * byte[:][:] -- i.e. (->void) or (byte[:][:] -> void).
 */
static int
maincompatible(Type *t)
{
	/* sub[0] is only meaningful on a function type; reject everything
	 * else before indexing, the same way initcompatible() does. */
	if (t->type != Tyfunc)
		return 0;
	if (t->nsub < 1 || t->nsub > 2)
		return 0;
	if (tybase(t->sub[0])->type != Tyvoid)
		return 0;
	if (t->nsub == 2) {
		/* peel the argument type: slice of slice of byte */
		t = tybase(t->sub[1]);
		if (t->type != Tyslice)
			return 0;
		t = tybase(t->sub[0]);
		if (t->type != Tyslice)
			return 0;
		t = tybase(t->sub[0]);
		if (t->type != Tybyte)
			return 0;
	}
	return 1;
}

/*
 * Verifies that the final type of the literal expression `n` agrees
 * with its operator.
 * A struct literal must have a struct base type, and each of its
 * initializers must name a member of that struct. An array literal
 * must have an array base type. Violations are fatal errors; all
 * other operators are accepted without checks.
 */
static void
verifyop(Node *n)
{
	Type *ty, *bty;
	Node *a;
	char *sa, *sb;
	size_t i, j;
	int found;

	ty = exprtype(n);
	bty = tybase(ty);
	switch (exprop(n)) {
	case Ostruct:
		if (bty->type != Tystruct)
			fatal(n, "wrong type for struct literal: %s\n", tystr(ty));
		for (i = 0; i < n->expr.nargs; i++) {
			found = 0;
			a = n->expr.args[i];
			/* the initializer's index expression names the member */
			for (j = 0; j < bty->nmemb; j++) {
				sa = namestr(a->expr.idx);
				sb = declname(bty->sdecls[j]);
				if (!strcmp(sa, sb))
					found = 1;
			}
			if (!found)
				fatal(n, "type %s has no member %s",
				    tystr(ty), namestr(a->expr.idx));
		}
		break;
	case Oarr:
		/* this diagnostic used to say "struct literal" */
		if (bty->type != Tyarray)
			fatal(n, "wrong type for array literal: %s\n", tystr(ty));
		break;
	default:
		break;
	}
}

/*
 * After type inference, replace all types in the tree rooted at `n`
 * with the final computed types.
 *
 * `noerr` is forwarded to the recursive calls and to tyfix() for
 * declarations; expression, function and literal nodes call tyfix()
 * with errors enabled (0) regardless. Declarations named "main" and
 * "__init__" are checked against maincompatible() / initcompatible().
 * Auto declarations and (outside generics) iteration statements are
 * registered for specialization. Integer literals are range-checked
 * and impl nodes are entered into the current symbol table.
 */
static void
typesub(Node *n, int noerr)
	size_t i;

	if (!n)
	switch (n->type) {
	case Nfile:
		for (i = 0; i < n->file.nstmts; i++)
			typesub(n->file.stmts[i], noerr);
	case Ndecl:
		settype(n, tyfix(n, type(n), noerr));
		if (n->decl.init)
			typesub(n->decl.init, noerr);
		if (streq(declname(n), "main"))
			if (!maincompatible(tybase(decltype(n))))
				fatal(n, "main must be (->void) or (byte[:][:] -> void), got %s",
		if (streq(declname(n), "__init__"))
			if (!initcompatible(tybase(decltype(n))))
				fatal(n, "__init__ must be (->void), got %s", tystr(decltype(n)));
		if (n->decl.isauto)
			adddispspecialization(n, curstab());
	case Nblock:
		for (i = 0; i < n->block.nstmts; i++)
			typesub(n->block.stmts[i], noerr);
	case Nifstmt:
		typesub(n->ifstmt.cond, noerr);
		typesub(n->ifstmt.iftrue, noerr);
		typesub(n->ifstmt.iffalse, noerr);
	case Nloopstmt:
		typesub(n->loopstmt.cond, noerr);
		typesub(n->loopstmt.init, noerr);
		typesub(n->loopstmt.step, noerr);
		typesub(n->loopstmt.body, noerr);
	case Niterstmt:
		typesub(n->iterstmt.elt, noerr);
		typesub(n->iterstmt.seq, noerr);
		typesub(n->iterstmt.body, noerr);
		if (!ingeneric)
			additerspecialization(n, curstab());
	case Nmatchstmt:
		typesub(n->matchstmt.val, noerr);
		for (i = 0; i < n->matchstmt.nmatches; i++) {
			typesub(n->matchstmt.matches[i], noerr);
	case Nmatch:
		typesub(n->match.pat, noerr);
		typesub(n->match.block, noerr);
	case Nexpr:
		settype(n, tyfix(n, type(n), 0));
		if (n->expr.param)
			n->expr.param = tyfix(n, n->expr.param, 0);
		typesub(n->expr.idx, noerr);
		/* a cast applied directly to an integer literal gives the
		 * literal (and the expression wrapping it) the cast's type */
		if (exprop(n) == Ocast && exprop(n->expr.args[0]) == Olit &&
				n->expr.args[0]->expr.args[0]->lit.littype == Lint) {
			settype(n->expr.args[0], exprtype(n));
			settype(n->expr.args[0]->expr.args[0], exprtype(n));
		for (i = 0; i < n->expr.nargs; i++)
			typesub(n->expr.args[i], noerr);
		if (!noerr)
	case Nfunc:
		settype(n, tyfix(n, n->func.type, 0));
		for (i = 0; i < n->func.nargs; i++)
			typesub(n->func.args[i], noerr);
		typesub(n->func.body, noerr);
	case Nlit:
		settype(n, tyfix(n, type(n), 0));
		switch (n->lit.littype) {
		case Lfunc:	typesub(n->lit.fnval, noerr);	break;
		case Lint:	checkrange(n);
		default: break;
	case Nimpl: putimpl(curstab(), n);
	case Nname:
	case Nuse: break;
	case Nnone:	die("Nnone should not be seen as node type!");	break;

/*
 * Builds the function type used for an iterator trait function over
 * the sequence expression `n`.
 * The function type is created by mktyfunc() with `ret` and no
 * arguments; a pointer to the sequence's type and a pointer to its
 * element type (from basetype()) are then appended to its sub types,
 * in that order. Dies if the sequence type has no element type.
 */
static Type *
itertype(Node *n, Type *ret)
	Type *it, *val, *itp, *valp, *fn;

	it = exprtype(n);
	itp = mktyptr(n->loc, it);
	val = basetype(it);
	if (!val)
		die("FAIL! %s", tystr(it));
	valp = mktyptr(n->loc, val);
	fn = mktyfunc(n->loc, NULL, 0, ret);
	lappend(&fn->sub, &fn->nsub, itp);
	lappend(&fn->sub, &fn->nsub, valp);
	return fn;

/*
 * Take generics and build new versions of them
 * with the type parameters replaced with the
 * specialized types.
 *
 * First, every declaration collected in impldecl is appended to the
 * file's statements and has its types substituted. Then each queued
 * specialization is handled by node kind:
 *  - expressions: the generic declaration at the same index is
 *    specialized for the expression's param/type and the expression is
 *    rewritten to refer to the specialized declaration;
 *  - iteration statements: both prototypes of the Tciter trait are
 *    specialized for the sequence type (one built with a bool return,
 *    one with void) and recorded in the prototypes' impls tables;
 *  - auto declarations: the single prototype of the Tcdisp trait is
 *    specialized for the declaration's type and recorded likewise.
 * NOTE(review): the line carrying this function's name and parameter
 * list, and the statement guarded by `ty->type == Typaram`, are absent
 * from this copy.
 */
static void
	Node *d, *n, *name;
	Type *ty, *it, *dt;
	size_t i;
	Trait *tr;

	for (i = 0; i < nimpldecl; i++) {
		d = impldecl[i];
		lappend(&file->file.stmts, &file->file.nstmts, d);
		typesub(d, 0);

	for (i = 0; i < nspecializations; i++) {
		n = specializations[i];
		if (n->type == Nexpr) {
			d = specializedcl(genericdecls[i], n->expr.param, n->expr.type, &name);
			n->expr.args[0] = name;
			n->expr.did = d->decl.did;

			/* we need to sub in default types in the specialization, so call
			 * typesub on the specialized function */
			typesub(d, 0);
		} else if (n->type == Niterstmt) {
			tr = traittab[Tciter];
			assert(tr->nproto == 2);
			ty = exprtype(n->iterstmt.seq);
			if (ty->type == Typaram)

			it = itertype(n->iterstmt.seq, mktype(n->loc, Tybool));
			d = specializedcl(tr->proto[0], ty, it, &name);
			htput(tr->proto[0]->decl.impls, ty, d);

			it = itertype(n->iterstmt.seq, mktype(n->loc, Tyvoid));
			d = specializedcl(tr->proto[1], ty, it, &name);
			htput(tr->proto[1]->decl.impls, ty, d);
		} else if (n->type == Ndecl && n->decl.isauto) {
			tr = traittab[Tcdisp];
			assert(tr->nproto == 1);
			ty = decltype(n);
			/* function type built with a void return, then the
			 * declaration's type appended as a sub type */
			dt = mktyfunc(n->loc, NULL, 0, mktype(n->loc, Tyvoid));
			lappend(&dt->sub, &dt->nsub, ty);
			d = specializedcl(tr->proto[0], ty, dt, &name);
			htput(tr->proto[0]->decl.impls, ty, d);
		} else {
			die("unknown node for specialization\n");

/*
 * Allocates a zeroed Traitmap and gives it an empty trait bitset and a
 * name-keyed hash table. The sub[] entries and filter lists are left
 * zeroed.
 * NOTE(review): the line carrying this function's name and parameter
 * list is absent from this copy; callers in this file refer to it as
 * mktraitmap().
 */
static Traitmap *
	Traitmap *m;

	m = zalloc(sizeof(Traitmap));
	m->traits = mkbs();
	m->name = mkht(namehash, nameeq);
	return m;

/*
 * Populates the global traitmap with the traits of the built-in types.
 * A sub-map is created for every type tag, after which trait ids are
 * set in the per-tag bitsets: char, byte and the tags from Tyint8 up
 * to (not including) Tyflt32 get Tcnum and Tcint; flt32 and flt64 get
 * Tcnum and Tcflt; pointers get Tcslice; slices and arrays get Tcidx,
 * Tcslice and Tciter; functions get Tcfunc.
 * NOTE(review): the line carrying this function's name and parameter
 * list is absent from this copy.
 */
static void
	size_t i;

	/* char::(numeric,integral) */
	for (i = 0; i < Ntypes; i++)
		traitmap->sub[i] = mktraitmap();

	bsput(traitmap->sub[Tychar]->traits, Tcnum);
	bsput(traitmap->sub[Tychar]->traits, Tcint);

	bsput(traitmap->sub[Tybyte]->traits, Tcnum);
	bsput(traitmap->sub[Tybyte]->traits, Tcint);

	/* <integer types>::(numeric,integral) */
	for (i = Tyint8; i < Tyflt32; i++) {
		bsput(traitmap->sub[i]->traits, Tcnum);
		bsput(traitmap->sub[i]->traits, Tcint);

	/* <floats>::(numeric,floating) */
	bsput(traitmap->sub[Tyflt32]->traits, Tcnum);
	bsput(traitmap->sub[Tyflt32]->traits, Tcflt);
	bsput(traitmap->sub[Tyflt64]->traits, Tcnum);
	bsput(traitmap->sub[Tyflt64]->traits, Tcflt);

	/* @a*::(sliceable) */
	bsput(traitmap->sub[Typtr]->traits, Tcslice);

	/* @a[:]::(indexable,sliceable,iterable) */
	bsput(traitmap->sub[Tyslice]->traits, Tcidx);
	bsput(traitmap->sub[Tyslice]->traits, Tcslice);
	bsput(traitmap->sub[Tyslice]->traits, Tciter);

	/* @a[SZ]::(indexable,sliceable,iterable) */
	bsput(traitmap->sub[Tyarray]->traits, Tcidx);
	bsput(traitmap->sub[Tyarray]->traits, Tcslice);
	bsput(traitmap->sub[Tyarray]->traits, Tciter);

	/* @a::function */
	bsput(traitmap->sub[Tyfunc]->traits, Tcfunc);

/*
 * Returns the trait implemented by the impl node `impl`.
 * When the impl already carries a resolved trait, that trait is
 * returned as is. Otherwise the trait name is looked up in the file's
 * global symbol table, or in the namespace the name is qualified with.
 * A missing trait, or one whose number of aux types differs from the
 * impl's, is a fatal error.
 */
static Trait*
findtrait(Node *impl)
{
	Trait *tr;
	Node *n;
	Stab *ns;

	tr = impl->impl.trait;
	if (!tr) {
		n = impl->impl.traitname;
		ns = file->file.globls;
		if (n->name.ns)
			ns = getns(n->name.ns);
		/* an unknown namespace leaves tr NULL and is reported below */
		if (ns)
			tr = gettrait(ns, n);
		if (!tr)
			fatal(impl, "trait %s does not exist near %s",
			    namestr(impl->impl.traitname), ctxstr(impl));
		/* the format now consumes ctxstr(impl), which was previously
		 * passed without a matching conversion */
		if (tr->naux != impl->impl.naux)
			fatal(impl, "incompatible implementation of %s: mismatched aux types near %s",
			    namestr(impl->impl.traitname), ctxstr(impl));
	}
	return tr;
}

/*
 * Records in the trait map `m` that type `ty` implements trait `tr`.
 * The map is keyed by type tag:
 *  - generic and parameter types are added as filters to the sub-map
 *    of every tag except Typaram;
 *  - named types without generic parameters set the trait's uid in a
 *    per-name bitset; named types with generic parameters are added as
 *    filters on the Tyname sub-map;
 *  - pointers and arrays recurse on their element type inside the
 *    sub-map for their own tag;
 *  - anything else sets the trait's uid directly when the type is
 *    primitive, and is added as a filter otherwise.
 * Missing sub-maps are created on demand.
 */
static void
addtraittab(Traitmap *m, Trait *tr, Type *ty)
{
	Bitset *bs;
	Traitmap *mm;
	size_t i;

	if (!m->sub[ty->type])
		m->sub[ty->type] = mktraitmap();
	mm = m->sub[ty->type];
	switch (ty->type) {
	case Tygeneric:
	case Typaram:
		for (i = 0; i < Ntypes; i++) {
			if (i == Typaram)
				continue;
			if (!m->sub[i])
				m->sub[i] = mktraitmap();
			mm = m->sub[i];
			lappend(&mm->filter, &mm->nfilter, ty);
			lappend(&mm->filtertr, &mm->nfiltertr, tr);
		}
		break;
	case Tyname:
		if (ty->ngparam == 0) {
			bs = htget(mm->name, ty->name);
			if (!bs) {
				bs = mkbs();
				htput(mm->name, ty->name, bs);
			}
			bsput(bs, tr->uid);
		} else {
			/* list and count must belong to the same map: the
			 * counts were previously taken from `m` while the
			 * lists came from `mm` */
			lappend(&mm->filter, &mm->nfilter, ty);
			lappend(&mm->filtertr, &mm->nfiltertr, tr);
		}
		break;
	case Typtr:
	case Tyarray:
		addtraittab(mm, tr, ty->sub[0]);
		break;
	default:
		if (istyprimitive(ty)) {
			bsput(mm->traits, tr->uid);
		} else {
			lappend(&mm->filter, &mm->nfilter, ty);
			lappend(&mm->filtertr, &mm->nfiltertr, tr);
		}
		break;
	}
}

/*
 * Allocates the global traitmap and registers every impl in impltab:
 * the impl's trait is resolved with findtrait() and entered into the
 * map for tf(impl type) through addtraittab(). For impls of the trait
 * whose uid is Tciter, the impl's first aux type is also stored in the
 * type's seqaux field.
 * NOTE(review): the line carrying this function's name and parameter
 * list is absent from this copy.
 */
static void
	size_t i;
	Node *impl;
	Trait *tr;
	Type *ty;

	traitmap = zalloc(sizeof(Traitmap));
	for (i = 0; i < nimpltab; i++) {
		impl = impltab[i];
		tr = findtrait(impl);

		ty = tf(impl->impl.type);
		addtraittab(traitmap, tr, ty);
		if (tr->uid == Tciter)
			ty->seqaux = tf(impl->impl.aux[0]);

static void
	Type *t;
	Node *n;
	size_t i;

	/* for now, traits can only be declared globally */
	for (i = 0; i < file->file.nstmts; i++) {
		n = file->file.stmts[i];
		if (n->type == Nimpl) {
			/* we merge, so we need to get it back again when error checking */
			if (n->impl.isproto)
				fatal(n, "missing implementation for prototype '%s %s'",
					namestr(n->impl.traitname), tystr(n->impl.type));
	for (i = 0; i < ntypes; i++) {
		t = types[i];
		if (t->type != Tyarray)
		t->asize = fold(t->asize, 1);
		if (t->asize && exprop(t->asize) != Olit)
			fatal(t->asize, "nonconstant array size near %s\n", ctxstr(t->asize));

	delayed = mkht(tyhash, tyeq);

	/* do the inference */
	infernode(&file, NULL, NULL);

	/* and replace type vars with actual types */
	typesub(file, 0);