shithub: mc

Download patch

ref: eab0f16acb161defcd8e6fa258c361df7e51968e
parent: aeb291f9491d8ef630d8b1e41116ebe4080fb8e3
author: Ori Bernstein <[email protected]>
date: Tue Oct 7 08:18:12 EDT 2014

Work on trimming BBs and inserting Orets.

    This is to enable work on checking BBs for missed returns. This
    seems to be working, but is overzealous, and will error out
    when we try to call terminating functions like die().

--- a/6/asm.h
+++ b/6/asm.h
@@ -198,7 +198,7 @@
 extern Loc **locmap; /* mapping from reg id => Loc * */
 
 char *genlblstr(char *buf, size_t sz);
-Node *genlbl(void);
+Node *genlbl(int line);
 Loc *loclbl(Node *lbl);
 Loc *locstrlbl(char *lbl);
 Loc *locreg(Mode m);
--- a/6/isel.c
+++ b/6/isel.c
@@ -675,13 +675,16 @@
         case Ocall:
             r = gencall(s, n);
             break;
+        case Oret: 
+            a = locstrlbl(s->cfg->end->lbls[0]);
+            g(s, Ijmp, a, NULL);
+            break;
         case Ojmp:
-            g(s, Ijmp, a = loclbl(args[0]), NULL);
+            g(s, Ijmp, loclbl(args[0]), NULL);
             break;
         case Ocjmp:
             selcjmp(s, n, args);
             break;
-
         case Olit: /* fall through */
             r = loc(s, n);
             break;
@@ -750,7 +753,7 @@
         /* These operators should never show up in the reduced trees,
          * since they should have been replaced with more primitive
          * expressions by now */
-        case Obad: case Oret: case Opreinc: case Opostinc: case Opredec:
+        case Obad: case Opreinc: case Opostinc: case Opredec:
         case Opostdec: case Olor: case Oland: case Oaddeq:
         case Osubeq: case Omuleq: case Odiveq: case Omodeq: case Oboreq:
         case Obandeq: case Obxoreq: case Obsleq: case Obsreq: case Omemb:
