[PATCH 01/17] Method::print_arg_list: introduce Method::print_fd_arg_list wrapper

3 views
Skip to first unread message

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:45 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

Unlike print_arg_list, print_fd_arg_list is an object method and
therefore has access to the object members, in particular "fd".
This will be used in the next commit.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/cpp.cc | 14 +++++++++++++-
interface/cpp.h | 2 ++
interface/plain_cpp.cc | 2 +-
3 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/interface/cpp.cc b/interface/cpp.cc
index 23e2281de1..40c6c9e4b4 100644
--- a/interface/cpp.cc
+++ b/interface/cpp.cc
@@ -1027,6 +1027,18 @@ void Method::print_arg_list(std::ostream &os, int start, int end,
os << ")";
}

+/* Print the arguments from "start" (inclusive) to "end" (exclusive)
+ * as arguments to a method of C function call, using "print_arg"
+ * to print each individual argument.
+ *
+ * Call print_arg_list to do the actual printing.
+ */
+void Method::print_fd_arg_list(std::ostream &os, int start, int end,
+ const std::function<void(int i)> &print_arg) const
+{
+ print_arg_list(os, start, end, print_arg);
+}
+
/* Print the arguments to the method call, using "print_arg"
* to print each individual argument.
*/
@@ -1034,7 +1046,7 @@ void Method::print_cpp_arg_list(std::ostream &os,
const std::function<void(int i)> &print_arg) const
{
int first_param = kind == member_method ? 1 : 0;
- print_arg_list(os, first_param, num_params(), print_arg);
+ print_fd_arg_list(os, first_param, num_params(), print_arg);
}

/* Should the parameter at position "pos" be a copy (rather than
diff --git a/interface/cpp.h b/interface/cpp.h
index 837cdee3e9..c6dfc6c211 100644
--- a/interface/cpp.h
+++ b/interface/cpp.h
@@ -35,6 +35,8 @@ struct Method {
bool is_subclass_mutator() const;
static void print_arg_list(std::ostream &os, int start, int end,
const std::function<void(int i)> &print_arg);
+ void print_fd_arg_list(std::ostream &os, int start, int end,
+ const std::function<void(int i)> &print_arg) const;
void print_cpp_arg_list(std::ostream &os,
const std::function<void(int i)> &print_arg) const;

diff --git a/interface/plain_cpp.cc b/interface/plain_cpp.cc
index ae15c34dbe..23013733eb 100644
--- a/interface/plain_cpp.cc
+++ b/interface/plain_cpp.cc
@@ -934,7 +934,7 @@ void plain_cpp_generator::impl_printer::print_method(const Method &method)

osprintf(os, " auto res = %s", methodname.c_str());

- Method::print_arg_list(os, 0, num_params, [&] (int i) {
+ method.print_fd_arg_list(os, 0, num_params, [&] (int i) {
method.print_param_use(os, i);
});
osprintf(os, ";\n");
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:46 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

Since the previous commit, the user pointer is skipped
by Method::print_fd_arg_list, so it no longer needs
to be removed from the number of arguments.
The callers of Method::num_params (which calls Method::c_num_params)
in template_cpp.cc need to be adjusted to take into account
the extra (user pointer) argument.
The callers plain_cpp_generator::impl_printer::print_save_ctx and
plain_cpp_generator::impl_printer::print_argument_validity_check
do not require any adjustments because they already skip
any arguments that are not of an isl type.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/cpp.cc | 14 ++++----------
interface/template_cpp.cc | 7 +++++--
2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/interface/cpp.cc b/interface/cpp.cc
index 3b979b5b0a..1e7648f249 100644
--- a/interface/cpp.cc
+++ b/interface/cpp.cc
@@ -985,19 +985,13 @@ Method::Method(const isl_class &clazz, FunctionDecl *fd) :

/* Return the number of parameters of the corresponding C function.
*
- * If the method has a callback argument, we reduce the number of parameters
- * that are exposed by one to hide the user pointer from the interface. On
- * the C++ side no user pointer is needed, as arguments can be forwarded
- * as part of the std::function argument which specifies the callback function.
- *
- * The user pointer is also removed from the number of parameters
- * of the C function because the pair of callback and user pointer
- * is considered as a single argument that is printed as a whole
- * by Method::print_param_use.
+ * This number includes any possible user pointers that follow callback
+ * arguments. These are skipped by Method::print_fd_arg_list
+ * during the actual argument printing.
*/
int Method::c_num_params() const
{
- return fd->getNumParams() - (callback != NULL);
+ return fd->getNumParams();
}

/* Return the number of parameters of the method
diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 900b56cd09..71d5b38df6 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1564,6 +1564,8 @@ void template_cpp_generator::method_decl_printer::print_method_sig(
* then the return kind of the callback appears at the position
* of the callback and the kinds of the arguments (except
* the user pointer argument) appear in the following positions.
+ * The user pointer argument that follows the callback argument
+ * is also removed.
*/
static int total_params(const Method &method)
{
@@ -1573,7 +1575,8 @@ static int total_params(const Method &method)
auto callback_type = method.callback->getType();
auto callback = generator::extract_prototype(callback_type);

- n += callback->getNumArgs() - 1;
+ n += callback->getNumParams() - 1;
+ n -= 1;
}

return n;
@@ -1686,7 +1689,7 @@ static void print_callback_lambda(std::ostream &os, const Method &method,
auto callback_name = method.callback->getName().str();
auto callback = generator::extract_prototype(callback_type);

- if (method.num_params() != 2)
+ if (method.num_params() != 3)
generator::die("callback is assumed to be single argument");

os << " auto lambda = [&] ";
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:47 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

Currently, the number of arguments is reduced by one in the case of a callback,
meaning that the (presumed unique) user pointer
is never passed to the "print_arg" callback of Method::print_arg_list.
However, this only works if there is only a single callback
(and there are no other arguments following the callback and
the user pointer).
To be able to support exported functions with multiple callbacks
(each with its own user pointer), a user pointer needs
to be skipped for every callback argument.

Replace the "print_arg" callback of Method::print_arg_list
by a "print_arg_skip_next" callback that prints an argument and
returns whether the next argument should be skipped.
Method::print_fd_arg_list uses this callback to skip
the user pointers following the callbacks in the exported functions.
print_callback_args in template_cpp.cc does not skip
any arguments since it is responsible for printing a callback and
callbacks are assumed not to have any callback arguments themselves.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/cpp.cc | 26 ++++++++++++++++++++------
interface/cpp.h | 2 +-
interface/template_cpp.cc | 3 +++
3 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/interface/cpp.cc b/interface/cpp.cc
index 40c6c9e4b4..3b979b5b0a 100644
--- a/interface/cpp.cc
+++ b/interface/cpp.cc
@@ -1012,17 +1012,19 @@ int Method::num_params() const
}

