shithub: mc

ref: c20862cda53c711fe476f6c7d0f6631af47a4933
dir: /mi/match.c/

View raw version
#include <stdio.h>
#include <inttypes.h>
#include <stdarg.h>
#include <ctype.h>
#include <string.h>
#include <assert.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <unistd.h>

#include "parse.h"
#include "mi.h"

/*
 * One node of the decision tree built for a match statement.
 * A node either accepts (accept != 0) or tests `load` against each pat[i],
 * continuing at next[i] on equality and at `any` when nothing else applies.
 */
typedef struct Dtree Dtree;
struct Dtree {
    int id;         /* sequence number handed out by mkdtree() */
    Srcloc loc;     /* source location used for the nodes generated from this step */

    /* values for matching */
    Node *lbl;              /* label node that jumps to this step target */
    Node *load;             /* expression compared against each pat[i] */
    size_t nconstructors;   /* count from nconstructors(); compared with nnext in verifymatch() */
    int accept;             /* nonzero: terminal node, code generation just jumps to lbl */
    int emitted;            /* set by genmatchcode() so shared nodes are emitted once */

    /* the decision tree step */
    Node **pat;     /* literal patterns tested at this node; pat[i] pairs with next[i] */
    size_t npat;
    Dtree **next;   /* successor taken when load equals the corresponding pat */
    size_t nnext;
    Dtree *any;     /* wildcard successor; NULL when there is none */

    /* captured variables and action */
    Node **cap;     /* not referenced by the code visible in this file */
    size_t ncap;

};

/* Forward declarations: addpat() is mutually recursive with the add* helpers below. */
Dtree *gendtree(Node *m, Node *val, Node **lbl, size_t nlbl);
static int addpat(Node *pat, Node *val, Dtree *start, Dtree *accept, Node ***cap, size_t *ncap, Dtree ***end, size_t *nend);
void dtreedump(FILE *fd, Dtree *dt);


/* Wraps `n` in an Outag expression whose type is a 32-bit int. */
static Node *utag(Node *n)
{
    Node *tagexpr = mkexpr(n->loc, Outag, n, NULL);

    tagexpr->expr.type = mktype(n->loc, Tyint32);
    return tagexpr;
}

/* Wraps `n` in an Oudata expression and gives it the type `ty`. */
static Node *uvalue(Node *n, Type *ty)
{
    Node *payload = mkexpr(n->loc, Oudata, n, NULL);

    payload->expr.type = ty;
    return payload;
}

/*
 * Builds an Otupget expression selecting element `i` of `n`.
 * The result takes the type of the i'th subtype of n's base type.
 */
static Node *tupelt(Node *n, size_t i)
{
    Node *index = mkintlit(n->loc, i);
    Node *get;

    index->expr.type = mktype(n->loc, Tyuint64);
    get = mkexpr(n->loc, Otupget, n, index, NULL);
    get->expr.type = tybase(exprtype(n))->sub[i];
    return get;
}

/*
 * Builds an Oidx expression selecting element `i` of `n`.
 * The result takes the type of the first subtype of n's base type.
 */
static Node *arrayelt(Node *n, size_t i)
{
    Node *index = mkintlit(n->loc, i);
    Node *at;

    index->expr.type = mktype(n->loc, Tyuint64);
    at = mkexpr(n->loc, Oidx, n, index, NULL);
    at->expr.type = tybase(exprtype(n))->sub[0];
    return at;
}

/*
 * Returns the first argument of `pat` whose expr.idx compares equal to
 * `name` under nameeq(), or NULL when no argument does.
 */
static Node *findmemb(Node *pat, Node *name)
{
    size_t k;

    for (k = 0; k < pat->expr.nargs; k++) {
        Node *candidate = pat->expr.args[k];

        if (nameeq(candidate->expr.idx, name))
            return candidate;
    }
    return NULL;
}

/*
 * Looks for an outgoing edge of `t` whose pattern literal equals the id of
 * the union constructor `uc`. Returns that edge's successor, or NULL.
 */
static Dtree *dtbytag(Dtree *t, Ucon *uc)
{
    size_t edge;

    for (edge = 0; edge < t->npat; edge++) {
        Node *lit = t->pat[edge]->expr.args[0];
        uint32_t tag = lit->lit.intval;

        if (tag == uc->id)
            return t->next[edge];
    }
    return NULL;
}

