
Notes on CFG vs CPS

Nico Lehmann edited this page Feb 25, 2021 · 4 revisions


This page uses an example with simple control flow to demonstrate some practical differences between CFG and CPS.

#[liquid::ty("fn(n: i32) -> {v: i32 | v >= 0}")]
fn abs(mut n: i32) -> i32 {
    if n < 0 {
        n = -n;
    }
    n
}

CFG

This is the textual representation of the corresponding MIR program produced by rustc. I added annotations on top of each basic block. Their exact meaning is not needed to understand the differences between CFG and CPS; what matters is that they should be read as preconditions that must be satisfied (in the form of subtyping) before jumping into the basic block. Two further things to note: (1) the locals (program variables, in MIR lingo) declared at the top are in scope for the entire body of the function, and (2) switchInt is a terminator, i.e., after testing the discriminant it immediately jumps to a basic block.

fn abs(_1: i32) -> {i32 | v >= 0} {
    let mut _0: i32;
    let mut _2: bool;
    let mut _3: i32;
    let mut _4: i32;

    // [g2: uninit(1), g3: {int | _ }, g4: uninit(1), g5: uninit(1), g6: uninit(1)]
    // [_0: own(g2)  , _1: own(g3)   , _2: own(g4)  , _3: own(g5)  , _4: own(g6)]
    bb0: {
        _3 = _1;
        _2 = Lt(move _3, const 0_i32);
        switchInt(_2) -> [false: bb1, otherwise: bb2];
    }

    // [g7: uninit(1), g8: {int | _ }, g9: uninit(1), g10: uninit(1), g11: uninit(1)]
    // [_0: own(g7)  , _1: own(g8)   , _2: own(g9)  , _3: own(g10)  , _4: own(g11)]
    bb1: {
        goto -> bb3;
    }

    // [g12: uninit(1), g13: {int | _ }, g14: uninit(1), g15: uninit(1), g16: uninit(1)]
    // [_0: own(g12) , _1: own(g13)   , _2: own(g14)  , _3: own(g15)  , _4: own(g16)]
    bb2: {
        _4 = _1;
        _1 = Neg(move _4);
        goto -> bb3;
    }

    // [g17: uninit(1), g18: {int | _ }, g19: uninit(1), g20: uninit(1), g21: uninit(1)]
    // [_0: own(g17) , _1: own(g18)   , _2: own(g19)  , _3: own(g20)  , _4: own(g21)]
    bb3: {
        _0 = _1; 
        return;
    }
}
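To make the shape of this CFG concrete, here is a minimal Rust sketch that encodes the MIR body above as plain data and interprets it. The names are loosely modeled on rustc's MIR, but this is not rustc's API: `Operand`, `Rvalue`, `Assign`, `Terminator`, `BasicBlock`, `run` and `abs_body` are hypothetical simplifications that only cover the operations used in `abs`, moves are treated as copies, and the annotations are ignored. Note how terminators only name their target blocks, while all locals live in one flat array that is in scope for the whole body.

```rust
/// Hypothetical, simplified model of the MIR body of `abs`.
/// Only the operations used above are covered; moves are treated as copies.
#[derive(Clone, Copy)]
enum Operand {
    Copy(usize),
    Move(usize),
    Const(i32),
}

enum Rvalue {
    Use(Operand),
    Lt(Operand, Operand),
    Neg(Operand),
}

/// `_dst = rvalue`
struct Assign(usize, Rvalue);

/// Terminators end a basic block and transfer control right away.
#[derive(Clone, Copy)]
enum Terminator {
    Goto(usize),
    /// `switchInt(_discr) -> [false: if_false, otherwise: otherwise]`
    SwitchInt { discr: usize, if_false: usize, otherwise: usize },
    Return,
}

struct BasicBlock {
    stmts: Vec<Assign>,
    term: Terminator,
}

fn eval(op: Operand, locals: &[i32]) -> i32 {
    match op {
        Operand::Copy(l) | Operand::Move(l) => locals[l],
        Operand::Const(c) => c,
    }
}

/// Runs the body from bb0. _0 is the return place, _1 the argument,
/// and the bool stored in _2 is encoded as 0/1.
fn run(body: &[BasicBlock], arg: i32) -> i32 {
    let mut locals = vec![0i32; 5];
    locals[1] = arg;
    let mut bb = 0;
    loop {
        for Assign(dst, rv) in &body[bb].stmts {
            let value = match rv {
                Rvalue::Use(o) => eval(*o, &locals),
                Rvalue::Lt(a, b) => (eval(*a, &locals) < eval(*b, &locals)) as i32,
                Rvalue::Neg(o) => -eval(*o, &locals),
            };
            locals[*dst] = value;
        }
        match body[bb].term {
            Terminator::Goto(target) => bb = target,
            Terminator::SwitchInt { discr, if_false, otherwise } => {
                bb = if locals[discr] == 0 { if_false } else { otherwise };
            }
            Terminator::Return => return locals[0],
        }
    }
}

/// The four basic blocks of `abs`, in the same order as the MIR listing.
fn abs_body() -> Vec<BasicBlock> {
    vec![
        // bb0: _3 = _1; _2 = Lt(move _3, const 0_i32); switchInt(_2)
        BasicBlock {
            stmts: vec![
                Assign(3, Rvalue::Use(Operand::Copy(1))),
                Assign(2, Rvalue::Lt(Operand::Move(3), Operand::Const(0))),
            ],
            term: Terminator::SwitchInt { discr: 2, if_false: 1, otherwise: 2 },
        },
        // bb1: goto -> bb3
        BasicBlock { stmts: vec![], term: Terminator::Goto(3) },
        // bb2: _4 = _1; _1 = Neg(move _4); goto -> bb3
        BasicBlock {
            stmts: vec![
                Assign(4, Rvalue::Use(Operand::Copy(1))),
                Assign(1, Rvalue::Neg(Operand::Move(4))),
            ],
            term: Terminator::Goto(3),
        },
        // bb3: _0 = _1; return
        BasicBlock {
            stmts: vec![Assign(0, Rvalue::Use(Operand::Copy(1)))],
            term: Terminator::Return,
        },
    ]
}
```

For example, `run(&abs_body(), -5)` takes the path bb0, bb2, bb3 and returns 5.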

CPS - Take 1

This is the naive translation of the CFG into CPS. Not surprisingly, it looks almost identical to the CFG, just with fancier syntax. The differences are:

  • Let bindings define a lexical scope. This program doesn't take advantage of that: all locals are declared at the top level.
  • It uses if then else instead of switchInt. Unlike switchInt, an if then else can contain extra statements in its branches and does not need to jump to another continuation right away. This program doesn't take advantage of that either.
  • letcont also defines a lexical scope, but this program doesn't exploit it, as all continuations are declared to be mutually recursive at the top level.
fn abs(g0: int; _1: own(g0)) ret k6(g1: {int | V >= 0}; ;own(g1)) =
  let _0 = alloc(1);
  let _2 = alloc(1);
  let _3 = alloc(1);
  let _4 = alloc(1);

  letcont k0( g2: uninit(1), g3: {int | _ }, g4: uninit(1), g5: uninit(1), g6: uninit(1)
            ; _0: own(g2)  , _1: own(g3)   , _2: own(g4)  , _3: own(g5)  , _4: own(g6)
            ) =
    _3 := _1;
    _2 := move _3 < 0;
    if _2 then
      jump k2()
    else
      jump k1()
  and k1( g7: uninit(1), g8: {int | _ }, g9: uninit(1), g10: uninit(1), g11: uninit(1)
        ; _0: own(g7)  , _1: own(g8)   , _2: own(g9)  , _3: own(g10)  , _4: own(g11)
        ) =
    jump k3()
  and k2( g12: uninit(1), g13: {int | _ }, g14: uninit(1), g15: uninit(1), g16: uninit(1)
        ;  _0: own(g12) , _1: own(g13)   , _2: own(g14)  , _3: own(g15)  , _4: own(g16)
        ) =
    _4 := _1;
    _1 := - move _4;
    jump k3()
  and k3( g17: uninit(1), g18: {int | _ }, g19: uninit(1), g20: uninit(1), g21: uninit(1)
        ;  _0: own(g17) , _1: own(g18)   , _2: own(g19)  , _3: own(g20)  , _4: own(g21)
        ) =
    _0 := _1;
    jump k6(_0)
  in
  jump k0()
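The naive translation is mechanical enough to sketch in a few lines. The following hypothetical Rust sketch works on strings: each block's statements are assumed to be already written in CPS syntax (translating individual MIR statements is not shown), and the continuation parameters (the annotations) are elided as `(...)`. Every basic block bbN becomes a continuation kN in a single `letcont ... and ...` group, `goto` becomes `jump`, `switchInt` becomes `if then else` with a jump in each branch, and `return` becomes a jump to the return continuation, passed in as the assumed parameter `ret_k` (`k6` above).

```rust
/// Hypothetical terminators; the same shape as the MIR terminators above.
enum Terminator {
    Goto(usize),
    SwitchInt { discr: usize, if_false: usize, otherwise: usize },
    Return,
}

/// A basic block whose statements are already written in CPS syntax.
struct Block {
    stmts: Vec<&'static str>,
    term: Terminator,
}

fn terminator_to_cps(term: &Terminator, ret_k: &str) -> Vec<String> {
    match term {
        Terminator::Goto(b) => vec![format!("jump k{b}()")],
        // `switchInt` is replaced by `if then else`; each branch jumps immediately.
        Terminator::SwitchInt { discr, if_false, otherwise } => vec![
            format!("if _{discr} then"),
            format!("  jump k{otherwise}()"),
            "else".to_string(),
            format!("  jump k{if_false}()"),
        ],
        Terminator::Return => vec![format!("jump {ret_k}(_0)")],
    }
}

/// Naive translation: one continuation per block, all mutually recursive.
fn naive_cps(blocks: &[Block], ret_k: &str) -> String {
    let mut out = Vec::new();
    for (i, b) in blocks.iter().enumerate() {
        let keyword = if i == 0 { "letcont" } else { "and" };
        out.push(format!("{keyword} k{i}(...) ="));
        for s in &b.stmts {
            out.push(format!("  {s}"));
        }
        for line in terminator_to_cps(&b.term, ret_k) {
            out.push(format!("  {line}"));
        }
    }
    out.push("in".to_string());
    out.push("jump k0()".to_string());
    out.join("\n")
}

/// The blocks of `abs`, with statements hand-written in CPS syntax.
fn abs_blocks() -> Vec<Block> {
    vec![
        Block {
            stmts: vec!["_3 := _1;", "_2 := move _3 < 0;"],
            term: Terminator::SwitchInt { discr: 2, if_false: 1, otherwise: 2 },
        },
        Block { stmts: vec![], term: Terminator::Goto(3) },
        Block { stmts: vec!["_4 := _1;", "_1 := - move _4;"], term: Terminator::Goto(3) },
        Block { stmts: vec!["_0 := _1;"], term: Terminator::Return },
    ]
}
```

Printing `naive_cps(&abs_blocks(), "k6")` yields the skeleton of Take 1 without the parameter lists, with six jumps in total.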

CPS - Take 2

Note that continuations k1 and k2 are redundant: they can be "inlined" into the branches of the if then else, resulting in the following program. It has fewer jumps to check and hence generates a smaller VC. This observation corresponds to the fact that in the CFG one only needs to check jumps into join points, i.e., nodes with more than one incoming edge. To generate this optimized CPS program from the CFG, one would need to look at the dominator tree, which can, in turn, be used to implement the same optimization directly on the CFG. In a way, when translating to CPS one pays the price by making the translation smarter while keeping the checker "simpler", whereas checking directly on the CFG requires making the checker smarter.

fn abs(g0: int; _1: own(g0)) ret k6(g1: {int | V >= 0}; ;own(g1)) =
  let _0 = alloc(1);
  let _2 = alloc(1);
  let _3 = alloc(1);
  let _4 = alloc(1);

  letcont k0( g2: uninit(1), g3: {int | _ }, g4: uninit(1), g5: uninit(1), g6: uninit(1)
            ; _0: own(g2)  , _1: own(g3)   , _2: own(g4)  , _3: own(g5)  , _4: own(g6)
            ) =
    _3 := _1;
    _2 := move _3 < 0;
    if _2 then
      _4 := _1;
      _1 := - move _4;
      jump k3()
    else
      jump k3()
  and k3( g17: uninit(1), g18: {int | _ }, g19: uninit(1), g20: uninit(1), g21: uninit(1)
        ;  _0: own(g17) , _1: own(g18)   , _2: own(g19)  , _3: own(g20)  , _4: own(g21)
        ) =
    _0 := _1;
    jump k6(_0)
  in
  jump k0()
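The join points and the dominator tree mentioned above can both be computed from the successor lists of the basic blocks. The Rust sketch below is a hypothetical illustration (block numbering follows the MIR above; all function names are made up): it finds the join points, the non-entry blocks with a single predecessor (which can be inlined at their jump site), and the immediate dominators using the iterative algorithm of Cooper, Harvey and Kennedy. It also counts the jumps whose target annotation has to be checked, assuming the initial `jump k0()` is checked too and that the entry block has no incoming edges: every edge in the naive translation versus only the edges into join points after inlining.

```rust
/// Successor lists of the MIR basic blocks of `abs` (bb0..bb3).
fn abs_successors() -> Vec<Vec<usize>> {
    vec![vec![1, 2], vec![3], vec![3], vec![]]
}

fn predecessors(succs: &[Vec<usize>]) -> Vec<Vec<usize>> {
    let mut preds = vec![Vec::new(); succs.len()];
    for (b, ss) in succs.iter().enumerate() {
        for &s in ss {
            preds[s].push(b);
        }
    }
    preds
}

/// Join points: blocks with more than one incoming edge.
fn join_points(succs: &[Vec<usize>]) -> Vec<usize> {
    let preds = predecessors(succs);
    (0..succs.len()).filter(|&b| preds[b].len() > 1).collect()
}

/// Non-entry blocks with exactly one incoming edge: candidates for inlining.
fn inlinable(succs: &[Vec<usize>]) -> Vec<usize> {
    let preds = predecessors(succs);
    (1..succs.len()).filter(|&b| preds[b].len() == 1).collect()
}

/// Jumps whose target must be checked, plus one for the initial jump into bb0.
fn checked_jumps(succs: &[Vec<usize>], only_join_points: bool) -> usize {
    let edges: usize = predecessors(succs)
        .iter()
        .filter(|p| !only_join_points || p.len() > 1)
        .map(|p| p.len())
        .sum();
    edges + 1
}

fn reverse_postorder(succs: &[Vec<usize>]) -> Vec<usize> {
    fn dfs(b: usize, succs: &[Vec<usize>], seen: &mut [bool], post: &mut Vec<usize>) {
        seen[b] = true;
        for &s in &succs[b] {
            if !seen[s] {
                dfs(s, succs, seen, post);
            }
        }
        post.push(b);
    }
    let mut seen = vec![false; succs.len()];
    let mut post = Vec::new();
    dfs(0, succs, &mut seen, &mut post);
    post.reverse();
    post
}

/// Walks both fingers up the current dominator tree until they meet.
fn intersect(idom: &[Option<usize>], order: &[usize], mut a: usize, mut b: usize) -> usize {
    while a != b {
        while order[a] > order[b] { a = idom[a].unwrap(); }
        while order[b] > order[a] { b = idom[b].unwrap(); }
    }
    a
}

/// Immediate dominators (Cooper, Harvey, Kennedy); assumes bb0 is the entry
/// and every block is reachable. The entry is its own immediate dominator.
fn immediate_dominators(succs: &[Vec<usize>]) -> Vec<usize> {
    let n = succs.len();
    let preds = predecessors(succs);
    let rpo = reverse_postorder(succs);
    let mut order = vec![0; n];
    for (i, &b) in rpo.iter().enumerate() { order[b] = i; }
    let mut idom: Vec<Option<usize>> = vec![None; n];
    idom[0] = Some(0);
    let mut changed = true;
    while changed {
        changed = false;
        for &b in rpo.iter().skip(1) {
            let mut new_idom: Option<usize> = None;
            for &p in &preds[b] {
                if idom[p].is_none() { continue; }
                new_idom = Some(match new_idom {
                    None => p,
                    Some(cur) => intersect(&idom, &order, p, cur),
                });
            }
            if idom[b] != new_idom { idom[b] = new_idom; changed = true; }
        }
    }
    idom.into_iter().map(|d| d.expect("unreachable block")).collect()
}
```

For `abs`, bb3 is the only join point, bb1 and bb2 can be inlined, and every block is immediately dominated by bb0. The check count is 5 for the naive translation and 3 after inlining, which matches the number of `tag ($k...)` jump obligations in the first two VCs below.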

CPS - Take 3

A further observation is that locals _2, _3, and _4 are only used in continuation k0, so one can push their definitions inside that continuation. This leaves fewer bindings in scope and fewer subtyping judgments to satisfy when jumping into a continuation. To translate the CFG into this optimized CPS program, one would need to do some control flow analysis on the CFG. This particular transformation could also easily be implemented directly on the CFG, because the scope of each optimized variable is contained within a single basic block, but it gets trickier if you want, for example, to define a variable that is in scope for a subset of the basic blocks. That said, I don't think this optimization is very relevant. We already get some of the data flow information in the form of the uninit type annotations, and even though we do create more bindings during type checking, refinements cannot soundly depend on uninitialized variables, so those bindings could easily be optimized away, perhaps in a later optimization phase directly on the VC.

fn abs(g0: int; _1: own(g0)) ret k6(g1: {int | V >= 0}; ;own(g1)) =
  let _0 = alloc(1);

  letcont k0( g2: uninit(1), g3: {int | _ }; _0: own(g2), _1: own(g3)) =
    let _3 = alloc(1);
    let _2 = alloc(1);
    let _4 = alloc(1);
    _3 := _1;
    _2 := move _3 < 0;
    if _2 then
      _4 := _1;
      _1 := -move _4;
      jump k3()
    else
      jump k3()
  and k3( g17: uninit(1), g18: {int | _ };  _0: own(g17), _1: own(g18)) =
    _0 := _1;
    jump k6(_0)
  in
  jump k0()
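The use analysis behind Take 3 can be sketched in the same spirit. In this hypothetical Rust sketch, the input lists, for each continuation of Take 2, the locals its body mentions (written down by hand rather than computed from MIR); a local mentioned by exactly one continuation is reported as sinkable into it. The return place `_0` and the argument `_1` are passed as `pinned`, an assumption that mirrors Take 3 keeping them at the top level.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// Hypothetical use analysis: `uses` lists, for each continuation, the locals
/// its body mentions. A local mentioned by exactly one continuation can have
/// its `alloc` moved inside that continuation, unless it is `pinned`.
fn sinkable_locals(uses: &[(&str, &[&str])], pinned: &[&str]) -> BTreeMap<String, String> {
    // local -> set of continuations whose body mentions it
    let mut used_in: BTreeMap<&str, BTreeSet<&str>> = BTreeMap::new();
    for (cont, locals) in uses {
        for local in locals.iter() {
            used_in.entry(*local).or_default().insert(*cont);
        }
    }
    used_in
        .into_iter()
        .filter(|(local, conts)| conts.len() == 1 && !pinned.contains(local))
        .map(|(local, conts)| {
            let cont = conts.into_iter().next().unwrap();
            (local.to_string(), cont.to_string())
        })
        .collect()
}
```

With the continuations of Take 2 (k0 mentioning _1 through _4, k3 mentioning _0 and _1), this reports _2, _3, and _4 as sinkable into k0, which is exactly the change from Take 2 to Take 3.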

Generated VCs

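These are the VCs generated for CPS Take 1, 2, and 3, in that order; each one starts with its (var ...) declarations. Note that the predicate variables are numbered per VC: in the VCs for Take 2 and Take 3, $k1 corresponds to continuation k3.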

(var $k1 ((int) (int) (int)))
(var $k2 ((int) (int) (int)))
(var $k0 ((int) (int) (int)))
(var $k3 ((int) (int) (int)))
(constraint
  (forall ((l0 int) (true))
    (and
      (forall ((l31 int) ((true)))
        (and
          (forall ((l32 int) ((true)))
            (and
              (forall ((l33 int) ((true)))
                (and
                  (forall ((l34 int) ((true)))
                    (and
                      (forall ((l1 int) ((true)))
                        (forall ((l2 int) ($k0 l2 l0 l1))
                          (forall ((l3 int) ((true)))
                            (forall ((l4 int) ((true)))
                              (forall ((l5 int) ((true)))
                                (and
                                  (forall ((l35 int) ((l35 = l2)))
                                    (and
                                      (forall ((l36 int) ((true)))
                                        (forall ((l37 bool) ((l37 <=> (l35 < 0))))
                                          (and
                                            (forall ((_ int) (l37))
                                              (and
                                                ((true))
                                                (forall ((v int) ((v = l2)))
                                                  (tag ($k2 v l0 l1) "($k2 v l0 l1)"))
                                                ((true))
                                                ((true))
                                                ((true))))
                                            (forall ((_ int) ((not l37)))
                                              (and
                                                ((true))
                                                (forall ((v int) ((v = l2)))
                                                  (tag ($k1 v l0 l1) "($k1 v l0 l1)"))
                                                ((true))
                                                ((true))
                                                ((true)))))))
                                      ((true))))
                                  ((true))))))))
                      (forall ((l6 int) ((true)))
                        (forall ((l7 int) ($k1 l7 l0 l6))
                          (forall ((l8 int) ((true)))
                            (forall ((l9 int) ((true)))
                              (forall ((l10 int) ((true)))
                                (and
                                  ((true))
                                  (forall ((v int) ((v = l7)))
                                    (tag ($k3 v l0 l6) "($k3 v l0 l6)"))
                                  ((true))
                                  ((true))
                                  ((true))))))))
                      (forall ((l11 int) ((true)))
                        (forall ((l12 int) ($k2 l12 l0 l11))
                          (forall ((l13 int) ((true)))
                            (forall ((l14 int) ((true)))
                              (forall ((l15 int) ((true)))
                                (and
                                  (forall ((l38 int) ((l38 = l12)))
                                    (and
                                      (forall ((l39 int) ((true)))
                                        (forall ((l40 int) ((l40 = -l38)))
                                          (and
                                            ((true))
                                            (forall ((v int) ((v = l40)))
                                              (tag ($k3 v l0 l11) "($k3 v l0 l11)"))
                                            ((true))
                                            ((true))
                                            ((true)))))
                                      ((true))))
                                  ((true))))))))
                      (forall ((l16 int) ((true)))
                        (forall ((l17 int) ($k3 l17 l0 l16))
                          (forall ((l18 int) ((true)))
                            (forall ((l19 int) ((true)))
                              (forall ((l20 int) ((true)))
                                (and
                                  (forall ((l41 int) ((l41 = l17)))
                                    (forall ((v int) ((v = l41)))
                                      (tag ((v >= 0)) "((v >= 0))")))
                                  ((true))))))))
                      (and
                        ((true))
                        (forall ((v int) ((v = l0)))
                          (tag ($k0 v l0 l31) "($k0 v l0 l31)"))
                        ((true))
                        ((true))
                        ((true)))))
                  ((true))))
              ((true))))
          ((true))))
      ((true)))))
(var $k0 ((int) (int) (int)))
(var $k1 ((int) (int) (int)))
(constraint
  (forall ((l0 int) (true))
    (and
      (forall ((l21 int) ((true)))
        (and
          (forall ((l22 int) ((true)))
            (and
              (forall ((l23 int) ((true)))
                (and
                  (forall ((l24 int) ((true)))
                    (and
                      (forall ((l1 int) ((true)))
                        (forall ((l2 int) ($k0 l2 l0 l1))
                          (forall ((l3 int) ((true)))
                            (forall ((l4 int) ((true)))
                              (forall ((l5 int) ((true)))
                                (and
                                  (forall ((l25 int) ((l25 = l2)))
                                    (and
                                      (forall ((l26 int) ((true)))
                                        (forall ((l27 bool) ((l27 <=> (l25 < 0))))
                                          (and
                                            (forall ((_ int) (l27))
                                              (and
                                                (forall ((l28 int) ((l28 = l2)))
                                                  (and
                                                    (forall ((l29 int) ((true)))
                                                      (forall ((l30 int) ((l30 = -l28)))
                                                        (and
                                                          ((true))
                                                          (forall ((v int) ((v = l30)))
                                                            (tag ($k1 v l0 l1) "($k1 v l0 l1)"))
                                                          ((true))
                                                          ((true))
                                                          ((true)))))
                                                    ((true))))
                                                ((true))))
                                            (forall ((_ int) ((not l27)))
                                              (and
                                                ((true))
                                                (forall ((v int) ((v = l2)))
                                                  (tag ($k1 v l0 l1) "($k1 v l0 l1)"))
                                                ((true))
                                                ((true))
                                                ((true)))))))
                                      ((true))))
                                  ((true))))))))
                      (forall ((l6 int) ((true)))
                        (forall ((l7 int) ($k1 l7 l0 l6))
                          (forall ((l8 int) ((true)))
                            (forall ((l9 int) ((true)))
                              (forall ((l10 int) ((true)))
                                (and
                                  (forall ((l31 int) ((l31 = l7)))
                                    (forall ((v int) ((v = l31)))
                                      (tag ((v >= 0)) "((v >= 0))")))
                                  ((true))))))))
                      (and
                        ((true))
                        (forall ((v int) ((v = l0)))
                          (tag ($k0 v l0 l21) "($k0 v l0 l21)"))
                        ((true))
                        ((true))
                        ((true)))))
                  ((true))))
              ((true))))
          ((true))))
      ((true)))))
(var $k1 ((int) (int) (int)))
(var $k0 ((int) (int) (int)))
(constraint
  (forall ((l0 int) (true))
    (and
      (forall ((l15 int) ((true)))
        (and
          (forall ((l1 int) ((true)))
            (forall ((l2 int) ($k0 l2 l0 l1))
              (and
                (forall ((l16 int) ((true)))
                  (and
                    (forall ((l17 int) ((true)))
                      (and
                        (forall ((l18 int) ((true)))
                          (and
                            (forall ((l19 int) ((l19 = l2)))
                              (and
                                (forall ((l20 int) ((true)))
                                  (forall ((l21 bool) ((l21 <=> (l19 < 0))))
                                    (and
                                      (forall ((_ int) (l21))
                                        (and
                                          (forall ((l22 int) ((l22 = l2)))
                                            (and
                                              (forall ((l23 int) ((true)))
                                                (forall ((l24 int) ((l24 = -l22)))
                                                  (and
                                                    ((true))
                                                    (forall ((v int) ((v = l24)))
                                                      (tag ($k1 v l0 l1) "($k1 v l0 l1)")))))
                                              ((true))))
                                          ((true))))
                                      (forall ((_ int) ((not l21)))
                                        (and
                                          ((true))
                                          (forall ((v int) ((v = l2)))
                                            (tag ($k1 v l0 l1) "($k1 v l0 l1)")))))))
                                ((true))))
                            ((true))))
                        ((true))))
                    ((true))))
                ((true)))))
          (forall ((l3 int) ((true)))
            (forall ((l4 int) ($k1 l4 l0 l3))
              (and
                (forall ((l25 int) ((l25 = l4)))
                  (forall ((v int) ((v = l25)))
                    (tag ((v >= 0)) "((v >= 0))")))
                ((true)))))
          (and
            ((true))
            (forall ((v int) ((v = l0)))
              (tag ($k0 v l0 l15) "($k0 v l0 l15)")))))
      ((true)))))
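A large part of each VC consists of `((true))` conjuncts and binders with trivial guards whose variable is never used, such as `(forall ((l3 int) ((true))) ...)` in the first VC. As a rough illustration of the VC-level clean-up suggested in Take 3, here is a hypothetical Rust sketch over a made-up constraint type (not the checker's actual constraint representation): it drops `true` conjuncts, flattens nested `and`s, replaces `forall x. g => true` with `true`, and removes a binder when its guard is trivially true and the bound variable does not occur in the body. The last rule is sound because `forall x: int. P` is equivalent to `P` when `x` is not free in `P`; binders with non-trivial guards are kept, since the guard could be unsatisfiable.

```rust
/// Hypothetical atomic predicate: its text plus the variables it mentions.
#[derive(Debug, Clone, PartialEq)]
struct Atom {
    text: String,      // e.g. "($k2 v l0 l1)" or "(v >= 0)"
    vars: Vec<String>,
}

fn atom(text: &str, vars: &[&str]) -> Atom {
    Atom { text: text.to_string(), vars: vars.iter().map(|v| v.to_string()).collect() }
}

#[derive(Debug, Clone, PartialEq)]
enum Constraint {
    True,
    Atom(Atom),
    And(Vec<Constraint>),
    /// `(forall ((var int) guard) body)`; `guard == None` stands for `((true))`.
    Forall { var: String, guard: Option<Atom>, body: Box<Constraint> },
}

/// Does `x` occur free in `c`?
fn mentions(c: &Constraint, x: &str) -> bool {
    match c {
        Constraint::True => false,
        Constraint::Atom(a) => a.vars.iter().any(|v| v == x),
        Constraint::And(cs) => cs.iter().any(|c| mentions(c, x)),
        Constraint::Forall { var, guard, body } => {
            var != x
                && (guard.as_ref().map_or(false, |g| g.vars.iter().any(|v| v == x))
                    || mentions(body, x))
        }
    }
}

fn simplify(c: Constraint) -> Constraint {
    match c {
        Constraint::And(cs) => {
            let mut out = Vec::new();
            for c in cs {
                match simplify(c) {
                    Constraint::True => {}                      // drop `((true))` conjuncts
                    Constraint::And(inner) => out.extend(inner), // flatten nested `and`s
                    other => out.push(other),
                }
            }
            match out.len() {
                0 => Constraint::True,
                1 => out.pop().unwrap(),
                _ => Constraint::And(out),
            }
        }
        Constraint::Forall { var, guard, body } => {
            let body = simplify(*body);
            if body == Constraint::True {
                Constraint::True // forall x. g => true is valid
            } else if guard.is_none() && !mentions(&body, &var) {
                body // unused binder with a trivial guard
            } else {
                Constraint::Forall { var, guard, body: Box::new(body) }
            }
        }
        other => other,
    }
}
```

On a fragment shaped like the first VC, `(forall l3 true. forall l4 true. (and true (forall v (v = l2). $k2 v l0 l1) true))` simplifies to just the inner `forall v` obligation, while a binder such as `l1`, which the `$k2` application mentions, is kept.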