I'm new to ISL. I'm trying to transform this AST:

    for r in range(10):
        for p in range(10):
            S0: A1[0,p] = max(A1[0,p], A0[r,p])
    for r in range(10):
        for p in range(10):
            S1: A2[r,p] = A0[r,p] - A1[0,p]

into this AST:

    for p in range(10):
        for r in range(10):
            S0: A1[0,p] = max(A1[0,p], A0[r,p])
            S1: A2[r,p] = A0[r,p] - A1[0,p]

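One quick way to see what the target nest computes is to execute both nests directly and compare A2. The plain-Python harness below is a hypothetical sketch: it assumes 10x10 arrays A0 and A2, a single-row A1 initialised to -inf, and arbitrary deterministic sample data, and it reads the target nest literally, with S0 and S1 sharing the inner r loop.

```python
import math

N = 10

def make_a0():
    # Arbitrary deterministic sample data (an assumption, not from the question).
    return [[(r * 7 + p * 3) % 10 for p in range(N)] for r in range(N)]

def original(a0):
    a1 = [[-math.inf] * N]                 # A1 has a single row, index 0
    a2 = [[0] * N for _ in range(N)]
    for r in range(N):
        for p in range(N):
            a1[0][p] = max(a1[0][p], a0[r][p])   # S0
    for r in range(N):
        for p in range(N):
            a2[r][p] = a0[r][p] - a1[0][p]       # S1
    return a2

def target(a0):
    a1 = [[-math.inf] * N]
    a2 = [[0] * N for _ in range(N)]
    for p in range(N):
        for r in range(N):
            a1[0][p] = max(a1[0][p], a0[r][p])   # S0
            a2[r][p] = a0[r][p] - a1[0][p]       # S1, in the same r iteration
    return a2

a0 = make_a0()
print(original(a0) == target(a0))
```

Run as-is, it prints False: with this sample data A2[0][0] is -9 in the original nest but 0 in the target nest, because S1 reads A1[0,p] before the max over all r has finished.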
I use the following Python code to perform this transformation with isl:

    import isl
    context = isl.set("{ : }")
    schedule_space = isl.set("{[t0,t1,t2]:}").get_space()
    precedes = isl.map.lex_lt(schedule_space)  # t < t' in lexicographic order
    # iteration domain
    domain = isl.union_set(" {S0[i,j] : 0<=i<10 and 0<=j<10; S1[i,j] : 0<=i<10 and 0<=j<10; }")
    schedule = isl.union_map("{S0[i,j] -> [t0,t1,t2] : t0 = 1 and t1 = i and t2 = j; S1[i,j] -> [t0,t1,t2] : t0 = 2 and t1 = i and t2 = j}")
    build = isl.ast_build.from_context(context)
    ast = build.node_from_schedule_map(schedule.intersect_domain(domain))
    print(ast.to_C_str())  # AST of the original schedule
    reads = isl.union_map("{S0[i,j] -> A0[a=i,b=j]; S0[i,j] -> A1[a=0,b=j]; S1[i,j] -> A0[a=i,b=j]; S1[i,j] -> A1[a=0,b=j];}")
    reads = reads.intersect_domain(domain)
    writes = isl.union_map("{S0[i,j] -> A1[a=0,b=j]; S1[i,j] -> A2[a=i,b=j]}")
    writes = writes.intersect_domain(domain)
    raw = writes.apply_range(reads.reverse())  # writer -> reader of the same element
    raw = raw.apply_domain(schedule).apply_range(schedule)  # express both ends as schedule times
    raw = raw.intersect(precedes)  # keep pairs whose source runs first
    raw = raw.apply_domain(schedule.reverse()).apply_range(schedule.reverse())  # back to statement instances
    war = reads.apply_range(writes.reverse())  # reader -> later writer
    war = war.apply_domain(schedule).apply_range(schedule)
    war = war.intersect(precedes)
    war = war.apply_domain(schedule.reverse()).apply_range(schedule.reverse())
    waw = writes.apply_range(writes.reverse())  # writer -> later writer
    waw = waw.apply_domain(schedule).apply_range(schedule)
    waw = waw.intersect(precedes)
    waw = waw.apply_domain(schedule.reverse()).apply_range(schedule.reverse())
    sc = isl.schedule_constraints.on_domain(domain)
    dep = war.union(raw).union(waw)
    sc = sc.set_validity(dep)  # the new schedule must respect every dependence
    sc = sc.set_proximity(dep)  # and should keep dependent instances close together
    sched = sc.compute_schedule()
    build = isl.ast_build.from_context(context)
    ast = build.node_from(sched)
    print(ast.to_C_str())  # AST of the computed schedule

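The dependence computation in the code above (writes composed with reversed reads, then restricted to pairs ordered by the original schedule) can be mimicked by brute force in plain Python, which makes it easy to check a candidate schedule by hand. This is only a sketch: the tuple encodings of the schedules, including the one read off the AST isl prints in the WITH case, are assumptions, and the target nest is taken literally, with S0 and S1 sharing the r loop.

```python
from itertools import product

N = 10

def accesses(s, i, j):
    """(reads, writes) of instance s[i, j] as sets of (array, index) pairs,
    following statements S0 and S1. A0 is read-only, so it never creates a dependence."""
    if s == "S0":  # A1[0,j] = max(A1[0,j], A0[i,j])
        return {("A0", (i, j)), ("A1", (0, j))}, {("A1", (0, j))}
    # S1: A2[i,j] = A0[i,j] - A1[0,j]
    return {("A0", (i, j)), ("A1", (0, j))}, {("A2", (i, j))}

INSTANCES = [(s, i, j) for s in ("S0", "S1") for i, j in product(range(N), repeat=2)]

def dependences(time):
    """RAW, WAR and WAW pairs (src, dst): both touch the same element, at
    least one of them writes it, and src runs strictly before dst under
    `time` (Python compares the tuples lexicographically)."""
    acc = {x: accesses(*x) for x in INSTANCES}
    deps = []
    for a, b in product(INSTANCES, repeat=2):
        if time(*a) >= time(*b):
            continue
        (ra, wa), (rb, wb) = acc[a], acc[b]
        if (wa & rb) or (ra & wb) or (wa & wb):
            deps.append((a, b))
    return deps

def respects(time, deps):
    """True if `time` still runs every dependence source before its sink."""
    return all(time(*a) < time(*b) for a, b in deps)

def original_sched(s, i, j):
    # The question's schedule map: S0[i,j] -> [1,i,j], S1[i,j] -> [2,i,j].
    return (1 if s == "S0" else 2, i, j)

def isl_result(s, i, j):
    # Assumed reading of the printed AST: S0(c1, c0) nest first, then S1(c0, c1).
    return (0, j, i) if s == "S0" else (1, i, j)

def fused_target(s, i, j):
    # Target nest taken literally: for p: for r: S0 then S1 (p is j, r is i).
    return (j, i, 0 if s == "S0" else 1)

deps = dependences(original_sched)
print(len(deps))
for name, sched in [("isl result", isl_result), ("fused target", fused_target)]:
    print(name, respects(sched, deps))
```

Run as-is, it should report 1450 dependence pairs, all respected by the schedule isl returned, while the literal fused target breaks some of them: for example, S0[9,p] would run after S1[0,p], which reads A1[0,p].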
Experimentally, WITHOUT lines 18 and 21 I get the expected AST.
WITH lines 18 and 21, I only get the following AST:

    for (int c0 = 0; c0 <= 9; c0 += 1)
      for (int c1 = 0; c1 <= 9; c1 += 1)
        S0(c1, c0);
    for (int c0 = 0; c0 <= 9; c0 += 1)
      for (int c1 = 0; c1 <= 9; c1 += 1)
        S1(c0, c1);

Why does this happen? What is the correct way to implement this transformation with isl?