/* Builds an Omemb expression accessing member `name` of `n`, typed as `ty`. */
static Node *structmemb(Node *n, Node *name, Type *ty)
{
    Node *access = mkexpr(n->loc, Omemb, n, name, NULL);

    access->expr.type = ty;
    return access;
}

/*
 * Rebuilds the statement list of block `n` so that the `ncap` nodes in
 * `cap` come first, followed by the block's original statements.
 * The old statement array is released; `n` is returned.
 */
static Node *addcapture(Node *n, Node **cap, size_t ncap)
{
    Node **merged = NULL;
    size_t nmerged = 0;
    size_t k;

    /* captures go in front of the existing body */
    for (k = 0; k < ncap; k++)
        lappend(&merged, &nmerged, cap[k]);
    for (k = 0; k < n->block.nstmts; k++)
        lappend(&merged, &nmerged, n->block.stmts[k]);

    lfree(&n->block.stmts, &n->block.nstmts);
    n->block.nstmts = nmerged;
    n->block.stmts = merged;
    return n;
}

/*
 * Allocates a zeroed decision tree node with the given location and label.
 * Each node receives the next value of a function-local counter as its id.
 */
static Dtree *mkdtree(Srcloc loc, Node *lbl)
{
    static int nextid;
    Dtree *node;

    node = zalloc(sizeof(Dtree));
    node->id = nextid++;
    node->loc = loc;
    node->lbl = lbl;
    return node;
}

/*
 * Picks the target for step `idx` of a `count`-step sequence: the final
 * step goes to `accept`, every earlier step gets a fresh labelled node.
 */
static Dtree *nextnode(Srcloc loc, size_t idx, size_t count, Dtree *accept)
{
    if (idx + 1 != count)
        return mkdtree(loc, genlbl(loc));
    return accept;
}

/*
 * Returns the number of distinct constructors this file assumes for the
 * base type of `t`; verifymatch() compares it against the number of
 * outgoing edges of a node. A NULL type yields 0. Types that cannot
 * appear in a match are reported through die().
 */
static size_t nconstructors(Type *t)
{
    if (!t)
        return 0;

    t = tybase(t);
    switch (t->type) {
        case Tyvoid:
            return 0;
        case Tybool:
            return 2;
        case Tychar:
            return 0x10ffff;

        /* 8 bit integers */
        case Tyint8: case Tybyte: case Tyuint8:
            return 0x100;

        /* 16 bit integers */
        case Tyint16: case Tyuint16:
            return 0x10000;

        /* 32 bit integers */
        case Tyint32: case Tyint: case Tyuint32: case Tyuint:
            return 0x100000000;

        /* 64 bit integers, floats and slices all use the same value */
        case Tyint64: case Tyuint64:
        case Tyflt32: case Tyflt64:
        case Tyslice:
            return ~0ull;

        /* aggregates and pointers count as a single constructor */
        case Typtr: case Tyarray: case Tytuple: case Tystruct:
            return 1;

        case Tyunion:
            return t->nmemb;

        case Tyvar: case Typaram: case Tyunres: case Tyname:
        case Tybad: case Tyvalist: case Tygeneric: case Ntypes:
        case Tyfunc: case Tycode:
            die("Invalid constructor type %s in match", tystr(t));
            break;
    }
    return 0;
}

/*
 * Checks the tree rooted at `t` for exhaustiveness.
 * An accepting node is complete. Any other node is complete when it has a
 * wildcard edge or one edge per constructor, and every next[] successor
 * is complete as well. Returns 1 when complete, 0 otherwise.
 */
static int verifymatch(Dtree *t)
{
    size_t child;
    int complete;

    if (t->accept)
        return 1;

    complete = (t->nnext == t->nconstructors || t->any) ? 1 : 0;
    /* every successor is visited, even once the answer is known to be 0 */
    for (child = 0; child < t->nnext; child++) {
        if (verifymatch(t->next[child]))
            continue;
        complete = 0;
    }
    return complete;
}

/*
 * Walks the tree from `t` and installs `accept` as the wildcard edge on
 * every reachable node that lacks one. A node whose wildcard already is
 * `accept` is left alone and not descended into.
 * Returns 1 if any wildcard edge was installed, 0 otherwise.
 */
