Hi there,
Sorry to bother you.
I am currently working on a pass (specifically, a MutateTransformer) that replaces a Block with a new Block constructed from information collected from the design at the Spatial IR level, and I'm stuck finding the right way to do so.
Here is a concrete (if not very useful) toy example. Suppose I have a new primitive called PRIM that takes two Bits values, e.g., x and y. The pass I want to write replaces PRIM with an actual computation on them, say x ** 2 + y ** 2, if x and y are added together somewhere in the design; otherwise, it does nothing.
```scala
// Primitive definition
@op case class PRIM[A:Bits](a: Bits[A], b: Bits[A]) extends Primitive[A]

// The API to use this primitive
@api def prim[A](a: Bits[A], b: Bits[A]): A = {
  implicit val tA: Bits[A] = a.selfType
  stage(PRIM[A](a, b))
}

// The design
@spatial object ToyDesign extends SpatialApp {
  def main(args: Array[String]): Unit = {
    type T = FixPt[TRUE, _24, _8]
    val x = ArgIn[T]
    val y = ArgIn[T]
    setArg(x, 2.0)
    setArg(y, 4.0)
    // The output
    val z = ArgOut[T]
    Accel {
      val d = x + y    // the addition the pass looks for
      z := prim(x, y)  // should be replaced
    }
    println(s"z = ${getArg(z)}")
  }
}
```
As far as I understand, this compiler pass needs to:
1) Locate the statement for the PRIM op
2) Check whether there is an addition (FixAdd) of the two inputs of PRIM
3) If so, construct a new Block containing the corresponding statements (e.g., FixAdd(FixMul(x, x), FixMul(y, y))); otherwise, create an empty block
4) Replace the PRIM statement with this new Block
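To make the four steps concrete, here is a minimal sketch on a deliberately tiny, self-contained IR. It is not Spatial or Argon code: `Sym`, `Stm`, `Block`, `Add`, `Mul`, `Prim`, `PrimExpansion`, and the local `stage`, `f`, and `subst` helpers are hypothetical stand-ins that only mimic the shape of such a pass. It follows the "otherwise, do nothing" variant of the toy example, so a PRIM with no matching addition is kept unchanged.

```scala
import scala.collection.mutable

// Hypothetical toy IR (NOT Spatial/Argon classes): symbols, ops, statements, blocks.
final case class Sym(id: Int)
sealed trait Op
final case class Add(a: Sym, b: Sym) extends Op  // stands in for FixAdd
final case class Mul(a: Sym, b: Sym) extends Op  // stands in for FixMul
final case class Prim(a: Sym, b: Sym) extends Op // stands in for PRIM
final case class Stm(lhs: Sym, rhs: Op)
final case class Block(stms: Vector[Stm], result: Sym)

object PrimExpansion {
  def run(block: Block): Block = {
    var nextId = (0 +: block.stms.map(_.lhs.id)).max + 1
    val out = mutable.ArrayBuffer.empty[Stm]
    // Old symbol -> replacement symbol; `f` looks symbols up here
    // (toy analogue of a transformer's substitution).
    val subst = mutable.Map.empty[Sym, Sym]
    def f(s: Sym): Sym = subst.getOrElse(s, s)
    // Append a statement with a fresh symbol (toy analogue of stage()).
    def stage(op: Op): Sym = { val s = Sym(nextId); nextId += 1; out += Stm(s, op); s }

    // Step 2: does the block contain an Add of exactly these two inputs (either order)?
    def hasAddOf(a: Sym, b: Sym): Boolean = block.stms.exists {
      case Stm(_, Add(x, y)) => (x == a && y == b) || (x == b && y == a)
      case _                 => false
    }

    block.stms.foreach {
      // Step 1: locate PRIM; step 2: check for the matching addition.
      case Stm(lhs, Prim(a, b)) if hasAddOf(a, b) =>
        val (fa, fb) = (f(a), f(b))
        // Step 3: stage a * a + b * b in place of PRIM.
        val sum = stage(Add(stage(Mul(fa, fa)), stage(Mul(fb, fb))))
        // Step 4: later uses of the old PRIM symbol now refer to `sum`.
        subst(lhs) = sum
      // Everything else (including a PRIM with no matching Add) is kept,
      // with its inputs rewritten through the substitution.
      case Stm(lhs, rhs) =>
        val newRhs = rhs match {
          case Add(a, b)  => Add(f(a), f(b))
          case Mul(a, b)  => Mul(f(a), f(b))
          case Prim(a, b) => Prim(f(a), f(b))
        }
        out += Stm(lhs, newRhs)
    }
    Block(out.toVector, f(block.result))
  }
}
```

For the toy design, with `Sym(1)` and `Sym(2)` standing in for x and y, an input block holding `Add(x, y)` (Sym 3) followed by `Prim(x, y)` (Sym 4) comes out as the original Add plus two Muls and a final Add (Syms 5 to 7), and the block result is redirected from Sym 4 to Sym 7. Routing later uses through a substitution, rather than editing each consumer by hand, is what keeps step 4 local to the PRIM statement in this model.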
The main problems I ran into while implementing these steps are as follows:
1) In an overridden transform method, how can I create a new Block and use it to replace an old one?
I looked at RegReadCSE and PipeInserter. My understanding is that a transform consists of these steps: pattern matching on the op, building a new op through an API (in spatial.lang), wrapping it with stage(), and returning the staged value. Alternatively, I noticed that the register function can register a mapping from a block (extracted from a ControlBody) to a new Block constructed by custom methods.
I'm not sure which direction to take for the case above, or what the exact steps are.
2) What is the purpose of calling super.transform and f (in a MutateTransformer)?
Specifically, are there any concrete examples showing the effect of mirror versus update?
3) Once we've implemented a transformer, how can we test it?
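A toy model may help frame the mirror/update question. It assumes, as the names suggest but without checking Argon's sources, that `f` looks a symbol up in the transformer's current substitution, that `register` adds an entry to that substitution, that `update` rewrites an existing node's inputs through `f` in place, and that `mirror` builds a fresh node with `f` applied to its inputs. On that reading, `super.transform` in an override is the default handling for nodes the pass leaves alone. `Node`, `ToyTransformer`, and its methods are hypothetical and only illustrate the distinction, not Argon's implementation.

```scala
import scala.collection.mutable

// Hypothetical toy node (NOT Argon's Sym/Op): an id, an op name, and input ids.
final class Node(val id: Int, val op: String, var inputs: List[Int])

class ToyTransformer(firstFreshId: Int) {
  private val subst = mutable.Map.empty[Int, Int] // old id -> new id
  private var nextId = firstFreshId

  // Record that every later use of `from` should refer to `to`.
  def register(from: Int, to: Int): Unit = subst(from) = to

  // f: look an id up in the substitution, defaulting to itself.
  def f(id: Int): Int = subst.getOrElse(id, id)

  // mirror-style: a NEW node whose inputs go through f; the original is untouched,
  // and the old id is registered to point at the copy.
  def mirror(n: Node): Node = {
    val copy = new Node(nextId, n.op, n.inputs.map(f))
    nextId += 1
    register(n.id, copy.id)
    copy
  }

  // update-style: the SAME node is kept and its inputs are rewritten through f.
  def update(n: Node): Node = {
    n.inputs = n.inputs.map(f)
    n
  }
}
```

For example, after `register(4, 7)` (the PRIM symbol 4 replaced by a new symbol 7), `update` on `new Node(8, "write", List(4))` returns that same node with inputs `List(7)`, while `mirror` on an identical node 9 returns a new node with a fresh id and inputs `List(7)`, leaves node 9 untouched, and registers 9 to the copy so later lookups through `f` follow it. Assertions of this kind, made on the rewritten graph, are one lightweight way to unit-test a rewrite in the toy model.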
In general, I have a rough idea of the algorithm behind my pass, but I'm not clear on how to actually implement it while respecting Spatial's type system. It would be great if anyone could draft the compiler pass described above (it doesn't need to be exact; an outline would be perfect) and give me some hints on these questions.
Thanks in advance for any help.
Best regards
Ruizhe