How to replace a Block with a new one at Spatial IR level inside a compiler pass?


Ruizhe Zhao

Apr 29, 2019, 12:43:16 PM
to Spatial Users
Hi there,

Sorry for disturbing you.

I am currently working on replacing a Block inside a pass (specifically, a MutateTransformer) with a new Block constructed from information that can be collected from the design at the Spatial IR level, and I'm stuck on finding the right way to do so.

A concrete toy example (which may not be very useful): suppose I have a new primitive called PRIM that takes two Bits values, x and y. The pass I intend to write will replace PRIM with an actual computation on them, say x ** 2 + y ** 2, if x and y are added somewhere in the design; otherwise it does nothing.

// Primitive definition
@op
case class PRIM[A:Bits](a: Bits[A], b: Bits[A]) extends Primitive[A] {}

// The API to use this primitive
@api def prim[A](a: Bits[A], b: Bits[A]): A = {
  implicit val tA: Bits[A] = a.selfType
  stage(PRIM[A](a, b))
}

// The design
@spatial object ToyDesign extends SpatialApp {

  def main(args: Array[String]): Unit = {
    type T = FixPt[TRUE, _24, _8]

    val x = ArgIn[T]
    val y = ArgIn[T]
    setArg(x, 2.0)
    setArg(y, 4.0)

    // The gradient output
    val z = ArgOut[T]

    Accel {
      val d = x + y

      z := prim(x, y) // should be replaced
    }

    println(s"x = $x, dy/dx = $dydx")
  }
}

As far as I understand, this compiler pass needs to:

1) Locate the statement of PRIM op
2) Find out whether there is an addition (FixAdd) between the two inputs of PRIM
3) If so, construct a new Block, which has a list of statements (e.g., FixAdd(FixMul(x, x), FixMul(y, y))); otherwise, create an empty block
4) Replace the PRIM statement with this new Block



The major problems that I met while I was trying to implement these steps are as follows:

1) In an override transform method, how could I create a new Block and use it to replace an old one?

I checked RegReadCSE and PipeInserter. I kind of get the idea that a transform should contain these steps: pattern matching, building a new op through an API (inside spatial.lang), wrapping it with stage(), and returning the staged value. Alternatively, I notice we can also use a register function to register a mapping from a block, which is extracted from a ControlBody, to a new Block constructed by some custom methods.

I'm not sure which direction I should take for the case given above, or what the exact steps are.

2) What are the purposes of calling super.transform and f (for MutateTransformer)?

Specifically, are there any concrete examples of the effect of mirror and update?

3) Once we've implemented a Transform, how could we test it?






In general, I have a rough idea about the algorithm behind my pass but I'm not clear about how to actually implement it respecting the type system inside Spatial. It would be great if anyone can draft the compiler pass mentioned above (no need to be very accurate, just an outline would be perfect) and give me a hint about those questions.

Thanks in advance for any help.

Best regards
Ruizhe



David Koeplinger

Apr 29, 2019, 1:04:28 PM
to Spatial Users
Great, this is exactly what MutateTransformer was created to do!

MutateTransformer works through substitution rules. Unless a node is explicitly replaced, the transformer updates all of its inputs based on a symbol mapping from s -> s'.

1) In an override transform method, how could I create a new Block and use it to replace an old one?
You probably don't need to create a new block. Note that PRIM is already in a Block, and by default if you have an arbitrary sequence of expressions in its substitution rule, those will be inlined in that block in place of PRIM.
Check out RewriteTransformer in Spatial for an example. 

2) What are the purposes of calling super.transform and f (for MutateTransformer)?
Calling `f(s)` looks up the substitution (just in case there is one). This is generally needed when you look at inputs of an Op. 

Calling super.transform, by default, notes that you are *not* updating this node, and instead should update its inputs based on other updated nodes. 

3) Once we've implemented a Transform, how could we test it?
The simplest way is to create a Spatial application like the one you have above, enable the transformation in the compiler, and run the compiler in verbose mode.
The compiler will output logs, both before and after the transformer. You can go and check that the IR has been changed the way you want from there. 
Unit testing of a single transformer in isolation is a little harder, we don't have a good test framework in place for that yet. 

The basic idea for your algorithm could look like:
case class MyTransformer(IR: State) extends MutateTransformer {
  // Must be a var (or a mutable set) so that += below can update it
  var added: Set[(Sym[_],Sym[_])] = Set.empty

  override def transform[A:Type](lhs: Sym[A], rhs: Op[A])(implicit ctx: SrcCtx): Sym[A] = rhs match {
    case Add(a, b) =>
      added += ((a,b))         // 1) Mark a and b as having been added
      super.transform(lhs,rhs) // 2) Transform addition as normal

    case PRIM(a, b) if added.contains((a,b)) =>
      // Expand to a*a + b*b
      // This is slightly more complicated because a and b are either Flt or Fix
      // Check out RewriteTransformer on how to do this

    case _ => super.transform(lhs,rhs) // Transform other operations as normal
  }
}





Ruizhe Zhao

May 1, 2019, 2:58:16 AM
to Spatial Users
Hey David,

Thank you so much for your clarification. Would you mind me asking several more questions based on your reply and my reading about the RewriteTransformer?

those will be inlined in that block in place of PRIM.

I'm not sure about what inlining means in a Transformer. Does it actually apply the chain of substitution rules? And when will it be called?


Calling super.transform, by default, notes that you are *not* updating this node, and instead should update its inputs based on other updated nodes. 


Suppose I write a transform function and, in some of its cases, I update the current statement. Should I make sure I call f? In other words, is it a convention to always update the inputs of a statement? Or does an update here only apply the substitution rules to get the correct symbol?



Questions about RewriteTransformer:

1) As far as I understand, the typical procedure to create a new symbol for a replacement statement is to create the node and then stage it. The Staging documentation says that stage creates a symbol for the newly created node (like assigning an ID?). But I can't work out the purpose of stageWithFlow, or what "flow" means in the context of staging.

2) For this specific line of code, how does the unwrapping work with F for the inputs? Is it like looking up key by value?

3) Should we manually call transferData, and if so, at which points?


In general the RewriteTransformer is indeed a good start, but I think I still have several tiny points to be resolved before fully understanding what is going on in that class.

Thank you so much for your patience and time!

Best regards
Ruizhe

David Koeplinger

May 1, 2019, 4:00:42 AM
to spatial-l...@googlegroups.com
Good timing, you caught me right as I was checking email. 

Good questions, let's take a concrete example, that'll probably help solidify this. 

Suppose I've written a simplified version of your code in Spatial:
val x = ArgIn[Int]
val y = ArgIn[Int]
val z = ArgOut[Int]

Accel{
  val d = x + y
  z := PRIM(x, y)
}

The (simplified) IR for that looks like:
Block1: // Program block
  x0 = ArgInNew[Int](0) // aka x
  x1 = ArgInNew[Int](0) // aka y
  x2 = ArgOutNew[Int]() // aka z
  x3 = Accel(Block2)

  Block2: // Accel block
    x4 = RegRead(x0)
    x5 = RegRead(x1)
    x6 = FixAdd(x4, x5) // aka d
    x7 = PRIM(x4, x5)
    x8 = RegWrite(x2, x7)
  // End of Block2
// End of Block1

Let's suppose we run the transformer we've been talking about, replacing PRIM(x, y) -> x*x + y*y if there is a prior x + y. We'll ignore for now that this doesn't make much sense as a transformation rule. Let's also assume for simplicity that PRIM is only used on fixed point numbers. What should this transformer do for a single PRIM instance?
1. Anywhere a = PRIM(b,c) is in the graph, replace it with three operations: a1 = FixMul(b,b), a2 = FixMul(c,c), a3 = FixAdd(a1, a2)
2. Replace anything that used "a" with "a3"

If we follow both of those rules, we would have a graph that looks like:

Block3: // Program block
  x0 = ArgInNew[Int](0) // aka x
  x1 = ArgInNew[Int](0) // aka y
  x2 = ArgOutNew[Int]() // aka z
  x3 = Accel(Block4)

  Block4: // Accel block
    x4 = RegRead(x0)
    x5 = RegRead(x1)
    x6 = FixAdd(x4, x5) // aka d
    x9 = FixMul(x4, x4)   // was PRIM (x7)
    x10 = FixMul(x5, x5)  // was PRIM (x7)
    x11 = FixAdd(x9, x10) // was PRIM (x7)
    x8 = RegWrite(x2, x11) // Previously used PRIM (x7)
  // End of Block4
// End of Block3


This is what I meant by "inlined." We haven't introduced a new Block, we've only replaced PRIM in Block2 with a couple other operations. This is what an Op transformation rule does by default. 

So we got the creation of those three things from the definition of the transform rule, but two questions remain. 
1. How did we know to remove the original PRIM node?
2. How did we know what node to replace it with? (in x8, for example, we now use x11 instead)

For 1, the answer is pretty simple. When we run a transformer, for each block, we visit each statement and re-register it with the block in the default (super.transform) transform rule. If you override the method to give a custom transformation, as we did here, no such registration occurs, and the original node is left out of the new block. Note that Block1 and Block2 are actually replaced by new blocks with mostly the same contents.

For 2, the answer is the surrounding transformer class. When we visit a block and call transform, we check if the symbol returned from transform is the same as the one it was called on. If it isn't, that means the transform method was overridden with a custom transformation rule and we've created a transformation for this node. In that case, we register in a hashmap the substitution s -> s' (original symbol to new symbol).
In the default transformer rule, nodes have all of their inputs updated using this hashmap. This is how x8 has come to now use the new symbol.
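
The mechanism described above can be imitated in plain Scala. This is only a toy sketch: the `Op` case classes, `stage`, `f`, and `transformBlock` below are hypothetical stand-ins written for illustration, not Spatial's actual classes or signatures, and the PRIM rule is applied unconditionally (no check for a prior addition) to keep the focus on substitution.

```scala
import scala.collection.mutable

// Toy model only: none of these types are Spatial's real IR classes.
object SubstDemo {
  sealed trait Op
  case class RegRead(reg: String)             extends Op
  case class FixAdd(a: Int, b: Int)           extends Op
  case class FixMul(a: Int, b: Int)           extends Op
  case class Prim(a: Int, b: Int)             extends Op
  case class RegWrite(reg: String, data: Int) extends Op

  type Stm = (Int, Op) // (symbol id, defining op)

  def transformBlock(block: List[Stm]): List[Stm] = {
    val subst  = mutable.Map.empty[Int, Int]    // the s -> s' hashmap
    val out    = mutable.ListBuffer.empty[Stm]  // statements of the new block
    var nextId = block.map(_._1).max + 1

    // Look up the substitution for a symbol, if there is one
    def f(s: Int): Int = subst.getOrElse(s, s)

    // Add a new statement to the new block and return its fresh symbol
    def stage(op: Op): Int = {
      val id = nextId; nextId += 1
      out += (id -> op)
      id
    }

    block.foreach {
      // Custom rule: the original node is never re-registered,
      // and the symbol of the final FixAdd becomes its substitution.
      case (lhs, Prim(a, b)) =>
        val aa = stage(FixMul(f(a), f(a)))
        val bb = stage(FixMul(f(b), f(b)))
        subst(lhs) = stage(FixAdd(aa, bb))

      // Default rule: keep the node, but update its inputs via the hashmap
      case (lhs, op) =>
        val updated: Op = op match {
          case FixAdd(a, b)   => FixAdd(f(a), f(b))
          case FixMul(a, b)   => FixMul(f(a), f(b))
          case RegWrite(r, d) => RegWrite(r, f(d))
          case other          => other
        }
        out += (lhs -> updated)
    }
    out.toList
  }

  // The Accel block from the example above (x4 .. x8)
  val accel: List[Stm] = List(
    4 -> RegRead("x0"), 5 -> RegRead("x1"),
    6 -> FixAdd(4, 5), 7 -> Prim(4, 5), 8 -> RegWrite("x2", 7))

  val result: List[Stm] = transformBlock(accel)
}
```

With this toy numbering, `SubstDemo.result` contains statements 9, 10 and 11 in place of 7, and the `RegWrite` at 8 reads 11, mirroring the transformed IR listed earlier.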


Ok, so now your second question. When should you use the magical "f" method? The answer is, whenever the input in question might be transformed. For example, if we had:
val x = ArgIn[Int]
val y = ArgIn[Int]
val z = ArgOut[Int]

Accel{
  val d = x + y
  val e = PRIM(x, y)
  val g = PRIM(y, x)
  z := PRIM(e, g)
}

This will result in a chain of substitutions, where PRIM "e" and "g" are transformed, and the correct transformation of "PRIM(e,g)" uses e' and g' (the transformed version of those PRIMs), not e and g themselves. 

A good rule of thumb is to just call the "f" function on every input, unless you know for sure that it won't be updated. (And if you're new to this, you won't know for sure, so just always call it.)
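
To see why skipping "f" goes wrong on a chain like this, here is a small plain-Scala sketch. It is a toy model under stated assumptions: `expand`, `dangling`, and the `Op` classes are hypothetical names made up for this illustration, not Spatial APIs, and the `useF` flag exists only to compare the two behaviours.

```scala
import scala.collection.mutable

// Toy model only: illustrates substitution chains, not Spatial's real IR.
object ChainDemo {
  sealed trait Op
  case class Input(name: String)    extends Op
  case class FixMul(a: Int, b: Int) extends Op
  case class FixAdd(a: Int, b: Int) extends Op
  case class Prim(a: Int, b: Int)   extends Op

  type Stm = (Int, Op)

  // Expand every Prim(a, b) into a*a + b*b.
  // useF = false simulates forgetting to call f on the inputs.
  def expand(block: List[Stm], useF: Boolean): List[Stm] = {
    val subst  = mutable.Map.empty[Int, Int]
    val out    = mutable.ListBuffer.empty[Stm]
    var nextId = block.map(_._1).max + 1

    def f(s: Int): Int = if (useF) subst.getOrElse(s, s) else s
    def stage(op: Op): Int = {
      val id = nextId; nextId += 1
      out += (id -> op)
      id
    }

    block.foreach {
      case (lhs, Prim(a, b)) =>
        val aa = stage(FixMul(f(a), f(a)))
        val bb = stage(FixMul(f(b), f(b)))
        subst(lhs) = stage(FixAdd(aa, bb))
      case stm =>
        out += stm
    }
    out.toList
  }

  def inputs(op: Op): List[Int] = op match {
    case Input(_)     => Nil
    case FixMul(a, b) => List(a, b)
    case FixAdd(a, b) => List(a, b)
    case Prim(a, b)   => List(a, b)
  }

  // Symbols that are used in a block but no longer defined in it
  def dangling(block: List[Stm]): Set[Int] =
    block.flatMap { case (_, op) => inputs(op) }.toSet -- block.map(_._1).toSet

  // 0 = x, 1 = y, 2 = e, 3 = g, 4 = PRIM(e, g)
  val program: List[Stm] = List(
    0 -> Input("x"), 1 -> Input("y"),
    2 -> Prim(0, 1), 3 -> Prim(1, 0), 4 -> Prim(2, 3))

  val withF    = dangling(expand(program, useF = true))  // empty
  val withoutF = dangling(expand(program, useF = false)) // Set(2, 3): stale e and g
}
```

Without the lookup, the expansion of PRIM(e, g) still refers to the old symbols for e and g, which no longer exist in the new block.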


RewriteTransformer questions
1) stageWithFlow - you generally won't need to call this. At one point I realized that a lot of our analysis passes were extraneous if you could just fuse (akin to loop fusion) them with the staging process itself. The stageWithFlow method gives you a way of defining custom analyses while staging the node, separate from the DSL definition itself, but there are few use cases for this in practice. Generally, I recommend calling the API method that exists to stage a node wherever possible, e.g. use "x + y" or add(x, y) rather than stage(FixAdd(x,y)).

2) Yes. The F here is the same as the f method, but can be used in pattern matching. It literally just looks up a symbol in the hashmap I mentioned before.

3) The answer here is, it depends, but if you don't know why you would need to call it, you probably don't need to call it. The transformer moves metadata on simple substitutions by default, but more complicated things can sometimes cause there to be two (or more!) symbols that should both get the metadata.

Suppose we had an analysis pass that chose a "favorite" symbol. Suppose it marked the metadata "MyFavorite" on a PRIM node. Now we run the transformer, and we have three nodes where one used to be. Who gets to be MyFavorite? By default, the answer is "whatever symbol is replacing the original," aka the result of the transform method. But we could also use the transferData method explicitly to give all three nodes the metadata that the original symbol had. 
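
The difference between the two outcomes can be shown with a toy metadata table. This is only an illustrative sketch in plain Scala: the `transfer` helper and the `Meta` type are hypothetical and do not reflect the real signature or behaviour of Spatial's transferData.

```scala
// Toy model only: metadata as a map from symbol id to a set of tags.
object MetadataDemo {
  type Meta = Map[Int, Set[String]]

  // Hypothetical stand-in for a metadata copy: add all tags of `from` to `to`
  def transfer(meta: Meta, from: Int, to: Int): Meta = {
    val src = meta.getOrElse(from, Set.empty[String])
    val dst = meta.getOrElse(to, Set.empty[String])
    meta.updated(to, dst ++ src)
  }

  // An analysis marked the PRIM node (symbol 7) as the favorite
  val before: Meta = Map(7 -> Set("MyFavorite"))

  // The transformer replaced 7 with three nodes; 11 is the returned symbol
  val newSyms = List(9, 10, 11)

  // Default behaviour described above: only the replacement symbol inherits
  val byDefault: Meta = transfer(before, 7, 11)

  // Explicit choice: give all three new nodes the original metadata
  val explicit: Meta = newSyms.foldLeft(before)((m, s) => transfer(m, 7, s))
}
```

In `byDefault` only symbol 11 carries "MyFavorite"; in `explicit` symbols 9, 10 and 11 all do.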

I'm probably going on for too long here, but just to finish this thought: You can see here how there starts to be an interesting dynamic between preserving analysis metadata and writing a transformer. If you actually want to preserve all information, the transformer has to be fully aware of what information existed before it and semantically how that metadata should propagate through it. In practice, in most cases it's far easier and just as correct to just re-add the metadata after the transformer. But this means you need to schedule a lot of analysis passes after most transformers to make sure the data is available. Ideally, it'd be nice to instead have a way of automatically re-populating the metadata on the fly for performance (and developer sanity) reasons. This is the reason those "flow rules" were added to the compiler - do the analysis on the fly while you add things to the graph, assuming that the information you need is based on data "flow". 

Anyway, sorry for the long answer. Hope this helps clear things up a bit.

Ruizhe Zhao

May 3, 2019, 10:34:25 AM
to Spatial Users
Hi David,

Thank you very much for your detailed reply! I really like the concrete examples you've listed, they make understanding the mechanism of transformation much easier.

I have created a demo based on what we've discussed in this post: https://github.com/kumasento/spatial/tree/dev-trans. I've changed the semantics of prim a little bit: if an addition doesn't appear, we replace prim with a multiplication.

The test demo resides under the apps/ directory.


It almost works: I've checked the IR that the compiler printed, and I can see that prim has been properly replaced. But there are some issues I need further help with:

1) I tried to print the output values from the newly replaced registers, but what gets printed are symbols rather than values, and getArg doesn't help. This is why the assertions in the demo fail: the symbols don't match.
2) I am not sure where I should place this new pass inside Spatial. Currently I just insert it after blackbox lowering.
3) Looking up whether two variables have been added is a bit tricky: a RegRead is inserted before any variable passed to prim, which changes the symbol. I need to manually extract the source of the RegRead by pattern matching:

(a, b) match {
  case (Op(RegRead(ra)), Op(RegRead(rb))) =>
    println(s"Current set is $added, required symbol: $ra, $rb")

    if (added.contains((ra, rb)))
      stage(FixAdd(stage(FixMul(a, a)), stage(FixMul(b, b))))
    else
      stage(FixMul(a, b))
  case _ =>
    stage(FixAdd(a, b))
}

Is there a more straightforward way to decide whether two variables are equivalent?


Thank you again for all your clarification!

Best regards
Ruizhe

David Koeplinger

May 3, 2019, 2:45:39 PM
to Spatial Users
Glad you were able to make progress!

1) The syntax in Spatial is r"my value is $x" (using r"", not s""). The r"" syntax delays stringifying things until runtime, whereas s"" does it as part of application compilation.

2) Ordering of passes in a compiler is more of an art than a science, and is very context dependent. Just after black box lowering seems like a reasonable placement for your new pass in this case.

3) Yes, unfortunately variable equivalence is actually a rather hard problem in general when writing a compiler pass, especially when you have to account for things like these register reads and writes.

There are two approaches here: 
a. Write the simple rules like you have here and expect this will cover most of the cases you care about in practice. This often works well enough because Spatial also does some basic common subexpression elimination already, so the number of cases with identical values with different names is reduced a bit.

b. OR figure out how to do a comprehensive graph equivalence analysis based on the dataflow dependencies of any two symbols, including checks for whether the value passes through registers and whether that register may be overwritten with an incompatible value on any control path reaching the point you're currently looking at. This is... doable... but generally overkill. 
(Incidentally, this is the reason ArgIns were added to Spatial - their value is runtime dependent, but the value of all reads to one ArgIn is constant throughout the entire body of Accel by contract with the host). 
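
Approach (a) can be sketched in plain Scala as a "look through the read" rule. This is a toy model, not Spatial code: `source`, `wereAdded`, and the `Op` classes are hypothetical names for illustration, and the rule is only sound here because every register in the example is an ArgIn, whose value does not change inside Accel.

```scala
// Toy model only: resolve each symbol to the register it was read from.
object EquivDemo {
  sealed trait Op
  case class ArgIn(name: String)    extends Op
  case class RegRead(reg: Int)      extends Op
  case class FixAdd(a: Int, b: Int) extends Op

  // Symbol id -> defining op. 4/5 and 12/13 are separate reads of x and y.
  val defs: Map[Int, Op] = Map(
    0 -> ArgIn("x"), 1 -> ArgIn("y"),
    4 -> RegRead(0), 5 -> RegRead(1), 6 -> FixAdd(4, 5),
    12 -> RegRead(0), 13 -> RegRead(1))

  // Simple rule: a RegRead stands for its register, anything else for itself
  def source(s: Int): Int = defs.get(s) match {
    case Some(RegRead(r)) => r
    case _                => s
  }

  // Record additions in terms of sources rather than raw symbols
  val added: Set[(Int, Int)] =
    defs.values.collect { case FixAdd(a, b) => (source(a), source(b)) }.toSet

  def wereAdded(a: Int, b: Int): Boolean =
    added.contains((source(a), source(b)))
}
```

Here `wereAdded(12, 13)` holds even though the addition was staged on symbols 4 and 5, while `wereAdded(13, 12)` does not, because the pair is still order-sensitive, just as in the original pattern match.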




Ruizhe Zhao

May 4, 2019, 9:03:55 AM
to Spatial Users
Hi David

Thank you so much for your reply! Your suggestions are really helpful, and I've managed to get the code running smoothly.

I am now moving forward to some more exciting and practical optimisation passes to be implemented on Spatial. Will let you know how it progresses.

Best regards
Ruizhe