static int acceptall(Dtree *t, Dtree *accept)
{
    size_t edge;
    int changed;

    if (t->any == accept)
        return 0;

    changed = 0;
    if (!t->any) {
        t->any = accept;
        changed = 1;
    } else if (acceptall(t->any, accept)) {
        changed = 1;
    }

    for (edge = 0; edge < t->nnext; edge++) {
        if (acceptall(t->next[edge], accept))
            changed = 1;
    }
    return changed;
}

/*
 * Adds a wildcard for a value of type `ty` to the tree at `start`, with
 * `accept` as the node to reach once the whole value has been skipped.
 * Nodes reached are appended to the list (*end, *nend).
 * Returns 1 if the tree was changed (some wildcard edge was installed),
 * 0 if existing edges already covered the value.
 */
static int addwildrec(Srcloc loc, Type *ty, Dtree *start, Dtree *accept, Dtree ***end, size_t *nend)
{
    Dtree *next, **last, **tail;
    size_t i, j, nelt, nlast, ntail;
    Node *asize;
    Ucon *uc;
    int ret;

    tail = NULL;
    ntail = 0;
    ty = tybase(ty);
    /* primitives: report every existing successor, then use or install the wildcard edge */
    if (istyprimitive(ty)) {
        for (i = 0; i < start->nnext; i++)
            lappend(end, nend, start->next[i]);
        if (start->any) {
            lappend(end, nend, start->any);
            return 0;
        } else {
            start->any = accept;
            lappend(end, nend, accept);
            return 1;
        }
    }
    
    ret = 0;
    /* `last` is the set of nodes reached so far; it starts out as just `start` */
    last = NULL;
    nlast = 0;
    lappend(&last, &nlast, start);
    switch (ty->type) {
        /*
         * tuple, array and struct: add a wildcard for each element in turn.
         * Element i is added to every node in `last`; the nodes it reaches
         * (`tail`) become `last` for element i + 1. nextnode() hands back
         * `accept` for the final element.
         */
        case Tytuple:
            for (i = 0; i < ty->nsub; i++) {
                next = nextnode(loc, i, ty->nsub, accept);
                tail = NULL;
                ntail = 0;
                for (j = 0; j < nlast; j++)
                    if (addwildrec(loc, ty->sub[i], last[j], next, &tail, &ntail))
                        ret = 1;
                lfree(&last, &nlast);
                last = tail;
                nlast = ntail;
            }
            break;
        case Tyarray:
            /* element count comes from the folded array size expression */
            asize = fold(ty->asize, 1);
            nelt = asize->expr.args[0]->lit.intval;
            for (i = 0; i < nelt; i++) {
                next = nextnode(loc, i, nelt, accept);
                tail = NULL;
                ntail = 0;
                for (j = 0; j < nlast; j++)
                    if (addwildrec(loc, ty->sub[0], last[j], next, &tail, &ntail))
                        ret = 1;
                lfree(&last, &nlast);
                last = tail;
                nlast = ntail;
            }
            break;
        case Tystruct:
            for (i = 0; i < ty->nmemb; i++) {
                next = nextnode(loc, i, ty->nmemb, accept);
                tail = NULL;
                ntail = 0;
                for (j = 0; j < nlast; j++)
                    if (addwildrec(loc, decltype(ty->sdecls[i]), last[j], next, &tail, &ntail))
                        ret = 1;
                lfree(&last, &nlast);
                last = tail;
                nlast = ntail;
            }
            break;
        case Tyunion:
            /*
             * For each constructor that already has an edge out of `start`,
             * add a wildcard for its payload (or just report the edge target
             * when it has none); then make sure `start` has a wildcard edge.
             */
            for (i = 0; i < ty->nmemb; i++) {
                uc = ty->udecls[i];
                next = dtbytag(start, uc);
                if (next) {
                    if (uc->etype) {
                        if (addwildrec(loc, uc->etype, next, accept, end, nend))
                            ret = 1;
                    } else {
                        lappend(end, nend, next);
                    }
                }
            }
            if (!start->any) {
                start->any = accept;
                ret = 1;
            }
            lappend(end, nend, start->any);
            break;
        case Tyslice:
            /* slices: install `accept` as the wildcard on every node below start that lacks one */
            ret = acceptall(start, accept);
            lappend(&last, &nlast, accept);
            break;
        default:
            lappend(&last, &nlast, accept);
            break;
    }
    /* NOTE(review): for union, slice and default, `last` still holds `start` here, so it is reported too -- confirm intended */
    lcat(end, nend, last, nlast);
    lfree(&last, &nlast);
    return ret;
}

