/** Builds a signature computing `x + 5` over a float32 placeholder.
  *
  * @param tf ops builder used to create the placeholder, the constant and the add operation
  * @return a signature keyed "plusFive" with input "x" and output "y"
  */
def plusFive(tf: Ops): Signature =
  val x = tf.placeholder(classOf[TFloat32])
  val five = tf.constant(5.0f)
  val sum = tf.math.add(x, five)
  Signature.builder()
    .key("plusFive")
    .input("x", x)
    .output("y", sum)
    .build()
/** Prints diagnostic information about a graph to stdout.
  *
  * The output covers:
  *  - its eager/graph flags;
  *  - its environment type;
  *  - the device string of its base scope;
  *  - the functions and operations it contains.
  *
  * @param g the graph to inspect
  */
def printGraphInfo(g: Graph): Unit =
  println(s"g.isEager() = ${g.isEager()}")
  println(s"g.isGraph() = ${g.isGraph()}")
  println(s"g.environmentType() = ${g.environmentType()}")
  val scope = g.baseScope()
  println(s"scope.getDeviceString() = ${scope.getDeviceString()}")
  // Materialize both views with `toList`. An unmaterialized Scala Iterator
  // (what `operations().asScala` yields) interpolates as "<iterator>"
  // rather than listing its elements.
  println(s"g.getFunctions() = ${g.getFunctions().asScala.map(_.toString()).toList}")
  println(s"g.operations() = ${g.operations().asScala.map(_.toString()).toList}")
/** Demonstrates attaching a [[ConcreteFunction]] to a [[Graph]] and calling it.
  *
  * Steps:
  *  - calls `plusFive` through the graph's Ops;
  *  - looks the function up again by its defined name;
  *  - calls the looked-up function on 10 and asserts the result is 15;
  *  - prints the graph's state.
  *
  * The function and the graph are closed when the block exits. So are the
  * input and output tensors.
  */
def getGraphFunctions(): Unit =
  println("\n\ngetGraphFunctions")
  Using.resources(ConcreteFunction.create(plusFive), Graph()) { (function, g) =>
    val tf = Ops.create(g)
    // The call through the graph's Ops is expected to register the function
    // in `g`. The by-name lookup below asserts that it did.
    tf.call(function, tf.constant(3f))
    val attached = g.getFunction(function.getDefinedName())
    assert(attached != null)
    // Both tensors are AutoCloseable and must be released.
    // The returned tensor is managed before the downcast so it is closed
    // even if the cast fails.
    Using.resource(TFloat32.scalarOf(10f)) { x =>
      Using.resource(attached.call(x)) { result =>
        val y = result.asInstanceOf[TFloat32]
        assert(y.getFloat() == 15f)
      }
    }
    printGraphInfo(g)
  }