@@ -1001,6 +1004,8 @@
         fprintf(fd, ".globl %s\n", fn->name);
     fprintf(fd, "%s:\n", fn->name);
     for (j = 0; j < s->cfg->nbb; j++) {
+        if (!s->bb[j])
+            continue;
         for (i = 0; i < s->bb[j]->nlbls; i++)
             fprintf(fd, "%s:\n", s->bb[j]->lbls[i]);
         for (i = 0; i < s->bb[j]->ni; i++)
@@ -1012,6 +1017,8 @@
 {
     Asmbb *as;
 
+    if (!bb)
+        return NULL;
     as = zalloc(sizeof(Asmbb));
     as->id = bb->id;
     as->pred = bsdup(bb->pred);
@@ -1238,6 +1245,8 @@
     prologue(&is, fn->stksz);
     for (j = 0; j < fn->cfg->nbb - 1; j++) {
         is.curbb = is.bb[j];
+        if (!is.bb[j])
+            continue;
         for (i = 0; i < fn->cfg->bb[j]->nnl; i++) {
             /* put in a comment that says where this line comes from */
             snprintf(buf, sizeof buf, "\n\t# bb = %ld, bbidx = %ld, %s:%d",
--- a/6/locs.c
+++ b/6/locs.c
@@ -79,12 +79,12 @@
     return buf;
 }
 
-Node *genlbl(void)
+Node *genlbl(int line)
 {
     char buf[128];
 
     genlblstr(buf, 128);
-    return mklbl(-1, buf);
+    return mklbl(line, buf);
 }
 
 Loc *locstrlbl(char *lbl)
--- a/6/ra.c
+++ b/6/ra.c
@@ -261,6 +261,8 @@
     bb = s->bb;
     nbb = s->nbb;
     for (i = 0; i < nbb; i++) {
+        if (!bb[i])
+            continue;
         udcalc(s->bb[i]);
         bb[i]->livein = bsclear(bb[i]->livein);
         bb[i]->liveout = bsclear(bb[i]->liveout);
@@ -270,6 +272,8 @@
     while (changed) {
         changed = 0;
         for (i = nbb - 1; i >= 0; i--) {
+            if (!bb[i])
+                continue;
             old = bsdup(bb[i]->liveout);
             /* liveout[b] = U(s in succ) livein[s] */
             for (j = 0; bsiter(bb[i]->succ, &j); j++)
@@ -456,6 +460,8 @@
     nbb = s->nbb;
 
     for (i = 0; i < nbb; i++) {
+        if (!bb[i])
+            continue;
         live = bsdup(bb[i]->liveout);
         for (j = bb[i]->ni - 1; j >= 0; j--) {
             insn = bb[i]->il[j];
@@ -1086,6 +1092,8 @@
 
     new = NULL;
     nnew = 0;
+    if (!bb)
+        return;
     for (j = 0; j < bb->ni; j++) {
         /* if there is a remapping, insert the loads and stores as needed */
         if (remap(s, bb->il[j], use, &nuse, def, &ndef)) {
@@ -1186,6 +1194,8 @@
     size_t i, j;
 
     for (i = 0; i < s->nbb; i++) {
+        if (!s->bb[i])
+            continue;
         new = NULL;
         nnew = 0;
         bb = s->bb[i];
@@ -1325,6 +1335,8 @@
     fprintf(fd, "ASM -------- \n");
     for (j = 0; j < s->nbb; j++) {
         bb = s->bb[j];
+        if (!bb)
+            continue;
         fprintf(fd, "\n");
         fprintf(fd, "Bb: %d labels=(", bb->id);
         sep = "";
--- a/6/simp.c
+++ b/6/simp.c
@@ -490,12 +490,12 @@
     Node *l1, *l2, *l3;
     Node *iftrue, *iffalse;
 
-    l1 = genlbl();
-    l2 = genlbl();
+    l1 = genlbl(n->line);
+    l2 = genlbl(n->line);
     if (exit)
         l3 = exit;
     else
-        l3 = genlbl();
+        l3 = genlbl(n->line);
 
     iftrue = n->ifstmt.iftrue;
     iffalse = n->ifstmt.iffalse;
@@ -537,10 +537,10 @@
     Node *lcond;
     Node *lstep;
 
-    lbody = genlbl();
-    lcond = genlbl();
-    lstep = genlbl();
-    lend = genlbl();
+    lbody = genlbl(n->line);
+    lcond = genlbl(n->line);
+    lstep = genlbl(n->line);
+    lend = genlbl(n->line);
 
     lappend(&s->loopstep, &s->nloopstep, lstep);
     lappend(&s->loopexit, &s->nloopexit, lend);
@@ -583,11 +583,11 @@
     Node *idx, *len, *dcl, *seq, *val, *done;
     Node *zero;
 
-    lbody = genlbl();
-    lstep = genlbl();
-    lcond = genlbl();
-    lmatch = genlbl();
-    lend = genlbl();
+    lbody = genlbl(n->line);
+    lstep = genlbl(n->line);
+    lcond = genlbl(n->line);
+    lmatch = genlbl(n->line);
+    lend = genlbl(n->line);
 
     lappend(&s->loopstep, &s->nloopstep, lstep);
     lappend(&s->loopexit, &s->nloopexit, lend);
@@ -694,7 +694,7 @@
             str = lit->lit.strval;
 
             /* load slice length */
-            next = genlbl();
+            next = genlbl(pat->line);
             x = slicelen(s, val);
             len = strlen(str);
             y = mkintlit(lit->line, len);
@@ -704,9 +704,9 @@
             append(s, next);
 
             for (i = 0; i < len; i++) {
-                next = genlbl();
+                next = genlbl(pat->line);
                 x = mkintlit(pat->line, str[i]);
-                x->expr.type = mktype(-1, Tybyte);
+                x->expr.type = mktype(pat->line, Tybyte);
                 idx = mkintlit(pat->line, i);
                 idx->expr.type = tyintptr;
                 y = load(idxaddr(s, val, idx));
@@ -735,7 +735,7 @@
             off = 0;
             for (i = 0; i < pat->expr.nargs; i++) {
                 off = alignto(off, exprtype(patarg[i]));
-                next = genlbl();
+                next = genlbl(pat->line);
                 v = load(addk(addr(s, val, exprtype(patarg[i])), off));
                 matchpattern(s, patarg[i], v, exprtype(patarg[i]), next, iffalse);
                 append(s, next);
@@ -747,7 +747,7 @@
             patarg = pat->expr.args;
             for (i = 0; i < pat->expr.nargs; i++) {
                 off = offset(pat, patarg[i]->expr.idx);
-                next = genlbl();
+                next = genlbl(pat->line);
                 v = load(addk(addr(s, val, exprtype(patarg[i])), off));
                 matchpattern(s, patarg[i], v, exprtype(patarg[i]), next, iffalse);
                 append(s, next);
@@ -758,7 +758,7 @@
             if (!uc)
                 uc = finducon(val);
 
-            deeper = genlbl();
+            deeper = genlbl(pat->line);
 
             x = uconid(s, pat);
             y = uconid(s, val);
@@ -782,7 +782,7 @@
     Node *m;
     size_t i;
 
-    end = genlbl();
+    end = genlbl(n->line);
     val = temp(s, n->matchstmt.val);
     tmp = rval(s, n->matchstmt.val, val);
     if (val != tmp)
@@ -791,8 +791,8 @@
         m = n->matchstmt.matches[i];
 
         /* check pattern */
-        cur = genlbl();
-        next = genlbl();
+        cur = genlbl(n->line);
+        next = genlbl(n->line);
         matchpattern(s, m->match.pat, val, val->expr.type, cur, next);
 
         /* do the action if it matches */
@@ -876,11 +876,11 @@
 
     /* create expressions */
     cmp = mkexpr(idx->line, Olt, ptrsized(s, idx), ptrsized(s, len), NULL);
-    cmp->expr.type = mktype(-1, Tybool);
-    ok = genlbl();
-    fail = genlbl();
+    cmp->expr.type = mktype(len->line, Tybool);
+    ok = genlbl(len->line);
+    fail = genlbl(len->line);
     die = mkexpr(idx->line, Ocall, abortfunc, NULL);
-    die->expr.type = mktype(-1, Tyvoid);
+    die->expr.type = mktype(len->line, Tyvoid);
 
     /* insert them */
     cjmp(s, cmp, ok, fail);
@@ -969,13 +969,13 @@
     args = n->expr.args;
     switch (exprop(n)) {
         case Oland:
-            lnext = genlbl();
+            lnext = genlbl(n->line);
             simpcond(s, args[0], lnext, lfalse);
             append(s, lnext);
             simpcond(s, args[1], ltrue, lfalse);
             break;
         case Olor:
-            lnext = genlbl();
+            lnext = genlbl(n->line);
             simpcond(s, args[0], ltrue, lnext);
             append(s, lnext);
             simpcond(s, args[1], ltrue, lfalse);
@@ -1308,9 +1308,9 @@
 
     /* set up temps and labels */
     r = temp(s, n);
-    ltrue = genlbl();
-    lfalse = genlbl();
-    ldone = genlbl();
+    ltrue = genlbl(n->line);
+    lfalse = genlbl(n->line);
+    ldone = genlbl(n->line);
 
     /* simp the conditional */
     simpcond(s, n, ltrue, lfalse);
@@ -1532,7 +1532,7 @@
             for (i = 0; i < s->nqueue; i++)
                 append(s, s->incqueue[i]);
             lfree(&s->incqueue, &s->nqueue);
-            jmp(s, s->endlbl);
+            append(s, mkexpr(n->line, Oret, NULL));
             break;
         case Oasn:
             r = assign(s, args[0], args[1]);
@@ -1686,7 +1686,7 @@
     assert(f->type == Nfunc);
     s->nstmts = 0;
     s->stmts = NULL;
-    s->endlbl = genlbl();
+    s->endlbl = genlbl(f->line);
     s->ret = NULL;
 
     /* make a temp for the return type */
@@ -1708,12 +1708,16 @@
     append(s, s->endlbl);
 }
 
-static Func *simpfn(Simp *s, char *name, Node *n, Vis vis)
+static Func *simpfn(Simp *s, char *name, Node *dcl)
 {
+    Node *n;
+    Vis vis;
     size_t i;
     Func *fn;
     Cfg *cfg;
 
+    n = dcl->decl.init;
+    vis = dcl->decl.vis;
     if(debugopt['i'] || debugopt['F'] || debugopt['f'])
         printf("\n\nfunction %s\n", name);
 
@@ -1743,7 +1747,9 @@
         }
     }
 
-    cfg = mkcfg(s->stmts, s->nstmts);
+    cfg = mkcfg(dcl, s->stmts, s->nstmts);
+    if (debugopt['C'])
+	check(cfg);
     if (debugopt['t'] || debugopt['s'])
         dumpcfg(cfg, stdout);
 
@@ -1845,7 +1851,7 @@
     if (dcl->decl.isextern || dcl->decl.isgeneric)
         return;
     if (isconstfn(dcl)) {
-        f = simpfn(&s, name, dcl->decl.init, dcl->decl.vis);
+        f = simpfn(&s, name, dcl);
         lappend(fn, nfn, f);
     } else {
         simpconstinit(&s, dcl);
--- a/opt/cfg.c
+++ b/opt/cfg.c
@@ -34,10 +34,15 @@
     return n->expr.args[0]->lit.lblval;
 }
 
+static void strlabel(Cfg *cfg, char *lbl, Bb *bb)
+{
+    htput(cfg->lblmap, lbl, bb);
+    lappend(&bb->lbls, &bb->nlbls, lbl);
+}
+
 static void label(Cfg *cfg, Node *lbl, Bb *bb)
 {
-    htput(cfg->lblmap, lblstr(lbl), bb);
-    lappend(&bb->lbls, &bb->nlbls, lblstr(lbl));
+    strlabel(cfg, lblstr(lbl), bb);
 }
 
 static int addnode(Cfg *cfg, Bb *bb, Node *n)
@@ -45,6 +50,7 @@
     switch (exprop(n)) {
         case Ojmp:
         case Ocjmp:
+        case Oret:
             lappend(&bb->nl, &bb->nnl, n);
             lappend(&cfg->fixjmp, &cfg->nfixjmp, n);
             lappend(&cfg->fixblk, &cfg->nfixblk, bb);
@@ -85,8 +91,84 @@
     return bb;
 }
 
-Cfg *mkcfg(Node **nl, size_t nn)
+void delete(Cfg *cfg, Bb *bb)
 {
+    size_t i, j;
+
+    if (bb == cfg->start || bb == cfg->end)
+        return;
+    for (i = 0; bsiter(bb->pred, &i); i++) {
+        bsunion(cfg->bb[i]->succ, bb->succ);
+        bsdel(cfg->bb[i]->succ, bb->id);
+    }
+    for (i = 0; bsiter(bb->succ, &i); i++) {
+        bsunion(cfg->bb[i]->pred, bb->pred);
+        bsdel(cfg->bb[i]->pred, bb->id);
+        for (j = 0; j < bb->nlbls; j++)
+            strlabel(cfg, bb->lbls[j], cfg->bb[i]);
+    }
+    cfg->bb[bb->id] = NULL;
+}
+
+void trimdead(Bb *bb)
+{
+    size_t i;
+
+    for (i = 0; i < bb->nnl; i++) {
+        switch (exprop(bb->nl[i])) {
+            /* if we're jumping, we can't keep going
+             * within this BB */
+            case Ojmp:
+            case Ocjmp:
+            case Oret:
+                bb->nnl = i + 1;
+                return;
+            default:
+                /* nothing */
+                break;
+        }
+    }
+}
+
+void trim(Cfg *cfg)
+{
+    Bb *bb;
+    size_t i;
+
+    /* delete empty blocks and trivially unreachable code */
+    for (i = 0; i < cfg->nbb; i++) {
+
+        bb = cfg->bb[i];
+        if (bb->nnl == 0)
+            delete(cfg, bb);
+        else
+            trimdead(bb);
+    }
+}
+
+void delunreachable(Cfg *cfg)
+{
+    Bb *bb;
+    size_t i;
+    int deleted;
+
+    deleted = 1;
+    while (deleted) {
+        deleted = 0;
+        for (i = 0; i < cfg->nbb; i++) {
+            bb = cfg->bb[i];
+            if (bb == cfg->start || bb == cfg->end)
+                continue;
+            if (bb && bsisempty(bb->pred)) {
+                delete(cfg, bb);
+                deleted = 1;
+            }
+        }
+    }
+}
+
+Cfg *mkcfg(Node *fn, Node **nl, size_t nn)
+{
     Cfg *cfg;
     Bb *pre, *post;
     Bb *bb, *targ;
@@ -94,6 +176,7 @@
     size_t i;
 
     cfg = zalloc(sizeof(Cfg));
+    cfg->fn = fn;
     cfg->lblmap = mkht(strhash, streq);
     pre = mkbb(cfg);
     bb = mkbb(cfg);
@@ -113,10 +196,14 @@
         }
     }
     post = mkbb(cfg);
+    cfg->start = pre;
+    cfg->end = post;
     bsput(pre->succ, cfg->bb[1]->id);
     bsput(cfg->bb[1]->pred, pre->id);
     bsput(cfg->bb[cfg->nbb - 2]->succ, post->id);
     bsput(post->pred, cfg->bb[cfg->nbb - 2]->id);
+    trim(cfg);
+
     for (i = 0; i < cfg->nfixjmp; i++) {
         bb = cfg->fixblk[i];
         switch (exprop(cfg->fixjmp[i])) {
@@ -128,6 +215,10 @@
                 a = cfg->fixjmp[i]->expr.args[1];
                 b = cfg->fixjmp[i]->expr.args[2];
                 break;
+            case Oret:
+                a = mklbl(cfg->fixjmp[i]->line, cfg->end->lbls[0]);
+                b = NULL;
+                break;
             default:
                 die("Bad jump fix thingy");
                 break;
@@ -147,6 +238,7 @@
             bsput(targ->pred, bb->id);
         }
     }
+    delunreachable(cfg);
     return cfg;
 }
 
@@ -158,6 +250,8 @@
 
     for (j = 0; j < cfg->nbb; j++) {
         bb = cfg->bb[j];
+        if (!bb)
+            continue;
         fprintf(fd, "\n");
         fprintf(fd, "Bb: %d labels=(", bb->id);
         sep = "";
--- a/opt/df.c
+++ b/opt/df.c
@@ -31,10 +31,40 @@
 }
 */
 
-void flow(Cfg *cfg)
+static void checkreach(Cfg *cfg)
 {
 }
 
-void checkret(Cfg *cfg)
+static void checkpredret(Cfg *cfg, Bb *bb)
 {
+    Bb *pred;
+    size_t i;
+
+    for (i = 0; bsiter(bb->pred, &i); i++) {
+        pred = cfg->bb[i];
+        if (pred->nnl == 0) {
+            checkpredret(cfg, pred);
+        } else if (exprop(pred->nl[pred->nnl - 1]) != Oret) {
+            dumpcfg(cfg, stdout);
+            fatal(pred->nl[pred->nnl-1]->line, "Reaches end of function without return\n");
+        }
+    }
+}
+
+static void checkret(Cfg *cfg)
+{
+    Type *ft;
+
+    ft = tybase(decltype(cfg->fn));
+    assert(ft->type == Tyfunc);
+    if (ft->sub[0]->type == Tyvoid)
+        return;
+
+    checkpredret(cfg, cfg->end);
+}
+
+void check(Cfg *cfg)
+{
+    checkret(cfg);
+    checkreach(cfg);
 }
--- a/opt/opt.h
+++ b/opt/opt.h
@@ -2,6 +2,7 @@
 typedef struct Bb Bb;
 
 struct  Cfg {
+    Node *fn;
     Bb **bb;
     Bb *start;
     Bb *end;
@@ -29,6 +30,6 @@
 /* expression folding */
 Node *fold(Node *n, int foldvar);
 /* Takes a reduced block, and returns a flow graph. */
-Cfg *mkcfg(Node **nl, size_t nn);
+Cfg *mkcfg(Node *fn, Node **nl, size_t nn);
 void dumpcfg(Cfg *c, FILE *fd);
-void flow(Cfg *cfg);
+void check(Cfg *cfg);
--- a/parse/bitset.c
+++ b/parse/bitset.c
@@ -208,3 +208,13 @@
             return 0;
     return 1;
 }
+
+int bsisempty(Bitset *set)
+{
+    size_t i;
+
+    for (i = 0; i < set->nchunks; i++)
+        if (set->chunks[i])
+            return 0;
+    return 1;
+}
--- a/parse/dump.c
+++ b/parse/dump.c
@@ -110,7 +110,7 @@
         findentf(fd, depth, "Nil\n");
         return;
     }
-    findentf(fd, depth, "%s", nodestr(n->type));
+    findentf(fd, depth, "%s.%zd@%i", nodestr(n->type), n->nid, n->line);
     switch(n->type) {
         case Nfile:
             fprintf(fd, "(name = %s)\n", n->file.name);
--- a/parse/parse.h
+++ b/parse/parse.h
@@ -363,6 +363,7 @@
 void bsdiff(Bitset *a, Bitset *b);
 int  bseq(Bitset *a, Bitset *b);
 int  bsissubset(Bitset *set, Bitset *sub);
+int  bsisempty(Bitset *set);
 int  bsiter(Bitset *bs, size_t *elt);
 size_t bsmax(Bitset *bs);
 size_t bscount(Bitset *bs);