/*
 * Adds a wildcard pattern `pat` at `start`, leading to `accept`.
 *
 * When both `cap` and `ncap` are non-NULL, an assignment `pat = val` typed
 * like `pat` is appended to the capture list; otherwise nothing is
 * captured. The tree update itself is done by addwildrec() using the
 * location and type of `pat`, so `pat` must not be NULL.
 *
 * Returns the result of addwildrec(): 1 if the tree changed, 0 otherwise.
 */
static int addwild(Node *pat, Node *val, Dtree *start, Dtree *accept, Node ***cap, size_t *ncap, Dtree ***end, size_t *nend)
{
    Node *asn;

    /* only build the assignment when there is a capture list to hold it */
    if (cap && ncap) {
        asn = mkexpr(pat->loc, Oasn, pat, val, NULL);
        asn->expr.type = exprtype(pat);
        lappend(cap, ncap, asn);
    }
    return addwildrec(pat->loc, exprtype(pat), start, accept, end, nend);
}

/*
 * Adds a union constructor pattern at `start`.
 *
 * If `start` already has a wildcard edge, that edge's target is reported
 * and nothing is added. If an edge for this constructor exists, the
 * payload pattern (if any) is added below it. Otherwise a new tag edge is
 * created, leading to `accept` directly or to a fresh node that matches
 * the payload.
 *
 * Returns 1 if the tree changed, 0 otherwise.
 */
static int addunion(Node *pat, Node *val, Dtree *start, Dtree *accept, Node ***cap, size_t *ncap, Dtree ***end, size_t *nend)
{
    Node *taglit;
    Dtree *sub;
    Ucon *uc;

    if (start->any) {
        lappend(end, nend, start->any);
        return 0;
    }

    uc = finducon(tybase(exprtype(pat)), pat->expr.args[0]);
    sub = dtbytag(start, uc);

    /* an edge for this tag exists: descend into the payload or report it */
    if (sub && uc->etype)
        return addpat(pat->expr.args[1], uvalue(val, uc->etype), sub, accept, cap, ncap, end, nend);
    if (sub) {
        lappend(end, nend, sub);
        return 0;
    }

    /* first pattern at this node: it tests the union tag */
    if (!start->load) {
        start->load = utag(val);
        start->nconstructors = nconstructors(tybase(exprtype(pat)));
    }

    taglit = mkintlit(pat->loc, uc->id);
    taglit->expr.type = mktype(pat->loc, Tyint32);
    lappend(&start->pat, &start->npat, taglit);

    if (!uc->etype) {
        lappend(&start->next, &start->nnext, accept);
        lappend(end, nend, accept);
        return 1;
    }

    sub = mkdtree(pat->loc, genlbl(pat->loc));
    lappend(&start->next, &start->nnext, sub);
    /* the new tag edge already counts as a change, so the result is not needed */
    addpat(pat->expr.args[1], uvalue(val, uc->etype), sub, accept, cap, ncap, end, nend);
    return 1;
}

/*
 * Adds a string literal pattern at `start`.
 *
 * The string is matched as a sequence of integer literal patterns: first
 * the `len` member of `val` against the literal's length, then each byte
 * of the literal against the corresponding element of `val`. The capture
 * list is only passed to the length match; byte matches capture nothing.
 * The final step leads to `accept` (immediately so for an empty string).
 *
 * Returns 1 if any step changed the tree, 0 otherwise.
 */
