shithub: mc

Download patch

ref: af7de567e568e79be41ec01425422579e8d2f5ce
parent: 5007bc86a71fb725fe586acfc0d320e8f50c22b9
author: Ori Bernstein <[email protected]>
date: Tue Oct 7 08:18:12 EDT 2014

Work on trimming BBs and inserting Orets.

    This is to enable work on checking BBs for missed returns. The
    check seems to work, but it is overzealous: it errors out on paths
    that end in a call to a terminating function such as die().

--- a/6/asm.h
+++ b/6/asm.h
@@ -198,7 +198,7 @@
 extern Loc **locmap; /* mapping from reg id => Loc * */
 
 char *genlblstr(char *buf, size_t sz);
-Node *genlbl(void);
+Node *genlbl(int line);
 Loc *loclbl(Node *lbl);
 Loc *locstrlbl(char *lbl);
 Loc *locreg(Mode m);
--- a/6/isel.c
+++ b/6/isel.c
@@ -675,13 +675,16 @@
         case Ocall:
             r = gencall(s, n);
             break;
+        case Oret: /* NOTE(review): assumes cfg->end has a label (lbls[0]) */
+            a = locstrlbl(s->cfg->end->lbls[0]);
+            g(s, Ijmp, a, NULL);
+            break;
         case Ojmp:
-            g(s, Ijmp, a = loclbl(args[0]), NULL);
+            g(s, Ijmp, loclbl(args[0]), NULL);
             break;
         case Ocjmp:
             selcjmp(s, n, args);
             break;
-
         case Olit: /* fall through */
             r = loc(s, n);
             break;
@@ -750,7 +753,7 @@
         /* These operators should never show up in the reduced trees,
          * since they should have been replaced with more primitive
          * expressions by now */
-        case Obad: case Oret: case Opreinc: case Opostinc: case Opredec:
+        case Obad: case Opreinc: case Opostinc: case Opredec:
         case Opostdec: case Olor: case Oland: case Oaddeq:
         case Osubeq: case Omuleq: case Odiveq: case Omodeq: case Oboreq:
         case Obandeq: case Obxoreq: case Obsleq: case Obsreq: case Omemb:
