Patch for clast traversal pasted below.
-Uday
From d255400679150ed91c39bdaa730d200435388e5c Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday...@gmail.com>
Date: Wed, 23 May 2012 22:32:57 +0530
Subject: [PATCH 2/2] clast traversal support
See comment for function clast_traverse.
Signed-off-by: Uday Bondhugula <uday...@gmail.com>
---
Makefile.am | 1 +
include/cloog/clast.h | 2 +
source/clast_traversal.c | 192 ++++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 195 insertions(+), 0 deletions(-)
create mode 100644 source/clast_traversal.c
diff --git a/Makefile.am b/Makefile.am
index 1749c95..d45c7ae 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -73,6 +73,7 @@ SOURCES_CORE = \
$(GET_MEMORY_FUNCTIONS) \
source/block.c \
source/clast.c \
+ source/clast_traversal.c \
source/matrix.c \
source/state.c \
source/input.c \
diff --git a/include/cloog/clast.h b/include/cloog/clast.h
index 0d83a84..3b16649 100644
--- a/include/cloog/clast.h
+++ b/include/cloog/clast.h
@@ -154,6 +154,8 @@ int clast_expr_equal(struct clast_expr *e1, struct clast_expr *e2);
struct clast_expr *clast_bound_from_constraint(CloogConstraint *constraint,
int level, CloogNames *names);
+typedef enum filterType {exact, subset} ClastFilterType; /* how clast_traverse matches a loop's statements against stmt_filter */
+
#if defined(__cplusplus)
}
#endif
diff --git a/source/clast_traversal.c b/source/clast_traversal.c
new file mode 100644
index 0000000..fdeea51
--- /dev/null
+++ b/source/clast_traversal.c
@@ -0,0 +1,192 @@
+#include <stdlib.h>
+#include <string.h>
+#include <assert.h>
+#include "../include/cloog/cloog.h"
+
+
+/* Appends the 'size'-byte element at 'new' to the 'num'-element array *list
+ * unless a bytewise-equal element is already stored. Returns 1 if appended,
+ * 0 if it was already present; aborts if the array cannot be grown. */
+static int add_if_new(void **list, int num, void *new, int size)
+{
+    char *base = *list; /* byte pointer: arithmetic on void * is not ISO C */
+
+    for (int i = 0; i < num; i++) {
+        if (!memcmp(base + (size_t)i * size, new, size)) return 0;
+    }
+
+    base = realloc(*list, ((size_t)num + 1) * size);
+    if (base == NULL) abort(); /* keeps memcpy from writing through NULL */
+    memcpy(base + (size_t)num * size, new, size);
+    *list = base;
+    return 1;
+}
+
+
+/* Merges into *list1 (num1 elements of 'size' bytes each) every one of the
+ * num2 elements of list2 that *list1 does not hold yet;
+ * Returns the resulting number of elements in *list1 */
+int concat_if_new(void **list1, int num1, void *list2, int num2, int size)
+{
+    char *cursor = list2;
+    int remaining;
+
+    for (remaining = num2; remaining > 0; remaining--, cursor += size) {
+        if (add_if_new(list1, num1, cursor, size)) num1++;
+    }
+    return num1;
+}
+
+/* Compares list1 to list2 as sets (repeated entries do not matter)
+ * Returns 0 if both have the same elements; returns -1 if all elements of
+ * list1 are contained in list2 but list2 has elements missing from list1;
+ * 1 otherwise, i.e., when some element of list1 is not in list2
+ */
+int list_compare(const int *list1, int num1, const int *list2, int num2)
+{
+    int i, j;
+
+    /* Every element of list1 has to occur somewhere in list2 */
+    for (i = 0; i < num1; i++) {
+        for (j = 0; j < num2 && list1[i] != list2[j]; j++) continue;
+        if (j >= num2) return 1;
+    }
+
+    /* list1 is contained in list2; they are equal only if the converse holds
+     * too. Comparing the lengths is wrong when a list repeats an element. */
+    for (j = 0; j < num2; j++) {
+        for (i = 0; i < num1 && list1[i] != list2[j]; i++) continue;
+        if (i >= num1) return -1;
+    }
+    return 0;
+}
+
+
+
+/* Merges the 'nsrc' statement numbers at 'src' into *stmts, skipping those
+ * already present; *nstmts is updated to the new element count. */
+static void merge_stmts(int **stmts, int *nstmts, int *src, int nsrc)
+{
+    void *buf = *stmts; /* avoids punning an int ** as a void ** */
+
+    *nstmts = concat_if_new(&buf, *nstmts, src, nsrc, sizeof(int));
+    *stmts = buf;
+}
+
+/* Same as merge_stmts, for arrays of loop pointers. */
+static void merge_loops(struct clast_for ***loops, int *nloops,
+                        struct clast_for **src, int nsrc)
+{
+    void *buf = *loops;
+
+    *nloops = concat_if_new(&buf, *nloops, src, nsrc,
+                            sizeof(struct clast_for *));
+    *loops = buf;
+}
+
+/* Returns nonzero if a loop with iterator 'name', whose body contains the
+ * statement numbers body_stmts[0..nbody-1], passes both filters. */
+static int loop_matches(const char *name, const int *body_stmts, int nbody,
+                        const char *iter, const int *stmt_filter,
+                        int nstmts_filter, ClastFilterType filter_type)
+{
+    int cmp;
+
+    if (iter != NULL && strcmp(name, iter) != 0) return 0;
+    if (stmt_filter == NULL) return 1;
+
+    cmp = list_compare(body_stmts, nbody, stmt_filter, nstmts_filter);
+    return (filter_type == subset && cmp <= 0)
+        || (filter_type == exact && cmp == 0);
+}
+
+/*
+ * A multi-purpose function to traverse and get information on Clast
+ * loops
+ *
+ * Walks 'node' and all the siblings that follow it, descending into the
+ * 'then' part of guards and into the body of loops, and collects
+ *
+ * (1) the statement numbers of all user statements encountered
+ * (2) the loops that pass the two filters described below
+ *
+ * Parameters
+ *
+ * node: clast node where processing should start (may be NULL)
+ * iter: loop iterator name; if NULL, no filtering based on iterator name
+ * is done
+ * stmt_filter: list of statement numbers for filtering (1-indexed); if
+ * NULL, no filtering based on statements is done
+ * nstmts_filter: number of statements in stmt_filter
+ * filter_type: match exact (i.e., loops containing only and all those
+ * statements in stmt_filter) or subset, i.e., loops which have only those
+ * statements that appear in stmt_filter
+ *
+ * To disable all filtering, set 'iter' to NULL and either set
+ * 'stmt_filter' to NULL, or provide all statement numbers in 'stmt_filter'
+ * and set filter_type to subset
+ *
+ * Return fields
+ *
+ * stmts: statement numbers under node, each listed once
+ * nstmts: number of stmt numbers pointed to by stmts
+ * loops: list of clast loops that passed the filters; a loop is listed
+ * after the matching loops nested inside it
+ * nloops: number of clast loops in loops
+ *
+ * 'stmts' and 'loops' are allocated with realloc() (NULL when empty); the
+ * caller owns them and releases them with free()
+ *
+ * Node kinds other than root, guard, user statement and loop contribute
+ * nothing themselves, but the siblings that follow them are still visited
+ * NOTE(review): if such a kind carries a nested body, that body is not
+ * descended into here - confirm whether it needs to be
+ *
+ */
+void clast_traverse(struct clast_stmt *node,
+        const char *iter, const int *stmt_filter, int nstmts_filter,
+        struct clast_for ***loops, int *nloops,
+        int **stmts, int *nstmts, ClastFilterType filter_type)
+{
+    *loops = NULL;
+    *nloops = 0;
+    *nstmts = 0;
+    *stmts = NULL;
+
+    /* Siblings are walked iteratively; recursion is used only for nesting */
+    for (; node != NULL; node = node->next) {
+        struct clast_for *for_stmt = NULL;
+        struct clast_stmt *nested = NULL;
+        struct clast_for **sub_loops;
+        int *sub_stmts, num_sub_loops, num_sub_stmts;
+
+        if (CLAST_STMT_IS_A(node, stmt_user)) {
+            struct clast_user_stmt *user_stmt = (struct clast_user_stmt *) node;
+            merge_stmts(stmts, nstmts, &user_stmt->statement->number, 1);
+            continue;
+        }
+        if (CLAST_STMT_IS_A(node, stmt_guard)) {
+            nested = ((struct clast_guard *) node)->then;
+        } else if (CLAST_STMT_IS_A(node, stmt_for)) {
+            for_stmt = (struct clast_for *) node;
+            nested = for_stmt->body;
+        } else {
+            continue; /* root, or a kind with nothing to collect */
+        }
+
+        clast_traverse(nested, iter, stmt_filter, nstmts_filter, &sub_loops,
+                &num_sub_loops, &sub_stmts, &num_sub_stmts, filter_type);
+        merge_stmts(stmts, nstmts, sub_stmts, num_sub_stmts);
+        merge_loops(loops, nloops, sub_loops, num_sub_loops);
+
+        /* A loop is matched against the statements found in its own body */
+        if (for_stmt != NULL && loop_matches(for_stmt->iterator, sub_stmts,
+                    num_sub_stmts, iter, stmt_filter, nstmts_filter,
+                    filter_type)) {
+            merge_loops(loops, nloops, &for_stmt, 1);
+        }
+        free(sub_loops);
+        free(sub_stmts);
+    }
+}
+
--
1.7.4.4