static int addstr(Node *pat, Node *val, Dtree *start, Dtree *accept, Node ***cap, size_t *ncap, Dtree ***end, size_t *nend)
{
    Dtree **tail, **last, *next;
    size_t i, j, n, ntail, nlast;
    Node *p, *v, *lit;
    Type *ty;
    char *s;
    int ret;

    /* must start at 0: it is only ever raised to 1 below and is returned as-is */
    ret = 0;

    lit = pat->expr.args[0];
    n = lit->lit.strval.len;
    s = lit->lit.strval.buf;

    /* step one: compare the length */
    ty = mktype(pat->loc, Tyuint64);
    p = mkintlit(lit->loc, n);
    v = structmemb(val, mkname(pat->loc, "len"), ty);
    p->expr.type = ty;

    if (n == 0)
        next = accept;
    else
        next = mkdtree(pat->loc, genlbl(pat->loc));

    last = NULL;
    nlast = 0;
    if (addpat(p, v, start, next, cap, ncap, &last, &nlast))
        ret = 1;

    /* remaining steps: compare byte i from every node the previous step reached */
    ty = mktype(pat->loc, Tybyte);
    for (i = 0; i < n; i++) {
        /* NOTE(review): s[i] is a plain char, so bytes >= 0x80 may convert as negative values -- confirm */
        p = mkintlit(lit->loc, s[i]);
        p->expr.type = ty;
        v = arrayelt(val, i);

        tail = NULL;
        ntail = 0;
        next = nextnode(pat->loc, i, n, accept);
        for (j = 0; j < nlast; j++)
            if (addpat(p, v, last[j], next, NULL, NULL, &tail, &ntail))
                ret = 1;
        lfree(&last, &nlast);
        last = tail;
        nlast = ntail;
    }
    lcat(end, nend, last, nlast);
    lfree(&last, &nlast);
    return ret;
}

/*
 * Adds a literal pattern at `start`.
 *
 * String literals are handed to addstr(). For any other literal: an
 * existing wildcard edge or an existing edge with an equal literal is
 * reported without changing the tree; otherwise a new edge from `start`
 * to `accept` is created.
 *
 * Returns 1 if the tree changed, 0 otherwise.
 */
static int addlit(Node *pat, Node *val, Dtree *start, Dtree *accept, Node ***cap, size_t *ncap, Dtree ***end, size_t *nend)
{
    Node *lit;
    size_t edge;

    lit = pat->expr.args[0];
    if (lit->lit.littype == Lstr)
        return addstr(pat, val, start, accept, cap, ncap, end, nend);

    /* a wildcard already matches everything here, so nothing new is added */
    if (start->any) {
        lappend(end, nend, start->any);
        return 0;
    }

    /* reuse the edge when this literal has been seen before */
    for (edge = 0; edge < start->npat; edge++) {
        if (!liteq(start->pat[edge]->expr.args[0], lit))
            continue;
        lappend(end, nend, start->next[edge]);
        return 0;
    }

    /* the first pattern at a node decides what value the node loads */
    if (!start->load) {
        start->load = val;
        start->nconstructors = nconstructors(exprtype(pat));
    }

    /* new edge: start --pat--> accept */
    lappend(&start->pat, &start->npat, pat);
    lappend(&start->next, &start->nnext, accept);
    lappend(end, nend, accept);
    return 1;
}

/*
 * Adds a tuple pattern at `start`, one element at a time.
 *
 * Element i is added to every node reached by element i - 1 (initially
 * just `start`); the last element leads to `accept`. The nodes reached by
 * the last element are appended to the `end` list.
 *
 * Returns 1 if any element changed the tree, 0 otherwise.
 */
static int addtup(Node *pat, Node *val, Dtree *start, Dtree *accept, Node ***cap, size_t *ncap, Dtree ***end, size_t *nend)
{
    Dtree **frontier, **reached, *target;
    size_t nfrontier, nreached, nelts, elt, k;
    int added;

    nelts = pat->expr.nargs;
    added = 0;
    frontier = NULL;
    nfrontier = 0;
    lappend(&frontier, &nfrontier, start);

    for (elt = 0; elt < nelts; elt++) {
        target = nextnode(pat->expr.args[elt]->loc, elt, nelts, accept);
        reached = NULL;
        nreached = 0;
        for (k = 0; k < nfrontier; k++) {
            if (addpat(pat->expr.args[elt], tupelt(val, elt), frontier[k], target, cap, ncap, &reached, &nreached))
                added = 1;
        }
        /* what this element reached is where the next element starts */
        lfree(&frontier, &nfrontier);
        frontier = reached;
        nfrontier = nreached;
    }
    lcat(end, nend, frontier, nfrontier);
    lfree(&frontier, &nfrontier);
    return added;
}

/*
 * Adds an array pattern at `start`, one element at a time.
 *
 * Element i is added to every node reached by element i - 1 (initially
 * just `start`); the last element leads to `accept`. The nodes reached by
 * the last element are appended to the `end` list.
 *
 * Returns 1 if any element changed the tree, 0 otherwise.
 */
