From: Sven Verdoolaege <sv...@cerebras.net>
For the templated interface, the arguments of any callbacks
are spliced into the list of arguments.
If the callback is the last argument (aside from
the corresponding user pointer), then the splicing can be performed
purely locally. However, if a callback appears in any other position
(in particular, if there are multiple callbacks), then the splicing
shifts the positions of all later arguments and this shift
needs to be taken into account.
Let Method::print_fd_arg_list take care of the accounting.
Most callers ignore the extra argument that is now passed to
the Method::print_fd_arg_list callback. Only cpp_generator::class_printer::print_method_header
needs to know the position in the flattened list of arguments
because it needs to pass this position to the type printer.
interface/cpp.cc | 33 ++++++++++++++++++++++++++-------
interface/cpp.h | 4 ++--
interface/generator.cc | 8 ++++++++
interface/generator.h | 1 +
interface/plain_cpp.cc | 4 ++--
interface/template_cpp.cc | 6 +++---
6 files changed, 42 insertions(+), 14 deletions(-)
diff --git a/interface/cpp.cc b/interface/cpp.cc
index 1e7648f249..2fdda6e7d2 100644
--- a/interface/cpp.cc
+++ b/interface/cpp.cc
@@ -649,11 +649,11 @@ void cpp_generator::class_printer::print_method_header(
else
os << cppstring;
- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
std::string name = method.fd->getParamDecl(i)->getName().str();
ParmVarDecl *param = method.get_param(i);
QualType type = param->getOriginalType();
- string cpptype = type_printer.param(i, type);
+ string cpptype = type_printer.param(arg, type);
if (!method.param_needs_copy(i))
os << "const " << cpptype << " &" << name;
@@ -1026,6 +1026,10 @@ void Method::print_arg_list(std::ostream &os, int start, int end,
/* Print the arguments from "start" (inclusive) to "end" (exclusive)
* as arguments to a method of C function call, using "print_arg"
* to print each individual argument.
+ * The first argument to this callback is the position of the argument
+ * in this->fd.
+ * The second argument is the (first) position in the list of arguments
+ * with all callback arguments spliced in.
*
* Call print_arg_list to do the actual printing, skipping
* the user argument that comes after every callback argument.
@@ -1035,23 +1039,38 @@ void Method::print_arg_list(std::ostream &os, int start, int end,
* of the C function because the pair of callback and user pointer
* is considered as a single argument that is printed as a whole
* by Method::print_param_use.
+ *
+ * In case of a callback argument, the second argument to "print_arg"
+ * is also adjusted to account for the spliced-in arguments of the callback.
+ * The return value of the callback takes the place of the callback itself,
+ * while the arguments (excluding the final user pointer)
+ * take the following positions.
*/
void Method::print_fd_arg_list(std::ostream &os, int start, int end,
- const std::function<void(int i)> &print_arg) const
+ const std::function<void(int i, int arg)> &print_arg) const
{
- print_arg_list(os, start, end, [this, &print_arg] (int i) {
+ int arg = start;
+
+ print_arg_list(os, start, end, [this, &print_arg, &arg] (int i) {
auto type = fd->getParamDecl(i)->getType();
- print_arg(i);
- return generator::is_callback(type);
+ print_arg(i, arg++);
+ if (!generator::is_callback(type))
+ return false;
+ arg += generator::prototype_n_args(type) - 1;
+ return true;
});
}
/* Print the arguments to the method call, using "print_arg"
* to print each individual argument.
+ * The first argument to this callback is the position of the argument
+ * in this->fd.
+ * The second argument is the (first) position in the list of arguments
+ * with all callback arguments spliced in.
*/
void Method::print_cpp_arg_list(std::ostream &os,
- const std::function<void(int i)> &print_arg) const
+ const std::function<void(int i, int arg)> &print_arg) const
{
int first_param = kind == member_method ? 1 : 0;
print_fd_arg_list(os, first_param, num_params(), print_arg);
diff --git a/interface/cpp.h b/interface/cpp.h
index b28d4f3c29..f29c5ed6a9 100644
--- a/interface/cpp.h
+++ b/interface/cpp.h
@@ -36,9 +36,9 @@ struct Method {
static void print_arg_list(std::ostream &os, int start, int end,
const std::function<bool(int i)> &print_arg_skip_next);
void print_fd_arg_list(std::ostream &os, int start, int end,
- const std::function<void(int i)> &print_arg) const;
+ const std::function<void(int i, int arg)> &print_arg) const;
void print_cpp_arg_list(std::ostream &os,
- const std::function<void(int i)> &print_arg) const;
+ const std::function<void(int i, int arg)> &print_arg) const;
const isl_class &clazz;
FunctionDecl *const fd;
diff --git a/interface/generator.cc b/interface/generator.cc
index eacbc36a65..566ed427b3 100644
--- a/interface/generator.cc
+++ b/interface/generator.cc
@@ -781,6 +781,14 @@ const FunctionProtoType *generator::extract_prototype(QualType type)
return type->getPointeeType()->getAs<FunctionProtoType>();
}
+/* Given the type of a function pointer, return the number of arguments
+ * of the corresponding function prototype.
+ */
+int generator::prototype_n_args(QualType type)
+{
+ return extract_prototype(type)->getNumArgs();
+}
+
/* Return the function name suffix for the type of "param".
*
* If the type of "param" is an isl object type,
diff --git a/interface/generator.h b/interface/generator.h
index 29faaaac08..864e123694 100644
--- a/interface/generator.h
+++ b/interface/generator.h
@@ -195,6 +195,7 @@ public:
static bool is_mutator(const isl_class &clazz, FunctionDecl *fd);
static string extract_type(QualType type);
static const FunctionProtoType *extract_prototype(QualType type);
+ static int prototype_n_args(QualType type);
static ParmVarDecl *persistent_callback_arg(FunctionDecl *fd);
};
diff --git a/interface/plain_cpp.cc b/interface/plain_cpp.cc
index 23013733eb..0b32e69691 100644
--- a/interface/plain_cpp.cc
+++ b/interface/plain_cpp.cc
@@ -934,7 +934,7 @@ void plain_cpp_generator::impl_printer::print_method(const Method &method)
osprintf(os, " auto res = %s", methodname.c_str());
- method.print_fd_arg_list(os, 0, num_params, [&] (int i) {
+ method.print_fd_arg_list(os, 0, num_params, [&] (int i, int arg) {
method.print_param_use(os, i);
});
osprintf(os, ";\n");
@@ -1002,7 +1002,7 @@ void plain_cpp_generator::impl_printer::print_method(
print_check_ptr("ptr");
osprintf(os, " return ");
method.print_call(os, generator.isl_namespace());
- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
ParmVarDecl *param = method.fd->getParamDecl(i);
print_arg_conversion(param, method.get_param(i));
diff --git a/interface/template_cpp.cc b/interface/template_cpp.cc
index 71d5b38df6..427056cdfd 100644
--- a/interface/template_cpp.cc
+++ b/interface/template_cpp.cc
@@ -1636,7 +1636,7 @@ void template_cpp_generator::method_impl_printer::print_constructor_body(
const auto &base_name = instance.base_name();
os << " : " << base_name;
- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
os << method.fd->getParamDecl(i)->getName().str();
});
os << "\n";
@@ -1735,7 +1735,7 @@ void template_cpp_generator::method_impl_printer::print_callback_method_body(
os << " return ";
os << base_name << "::" <<
method.name;
- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
auto param = method.fd->getParamDecl(i);
if (param == method.callback)
@@ -1763,7 +1763,7 @@ void template_cpp_generator::method_impl_printer::print_method_body(
os << "{\n";
os << " auto res = ";
os << base_name << "::" <<
method.name;
- method.print_cpp_arg_list(os, [&] (int i) {
+ method.print_cpp_arg_list(os, [&] (int i, int arg) {
os << method.fd->getParamDecl(i)->getName().str();
});
os << ";\n";
--
2.25.1