@@ -1001,6 +1004,8 @@
         fprintf(fd, ".globl %s\n", fn->name);
     fprintf(fd, "%s:\n", fn->name);
     for (j = 0; j < s->cfg->nbb; j++) {
+        if (!s->bb[j])
+            continue;
         for (i = 0; i < s->bb[j]->nlbls; i++)
             fprintf(fd, "%s:\n", s->bb[j]->lbls[i]);
         for (i = 0; i < s->bb[j]->ni; i++)
@@ -1012,6 +1017,8 @@
 {
     Asmbb *as;
 
+    if (!bb)
+        return NULL;
     as = zalloc(sizeof(Asmbb));
     as->id = bb->id;
     as->pred = bsdup(bb->pred);
@@ -1238,6 +1245,8 @@
     prologue(&is, fn->stksz);
     for (j = 0; j < fn->cfg->nbb - 1; j++) {
         is.curbb = is.bb[j];
+        if (!is.bb[j])
+            continue;
         for (i = 0; i < fn->cfg->bb[j]->nnl; i++) {
             /* put in a comment that says where this line comes from */
             snprintf(buf, sizeof buf, "\n\t# bb = %ld, bbidx = %ld, %s:%d",
--- a/6/locs.c
+++ b/6/locs.c
@@ -79,12 +79,12 @@
     return buf;
 }
 
-Node *genlbl(void)
+Node *genlbl(int line)
 {
     char buf[128];
 
-    genlblstr(buf, 128);
+    genlblstr(buf, sizeof buf); /* size follows buf's declaration */
-    return mklbl(-1, buf);
+    return mklbl(line, buf); /* label node named by genlblstr(), tagged with `line` */
 }
 
 Loc *locstrlbl(char *lbl)
--- a/6/ra.c
+++ b/6/ra.c
@@ -261,6 +261,8 @@
     bb = s->bb;
     nbb = s->nbb;
     for (i = 0; i < nbb; i++) {
+        if (!bb[i])
+            continue;
         udcalc(s->bb[i]);
         bb[i]->livein = bsclear(bb[i]->livein);
         bb[i]->liveout = bsclear(bb[i]->liveout);
@@ -270,6 +272,8 @@
     while (changed) {
         changed = 0;
         for (i = nbb - 1; i >= 0; i--) {
+            if (!bb[i])
+                continue;
             old = bsdup(bb[i]->liveout);
             /* liveout[b] = U(s in succ) livein[s] */
             for (j = 0; bsiter(bb[i]->succ, &j); j++)
@@ -456,6 +460,8 @@
     nbb = s->nbb;
 
     for (i = 0; i < nbb; i++) {
+        if (!bb[i])
+            continue;
         live = bsdup(bb[i]->liveout);
         for (j = bb[i]->ni - 1; j >= 0; j--) {
             insn = bb[i]->il[j];
@@ -1086,6 +1092,8 @@
 
     new = NULL;
     nnew = 0;
+    if (!bb)
+        return;
     for (j = 0; j < bb->ni; j++) {
         /* if there is a remapping, insert the loads and stores as needed */
         if (remap(s, bb->il[j], use, &nuse, def, &ndef)) {
@@ -1186,6 +1194,8 @@
     size_t i, j;
 
     for (i = 0; i < s->nbb; i++) {
+        if (!s->bb[i])
+            continue;
         new = NULL;
         nnew = 0;
         bb = s->bb[i];
@@ -1325,6 +1335,8 @@
     fprintf(fd, "ASM -------- \n");
     for (j = 0; j < s->nbb; j++) {
         bb = s->bb[j];
+        if (!bb)
+            continue;
         fprintf(fd, "\n");
         fprintf(fd, "Bb: %d labels=(", bb->id);
         sep = "";
--- a/6/simp.c
+++ b/6/simp.c
@@ -490,12 +490,12 @@
     Node *l1, *l2, *l3;
     Node *iftrue, *iffalse;
 
-    l1 = genlbl();
-    l2 = genlbl();
+    l1 = genlbl(n->line);
+    l2 = genlbl(n->line);
     if (exit)
         l3 = exit;
     else
-        l3 = genlbl();
+        l3 = genlbl(n->line);
 
     iftrue = n->ifstmt.iftrue;
     iffalse = n->ifstmt.iffalse;
@@ -537,10 +537,10 @@
     Node *lcond;
     Node *lstep;
 
-    lbody = genlbl();
-    lcond = genlbl();
-    lstep = genlbl();
-    lend = genlbl();
+    lbody = genlbl(n->line);
+    lcond = genlbl(n->line);
+    lstep = genlbl(n->line);
+    lend = genlbl(n->line);
 
     lappend(&s->loopstep, &s->nloopstep, lstep);
     lappend(&s->loopexit, &s->nloopexit, lend);
@@ -583,11 +583,11 @@
     Node *idx, *len, *dcl, *seq, *val, *done;
     Node *zero;
 
-    lbody = genlbl();
-    lstep = genlbl();
-    lcond = genlbl();
-    lmatch = genlbl();
-    lend = genlbl();
+    lbody = genlbl(n->line);
+    lstep = genlbl(n->line);
+    lcond = genlbl(n->line);
+    lmatch = genlbl(n->line);
+    lend = genlbl(n->line);
 
     lappend(&s->loopstep, &s->nloopstep, lstep);
     lappend(&s->loopexit, &s->nloopexit, lend);
@@ -694,7 +694,7 @@
             str = lit->lit.strval;
 
             /* load slice length */
-            next = genlbl();
+            next = genlbl(pat->line);
             x = slicelen(s, val);
             len = strlen(str);
             y = mkintlit(lit->line, len);
@@ -704,9 +704,9 @@
             append(s, next);
 
             for (i = 0; i < len; i++) {
-                next = genlbl();
+                next = genlbl(pat->line);
                 x = mkintlit(pat->line, str[i]);
-                x->expr.type = mktype(-1, Tybyte);
+                x->expr.type = mktype(pat->line, Tybyte);
                 idx = mkintlit(pat->line, i);
                 idx->expr.type = tyintptr;
                 y = load(idxaddr(s, val, idx));
@@ -735,7 +735,7 @@
             off = 0;
             for (i = 0; i < pat->expr.nargs; i++) {
                 off = alignto(off, exprtype(patarg[i]));
-                next = genlbl();
+                next = genlbl(pat->line);
                 v = load(addk(addr(s, val, exprtype(patarg[i])), off));
                 matchpattern(s, patarg[i], v, exprtype(patarg[i]), next, iffalse);
                 append(s, next);
@@ -747,7 +747,7 @@
             patarg = pat->expr.args;
             for (i = 0; i < pat->expr.nargs; i++) {
                 off = offset(pat, patarg[i]->expr.idx);
-                next = genlbl();
+                next = genlbl(pat->line);
                 v = load(addk(addr(s, val, exprtype(patarg[i])), off));
                 matchpattern(s, patarg[i], v, exprtype(patarg[i]), next, iffalse);
                 append(s, next);
@@ -758,7 +758,7 @@
             if (!uc)
                 uc = finducon(val);
 
-            deeper = genlbl();
+            deeper = genlbl(pat->line);
 
             x = uconid(s, pat);
             y = uconid(s, val);
@@ -782,7 +782,7 @@
     Node *m;
     size_t i;
 
-    end = genlbl();
+    end = genlbl(n->line);
     val = temp(s, n->matchstmt.val);
     tmp = rval(s, n->matchstmt.val, val);
     if (val != tmp)
@@ -791,8 +791,8 @@
         m = n->matchstmt.matches[i];
 
         /* check pattern */
-        cur = genlbl();
-        next = genlbl();
+        cur = genlbl(n->line);
+        next = genlbl(n->line);
         matchpattern(s, m->match.pat, val, val->expr.type, cur, next);
 
         /* do the action if it matches */
@@ -876,11 +876,11 @@
 
     /* create expressions */
     cmp = mkexpr(idx->line, Olt, ptrsized(s, idx), ptrsized(s, len), NULL);
-    cmp->expr.type = mktype(-1, Tybool);
-    ok = genlbl();
-    fail = genlbl();
+    cmp->expr.type = mktype(len->line, Tybool);
+    ok = genlbl(len->line);
+    fail = genlbl(len->line);
     die = mkexpr(idx->line, Ocall, abortfunc, NULL);
-    die->expr.type = mktype(-1, Tyvoid);
+    die->expr.type = mktype(len->line, Tyvoid);
 
     /* insert them */
     cjmp(s, cmp, ok, fail);
@@ -969,13 +969,13 @@
     args = n->expr.args;
     switch (exprop(n)) {
         case Oland:
-            lnext = genlbl();
+            lnext = genlbl(n->line);
             simpcond(s, args[0], lnext, lfalse);
             append(s, lnext);
             simpcond(s, args[1], ltrue, lfalse);
             break;
         case Olor:
-            lnext = genlbl();
+            lnext = genlbl(n->line);
             simpcond(s, args[0], ltrue, lnext);
             append(s, lnext);
             simpcond(s, args[1], ltrue, lfalse);
@@ -1308,9 +1308,9 @@
 
     /* set up temps and labels */
     r = temp(s, n);
-    ltrue = genlbl();
-    lfalse = genlbl();
-    ldone = genlbl();
+    ltrue = genlbl(n->line);
+    lfalse = genlbl(n->line);
+    ldone = genlbl(n->line);
 
     /* simp the conditional */
     simpcond(s, n, ltrue, lfalse);
@@ -1532,7 +1532,7 @@
             for (i = 0; i < s->nqueue; i++)
                 append(s, s->incqueue[i]);
             lfree(&s->incqueue, &s->nqueue);
-            jmp(s, s->endlbl);
+            append(s, mkexpr(n->line, Oret, NULL));
             break;
         case Oasn:
             r = assign(s, args[0], args[1]);
@@ -1686,7 +1686,7 @@
     assert(f->type == Nfunc);
     s->nstmts = 0;
     s->stmts = NULL;
-    s->endlbl = genlbl();
+    s->endlbl = genlbl(f->line);
     s->ret = NULL;
 
     /* make a temp for the return type */
@@ -1708,12 +1708,16 @@
     append(s, s->endlbl);
 }
 
-static Func *simpfn(Simp *s, char *name, Node *n, Vis vis)
+static Func *simpfn(Simp *s, char *name, Node *dcl)
 {
+    Node *n;
+    Vis vis;
     size_t i;
     Func *fn;
     Cfg *cfg;
 
+    n = dcl->decl.init;
+    vis = dcl->decl.vis;
     if(debugopt['i'] || debugopt['F'] || debugopt['f'])
         printf("\n\nfunction %s\n", name);
 
@@ -1743,7 +1747,9 @@
         }
     }
 
-    cfg = mkcfg(s->stmts, s->nstmts);
+    cfg = mkcfg(dcl, s->stmts, s->nstmts);
+    if (debugopt['C']) /* -C: sanity-check the CFG (e.g. missing returns) */
+        check(cfg);
     if (debugopt['t'] || debugopt['s'])
         dumpcfg(cfg, stdout);
 
@@ -1845,7 +1851,7 @@
     if (dcl->decl.isextern || dcl->decl.isgeneric)
         return;
     if (isconstfn(dcl)) {
-        f = simpfn(&s, name, dcl->decl.init, dcl->decl.vis);
+        f = simpfn(&s, name, dcl);
         lappend(fn, nfn, f);
     } else {
         simpconstinit(&s, dcl);
--- a/opt/cfg.c
+++ b/opt/cfg.c
@@ -34,10 +34,15 @@
     return n->expr.args[0]->lit.lblval;
 }
 
+static void strlabel(Cfg *cfg, char *lbl, Bb *bb) /* bind lbl <-> bb both ways */
+{
+    lappend(&bb->lbls, &bb->nlbls, lbl);
+    htput(cfg->lblmap, lbl, bb);
+}
+
 static void label(Cfg *cfg, Node *lbl, Bb *bb)
 {
-    htput(cfg->lblmap, lblstr(lbl), bb);
-    lappend(&bb->lbls, &bb->nlbls, lblstr(lbl));
+    strlabel(cfg, lblstr(lbl), bb); /* label-node form of strlabel() */
 }
 
 static int addnode(Cfg *cfg, Bb *bb, Node *n)
@@ -45,6 +50,7 @@
     switch (exprop(n)) {
         case Ojmp:
         case Ocjmp:
+        case Oret:
             lappend(&bb->nl, &bb->nnl, n);
             lappend(&cfg->fixjmp, &cfg->nfixjmp, n);
             lappend(&cfg->fixblk, &cfg->nfixblk, bb);
@@ -85,8 +91,84 @@
     return bb;
 }
 
-Cfg *mkcfg(Node **nl, size_t nn)
+static void delete(Cfg *cfg, Bb *bb)
 {
+    size_t i, j;
+    /* splice bb out; only a lone succ inherits its labels (no duplicates) */
+    if (bb == cfg->start || bb == cfg->end)
+        return;
+    for (i = 0; bsiter(bb->pred, &i); i++) {
+        bsunion(cfg->bb[i]->succ, bb->succ);
+        bsdel(cfg->bb[i]->succ, bb->id);
+    }
+    for (i = 0; bsiter(bb->succ, &i); i++) {
+        bsunion(cfg->bb[i]->pred, bb->pred);
+        bsdel(cfg->bb[i]->pred, bb->id);
+        for (j = 0; bscount(bb->succ) == 1 && j < bb->nlbls; j++)
+            strlabel(cfg, bb->lbls[j], cfg->bb[i]);
+    }
+    cfg->bb[bb->id] = NULL;
+}
+
+static void trimdead(Bb *bb)
+{
+    size_t i;
+
+    for (i = 0; i < bb->nnl; i++) {
+        switch (exprop(bb->nl[i])) {
+            /* control leaves bb here, so later nodes are dead: cut them.
+             * NOTE(review): cut jumps stay queued in cfg->fixjmp -- confirm */
+            case Ojmp:
+            case Ocjmp:
+            case Oret:
+                bb->nnl = i + 1;
+                return;
+            default:
+                /* ordinary node: keep scanning */
+                break;
+        }
+    }
+}
+
+static void trim(Cfg *cfg)
+{
+    Bb *bb;
+    size_t i;
+
+    /* delete empty blocks (delete() spares start/end), and cut the
+     * trivially unreachable code after a jump in every other block */
+    for (i = 0; i < cfg->nbb; i++) {
+        bb = cfg->bb[i];
+        if (bb->nnl == 0)
+            delete(cfg, bb);
+        else
+            trimdead(bb);
+    }
+}
+
+static void delunreachable(Cfg *cfg)
+{
+    Bb *bb;
+    size_t i;
+    int deleted;
+
+    deleted = 1; /* iterate to a fixpoint: a deletion can orphan successors */
+    while (deleted) {
+        deleted = 0;
+        for (i = 0; i < cfg->nbb; i++) {
+            bb = cfg->bb[i];
+            if (!bb || bb == cfg->start || bb == cfg->end)
+                continue;
+            if (bsisempty(bb->pred)) {
+                delete(cfg, bb);
+                deleted = 1;
+            }
+        }
+    }
+}
+
+Cfg *mkcfg(Node *fn, Node **nl, size_t nn)
+{
     Cfg *cfg;
     Bb *pre, *post;
     Bb *bb, *targ;
@@ -94,6 +176,7 @@
     size_t i;
 
     cfg = zalloc(sizeof(Cfg));
+    cfg->fn = fn;
     cfg->lblmap = mkht(strhash, streq);
     pre = mkbb(cfg);
     bb = mkbb(cfg);
@@ -113,10 +196,14 @@
         }
     }
     post = mkbb(cfg);
+    cfg->start = pre;
+    cfg->end = post;
     bsput(pre->succ, cfg->bb[1]->id);
     bsput(cfg->bb[1]->pred, pre->id);
     bsput(cfg->bb[cfg->nbb - 2]->succ, post->id);
     bsput(post->pred, cfg->bb[cfg->nbb - 2]->id);
+    trim(cfg);
+
     for (i = 0; i < cfg->nfixjmp; i++) {
         bb = cfg->fixblk[i];
         switch (exprop(cfg->fixjmp[i])) {
@@ -128,6 +215,10 @@
                 a = cfg->fixjmp[i]->expr.args[1];
                 b = cfg->fixjmp[i]->expr.args[2];
                 break;
+            case Oret:
+                a = mklbl(cfg->fixjmp[i]->line, cfg->end->lbls[0]);
+                b = NULL;
+                break;
             default:
                 die("Bad jump fix thingy");
                 break;
@@ -147,6 +238,7 @@
             bsput(targ->pred, bb->id);
         }
     }
+    delunreachable(cfg);
     return cfg;
 }
 
@@ -158,6 +250,8 @@
 
     for (j = 0; j < cfg->nbb; j++) {
         bb = cfg->bb[j];
+        if (!bb)
+            continue;
         fprintf(fd, "\n");
         fprintf(fd, "Bb: %d labels=(", bb->id);
         sep = "";
--- a/opt/df.c
+++ b/opt/df.c
@@ -31,10 +31,40 @@
 }
 */
 
-void flow(Cfg *cfg)
+static void checkreach(Cfg *cfg) /* TODO: stub, no reachability check yet */
 {
 }
 
-void checkret(Cfg *cfg)
+static void checkpredret(Cfg *cfg, Bb *bb)
 {
+    Bb *pred;
+    size_t i;
+
+    for (i = 0; bsiter(bb->pred, &i); i++) { /* each non-empty pred must end in Oret */
+        pred = cfg->bb[i];
+        if (pred->nnl == 0) { /* look through; NOTE(review): a pred-less empty block (start?) passes */
+            checkpredret(cfg, pred);
+        } else if (exprop(pred->nl[pred->nnl - 1]) != Oret) {
+            dumpcfg(cfg, stdout); /* NOTE(review): full CFG dump to stdout before dying -- intended? */
+            fatal(pred->nl[pred->nnl-1]->line, "Reaches end of function without return\n");
+        }
+    }
+}
+
+static void checkret(Cfg *cfg)
+{
+    Type *ft;
+
+    ft = tybase(decltype(cfg->fn));
+    assert(ft->type == Tyfunc);
+    if (tybase(ft->sub[0])->type == Tyvoid) /* resolve via tybase, like ft above */
+        return;
+
+    checkpredret(cfg, cfg->end); /* paths into end must finish in a ret */
+}
+
+void check(Cfg *cfg) /* CFG sanity checks; checkreach() is still a stub */
+{
+    checkret(cfg);
+    checkreach(cfg);
 }
--- a/opt/opt.h
+++ b/opt/opt.h
@@ -2,6 +2,7 @@
 typedef struct Bb Bb;
 
 struct  Cfg {
+    Node *fn;
     Bb **bb;
     Bb *start;
     Bb *end;
@@ -29,6 +30,6 @@
 /* expression folding */
 Node *fold(Node *n, int foldvar);
 /* Takes a reduced block, and returns a flow graph. */
-Cfg *mkcfg(Node **nl, size_t nn);
+Cfg *mkcfg(Node *fn, Node **nl, size_t nn);
 void dumpcfg(Cfg *c, FILE *fd);
-void flow(Cfg *cfg);
+void check(Cfg *cfg);
--- a/parse/bitset.c
+++ b/parse/bitset.c
@@ -208,3 +208,13 @@
             return 0;
     return 1;
 }
+
+int bsisempty(Bitset *set)
+{
+    size_t n;
+    /* empty iff no chunk has any bit set */
+    for (n = set->nchunks; n > 0; n--)
+        if (set->chunks[n - 1] != 0)
+            return 0;
+    return 1;
+}
--- a/parse/dump.c
+++ b/parse/dump.c
@@ -110,7 +110,7 @@
         findentf(fd, depth, "Nil\n");
         return;
     }
-    findentf(fd, depth, "%s", nodestr(n->type));
+    findentf(fd, depth, "%s.%zd@%i", nodestr(n->type), n->nid, n->line);
     switch(n->type) {
         case Nfile:
             fprintf(fd, "(name = %s)\n", n->file.name);
--- a/parse/parse.h
+++ b/parse/parse.h
@@ -363,6 +363,7 @@
 void bsdiff(Bitset *a, Bitset *b);
 int  bseq(Bitset *a, Bitset *b);
 int  bsissubset(Bitset *set, Bitset *sub);
+int  bsisempty(Bitset *set);
 int  bsiter(Bitset *bs, size_t *elt);
 size_t bsmax(Bitset *bs);
 size_t bscount(Bitset *bs);