static int addarr(Node *pat, Node *val, Dtree *start, Dtree *accept, Node ***cap, size_t *ncap, Dtree ***end, size_t *nend)
{
    Dtree **frontier, **reached, *target;
    size_t nfrontier, nreached, nelts, elt, k;
    int added;

    nelts = pat->expr.nargs;
    added = 0;
    frontier = NULL;
    nfrontier = 0;
    lappend(&frontier, &nfrontier, start);

    for (elt = 0; elt < nelts; elt++) {
        target = nextnode(pat->expr.args[elt]->loc, elt, nelts, accept);
        reached = NULL;
        nreached = 0;
        for (k = 0; k < nfrontier; k++) {
            if (addpat(pat->expr.args[elt], arrayelt(val, elt), frontier[k], target, cap, ncap, &reached, &nreached))
                added = 1;
        }
        /* what this element reached is where the next element starts */
        lfree(&frontier, &nfrontier);
        frontier = reached;
        nfrontier = nreached;
    }
    lcat(end, nend, frontier, nfrontier);
    lfree(&frontier, &nfrontier);
    return added;
}

/*
 * Adds a struct pattern at `start`, one declared member at a time.
 *
 * Members are visited in declaration order of the pattern's struct type.
 * A member named in the pattern is added with addpat(); a member the
 * pattern leaves out is treated as a wildcard of the member's declared
 * type. Member i is added to every node reached by member i - 1
 * (initially just `start`); the last member leads to `accept`.
 *
 * Returns 1 if any member changed the tree, 0 otherwise.
 */
static int addstruct(Node *pat, Node *val, Dtree *start, Dtree *accept, Node ***cap, size_t *ncap, Dtree ***end, size_t *nend)
{
    Dtree *next, **last, **tail;
    Node *memb, *name;
    Type *ty, *mty;
    size_t i, j, ntail, nlast;
    int ret;

    ret = 0;
    last = NULL;
    nlast = 0;
    lappend(&last, &nlast, start);
    ty = tybase(exprtype(pat));

    for (i = 0; i < ty->nmemb; i++) {
        mty = decltype(ty->sdecls[i]);
        name = ty->sdecls[i]->decl.name;
        memb = findmemb(pat, name);
        next = nextnode(pat->loc, i, ty->nmemb, accept);
        tail = NULL;
        ntail = 0;

        for (j = 0; j < nlast; j++) {
            if (!memb) {
                /*
                 * The pattern does not mention this member, so there is no
                 * node to hand to addwild() (it would dereference NULL).
                 * Add the wildcard directly from the declared member type.
                 */
                if (addwildrec(pat->loc, mty, last[j], next, &tail, &ntail))
                    ret = 1;
            } else {
                if (addpat(memb, structmemb(val, name, mty), last[j], next, cap, ncap, &tail, &ntail))
                    ret = 1;
            }
        }
        lfree(&last, &nlast);
        last = tail;
        nlast = ntail;
    }
    lcat(end, nend, last, nlast);
    lfree(&last, &nlast);
    return ret;
}

/*
 * Adds pattern `pat`, matched against the value expression `val`, to the
 * decision tree at `start`, with `accept` as the node to reach on success.
 *
 * The pattern is folded first and then dispatched on its operator. A
 * variable that names a constant is replaced by the constant's
 * initializer; any other variable is a capturing wildcard. An Ogap
 * pattern is a wildcard that captures nothing. Unsupported operators are
 * reported through fatal().
 *
 * Returns 1 if the tree changed, 0 otherwise.
 */
static int addpat(Node *pat, Node *val, Dtree *start, Dtree *accept, Node ***cap, size_t *ncap, Dtree ***end, size_t *nend)
{
    Node *dcl;

    pat = fold(pat, 1);
    switch (exprop(pat)) {
        case Oucon:
            return addunion(pat, val, start, accept, cap, ncap, end, nend);
        case Olit:
            return addlit(pat, val, start, accept, cap, ncap, end, nend);
        case Otup:
            return addtup(pat, val, start, accept, cap, ncap, end, nend);
        case Oarr:
            return addarr(pat, val, start, accept, cap, ncap, end, nend);
        case Ostruct:
            return addstruct(pat, val, start, accept, cap, ncap, end, nend);
        case Ogap:
            return addwild(pat, NULL, start, accept, NULL, NULL, end, nend);
        case Ovar:
            dcl = decls[pat->expr.did];
            if (!dcl->decl.isconst)
                return addwild(pat, val, start, accept, cap, ncap, end, nend);
            return addpat(dcl->decl.init, val, start, accept, cap, ncap, end, nend);
        default:
            fatal(pat, "unsupported pattern %s of type %s", opstr[exprop(pat)], tystr(exprtype(pat)));
            break;
    }
    return 0;
}