/* Print the arguments from "start" (inclusive) to "end" (exclusive)
- * as arguments to a method of C function call, using "print_arg"
- * to print each individual argument.
+ * as arguments to a method of C function call, using "print_arg_skip_next"
+ * to print each individual argument. If this callback return true
+ * then the next argument is skipped.
*/
void Method::print_arg_list(std::ostream &os, int start, int end,
- const std::function<void(int i)> &print_arg)
+ const std::function<bool(int i)> &print_arg_skip_next)
{
os << "(";
for (int i = start; i < end; ++i) {
if (i != start)
os << ", ";
- print_arg(i);
+ if (print_arg_skip_next(i))
+ ++i;
}
os << ")";
}
@@ -1031,12 +1033,24 @@ void Method::print_arg_list(std::ostream &os, int start, int end,
* as arguments to a method of C function call, using "print_arg"
* to print each individual argument.
*
- * Call print_arg_list to do the actual printing.
+ * Call print_arg_list to do the actual printing, skipping
+ * the user argument that comes after every callback argument.
+ * On the C++ side no user pointer is needed, as arguments can be forwarded
+ * as part of the std::function argument which specifies the callback function.
+ * The user pointer is also removed from the number of parameters
+ * of the C function because the pair of callback and user pointer
+ * is considered as a single argument that is printed as a whole
+ * by Method::print_param_use.
*/
void Method::print_fd_arg_list(std::ostream &os, int start, int end,
const std::function<void(int i)> &print_arg) const
{
- print_arg_list(os, start, end, print_arg);
+ print_arg_list(os, start, end, [this, &print_arg] (int i) {
+ auto type = fd->getParamDecl(i)->getType();
+
+ print_arg(i);
+ return generator::is_callback(type);
+ });
}

/* Print the arguments to the method call, using "print_arg"
diff --git a/interface/cpp.h b/interface/cpp.h
index c6dfc6c211..b28d4f3c29 100644
--- a/interface/cpp.h
+++ b/interface/cpp.h
@@ -34,7 +34,7 @@ struct Method {
virtual void print_param_use(ostream &os, int pos) const;
bool is_subclass_mutator() const;
static void print_arg_list(std::ostream &os, int start, int end,
- const std::function<void(int i)> &print_arg);
+ const std::function<bool(int i)> &print_arg_skip_next);
void print_fd_arg_list(std::ostream &os, int start, int end,
const std::function<void(int i)> &print_arg) const;
void print_cpp_arg_list(std::ostream &os,
diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index d58113c011..900b56cd09 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1646,6 +1646,7 @@ void template_cpp_generator::method_impl_printer::print_constructor_body(
* calling "print_arg" with the type and the name of the arguments,
* where the type is obtained from "type_printer" with argument positions
* shifted by "shift".
+ * None of the arguments should be skipped.
*/
static void print_callback_args(std::ostream &os,
const FunctionProtoType *callback, const cpp_type_printer &type_printer,
@@ -1661,6 +1662,8 @@ static void print_callback_args(std::ostream &os,
auto cpptype = type_printer.param(shift + i, type);

print_arg(cpptype, name);
+
+ return false;
});
}

--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:48 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

