I am using a non-separable loss function with my network, i.e., the loss cannot be expressed as individual losses over each training example that are then aggregated. I have set up the network so that each mini-batch is a group of "related" training examples, which the non-separable loss function jointly takes as input. The issue is that to group the related training examples together in one mini-batch, the batch size needs to be large, on the order of about 1000. For large-ish networks, this causes memory issues. I am wondering if it is possible to split the related batch into smaller sub-batches (say, 5 batches of 200), feed each one forward, and then only compute the loss and backprop after all 5 smaller batches have been fed forward. If anyone has other suggestions for how to deal with this, that would be appreciated as well, thanks.
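
To make the idea concrete, here is roughly what I have in mind, assuming PyTorch (`model` and `joint_loss` below are just placeholders for my actual network and loss, not real code from my project):

```python
import torch
import torch.nn as nn

# Placeholder network and optimizer; my real model is much larger.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

def joint_loss(outputs):
    # Stand-in for the real non-separable loss, which takes the
    # outputs for the whole group of related examples at once.
    return outputs.var(dim=0).sum()

group = torch.randn(1000, 32)   # one group of 1000 related examples
chunks = group.split(200)       # 5 sub-batches of 200

optimizer.zero_grad()
# Forward each sub-batch separately; autograd keeps each sub-batch's
# graph alive, so they can all be backpropped through later.
outputs = torch.cat([model(chunk) for chunk in chunks], dim=0)
loss = joint_loss(outputs)      # loss computed jointly over all 1000 outputs
loss.backward()                 # backprop through all 5 forward passes
optimizer.step()
```

One thing I am unsure about with this approach: since autograd has to keep the activations of every sub-batch alive until `backward()` is called, I do not know whether the peak memory actually drops compared to a single batch of 1000, unless it is combined with something like gradient checkpointing.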