/*
 * Builds the decision tree for match statement `m` over the value `val`.
 * val must be a pure, fully evaluated value.
 *
 * lbl[i] is used as the label of the accepting node for the i'th case.
 * The captures collected for each case are put at the front of that
 * case's block. A case that adds nothing to the tree, or a tree that
 * fails verifymatch(), is reported through fatal().
 *
 * Returns the root of the tree.
 */
Dtree *gendtree(Node *m, Node *val, Node **lbl, size_t nlbl)
{
    Dtree *root, *leaf, **ends;
    Node **captures, *arm;
    size_t ncaptures, nends, k;

    ends = NULL;
    nends = 0;
    root = mkdtree(m->loc, genlbl(m->loc));

    for (k = 0; k < m->matchstmt.nmatches; k++) {
        arm = m->matchstmt.matches[k];
        captures = NULL;
        ncaptures = 0;

        leaf = mkdtree(lbl[k]->loc, lbl[k]);
        leaf->accept = 1;

        if (!addpat(arm->match.pat, val, root, leaf, &captures, &ncaptures, &ends, &nends))
            fatal(arm, "pattern matched by earlier case");
        addcapture(arm->match.block, captures, ncaptures);
    }

    if (!verifymatch(root))
        fatal(m, "nonexhaustive pattern set in match statement");
    return root;
}

/*
 * Emits the code for the tree rooted at `dt` onto the list (*out, *nout).
 *
 * Each node is emitted at most once. An accepting node becomes a jump to
 * its label. Any other node becomes its label followed by one conditional
 * jump per edge (load == pat[i] ? next[i] : miss), each followed by the
 * code of the successor, and finally a jump to the wildcard successor
 * when there is one.
 */
void genmatchcode(Dtree *dt, Node ***out, size_t *nout)
{
    Node *branch, *test, *miss;
    size_t edge;
    int ownmiss;

    if (dt->emitted)
        return;
    dt->emitted = 1;

    /* an accepting node is just a jump to its accept label */
    if (dt->accept) {
        branch = mkexpr(dt->loc, Ojmp, dt->lbl, NULL);
        branch->expr.type = mktype(dt->loc, Tyvoid);
        lappend(out, nout, branch);
        return;
    }

    lappend(out, nout, dt->lbl);
    for (edge = 0; edge < dt->nnext; edge++) {
        /* the last test can fall straight to the wildcard; others need their own miss label */
        ownmiss = !(edge + 1 == dt->nnext && dt->any);
        if (ownmiss)
            miss = genlbl(dt->loc);
        else
            miss = dt->any->lbl;

        test = mkexpr(dt->loc, Oeq, dt->load, dt->pat[edge], NULL);
        test->expr.type = mktype(dt->loc, Tybool);
        branch = mkexpr(dt->loc, Ocjmp, test, dt->next[edge]->lbl, miss, NULL);
        branch->expr.type = mktype(dt->loc, Tyvoid);
        lappend(out, nout, branch);

        genmatchcode(dt->next[edge], out, nout);
        if (ownmiss)
            lappend(out, nout, miss);
    }

    if (!dt->any)
        return;
    branch = mkexpr(dt->loc, Ojmp, dt->any->lbl, NULL);
    branch->expr.type = mktype(dt->loc, Tyvoid);
    lappend(out, nout, branch);
    genmatchcode(dt->any, out, nout);
}

/*
 * Emits the code that tests `val` against the single pattern `pat`.
 *
 * Control reaches the label `iftrue` when the pattern matches and
 * `iffalse` otherwise. Captures are appended to (*cap, *ncap) and the
 * generated nodes to (*out, *nout).
 */