For the templated interface, the arguments of any callbacks
are spliced into the list of arguments.
If the callback is the last argument (aside from
the corresponding user pointer), then the splicing can be performed
purely locally. However, if a callback appears in any other position
(in particular, if there are multiple callbacks), then the splicing
needs to be taken into account for any further arguments.
Let Method::print_fd_arg_list take care of the accounting.
Most callers ignore the extra argument to the Method::print_fd_arg_list
callback. Only cpp_generator::class_printer::print_method_header
needs to know the position in the flattened list of arguments
because it needs to pass this position to the type printer.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/cpp.cc | 33 ++++++++++++++++++++++++++-------
interface/cpp.h | 4 ++--
interface/generator.cc | 8 ++++++++
interface/generator.h | 1 +
interface/plain_cpp.cc | 4 ++--
interface/template_cpp.cc | 6 +++---
6 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/interface/cpp.cc b/interface/cpp.cc
index 1e7648f249..2fdda6e7d2 100644
--- a/interface/cpp.cc
+++ b/interface/cpp.cc
@@ -649,11 +649,11 @@ void cpp_generator::class_printer::print_method_header(
else
os << cppstring;

- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
std::string name = method.fd->getParamDecl(i)->getName().str();
ParmVarDecl *param = method.get_param(i);
QualType type = param->getOriginalType();
- string cpptype = type_printer.param(i, type);
+ string cpptype = type_printer.param(arg, type);

if (!method.param_needs_copy(i))
os << "const " << cpptype << " &" << name;
@@ -1026,6 +1026,10 @@ void Method::print_arg_list(std::ostream &os, int start, int end,
/* Print the arguments from "start" (inclusive) to "end" (exclusive)
* as arguments to a method of C function call, using "print_arg"
* to print each individual argument.
+ * The first argument to this callback is the position of the argument
+ * in this->fd.
+ * The second argument is the (first) position in the list of arguments
+ * with all callback arguments spliced in.
*
* Call print_arg_list to do the actual printing, skipping
* the user argument that comes after every callback argument.
@@ -1035,23 +1039,38 @@ void Method::print_arg_list(std::ostream &os, int start, int end,
* of the C function because the pair of callback and user pointer
* is considered as a single argument that is printed as a whole
* by Method::print_param_use.
+ *
+ * In case of a callback argument, the second argument to "print_arg"
+ * is also adjusted to account for the spliced-in arguments of the callback.
+ * The return value takes the place of the callback itself,
+ * while the arguments (excluding the final user pointer)
+ * take the following positions.
*/
void Method::print_fd_arg_list(std::ostream &os, int start, int end,
- const std::function<void(int i)> &print_arg) const
+ const std::function<void(int i, int arg)> &print_arg) const
{
- print_arg_list(os, start, end, [this, &print_arg] (int i) {
+ int arg = start;
+
+ print_arg_list(os, start, end, [this, &print_arg, &arg] (int i) {
auto type = fd->getParamDecl(i)->getType();

- print_arg(i);
- return generator::is_callback(type);
+ print_arg(i, arg++);
+ if (!generator::is_callback(type))
+ return false;
+ arg += generator::prototype_n_args(type) - 1;
+ return true;
});
}

/* Print the arguments to the method call, using "print_arg"
* to print each individual argument.
+ * The first argument to this callback is the position of the argument
+ * in this->fd.
+ * The second argument is the (first) position in the list of arguments
+ * with all callback arguments spliced in.
*/
void Method::print_cpp_arg_list(std::ostream &os,
- const std::function<void(int i)> &print_arg) const
+ const std::function<void(int i, int arg)> &print_arg) const
{
int first_param = kind == member_method ? 1 : 0;
print_fd_arg_list(os, first_param, num_params(), print_arg);
diff --git a/interface/cpp.h b/interface/cpp.h
index b28d4f3c29..f29c5ed6a9 100644
--- a/interface/cpp.h
+++ b/interface/cpp.h
@@ -36,9 +36,9 @@ struct Method {
static void print_arg_list(std::ostream &os, int start, int end,
const std::function<bool(int i)> &print_arg_skip_next);
void print_fd_arg_list(std::ostream &os, int start, int end,
- const std::function<void(int i)> &print_arg) const;
+ const std::function<void(int i, int arg)> &print_arg) const;
void print_cpp_arg_list(std::ostream &os,
- const std::function<void(int i)> &print_arg) const;
+ const std::function<void(int i, int arg)> &print_arg) const;

const isl_class &clazz;
FunctionDecl *const fd;
diff --git a/interface/generator.cc b/interface/generator.cc
index eacbc36a65..566ed427b3 100644
--- a/interface/generator.cc
+++ b/interface/generator.cc
@@ -781,6 +781,14 @@ const FunctionProtoType *generator::extract_prototype(QualType type)
return type->getPointeeType()->getAs<FunctionProtoType>();
}

+/* Given the type of a function pointer, return the number of arguments
+ * of the corresponding function prototype.
+ */
+int generator::prototype_n_args(QualType type)
+{
+ return extract_prototype(type)->getNumArgs();
+}
+
/* Return the function name suffix for the type of "param".
*
* If the type of "param" is an isl object type,
diff --git a/interface/generator.h b/interface/generator.h
index 29faaaac08..864e123694 100644
--- a/interface/generator.h
+++ b/interface/generator.h
@@ -195,6 +195,7 @@ public:
static bool is_mutator(const isl_class &clazz, FunctionDecl *fd);
static string extract_type(QualType type);
static const FunctionProtoType *extract_prototype(QualType type);
+ static int prototype_n_args(QualType type);
static ParmVarDecl *persistent_callback_arg(FunctionDecl *fd);
};

diff --git a/interface/plain_cpp.cc b/interface/plain_cpp.cc
index 23013733eb..0b32e69691 100644
--- a/interface/plain_cpp.cc
+++ b/interface/plain_cpp.cc
@@ -934,7 +934,7 @@ void plain_cpp_generator::impl_printer::print_method(const Method &method)

osprintf(os, " auto res = %s", methodname.c_str());

- method.print_fd_arg_list(os, 0, num_params, [&] (int i) {
+ method.print_fd_arg_list(os, 0, num_params, [&] (int i, int arg) {
method.print_param_use(os, i);
});
osprintf(os, ";\n");
@@ -1002,7 +1002,7 @@ void plain_cpp_generator::impl_printer::print_method(
print_check_ptr("ptr");
osprintf(os, " return ");
method.print_call(os, generator.isl_namespace());
- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
ParmVarDecl *param = method.fd->getParamDecl(i);

print_arg_conversion(param, method.get_param(i));
diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 71d5b38df6..427056cdfd 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1636,7 +1636,7 @@ void template_cpp_generator::method_impl_printer::print_constructor_body(
const auto &base_name = instance.base_name();

os << " : " << base_name;
- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
os << method.fd->getParamDecl(i)->getName().str();
});
os << "\n";
@@ -1735,7 +1735,7 @@ void template_cpp_generator::method_impl_printer::print_callback_method_body(

os << " return ";
os << base_name << "::" << method.name;
- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
auto param = method.fd->getParamDecl(i);

if (param == method.callback)
@@ -1763,7 +1763,7 @@ void template_cpp_generator::method_impl_printer::print_method_body(
os << "{\n";
os << " auto res = ";
os << base_name << "::" << method.name;
- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
os << method.fd->getParamDecl(i)->getName().str();
});
os << ";\n";
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:49 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This makes room for a "callback" argument in the next commit.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/template_cpp.cc | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 427056cdfd..838398cdd5 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1687,20 +1687,20 @@ static void print_callback_lambda(std::ostream &os, const Method &method,
{
auto callback_type = method.callback->getType();
auto callback_name = method.callback->getName().str();
- auto callback = generator::extract_prototype(callback_type);
+ auto proto = generator::extract_prototype(callback_type);

if (method.num_params() != 3)
generator::die("callback is assumed to be single argument");

os << " auto lambda = [&] ";
- print_callback_args(os, callback, cpp_type_printer(), 2,
+ print_callback_args(os, proto, cpp_type_printer(), 2,
[&] (const std::string &type, const std::string &name) {
os << type << " " << name;
});
os << " {\n";

os << " return " << callback_name;
- print_callback_args(os, callback, template_cpp_arg_type_printer(sig), 2,
+ print_callback_args(os, proto, template_cpp_arg_type_printer(sig), 2,
[&] (const std::string &type, const std::string &name) {
os << type << "(" << name << ")";
});
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:50 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This prepares for supporting multiple callbacks in an upcoming commit.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/template_cpp.cc | 52 +++++++++++++++++++++++----------------
1 file changed, 31 insertions(+), 21 deletions(-)

diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 838398cdd5..62ee9f5889 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1670,37 +1670,29 @@ static void print_callback_args(std::ostream &os,
});
}

-/* Print a lambda for passing to the plain method corresponding to "method"
- * with signature "sig".
- *
- * The method is assumed to have only the callback as argument,
- * which means the arguments of the callback are shifted by 2
- * with respect to the arguments of the signature
- * (one for the position of the callback argument plus
- * one for the return kind of the callback).
+/* Print a lambda corresponding to "callback"
+ * with signature "sig" and argument positions shifted by "shift".
*
* The lambda takes arguments with plain isl types and
* calls the callback of "method" with templated arguments.
*/
-static void print_callback_lambda(std::ostream &os, const Method &method,
- const Signature &sig)
+static void print_callback_lambda(std::ostream &os, ParmVarDecl *callback,
+ const Signature &sig, int shift)
{
- auto callback_type = method.callback->getType();
- auto callback_name = method.callback->getName().str();
+ auto callback_type = callback->getType();
+ auto callback_name = callback->getName().str();
auto proto = generator::extract_prototype(callback_type);

- if (method.num_params() != 3)
- generator::die("callback is assumed to be single argument");
-
os << " auto lambda = [&] ";
- print_callback_args(os, proto, cpp_type_printer(), 2,
+ print_callback_args(os, proto, cpp_type_printer(), shift,
[&] (const std::string &type, const std::string &name) {
os << type << " " << name;
});
os << " {\n";

os << " return " << callback_name;
- print_callback_args(os, proto, template_cpp_arg_type_printer(sig), 2,
+ print_callback_args(os, proto, template_cpp_arg_type_printer(sig),
+ shift,
[&] (const std::string &type, const std::string &name) {
os << type << "(" << name << ")";
});
@@ -1709,13 +1701,31 @@ static void print_callback_lambda(std::ostream &os, const Method &method,
os << " };\n";
}

+/* Print lambdas for passing to the plain method corresponding to "method"
+ * with signature "sig".
+ *
+ * The method is assumed to have only the callback as argument,
+ * which means the arguments of the callback are shifted by 2
+ * with respect to the arguments of the signature
+ * (one for the position of the callback argument plus
+ * one for the return kind of the callback).
+ */
+static void print_callback_lambdas(std::ostream &os, const Method &method,
+ const Signature &sig)
+{
+ if (method.num_params() != 3)
+ generator::die("callback is assumed to be single argument");
+
+ print_callback_lambda(os, method.callback, sig, 2);
+}
+
/* Print a definition of the member method "method", which is known
* to have a callback argument, with signature "sig".
*
- * First print a lambda for passing to the corresponding plain method and
+ * First print lambdas for passing to the corresponding plain method and
* calling the callback of "method" with templated arguments.
- * Then call the plain method, replacing the original callback
- * by the lambda.
+ * Then call the plain method, replacing the original callbacks
+ * by the lambdas.
*
* The return value is assumed to be isl_bool or isl_stat
* so that no conversion to a template type is required.
@@ -1731,7 +1741,7 @@ void template_cpp_generator::method_impl_printer::print_callback_method_body(

os << "{\n";

- print_callback_lambda(os, method, sig);
+ print_callback_lambdas(os, method, sig);

os << " return ";
os << base_name << "::" << method.name;
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:51 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

That is, instead of checking whether a function parameter
is equal to a previously detected callback, directly
check whether the function parameter is a callback.
This prepares for supporting multiple callbacks.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/template_cpp.cc | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 62ee9f5889..6788be8c7c 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1748,7 +1748,7 @@ void template_cpp_generator::method_impl_printer::print_callback_method_body(
method.print_cpp_arg_list(os, [&] (int i, int arg) {
auto param = method.fd->getParamDecl(i);

- if (param == method.callback)
+ if (generator::is_callback(param->getType()))
os << "lambda";
else
os << param->getName().str();
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:52 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This prepares for supporting multiple callbacks
for a single exported function.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/template_cpp.cc | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 6788be8c7c..5a52fc7590 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1683,7 +1683,7 @@ static void print_callback_lambda(std::ostream &os, ParmVarDecl *callback,
auto callback_name = callback->getName().str();
auto proto = generator::extract_prototype(callback_type);

- os << " auto lambda = [&] ";
+ os << " auto lambda_" << callback_name << " = [&] ";
print_callback_args(os, proto, cpp_type_printer(), shift,
[&] (const std::string &type, const std::string &name) {
os << type << " " << name;
@@ -1749,9 +1749,8 @@ void template_cpp_generator::method_impl_printer::print_callback_method_body(
auto param = method.fd->getParamDecl(i);

if (generator::is_callback(param->getType()))
- os << "lambda";
- else
- os << param->getName().str();
+ os << "lambda_";
+ os << param->getName().str();

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:53 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This makes room for a "callback" range-based for loop variable
in the next commit.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/template_cpp.cc | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 5a52fc7590..8686e5ca97 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1573,9 +1573,9 @@ static int total_params(const Method &method)

if (method.callback) {
auto callback_type = method.callback->getType();
- auto callback = generator::extract_prototype(callback_type);
+ auto proto = generator::extract_prototype(callback_type);

- n += callback->getNumParams() - 1;
+ n += proto->getNumParams() - 1;
n -= 1;
}

--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:54 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This is needed to be able to export isl_*_list_foreach_scc.
The callbacks are still assumed to be the only arguments
(in the templated interface).

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/cpp.cc | 12 ++++++------
interface/cpp.h | 4 ++--
interface/plain_cpp.cc | 8 ++++----
interface/template_cpp.cc | 35 ++++++++++++++++++++++-------------
4 files changed, 34 insertions(+), 25 deletions(-)

diff --git a/interface/cpp.cc b/interface/cpp.cc
index 2fdda6e7d2..ef66c2b56e 100644
--- a/interface/cpp.cc
+++ b/interface/cpp.cc
@@ -942,20 +942,20 @@ static Method::Kind get_kind(const isl_class &clazz, FunctionDecl *method)
return Method::Kind::member_method;
}

-/* Return the callback argument of "fd", if there is any.
- * Return NULL otherwise.
+/* Return the callback arguments of "fd".
*/
-static ParmVarDecl *find_callback_arg(FunctionDecl *fd)
+static std::vector<ParmVarDecl *> find_callback_args(FunctionDecl *fd)
{
+ std::vector<ParmVarDecl *> callbacks;
int num_params = fd->getNumParams();

for (int i = 0; i < num_params; ++i) {
ParmVarDecl *param = fd->getParamDecl(i);
if (generator::is_callback(param->getType()))
- return param;
+ callbacks.emplace_back(param);
}

- return NULL;
+ return callbacks;
}

/* Construct a C++ method object from the class to which is belongs,
@@ -968,7 +968,7 @@ Method::Method(const isl_class &clazz, FunctionDecl *fd,
const std::string &name) :
clazz(clazz), fd(fd), name(rename_method(name)),
kind(get_kind(clazz, fd)),
- callback(find_callback_arg(fd))
+ callbacks(find_callback_args(fd))
{
}

diff --git a/interface/cpp.h b/interface/cpp.h
index f29c5ed6a9..f7508b1b39 100644
--- a/interface/cpp.h
+++ b/interface/cpp.h
@@ -14,7 +14,7 @@
* "name" is the name of the method, which may be different
* from the default name derived from "fd".
* "kind" is the type of the method.
- * "callback" stores the callback argument, if any, or NULL.
+ * "callbacks" stores the callback arguments.
*/
struct Method {
enum Kind {
@@ -44,7 +44,7 @@ struct Method {
FunctionDecl *const fd;
const std::string name;
const enum Kind kind;
- ParmVarDecl *const callback;
+ const std::vector<ParmVarDecl *> callbacks;
};

/* A method that does not require its isl type parameters to be a copy.
diff --git a/interface/plain_cpp.cc b/interface/plain_cpp.cc
index 0b32e69691..ce8b79123d 100644
--- a/interface/plain_cpp.cc
+++ b/interface/plain_cpp.cc
@@ -929,8 +929,8 @@ void plain_cpp_generator::impl_printer::print_method(const Method &method)
print_save_ctx(method);
print_on_error_continue();

- if (method.callback)
- print_callback_local(method.callback);
+ for (const auto &callback : method.callbacks)
+ print_callback_local(callback);

osprintf(os, " auto res = %s", methodname.c_str());

@@ -1393,10 +1393,10 @@ void plain_cpp_generator::impl_printer::print_exceptional_execution_check(

print_persistent_callback_exceptional_execution_check(os, method);

- if (method.callback) {
+ for (const auto &callback : method.callbacks) {
std::string name;

- name = method.callback->getName().str();
+ name = callback->getName().str();
osprintf(os, " if (%s_data.eptr)\n", name.c_str());
osprintf(os, " std::rethrow_exception(%s_data.eptr);\n",
name.c_str());
diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 8686e5ca97..57ccdf2886 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1558,7 +1558,7 @@ void template_cpp_generator::method_decl_printer::print_method_sig(
}

/* Return the total number of arguments in the signature for "method",
- * taking into account a possible callback argument.
+ * taking into account any possible callback arguments.
*
* In particular, if the method has a callback argument,
* then the return kind of the callback appears at the position
@@ -1571,8 +1571,8 @@ static int total_params(const Method &method)
{
int n = method.num_params();

- if (method.callback) {
- auto callback_type = method.callback->getType();
+ for (const auto &callback : method.callbacks) {
+ auto callback_type = callback->getType();
auto proto = generator::extract_prototype(callback_type);

n += proto->getNumParams() - 1;
@@ -1704,19 +1704,28 @@ static void print_callback_lambda(std::ostream &os, ParmVarDecl *callback,
/* Print lambdas for passing to the plain method corresponding to "method"
* with signature "sig".
*
- * The method is assumed to have only the callback as argument,
- * which means the arguments of the callback are shifted by 2
+ * The method is assumed to have only callbacks as argument,
+ * which means the arguments of the first callback are shifted by 2
* with respect to the arguments of the signature
* (one for the position of the callback argument plus
* one for the return kind of the callback).
+ * The arguments of a subsequent callback are shifted by
+ * the number of arguments of the previous callback minus one
+ * for the user pointer plus one for the return kind.
*/
static void print_callback_lambdas(std::ostream &os, const Method &method,
const Signature &sig)
{
- if (method.num_params() != 3)
- generator::die("callback is assumed to be single argument");
+ int shift;

- print_callback_lambda(os, method.callback, sig, 2);
+ if (method.num_params() != 1 + 2 * method.callbacks.size())
+ generator::die("callbacks are assumed to be only arguments");
+
+ shift = 2;
+ for (const auto &callback : method.callbacks) {
+ print_callback_lambda(os, callback, sig, shift);
+ shift += generator::prototype_n_args(callback->getType());
+ }
}

/* Print a definition of the member method "method", which is known
@@ -1792,7 +1801,7 @@ void template_cpp_generator::method_impl_printer::print_method_body(
* Otherwise print the method header, preceded by the template parameters,
* if needed.
* The body depends on whether the method is a constructor or
- * takes a callback.
+ * takes any callbacks.
*/
void template_cpp_generator::method_impl_printer::print_method_sig(
const Method &method, const Signature &sig, bool deleted)
@@ -1806,7 +1815,7 @@ void template_cpp_generator::method_impl_printer::print_method_sig(
os << "\n";
if (method.kind == Method::Kind::constructor)
print_constructor_body(method, sig);
- else if (method.callback)
+ else if (method.callbacks.size() != 0)
print_callback_method_body(method, sig);
else
print_method_body(method, sig);
@@ -2549,15 +2558,15 @@ const std::string name_without_return(const Method &method)
}

/* If this method has a callback, then remove the type
- * of the first argument of the callback from the name of the method.
+ * of the first argument of the first callback from the name of the method.
* Otherwise, simply return the name of the method.
*/
const std::string callback_name(const Method &method)
{
- if (!method.callback)
+ if (method.callbacks.size() == 0)
return method.name;

- auto type = method.callback->getType();
+ auto type = method.callbacks.at(0)->getType();
auto callback = cpp_generator::extract_prototype(type);
auto arg_type = plain_type(callback->getArgType(0));
return generator::drop_suffix(method.name, "_" + arg_type);
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:55 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This prepares for supporting multiple callback arguments
since it drops the assumption that the user pointer
always appears as the last argument.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/python.cc | 17 ++++++++++-------
1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/interface/python.cc b/interface/python.cc
index dd78c42a9b..39855df9d2 100644
--- a/interface/python.cc
+++ b/interface/python.cc
@@ -824,34 +824,37 @@ void python_generator::print_restype(FunctionDecl *fd)
}

/* Tell ctypes about the types of the arguments of the function "fd".
+ *
+ * Any callback argument is followed by a user pointer argument.
+ * Each such pair of arguments is handled together.
*/
void python_generator::print_argtypes(FunctionDecl *fd)
{
string fullname = fd->getName().str();
int n = fd->getNumParams();
- int drop_user = 0;

printf("isl.%s.argtypes = [", fullname.c_str());
- for (int i = 0; i < n - drop_user; ++i) {
+ for (int i = 0; i < n; ++i) {
ParmVarDecl *param = fd->getParamDecl(i);
QualType type = param->getOriginalType();
- if (is_callback(type))
- drop_user = 1;
if (i)
printf(", ");
if (is_isl_ctx(type))
printf("Context");
- else if (is_isl_type(type) || is_callback(type))
+ else if (is_isl_type(type))
printf("c_void_p");
+ else if (is_callback(type))
+ printf("c_void_p, c_void_p");
else if (is_string(type))
printf("c_char_p");
else if (is_long(type))
printf("c_long");
else
printf("c_int");
+
+ if (is_callback(type))
+ ++i;
}
- if (drop_user)
- printf(", c_void_p");
printf("]\n");
}

--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:56 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This prepares for supporting multiple callbacks
on a single exported function.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/python.cc | 15 ++++++++-------
1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/interface/python.cc b/interface/python.cc
index 39855df9d2..b637859bf6 100644
--- a/interface/python.cc
+++ b/interface/python.cc
@@ -195,8 +195,7 @@ void python_generator::print_copy(QualType type)
}

/* Construct a wrapper for callback argument "param" (at position "arg").
- * Assign the wrapper to "cb". We assume here that a function call
- * has at most one callback argument.
+ * Assign the wrapper to "cb{arg}".
*
* The wrapper converts the arguments of the callback to python types,
* taking a copy if the C callback does not take its arguments.
@@ -272,7 +271,7 @@ void python_generator::print_callback(ParmVarDecl *param, int arg)
print_copy(return_type);
printf("(res.ptr)\n");
}
- printf(" cb = fn(cb_func)\n");
+ printf(" cb%d = fn(cb_func)\n", arg);
}

/* Print the argument at position "arg" in call to "fd".
@@ -284,7 +283,7 @@ void python_generator::print_callback(ParmVarDecl *param, int arg)
* assuming that the caller has made the context available
* in a "ctx" variable.
* Otherwise, if the argument is a callback, then print a reference to
- * the callback wrapper "cb".
+ * the corresponding callback wrapper.
* Otherwise, if the argument is marked as consuming a reference,
* then pass a copy of the pointer stored in the corresponding
* argument passed to the Python method.
@@ -302,7 +301,7 @@ void python_generator::print_arg_in_call(FunctionDecl *fd, const char *fmt,
if (is_isl_ctx(type)) {
printf("ctx");
} else if (is_callback(type)) {
- printf("cb");
+ printf("cb%d", arg - skip);
} else if (takes(param)) {
print_copy(type);
printf("(");
@@ -372,6 +371,8 @@ static void print_persistent_callback_failure_check(int indent,
* then keep track of the constructed C callback (such that it doesn't
* get destroyed) and the data structure that holds the captured exception
* (such that it can be raised again).
+ * The callback appears in position 1 and the C callback is therefore
+ * called "cb1".
*
* If the return type is a (const) char *, then convert the result
* to a Python string, raising an error on NULL and freeing
@@ -406,7 +407,7 @@ void python_generator::print_method_return(int indent, const isl_class &clazz,
string callback_name;

callback_name = clazz.persistent_callback_name(method);
- print_indent(indent, "obj.%s = { 'func': cb, "
+ print_indent(indent, "obj.%s = { 'func': cb1, "
"'exc_info': exc_info }\n",
callback_name.c_str());
}
@@ -509,7 +510,7 @@ void python_generator::print_method_call(int indent, const isl_class &clazz,
* If the function has a callback argument, then it also has a "user"
* argument. Since Python has closures, there is no need for such
* a user argument in the Python interface, so we simply drop it.
- * We also create a wrapper ("cb") for the callback.
+ * We also create a wrapper ("cb{arg}") for the callback.
*
* If the function consumes a reference, then we pass it a copy of
* the actual argument.
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:58 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

In particular, let print_type_checks consider positions
corresponding to user pointers as well.
The corresponding arguments are ignored anyway because
print_type_checks only prints type checks for arguments
of an isl type.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/python.cc | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/interface/python.cc b/interface/python.cc
index b637859bf6..24f96de479 100644
--- a/interface/python.cc
+++ b/interface/python.cc
@@ -538,7 +538,7 @@ void python_generator::print_method(const isl_class &clazz,
num_params - drop_ctx - drop_user);

print_type_checks(cname, method, drop_ctx,
- num_params - drop_user, super);
+ num_params, super);
for (int i = 1; i < num_params; ++i) {
ParmVarDecl *param = method->getParamDecl(i);
QualType type = param->getOriginalType();
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:05:59 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This will be reused in the next commit.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/generator.cc | 10 ++++++++++
interface/generator.h | 1 +
interface/python.cc | 4 +---
3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/interface/generator.cc b/interface/generator.cc
index 566ed427b3..4b9890325c 100644
--- a/interface/generator.cc
+++ b/interface/generator.cc
@@ -735,6 +735,16 @@ bool generator::is_callback(QualType type)
return type->isFunctionType();
}

+/* Is the parameter at position "i" of "fd" a pointer to a function?
+ */
+bool generator::is_callback_arg(FunctionDecl *fd, int i)
+{
+ ParmVarDecl *param = fd->getParamDecl(i);
+ QualType type = param->getOriginalType();
+
+ return is_callback(type);
+}
+
/* Is "type" that of "char *" of "const char *"?
*/
bool generator::is_string(QualType type)
diff --git a/interface/generator.h b/interface/generator.h
index 864e123694..90eae8266a 100644
--- a/interface/generator.h
+++ b/interface/generator.h
@@ -190,6 +190,7 @@ public:
static bool is_isl_size(QualType type);
static bool is_long(QualType type);
static bool is_callback(QualType type);
+ static bool is_callback_arg(FunctionDecl *fd, int i);
static bool is_string(QualType type);
static bool is_static(const isl_class &clazz, FunctionDecl *method);
static bool is_mutator(const isl_class &clazz, FunctionDecl *fd);
diff --git a/interface/python.cc b/interface/python.cc
index 24f96de479..62f30519b8 100644
--- a/interface/python.cc
+++ b/interface/python.cc
@@ -528,9 +528,7 @@ void python_generator::print_method(const isl_class &clazz,
int drop_ctx = first_arg_is_isl_ctx(method);

for (int i = 1; i < num_params; ++i) {
- ParmVarDecl *param = method->getParamDecl(i);
- QualType type = param->getOriginalType();
- if (is_callback(type))
+ if (is_callback_arg(method, i))
drop_user = 1;
}

--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:06:00 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

In the next commit, support will be added for multiple callbacks.
This requires python_generator::print_method_call to be aware
of the number of user pointer arguments that have been skipped
so far, not only the total number of user pointer arguments.
This means python_generator::print_method_call needs
to recompute drop_user anyway.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/python.cc | 11 ++++++-----
interface/python.h | 2 +-
2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/interface/python.cc b/interface/python.cc
index 62f30519b8..cd72aa8f87 100644
--- a/interface/python.cc
+++ b/interface/python.cc
@@ -457,8 +457,6 @@ void python_generator::print_get_method(const isl_class &clazz,
/* Print a call to "method", along with the corresponding
* return statement, with the given indentation.
* "drop_ctx" is set if the first argument is an isl_ctx.
- * "drop_user" is set if the last argument is a "user" argument
- * corresponding to a callback argument.
*
* A "ctx" variable is first initialized as it may be needed
* in the first call to print_arg_in_call and in print_method_return.
@@ -467,10 +465,11 @@ void python_generator::print_get_method(const isl_class &clazz,
* thrown in the callback also need to be rethrown.
*/
void python_generator::print_method_call(int indent, const isl_class &clazz,
- FunctionDecl *method, const char *fmt, int drop_ctx, int drop_user)
+ FunctionDecl *method, const char *fmt, int drop_ctx)
{
string fullname = method->getName().str();
int num_params = method->getNumParams();
+ int drop_user = 0;

if (drop_ctx) {
print_indent(indent, "ctx = Context.getDefaultInstance()\n");
@@ -484,6 +483,8 @@ void python_generator::print_method_call(int indent, const isl_class &clazz,
if (i > 0)
printf(", ");
print_arg_in_call(method, fmt, i, drop_ctx);
+ if (is_callback_arg(method, i))
+ drop_user = 1;
}
if (drop_user)
printf(", None");
@@ -544,7 +545,7 @@ void python_generator::print_method(const isl_class &clazz,
continue;
print_callback(param, i - drop_ctx);
}
- print_method_call(8, clazz, method, fixed_arg_fmt, drop_ctx, drop_user);
+ print_method_call(8, clazz, method, fixed_arg_fmt, drop_ctx);

if (clazz.is_get_method(method))
print_get_method(clazz, method);
@@ -638,7 +639,7 @@ void python_generator::print_method_overload(const isl_class &clazz,
int drop_ctx = first_arg_is_isl_ctx(method);

print_argument_checks(clazz, method, drop_ctx);
- print_method_call(12, clazz, method, var_arg_fmt, drop_ctx, 0);
+ print_method_call(12, clazz, method, var_arg_fmt, drop_ctx);
}

/* Print a python method with a name derived from "fullname"
diff --git a/interface/python.h b/interface/python.h
index e56c23d9c5..db54b02a0c 100644
--- a/interface/python.h
+++ b/interface/python.h
@@ -51,7 +51,7 @@ private:
vector<string> super);
void print_method_call(int indent, const isl_class &clazz,
FunctionDecl *method, const char *fmt,
- int drop_ctx, int drop_user);
+ int drop_ctx);
void print_argument_checks(const isl_class &clazz, FunctionDecl *fd,
int drop_ctx);
void print_method_overload(const isl_class &clazz,
--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:06:01 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This is needed to be able to export isl_*_list_foreach_scc.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
interface/python.cc | 33 ++++++++++++++++++---------------
1 file changed, 18 insertions(+), 15 deletions(-)

diff --git a/interface/python.cc b/interface/python.cc
index cd72aa8f87..450362943e 100644
--- a/interface/python.cc
+++ b/interface/python.cc
@@ -461,8 +461,8 @@ void python_generator::print_get_method(const isl_class &clazz,
* A "ctx" variable is first initialized as it may be needed
* in the first call to print_arg_in_call and in print_method_return.
*
- * If the method has a callback function, then any exception
- * thrown in the callback also need to be rethrown.
+ * If the method has any callback functions, then any exception
+ * thrown in any of these callbacks also needs to be rethrown.
*/
void python_generator::print_method_call(int indent, const isl_class &clazz,
FunctionDecl *method, const char *fmt, int drop_ctx)
@@ -479,18 +479,19 @@ void python_generator::print_method_call(int indent, const isl_class &clazz,
printf(".ctx\n");
}
print_indent(indent, "res = isl.%s(", fullname.c_str());
- for (int i = 0; i < num_params - drop_user; ++i) {
+ for (int i = 0; i < num_params; ++i) {
if (i > 0)
printf(", ");
- print_arg_in_call(method, fmt, i, drop_ctx);
- if (is_callback_arg(method, i))
- drop_user = 1;
- }
- if (drop_user)
+ print_arg_in_call(method, fmt, i, drop_ctx + drop_user);
+ if (!is_callback_arg(method, i))
+ continue;
+ ++drop_user;
+ ++i;
printf(", None");
+ }
printf(")\n");

- if (drop_user)
+ if (drop_user > 0)
print_rethrow(indent, "exc_info[0]");

print_method_return(indent, clazz, method, fmt);
@@ -508,10 +509,10 @@ void python_generator::print_method_call(int indent, const isl_class &clazz,
* If, moreover, this first argument is an isl_ctx, then remove
* it from the arguments of the Python method.
*
- * If the function has a callback argument, then it also has a "user"
- * argument. Since Python has closures, there is no need for such
- * a user argument in the Python interface, so we simply drop it.
- * We also create a wrapper ("cb{arg}") for the callback.
+ * If the function has any callback arguments, then it also has corresponding
+ * "user" arguments. Since Python has closures, there is no need for such
+ * user arguments in the Python interface, so we simply drop them.
+ * We also create a wrapper ("cb{arg}") for each callback.
*
* If the function consumes a reference, then we pass it a copy of
* the actual argument.
@@ -530,7 +531,7 @@ void python_generator::print_method(const isl_class &clazz,

for (int i = 1; i < num_params; ++i) {
if (is_callback_arg(method, i))
- drop_user = 1;
+ drop_user += 1;
}

print_method_header(is_static(clazz, method), cname,
@@ -538,12 +539,14 @@ void python_generator::print_method(const isl_class &clazz,

print_type_checks(cname, method, drop_ctx,
num_params, super);
+ drop_user = 0;
for (int i = 1; i < num_params; ++i) {
ParmVarDecl *param = method->getParamDecl(i);
QualType type = param->getOriginalType();
if (!is_callback(type))
continue;
- print_callback(param, i - drop_ctx);
+ print_callback(param, i - drop_ctx - drop_user);
+ drop_user += 1;
}
print_method_call(8, clazz, method, fixed_arg_fmt, drop_ctx);

--
2.25.1

skim...@kotnet.org

unread,
Sep 11, 2021, 3:06:02 PM9/11/21
to isl-dev...@googlegroups.com
From: Sven Verdoolaege <sv...@cerebras.net>

This is useful for performing a topological sort.

Signed-off-by: Sven Verdoolaege <sv...@cerebras.net>
---
include/isl/list.h | 5 +++--
interface/template_cpp.cc | 19 +++++++++++++++++++
isl_test_cpp.cc | 36 ++++++++++++++++++++++++++++++++++++
isl_test_python.py | 32 ++++++++++++++++++++++++++++++++
4 files changed, 90 insertions(+), 2 deletions(-)

diff --git a/include/isl/list.h b/include/isl/list.h
index 3d8dabedf0..d269cacfa6 100644
--- a/include/isl/list.h
+++ b/include/isl/list.h
@@ -89,9 +89,10 @@ __isl_give isl_##EL##_list *isl_##EL##_list_sort( \
int (*cmp)(__isl_keep struct isl_##EL *a, \
__isl_keep struct isl_##EL *b, \
void *user), void *user); \
+EXPORT \
isl_stat isl_##EL##_list_foreach_scc(__isl_keep isl_##EL##_list *list, \
- isl_bool (*follows)(__isl_keep struct isl_##EL *a, \
- __isl_keep struct isl_##EL *b, void *user), \
+ isl_bool (*follows)(__isl_keep isl_##EL *a, \
+ __isl_keep isl_##EL *b, void *user), \
void *follows_user, \
isl_stat (*fn)(__isl_take isl_##EL##_list *scc, void *user), \
void *fn_user); \
diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 57ccdf2886..10450c516a 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -529,6 +529,24 @@ static Signature each_map =
{ { Res }, { { Domain, Range }, { Res }, { Domain, Range } } };
static std::vector<Signature> each = { each_params, each_set, each_map };

+/* Signatures for isl_*_list_foreach_scc.
+ *
+ * The first callback takes two elements with the same tuple kinds.
+ * The second callback takes a list with the same tuple kinds.
+ */
+static Signature each_scc_params =
+ { { Res }, { { }, { Res }, { }, { }, { Res }, { } } };
+static Signature each_scc_set =
+ { { Res }, { { Domain },
+ { Res }, { Domain }, { Domain },
+ { Res }, { Domain } } };
+static Signature each_scc_map =
+ { { Res }, { { Domain, Range },
+ { Res }, { Domain, Range }, { Domain, Range },
+ { Res }, { Domain, Range } } };
+static std::vector<Signature> each_scc =
+ { each_scc_params, each_scc_set, each_scc_map };
+
/* Signature for creating a map from a range,
* where the domain is given by an extra argument.
*/
@@ -798,6 +816,7 @@ member_methods {
{ "flatten_range", flatten_range },
{ "floor", fn_un_op },
{ "foreach", each },
+ { "foreach_scc", each_scc },
{ "ge_set", { set_join } },
{ "gt_set", { set_join } },
{ "gist", bin_op },
diff --git a/isl_test_cpp.cc b/isl_test_cpp.cc
index d606a21f8a..01308d1718 100644
--- a/isl_test_cpp.cc
+++ b/isl_test_cpp.cc
@@ -12,6 +12,8 @@
#include <stdlib.h>
#include <string.h>

+#include <map>
+
#include <isl/options.h>
#include <isl/typed_cpp.h>

@@ -114,6 +116,38 @@ static void test_foreach(isl::ctx ctx)
assert(caught);
}

+/* Test the functionality of "foreach_scc" functions.
+ *
+ * In particular, test it on a list of elements that can be completely sorted
+ * but where two of the elements ("a" and "b") are incomparable.
+ */
+static void test_foreach_scc(isl::ctx ctx)
+{
+ isl::multi_pw_aff id;
+ isl::id_list list(ctx, 3);
+ isl::id_list sorted(ctx, 3);
+ std::map<std::string, isl::map> data = {
+ { "a", isl::map(ctx, "{ [0] -> [1] }") },
+ { "b", isl::map(ctx, "{ [1] -> [0] }") },
+ { "c", isl::map(ctx, "{ [i = 0:1] -> [i] }") },
+ };
+
+ for (const auto &kvp: data)
+ list = list.add(kvp.first);
+ id = data.at("a").space().domain().identity_multi_pw_aff_on_domain();
+ list.foreach_scc([&data, &id] (isl::id a, isl::id b) {
+ auto map = data.at(b.name()).apply_domain(data.at(a.name()));
+ return !map.lex_ge_at(id).is_empty();
+ }, [&sorted] (isl::id_list scc) {
+ assert(scc.size() == 1);
+ sorted = sorted.concat(scc);
+ });
+ assert(sorted.size() == 3);
+ assert(sorted.at(0).name() == "b");
+ assert(sorted.at(1).name() == "c");
+ assert(sorted.at(2).name() == "a");
+}
+
/* Test the functionality of "every" functions.
*
* In particular, test the generic functionality and
@@ -313,6 +347,7 @@ static void test_typed(isl::ctx ctx)
* - Different parameter types
* - Different return types
* - Foreach functions
+ * - Foreach SCC function
* - Exceptions
* - Spaces
* - Schedule trees
@@ -331,6 +366,7 @@ int main()
test_parameters(ctx);
test_return(ctx);
test_foreach(ctx);
+ test_foreach_scc(ctx);
test_every(ctx);
test_exception(ctx);
test_space(ctx);
diff --git a/isl_test_python.py b/isl_test_python.py
index 443f5a1f4a..2870a16959 100755
--- a/isl_test_python.py
+++ b/isl_test_python.py
@@ -192,6 +192,36 @@ def test_foreach():
caught = True
assert(caught)

+# Test the functionality of "foreach_scc" functions.
+#
+# In particular, test it on a list of elements that can be completely sorted
+# but where two of the elements ("a" and "b") are incomparable.
+#
+def test_foreach_scc():
+ list = isl.id_list(3)
+ sorted = [isl.id_list(3)]
+ data = {
+ 'a' : isl.map("{ [0] -> [1] }"),
+ 'b' : isl.map("{ [1] -> [0] }"),
+ 'c' : isl.map("{ [i = 0:1] -> [i] }"),
+ }
+ for k, v in data.items():
+ list = list.add(k)
+ id = data['a'].space().domain().identity_multi_pw_aff_on_domain()
+ def follows(a, b):
+ map = data[b.name()].apply_domain(data[a.name()])
+ return not map.lex_ge_at(id).is_empty()
+
+ def add_single(scc):
+ assert(scc.size() == 1)
+ sorted[0] = sorted[0].concat(scc)
+
+ list.foreach_scc(follows, add_single)
+ assert(sorted[0].size() == 3)
+ assert(sorted[0].at(0).name() == "b")
+ assert(sorted[0].at(1).name() == "c")
+ assert(sorted[0].at(2).name() == "a")
+
# Test the functionality of "every" functions.
#
# In particular, test the generic functionality and
@@ -432,6 +462,7 @@ def test_ast_build_expr():
# - Different parameter types
# - Different return types
# - Foreach functions
+# - Foreach SCC function
# - Every functions
# - Spaces
# - Schedule trees
@@ -442,6 +473,7 @@ test_constructors()
test_parameters()
test_return()
test_foreach()
+test_foreach_scc()
test_every()
test_space()
test_schedule_tree()
--
2.25.1

Reply all
Reply to author
Forward
0 new messages