Tensor creation performance bottleneck

Jose Ignacio Fernandez

Mar 18, 2019, 1:30:54 PM
to TensorFlow.js Discussion
I've found that creating tensors is a performance bottleneck in TensorFlow.js, so I ran a simple benchmark against TensorFlow for Python. Both versions build a simple Layers model and optimise it with a mean squared error loss for a number of epochs on synthetic data. The comparison shows the overhead of creating tensors:

JS code:

#!/usr/bin/env node
const tf =  require(`@tensorflow/tfjs-node${process.env.GPU === 'true' ? '-gpu' : ''}`);
const model = tf.sequential();
const featuresCount = 1;
model.add(tf.layers.dense({ inputDim: featuresCount, units: 128 }));
model.add(tf.layers.dense({ inputDim: 128, units: 128 }));
model.add(tf.layers.dense({ inputDim: 128, units: 1 }));

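// Runs fn (sync or async), prints the elapsed time, and returns fn's result.
// The label is optional: time(fn) and time(label, fn) both work.
async function time(str, fn) {
  if (typeof str === 'string') { process.stdout.write(`${str}... `); }
  const start = Date.now();
  const out = await (fn || str)();
  console.log(`done in ${(Date.now() - start) / 1000}s`);
  return out;
}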

const AMOUNT = 10000;
const EPOCHS = 50;
const LEARNING_RATE = 0.001;

time(`Training ${EPOCHS} epochs`, async () => {
  for (let i = 0; i < EPOCHS; i += 1) {
    tf.tidy(() => {
      const xs = tf.tensor2d(Array(AMOUNT).fill(0).map((x, i) => [Math.random()]));
      const ys = tf.tensor2d(Array(AMOUNT).fill(0).map((x, i) => [Math.random() + 10]));
      tf.train.sgd(LEARNING_RATE).minimize(() => model.apply(xs).sub(ys).square().mean());
    });
  }
});

Equivalent Python code:

#!/usr/bin/env python3
import tensorflow as tf
import numpy as np
import time
import sys
import random

features_count = 1
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(128, input_shape=[1]))
model.add(tf.keras.layers.Dense(128))
model.add(tf.keras.layers.Dense(1))

class Timer:
  def __init__(self, title=None):
    self.title = title
  def __enter__(self):
    self.start = time.time()
    if self.title:
      sys.stdout.write(f'{self.title}... ')
      sys.stdout.flush()
  def __exit__(self, *args):
    duration = time.time() - self.start
    print(f'done in {duration}s')

AMOUNT = 10000
EPOCHS = 50
LEARNING_RATE = 0.001

with Timer(f'Training {EPOCHS} epochs'):
  sess = tf.keras.backend.get_session()
  inputs = tf.placeholder(shape=[AMOUNT, features_count], dtype=tf.float32)
  labels = tf.placeholder(shape=[AMOUNT, 1], dtype=tf.float32)
  loss = tf.reduce_mean(tf.square(labels - model(inputs)))
  optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
  train = optimizer.minimize(loss, var_list=model.trainable_variables)
  for i in range(EPOCHS):
    xs = np.random.random((AMOUNT, 1))
    ys = np.random.random((AMOUNT, 1)) + 10
    sess.run(train, feed_dict = { inputs: xs, labels: ys })

The Python code takes less than 2s, while the JS code takes around 8s. Experimenting with the JS code shows that:
- tf.tidy is not the bottleneck (i.e., disposing tensors is fast).
- creating the JS array of synthetic data is not the bottleneck (i.e., moving Array(AMOUNT).fill(0).map((x, i) => [Math.random()]) out of the for loop does not improve performance).
- the calls to tf.tensor2d are the actual bottleneck: moving the tensor creation out of the for loop puts JS performance on par with Python's (but this would mean reusing the same tensors on every training iteration).
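
To check the second point in isolation, the data-generation step can be timed on its own, without any TensorFlow.js calls. The following is only a rough sketch: timeMs is a hypothetical helper, not part of any library, and the absolute numbers will depend on the machine and the Node.js version.

```javascript
// Hypothetical helper: run fn `repeats` times and return the elapsed
// wall-clock time in milliseconds.
function timeMs(fn, repeats) {
  const start = Date.now();
  for (let i = 0; i < repeats; i += 1) fn();
  return Date.now() - start;
}

const AMOUNT = 10000;
const EPOCHS = 50;

// Same nested-array construction as in the benchmark above.
const nestedMs = timeMs(
  () => Array(AMOUNT).fill(0).map(() => [Math.random()]),
  EPOCHS
);

// A flat typed array filled in place, for comparison.
const typedMs = timeMs(() => {
  const data = new Float32Array(AMOUNT);
  for (let i = 0; i < data.length; i += 1) data[i] = Math.random();
  return data;
}, EPOCHS);

console.log(`nested arrays: ${nestedMs}ms, typed array: ${typedMs}ms`);
```

Wrapping the tf.tensor2d call in the same helper would give the conversion cost separately, so the two can be compared directly.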

My impression is that Python manages tensor creation differently: I guess that once a tensor placeholder is created, the memory it requires is not released as long as the placeholder exists. This lets calls to sess.run feed the tensor with the data supplied in feed_dict without re-creating tensors.

Is there a similar way to achieve this in JavaScript? I.e., create a kind of symbolic tensor with empty data, then supply new data whenever the symbolic tensor is to be used, to avoid disposing and creating tensors of the same size again and again. I'm assuming that would improve performance, but I don't know much about the internals. Or is there another way to speed up tensor creation in JS?
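
A partial version of this idea can at least be sketched on the host side: allocate one Float32Array per batch up front and overwrite it in place on every iteration. This is not a placeholder mechanism, and a tensor would still have to be created from the buffer each time. The helper names createBatchBuffer and fillRandom are hypothetical and not part of TensorFlow.js.

```javascript
// Hypothetical helper (not part of TensorFlow.js): a reusable host-side
// buffer holding one batch of shape [rows, cols].
function createBatchBuffer(rows, cols) {
  return { shape: [rows, cols], data: new Float32Array(rows * cols) };
}

// Overwrite the buffer in place with Math.random() + offset.
// No new arrays are allocated, so the same memory is reused every call.
function fillRandom(batch, offset = 0) {
  const { data } = batch;
  for (let i = 0; i < data.length; i += 1) {
    data[i] = Math.random() + offset;
  }
  return batch;
}

const AMOUNT = 10000;
const xsBuffer = createBatchBuffer(AMOUNT, 1);
const ysBuffer = createBatchBuffer(AMOUNT, 1);

fillRandom(xsBuffer);      // features, roughly in [0, 1)
fillRandom(ysBuffer, 10);  // labels, roughly in [10, 11)
```

Inside the training loop, each buffer would be refilled and then handed over with its explicit shape, e.g. tf.tensor2d(xsBuffer.data, xsBuffer.shape) inside tf.tidy. Whether a given backend copies the typed array or keeps a reference to it is an assumption worth checking before mutating the buffer while a tensor built from it is still alive.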

Or is it a matter of how each TensorFlow version (Python and JS) is compiled? I've tested on both CPU and GPU with similar results.

Of course, individual applications may have their own workarounds for this issue. In cases such as reinforcement learning, however, fresh environment observations have to be turned into tensors on every training iteration, which is why a traditional training loop can't be used.

Thanks,
Jose

Nikhil Thorat

Mar 18, 2019, 3:37:15 PM
to Jose Ignacio Fernandez, TensorFlow.js Discussion, Shanqing Cai, Nick Kreeger
Thanks for this detailed analysis, this is really helpful.

I'm adding a few of the folks who are starting to dig into the details of Node.js vs. Python performance so they can take a look at this breakdown.

My guess is that creating a Tensor handle in the binding is much slower than in Python.

--
You received this message because you are subscribed to the Google Groups "TensorFlow.js Discussion" group.
To unsubscribe from this group and stop receiving emails from it, send an email to tfjs+uns...@tensorflow.org.
Visit this group at https://groups.google.com/a/tensorflow.org/group/tfjs/.
To view this discussion on the web visit https://groups.google.com/a/tensorflow.org/d/msgid/tfjs/bb5ef9f7-05b4-4111-9d2d-c0db37d12aa0%40tensorflow.org.

Shanqing Cai

Mar 18, 2019, 3:44:25 PM
to TensorFlow.js Discussion, joseignaci...@gmail.com, ca...@google.com, kre...@google.com
+1 Thanks, Jose. This is really good information.

Just one clarifying question: what version of TensorFlow (Python) are you using? Is it using eager or graph execution?

Shanqing Cai

Mar 18, 2019, 3:47:06 PM
to TensorFlow.js Discussion, joseignaci...@gmail.com, ca...@google.com, kre...@google.com
Ah, never mind. The sess.run() call makes it clear it is graph execution. But it would still be nice to know which version of TensorFlow you are using.

Jose Ignacio Fernandez

Mar 18, 2019, 4:29:34 PM
to Shanqing Cai, TensorFlow.js Discussion, ca...@google.com, kre...@google.com
Thanks for the response!

In Python, I’m using tensorflow 1.13.1 (and tensorflow-gpu 1.12.0 when running on GPU). In JavaScript, I’m using tfjs-node 0.3.0, tfjs-node-gpu 0.3.1, tfjs-core 0.15.2.

I also tried it in the browser, using tfjs 1.0.0, with similar results. In this case, I assume there’s no bindings overhead. You can try it yourself using the attached tf.html file.
tf.html

Jose Ignacio Fernandez

Mar 28, 2019, 10:40:46 AM
to TensorFlow.js Discussion
I think I found where the bottleneck is. The main issue is passing an Array of data without a shape when calling tf.tensor. A second possible optimization is passing a Float32Array instead of a regular Array.

I've run the following benchmark that shows this:

#!/usr/bin/env node

const tf =  require(`@tensorflow/tfjs-node${process.env.GPU === 'true' ? '-gpu' : ''}`);

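// Runs fn (sync or async), prints the elapsed time, and returns fn's result.
// The label is optional: time(fn) and time(label, fn) both work.
async function time(str, fn) {
  if (typeof str === 'string') { process.stdout.write(`${str}... `); }
  const start = Date.now();
  const out = await (fn || str)();
  console.log(`done in ${(Date.now() - start) / 1000}s`);
  return out;
}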

(async () => {

const AMOUNT = 2000;

function randomTensor() {
  const data = Array(AMOUNT).fill(0).map((x, i) => [Math.random()]);
  return tf.tensor2d(data);
}

await time(`Creating ${AMOUNT} tensors`, async () => {
  for (let i = 0; i < AMOUNT; i++) {
    randomTensor().dispose();
  }
});

console.log();
console.log("Now, let's pass shape to tf.tensor:");

function fasterRandomTensor() {
  const data = Array(AMOUNT).fill(0).map((x, i) => Math.random());
  return tf.tensor(data, [AMOUNT, 1], 'float32');
}

await time(`Creating ${AMOUNT} "faster" tensors`, async () => {
  for (let i = 0; i < AMOUNT; i++) {
    fasterRandomTensor().dispose();
  }
});

console.log();
console.log("Now, let's pass shape and Float32Arrays to tf.tensor:");

function fastestRandomTensor() {
  const data = new Float32Array(Array(AMOUNT).fill(0).map((x, i) => Math.random()));
  return tf.tensor(data, [AMOUNT, 1], 'float32');
}

await time(`Creating ${AMOUNT} "fastest" tensors`, async () => {
  for (let i = 0; i < AMOUNT; i++) {
    fastestRandomTensor().dispose();
  }
});

})().catch(console.error);

The output I get on my computer is:

Creating 2000 tensors... done in 10.484s

Now, let's pass shape to tf.tensor:
Creating 2000 "faster" tensors... done in 2.026s

Now, let's pass shape and Float32Arrays to tf.tensor:
Creating 2000 "fastest" tensors... done in 0.211s

My assumption is that passing no shape to the tensor constructor means the shape has to be inferred, which is a slow task. I think this is worth a mention in the documentation of tf.tensor, which so far just states that the tensor's shape can be inferred from the values (maybe it's documented somewhere I didn't find!). Using a Float32Array can increase the speed even further (of course, the actual numbers will vary from one computer to another).
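
To illustrate the kind of work that may be involved when only a nested Array is supplied, here is a plain JavaScript sketch of the two passes a library would need: one to infer the shape and one to flatten the values into a typed array. The helpers inferShape and flattenToFloat32 are hypothetical and are not TensorFlow.js's actual implementation, which may differ substantially.

```javascript
// Hypothetical helper: infer the shape of a (rectangular) nested array by
// following the first element at each depth.
function inferShape(values) {
  const shape = [];
  let current = values;
  while (Array.isArray(current)) {
    shape.push(current.length);
    current = current[0];
  }
  return shape;
}

// Hypothetical helper: copy every leaf value, in order, into a
// preallocated Float32Array sized from the shape.
function flattenToFloat32(values, shape) {
  const size = shape.reduce((a, b) => a * b, 1);
  const out = new Float32Array(size);
  let index = 0;
  const visit = (node) => {
    if (Array.isArray(node)) {
      for (const child of node) visit(child);
    } else {
      out[index] = node;
      index += 1;
    }
  };
  visit(values);
  return out;
}

const nested = [[0.5], [1.5], [2.5]];
const shape = inferShape(nested);              // [3, 1]
const flat = flattenToFloat32(nested, shape);  // Float32Array [0.5, 1.5, 2.5]
```

When the caller already has flat data and knows the shape, neither pass is needed, which would be consistent with the "faster" and "fastest" timings above.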

Nikhil Thorat

Mar 28, 2019, 10:48:11 AM
to Jose Ignacio Fernandez, TensorFlow.js Discussion, Nick Kreeger, Shanqing Cai
+Nick and Shanqing


Nick Kreeger

Mar 29, 2019, 2:52:43 PM
to Nikhil Thorat, Jose Ignacio Fernandez, TensorFlow.js Discussion, Shanqing Cai
I think the root cause comes down to two things:
  • There is a memory leak for non-tracked Tensors (e.g. the output from tf.add([1], [2]) will leak).
  • We can share native V8 memory with Tensor creation.
I have a PR to track these things. It's not ready for testing yet, and I have a couple of other things to juggle at the moment, but it does fix the case of massive Tensor creation over a long period of time.
