Data fetch bottleneck on TPU during inference, but not during training


Santosh Gupta

Jul 28, 2020, 5:27:09 PM
to TPU Users


This is what my inference setup looks like:

import time

import tensorflow as tf

# Assumed to be defined elsewhere (not shown): strategy, LoadModel,
# _parse_example, tfRecordAddress, batch_size.
autotune = tf.data.experimental.AUTOTUNE

with strategy.scope():
    model = LoadModel()
    raw_dataset = tf.data.TFRecordDataset(tfRecordAddress)
    train_dataset = raw_dataset.map(_parse_example, num_parallel_calls=autotune)
    train_dataset = train_dataset.padded_batch(
        batch_size, padding_values=(1, 1, b'-'), padded_shapes=(512, 512, 1))
    # train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.prefetch(autotune)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)


def per_core_inference_fn(inputIds, attnIds):
    # Runs once per replica (TPU core) under strategy.run.
    return model.inference((inputIds, attnIds))


@tf.function
def inference_fn(inputIds, attnIds):
    return strategy.run(per_core_inference_fn, args=(inputIds, attnIds))


results = []
for x in train_dataset:
    # Note: t0/t1 bracket only the inference call, not the fetch of x.
    t0 = time.time()
    results.append(inference_fn(x[0], x[1]))
    t1 = time.time()
    print('time is :', t1 - t0)


With very large batch sizes, the inference call itself is extremely fast, around 0.0003 seconds. However, fetching the next batch (the "for x in train_dataset:" line) takes a long time, around 60-80 seconds.
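One way to confirm where the time actually goes is to time the fetch and the step separately, and to force each step's result to materialize before stopping the step timer. The loop above times only the inference call; if that call returns before the device has finished (asynchronous dispatch is an assumption here, not something the post confirms), 0.0003 seconds may reflect dispatch only, and the device time could then surface in the next fetch. This is a minimal plain-Python sketch; profile_loop and sync_fn are hypothetical names. For the TPU case, sync_fn could be something like lambda out: [t.numpy() for t in strategy.experimental_local_results(out)].

```python
import time
from typing import Any, Callable, Iterable, List, Tuple


def profile_loop(
    batches: Iterable[Any],
    step_fn: Callable[[Any], Any],
    sync_fn: Callable[[Any], Any] = lambda out: out,
) -> Tuple[List[Any], List[float], List[float]]:
    """Run step_fn over batches, timing fetch and step separately.

    sync_fn (hypothetical hook) is applied to each step's output before the
    step timer stops, so asynchronously dispatched work can be forced to
    finish inside the step measurement instead of leaking into the next fetch.
    """
    results: List[Any] = []
    fetch_times: List[float] = []
    step_times: List[float] = []
    it = iter(batches)
    while True:
        t0 = time.perf_counter()
        try:
            batch = next(it)  # time spent waiting for the input pipeline
        except StopIteration:
            break
        t1 = time.perf_counter()
        out = sync_fn(step_fn(batch))  # time spent computing (and syncing)
        t2 = time.perf_counter()
        results.append(out)
        fetch_times.append(t1 - t0)
        step_times.append(t2 - t1)
    return results, fetch_times, step_times
```

For example, profile_loop(range(3), lambda x: x * 2) returns [0, 2, 4] together with two lists of three timings each; comparing sum(fetch_times) with sum(step_times) shows which side dominates.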

As far as I can tell, I am running inference correctly, but the TPU host CPU seems to hit a huge bottleneck when retrieving batches.

I did not see this bottleneck during training, so it looks like model.fit is doing something that I am not.
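One difference worth testing: in TF 2.2 and later, as far as I can tell from the Keras source, model.fit passes the distributed iterator into its tf.function and calls next() inside it, so fetching the batch is part of the traced step rather than a separate Python-level iteration. The same pattern appears in TensorFlow's custom-training-on-TPU examples. Using drop_remainder=True also keeps the batch dimension static, which TPU compilation generally prefers (it discards the final partial batch). This is a hedged sketch, not a confirmed fix for the 60-80 second stall: it uses the default strategy so it runs on CPU (substitute the TPUStrategy from the post on a TPU), and make_dataset and per_replica_fn are hypothetical stand-ins for the TFRecord pipeline and model.inference.

```python
import tensorflow as tf

# Default (no-op) strategy so the sketch runs anywhere; on a TPU, use the
# TPUStrategy from the original post instead.
strategy = tf.distribute.get_strategy()

NUM_EXAMPLES = 20  # hypothetical dataset size
BATCH_SIZE = 8     # hypothetical batch size
SEQ_LEN = 16       # hypothetical sequence length (the post pads to 512)


def make_dataset():
    # Hypothetical stand-in for the TFRecord pipeline: two int32 features.
    ids = tf.ones([NUM_EXAMPLES, SEQ_LEN], dtype=tf.int32)
    ds = tf.data.Dataset.from_tensor_slices((ids, ids))
    # drop_remainder=True keeps every batch the same (static) shape.
    ds = ds.batch(BATCH_SIZE, drop_remainder=True)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)


def per_replica_fn(input_ids, attn_ids):
    # Hypothetical stand-in for model.inference((inputIds, attnIds)).
    return tf.reduce_sum(input_ids * attn_ids, axis=-1)


@tf.function
def inference_step(iterator):
    # next() runs inside the tf.function, so the fetch is part of the step.
    input_ids, attn_ids = next(iterator)
    return strategy.run(per_replica_fn, args=(input_ids, attn_ids))


dist_ds = strategy.experimental_distribute_dataset(make_dataset())
iterator = iter(dist_ds)
num_steps = NUM_EXAMPLES // BATCH_SIZE  # number of full batches

results = []
for _ in range(num_steps):
    per_replica = inference_step(iterator)
    # Collect each replica's output tensor on the host.
    results.extend(strategy.experimental_local_results(per_replica))
```

A further step in the same direction, similar to what model.fit's steps_per_execution setting does, would be to run several steps per tf.function call so the host makes fewer round trips.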