batching dataset elements using filename

Gustavo Hylander

Jan 4, 2022, 8:09:37 AM
to Discuss
Hi all,
I need to batch my data in windows of 3 elements in size.
The data must be batched according to critical information contained in the filename. 
How can I approach this? My failed attempts:

I have tried using tf.data.Dataset.window(). This conveniently creates windows of elements, but there is no way for me to drop the windows that contain unrelated elements (the dataset loses the file_paths property after applying tf.data.Dataset.window()).

I have tried to use tf.data.Dataset.group_by_window(), but have been unable to figure out how to write the key and reduce functions.


Jiri Simsa

Jan 4, 2022, 1:02:28 PM
to Gustavo Hylander, Discuss
Hi Gustavo, something along the following lines should do the trick for you:

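import tensorflow as tf

NUM_HASH_BUCKETS = 1024
WINDOW_SIZE = 3

filenames = ["a.tfrecord", "b.tfrecord", "c.tfrecord"]

def flat_map_fn(filename):
  # Pair every record read from the file with the name of that file.
  return tf.data.TFRecordDataset(filename).map(lambda x: (filename, x))

# `flat_map` is a dataset method, so wrap the Python list in a dataset first.
dataset = tf.data.Dataset.from_tensor_slices(filenames).flat_map(flat_map_fn)

def key_fn(filename, _):
  # Map the filename to an int64 key; distinct filenames can share a bucket.
  return tf.strings.to_hash_bucket_fast(filename, NUM_HASH_BUCKETS)

def reduce_fn(key, dataset):
  # Batch the elements that share a key.
  return dataset.batch(WINDOW_SIZE)

dataset = dataset.group_by_window(key_fn, reduce_fn, WINDOW_SIZE)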

The idea is that you will augment each element read from a file with the filename (the `map` transformation in `flat_map_fn` takes care of that) and then use this information as the key for the `key_fn` passed into `group_by_window`.
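To make the grouping idea concrete, here is a plain-Python model of it. It is only a sketch and does not use TensorFlow: `group_by_key_in_windows` and the sample `elements` are hypothetical, and emitting leftover elements as smaller final batches is an assumption about how the end of the input is handled.

```python
from collections import defaultdict

WINDOW_SIZE = 3


def group_by_key_in_windows(elements, key_fn, window_size):
    """Model of keyed windowing: collect elements per key and emit a batch
    whenever a key has accumulated `window_size` elements."""
    pending = defaultdict(list)
    batches = []
    for element in elements:
        key = key_fn(element)
        pending[key].append(element)
        if len(pending[key]) == window_size:
            batches.append(pending.pop(key))
    # Assumption: groups that never filled up are emitted as smaller batches.
    for leftover in pending.values():
        batches.append(leftover)
    return batches


# Hypothetical (filename, record) pairs, interleaved across three files.
elements = [
    ("a", 0), ("b", 0), ("a", 1), ("a", 2),
    ("b", 1), ("c", 0), ("b", 2), ("a", 3),
]

# The filename is the first component of each element and serves as the key.
batches = group_by_key_in_windows(elements, lambda e: e[0], WINDOW_SIZE)

for batch in batches:
    print(batch)
```

Every batch contains elements from a single file only, which is the property the filename key is meant to guarantee.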

Best,

Jiri

Gustavo Hylander

Jan 5, 2022, 6:22:51 AM
to Discuss, Jiri Simsa, Discuss, Gustavo Hylander
Hi Jiri,
thanks for your help. I'm trying to implement your suggestions.

I have trouble with the flat_map function:

path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

train_dataset = tf.keras.preprocessing.image_dataset_from_directory(train_dir,
                                             shuffle=False,
                                             label_mode='categorical',
                                             batch_size=hyperparameters["BATCH_SIZE"],
                                             image_size=IMG_SIZE)
unb_train_dataset = train_dataset.unbatch()
filenames = os.listdir(os.path.join(train_dir,'dogs')) + os.listdir(os.path.join(train_dir,'cats'))
def flat_map_fn(filename):
  return tf.data.TFRecordDataset(filename).map(lambda x: (filename, x))

test_dataset = unb_train_dataset.flat_map(flat_map_fn)
TypeError: tf__flat_map_fn() takes 1 positional argument but 2 were given

Jiri Simsa

Jan 5, 2022, 12:20:57 PM
to Gustavo Hylander, Discuss
The error suggests that the elements of the `unb_train_dataset` dataset, which is the input to `flat_map`, are pairs. You can inspect the signature of the dataset by printing `unb_train_dataset.element_spec`. The signature of `flat_map_fn` needs to match the signature of `unb_train_dataset`.
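
The mismatch can be reproduced without TensorFlow. The sketch below is a plain-Python model, not tf.data itself: it assumes, as the error message indicates, that an element that is a pair is unpacked into two positional arguments before the function is called. `apply_to_elements` is a hypothetical stand-in for `flat_map`, and the string elements are placeholders for images and labels.

```python
def apply_to_elements(elements, fn):
    """Hypothetical stand-in for flat_map: tuple elements are unpacked
    into separate positional arguments before `fn` is called."""
    results = []
    for element in elements:
        if isinstance(element, tuple):
            results.append(fn(*element))
        else:
            results.append(fn(element))
    return results


# Placeholder elements shaped like the unbatched dataset: (image, label).
elements = [("image_0", "label_0"), ("image_1", "label_1")]


def one_arg_fn(filename):
    # Signature written for single-component elements.
    return filename


def two_arg_fn(image, label):
    # Signature that matches two-component elements.
    return (image, label)


try:
    apply_to_elements(elements, one_arg_fn)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)

matched = apply_to_elements(elements, two_arg_fn)
print(matched)
```

Giving the function one parameter per component of the element removes the error, but the first parameter then receives whatever the first component actually holds, regardless of how the parameter is named.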

Gustavo Hylander

Jan 10, 2022, 3:58:06 AM
to Discuss, Jiri Simsa, Discuss, Gustavo Hylander
Hi Jiri,
I picked this up again. Using 'tf.keras.utils.image_dataset_from_directory' with 'label_mode=categorical' creates a dataset whose elements have the shape ((batch), (labels)). I modified the flat_map_fn accordingly: 'flat_map_fn(filename, labels)'.

I'm now trying to apply it to my existing dataset. In your example code, you apply the flat_map to the list of filenames. I'm trying to apply it to my dataset object, and I'm getting the following error:

NUM_HASH_BUCKETS = 1024
WINDOW_SIZE = 3

filenames = os.listdir(os.path.join(train_dir,'dogs')) + os.listdir(os.path.join(train_dir,'cats'))
# for element in filenames:
#   print(element.dtype)

def flat_map_fn(filename, labels):
  return tf.data.TFRecordDataset(filename).map(lambda x: (filename, x))

test_dataset = unb_train_dataset.flat_map(flat_map_fn)

TypeError: in user code:

    /tmp/ipykernel_1179/2477723604.py:9 flat_map_fn  *
        return tf.data.TFRecordDataset(filename).map(lambda x: (filename, x))
    /root/miniconda3/lib/python3.8/site-packages/tensorflow/python/data/ops/readers.py:335 __init__  **
        filenames = _create_or_validate_filenames_dataset(filenames)
    /root/miniconda3/lib/python3.8/site-packages/tensorflow/python/data/ops/readers.py:66 _create_or_validate_filenames_dataset
        raise TypeError(

    TypeError: `filenames` must be a `tf.Tensor` of dtype `tf.string` dtype. Got <dtype: 'float32'>

I'm not experienced with TFRecords, but it seems the function you wrote is trying to get the filenames from the dataset?

Jiri Simsa

Jan 10, 2022, 12:14:49 PM
to Gustavo Hylander, Discuss
Hi Gustavo, my suggestion is to post your question on https://discuss.tensorflow.org/, which will allow you to receive support from the broader TF community. Unfortunately, I do not have the bandwidth to help you debug your code.

Best,

Jiri

Gustavo Hylander

Jan 11, 2022, 4:48:19 AM
to Discuss, Jiri Simsa, Discuss, Gustavo Hylander
Hi Jiri,
thanks for the help and the pointer. I'll check out the forum you linked.