void genonematch(Node *pat, Node *val, Node *iftrue, Node *iffalse, Node ***out, size_t *nout, Node ***cap, size_t *ncap)
{
    Dtree *start, *accept, *reject, **end;
    size_t nend;
    int added;

    end = NULL;
    nend = 0;

    start = mkdtree(pat->loc, genlbl(pat->loc));
    accept = mkdtree(iftrue->loc, iftrue);
    accept->accept = 1;
    reject = mkdtree(iffalse->loc, iffalse);
    reject->accept = 1;

    /*
     * addpat() builds the tree, so it must run outside assert():
     * with NDEBUG defined the assert argument is not evaluated at all.
     */
    added = addpat(pat, val, start, accept, cap, ncap, &end, &nend);
    assert(added);
    (void)added;

    /* every node without a wildcard edge falls through to the reject label */
    acceptall(start, reject);
    genmatchcode(start, out, nout);
}

/*
 * Emits the code for match statement `m` over the value `val` onto the
 * list (*out, *nout).
 *
 * The decision tree code comes first. It is followed, for every case, by
 * the case's label, its block and a jump to a shared end label, and then
 * by the end label itself. With debug option 'm' set, the tree and the
 * emitted nodes are dumped to stdout.
 */
void genmatch(Node *m, Node *val, Node ***out, size_t *nout)
{
    Node **arms, **armlbl, *leave, *done;
    size_t narms, narmlbl, k;
    Dtree *tree;

    arms = m->matchstmt.matches;
    narms = m->matchstmt.nmatches;

    /* one label per case, generated before the end label */
    armlbl = NULL;
    narmlbl = 0;
    for (k = 0; k < narms; k++)
        lappend(&armlbl, &narmlbl, genlbl(arms[k]->match.block->loc));

    done = genlbl(m->loc);
    tree = gendtree(m, val, armlbl, narmlbl);
    genmatchcode(tree, out, nout);

    for (k = 0; k < narms; k++) {
        leave = mkexpr(arms[k]->loc, Ojmp, done, NULL);
        leave->expr.type = mktype(leave->loc, Tyvoid);
        lappend(out, nout, armlbl[k]);
        lappend(out, nout, arms[k]->match.block);
        lappend(out, nout, leave);
    }
    lappend(out, nout, done);

    if (!debugopt['m'])
        return;
    dtreedump(stdout, tree);
    for (k = 0; k < *nout; k++)
        dump((*out)[k], stdout);
}


/*
 * Prints one line describing the literal `n` tested at node `dt`:
 * the node's label string, the literal kind and, where one is printed,
 * its value. The line is written to `fd` through findentf() at `depth`.
 */
void dtreedumplit(FILE *fd, Dtree *dt, Node *n, size_t depth)
{
    char *s;

    s = lblstr(dt->lbl);
    switch (n->lit.littype) {
        case Lchr:      findentf(fd, depth, "%s: Lchr %c\n", s, n->lit.chrval); break;
        case Lbool:     findentf(fd, depth, "%s: Lbool %s\n", s, n->lit.boolval ? "true" : "false"); break;
        case Lint:      findentf(fd, depth, "%s: Lint %llu\n", s, n->lit.intval); break;
        case Lflt:      findentf(fd, depth, "%s: Lflt %lf\n", s, n->lit.fltval); break;
        case Lstr:      findentf(fd, depth, "%s: Lstr %.*s\n", s, (int)n->lit.strval.len, n->lit.strval.buf); break;
        case Llbl:      findentf(fd, depth, "%s: Llbl %s\n", s, n->lit.lblval); break;
        /* the %s conversion needs its argument here too */
        case Lfunc:     findentf(fd, depth, "%s: Lfunc\n", s); break;
    }
}

/*
 * Prints the tree rooted at `dt` to `fd`, one line per edge, with
 * successors printed one level deeper than the node they hang off.
 * Accepting nodes and wildcard edges get their own lines.
 */
void dtreedumpnode(FILE *fd, Dtree *dt, size_t depth)
{
    size_t edge;

    if (dt->accept)
        findentf(fd, depth, "%s: accept\n", lblstr(dt->lbl));

    edge = 0;
    while (edge < dt->nnext) {
        dtreedumplit(fd, dt, dt->pat[edge]->expr.args[0], depth);
        dtreedumpnode(fd, dt->next[edge], depth + 1);
        edge++;
    }

    if (!dt->any)
        return;
    findentf(fd, depth, "%s: wildcard\n", lblstr(dt->lbl));
    dtreedumpnode(fd, dt->any, depth + 1);
}

/* Prints the whole tree rooted at `dt` to `fd`, starting at depth 0. */
void dtreedump(FILE *fd, Dtree *dt)
{
    const size_t rootdepth = 0;

    dtreedumpnode(fd, dt, rootdepth);
}