Using groupby() and apply() over more than a single dimension?


James Adams

May 31, 2016, 2:02:53 PM
to xarray
I'm trying to use xarray's split-apply-combine approach for processing a 3-D dataset with dimensions: time, lon, and lat.

My use case is that for each lat/lon point I want to process all the data values along the time axis.

For example, let's say my dataset is 30 lats x 70 lons x 240 times. I want to process the time series at each lon/lat point, i.e. loop over all lons/lats and apply a function to the 240 time values at each point.

I can already do this myself with simple loops, but my understanding is that I can gain some performance by leveraging xarray for this sort of thing. (Incidentally, I've already seen some convenience benefits of xarray for NetCDF access, which is another reason I'd like to use xarray for my work going forward.)
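For reference, the plain-loop version I have in mind looks something like this (a toy sketch with a stand-in computation, not my real processing function):

```python
import numpy as np

# stand-in for the real dataset: 240 times x 70 lons x 30 lats
data = np.random.rand(240, 70, 30)

def process_time_series(ts):
    return ts * 2  # placeholder for the real computation

result = np.empty_like(data)
for i in range(data.shape[1]):      # loop over lons
    for j in range(data.shape[2]):  # loop over lats
        result[:, i, j] = process_time_series(data[:, i, j])
```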

My understanding of whether this can work is unclear. The documentation explicitly describes groupby() as being limited to a single dimension, but in another thread related to this (issue #818) I was told that stack() can somehow be used in conjunction with groupby() to get the functionality I'm after. I've tried several approaches along those lines and none have worked so far.

I've also tried using 'time' as the dimension name argument to groupby() in order to take advantage of the "convenient shortcut for aggregating over all dimensions other than the provided one". This does seem to invoke my function the correct number of times, but when I look at the resulting output NetCDF none of the data values have been updated. That may be a separate issue with the combine part of this; I've not found anything in the documentation about how that works, other than that it somehow combines the groups back into a single data object.


My code looks like this:


import sys

import numpy as np
import xarray

#----------------------------------------------------------------------------------------------------------------------
def function_to_be_applied(data):


    '''
    Dummy function to use for testing xarray's groupby and related functionality.
    '''


    # perform a computation on the data
    computed_data = data * 2
    return computed_data


#----------------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':


    # get the command line arguments

    precip_file = sys.argv[1]
    output_file_base = sys.argv[2]


    with xarray.open_dataset(precip_file) as dataset:

        # group by time in order to get a shortcut way of grouping by all the other dimensions, in our case 'lat' & 'lon'
        dataset.groupby('time').apply(function_to_be_applied)

        # rename the input dataset's prcp variable (which we've overwritten with computed values)
        variable_name = 'dummy_test'
        dataset = dataset.rename({'prcp': variable_name}, True)


        # write the dataset to NetCDF
        output_file = output_file_base + '_' + variable_name + '.nc'
        dataset.to_netcdf(output_file, encoding = {variable_name: {'_FillValue': np.nan, 'dtype': 'float32'}})


#----------  end example Python code --------------------------


The input NetCDF file looks like this (ncdump -h):


netcdf file:/C:/home/nclimgrid/nclimgrid_prcp_mini.nc 
{
  dimensions:
    lat = 30;
    lon = 70;
    time = UNLIMITED;   // (240 currently)
  variables:
    float lat(lat=30);
      :long_name = "latitude";
      :standard_name = "latitude";

    float lon(lon=70);
      :long_name = "longitude";
      :standard_name = "longitude";

    float prcp(time=240, lon=70, lat=30);
      :standard_name = "precipitation";
      :long_name = "Precipitation, total";
      :units = "millimeters";
      :valid_max = 2000.0f; // float
      :valid_min = 0.0f; // float
      :_FillValue = -999.9f; // float
      :missing_value = -999.9f; // float

    int time(time=240);
      :long_name = "time";
      :standard_name = "time";
      :units = "days since 1800-01-01 00:00:00";
      :calendar = "gregorian";

  // global attributes:
  :standard_name_vocabulary = "CF Standard Name Table (v26, 08 November 2013)";
  :date_created = "2016-04-08  11:27:55";
  :date_modified = "2016-04-08  11:27:55";
  :geospatial_lon_max = -67.0209f; // float
  :geospatial_lat_max = 49.3542f; // float
  :geospatial_lon_min = -124.6875f; // float
  :geospatial_lat_min = 24.5625f; // float
}


Thanks in advance for any guidance.

--James

Ryan Abernathey

May 31, 2016, 4:16:32 PM
to xar...@googlegroups.com
James,

I think what you want to do is use stack to collapse lon and lat into a single dimension. Then you can groupby('time'), apply whatever function you want, and unstack back to lon / lat.

Let us know if this works.

-Ryan

--
You received this message because you are subscribed to the Google Groups "xarray" group.
To unsubscribe from this group and stop receiving emails from it, send an email to xarray+un...@googlegroups.com.
To post to this group, send email to xar...@googlegroups.com.
To view this discussion on the web visit https://groups.google.com/d/msgid/xarray/be9c4d6a-35e9-4c02-91a0-fe6d0d431ca6%40googlegroups.com.
For more options, visit https://groups.google.com/d/optout.

James Adams

May 31, 2016, 6:03:44 PM
to xarray
Thanks for your quick help, Ryan.

This isn't working for me yet, but maybe I'm still making rookie mistakes.

My code now looks like this:

import logging
import sys

import numpy as np
import xarray

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

#----------------------------------------------------------------------------------------------------------------------
def function_to_be_applied(data):

    '''
    Dummy function to use for testing the xarray.groupby and related functionality.
    '''    
    logger.info("Function now being applied")
    
#     data['precipitation'] = data['precipitation'] * 2
    
    # perform a computation on the data
    computed_data = data * 2
    return computed_data

#----------------------------------------------------------------------------------------------------------------------
if __name__ == '__main__':
    try:
                
        # get the command line arguments
        precip_file = sys.argv[1]
        output_file_base = sys.argv[2]

        with xarray.open_dataset(precip_file) as dataset:
    
            # use stack() to collapse lon and lat to a single dimension
            dataset = dataset.stack(grid_cells=('lon', 'lat'))
            
            # group by time and apply the function
            dataset.groupby('time').apply(function_to_be_applied)
            
            # unstack to reinflate the lon & lat dimensions
            dataset = dataset.unstack('grid_cells')
            
            # rename the input dataset's prcp variable (which we've overwritten with computed values)
            variable_name = 'dummy'
            dataset = dataset.rename({'prcp': variable_name}, True)

            # write the dataset to NetCDF
            output_file = output_file_base + '_' + variable_name + '.nc'
            dataset.to_netcdf(output_file, encoding = {variable_name: {'_FillValue': np.nan, 'dtype': 'float32'}})

    except Exception:
        logger.error('Failed to complete', exc_info=True)
        raise

#--------------------------------- end example code

Also, I created a bare-bones toy NetCDF file for this test (attached), which looks like this (ncdump -h):

       File "toy.nc"

Dataset type: Hierarchical Data Format, version 5


netcdf file:/C:/tmp/toy.nc {
  dimensions:
    lat = 2;
    lon = 2;
    time = 3;
  variables:
    int prcp(time=3, lon=2, lat=2);

    double lat(lat=2);

    double lon(lon=2);

    long time(time=3);
      :calendar = "proleptic_gregorian";
      :units = "days since 2014-06-09 00:00:00";
}


The code runs to completion but the output file appears to match the input, so something isn't right: nothing is being updated. Also, I only see the function's debug message printed three times, which I assume means the function is being applied to all lats and lons of each time step as a group (there are three time steps in my dataset), rather than to the grid cells, each of which should contain all time steps for one lon/lat pair. There are four such grid cells (lon/lat pairs) in my dataset, so I'd expect to see the debug message four times if apply() were happening along the proper axis. Is this correct?

--James
toy.nc

Stephan Hoyer

May 31, 2016, 7:52:43 PM
to xar...@googlegroups.com
For starters, you definitely need to be assigning the result of groupby.apply to something, e.g.,

applied_dataset = dataset.groupby('time').apply(function_to_be_applied)

These operations don't work in-place.

James Adams

May 31, 2016, 11:33:53 PM
to xar...@googlegroups.com
Thanks for noticing that, Stephan, it fixed the trouble I was having with values not updating as expected.

From what I can tell the split-apply-combine using stack() / groupby() / apply() is not working as suggested, as I'm still seeing just three invocations of the function mapped using apply() when I should be seeing four (I want all three time steps to be processed together as a group for each of the four lon/lat pairs). It seems that the approach Ryan suggested is doing the reverse of what I want to do, in that it appears to be grouping the grid cells (lon/lat pairs) by time step rather than the other way around. But I don't see how you can do things the right way if groupby() is limited to a single dimension.

I remember reading in the notes for issue #818 that implementing support for multiple dimensions in groupby() would not be that difficult, but the discussion quickly went beyond my depth and even drifted away from the original issue of grouping over multiple dimensions. If it really is relatively straightforward to implement then I'd be happy to take a crack at it, but looking at groupby.py hasn't inspired confidence. I want to see how much performance gain is possible for my code using xarray, but I'm hesitant to tackle this issue first; it looks daunting and I have competing obligations. I don't want to wait on someone else to do the work if it's reasonable to tackle it myself, but I first want to make sure there isn't another existing way to get this functionality for now, or that multi-dimensionality won't come to groupby() in the near future if I can just be patient.

--James



Ryan Abernathey

Jun 1, 2016, 7:45:44 AM
to xar...@googlegroups.com
> It seems that the approach Ryan suggested is doing the reverse of what I want to do, in that it appears to be grouping the grid cells (lon/lat pairs) by time step rather than the other way around. But I don't see how you can do things the right way if groupby() is limited to a single dimension.

Yes, I gave you the wrong advice: you want to group by the OTHER dimension of your stacked array, instead of time. That's all. Please give that a try.

I can assure you that xarray in its current version can accomplish what you want. And as the one developing the multidimensional groupby feature, I can promise that it doesn't do what you think it does: it is designed for the case where the group variable is itself multidimensional.
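To make this concrete, here is a minimal self-contained sketch on a toy array (names like grid_cells and double_up are illustrative, and in recent xarray versions apply on groupby objects has been renamed map):

```python
import numpy as np
import xarray

# toy 3-D array: 3 times x 2 lons x 2 lats
values = np.arange(12, dtype=float).reshape(3, 2, 2)
da = xarray.DataArray(
    values,
    dims=("time", "lon", "lat"),
    coords={"time": [0, 1, 2], "lon": [10.0, 20.0], "lat": [30.0, 40.0]},
)

def double_up(ts):
    # each group is the full time series for one lon/lat pair
    return ts * 2

# collapse lon/lat into one dimension, group by THAT dimension, then unstack
stacked = da.stack(grid_cells=("lon", "lat"))
result = stacked.groupby("grid_cells").map(double_up).unstack("grid_cells")
```

Grouping by grid_cells yields one group per lon/lat pair (four here), each containing the whole time series; grouping by time would instead yield one group per time step.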

Ryan Abernathey

Jun 1, 2016, 8:01:01 AM
to xar...@googlegroups.com

James Adams

Jun 1, 2016, 12:15:03 PM
to xar...@googlegroups.com
Thanks, Ryan, this is very helpful. That code runs with no trouble for me. I had actually already tried grouping on the stacked axis (along with several other permutations), but I kept getting errors which I assumed were caused by an incorrect grouping. Now that I've retried this approach I'm seeing the expected number of iterations, so I think the grouping/application is happening as expected. However, the errors show up after all iterations of the applied function have completed, so I assume something is still amiss which is throwing off the combine phase. The code and errors are listed below:

#--------- begin example code -----------------------------------------------------------------------------------------------

import logging

import numpy as np
import xarray

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def double_up(data):

    # display a message so users can see evidence of the application of this function
    logger.info("Function now being applied")
    
    # perform a computation on the data
    computed_data = data * 2
    return computed_data

#----------------------------------------------------------------------------------------------------------------------

if __name__ == '__main__':

        with xarray.open_dataset("example.nc") as dataset:
    
            # use stack() to collapse lon and lat to a single dimension
            dataset = dataset.stack(grid_cells=('lon', 'lat'))
            
            # group by grid cell and apply the function
            dataset = dataset.groupby('grid_cells').apply(double_up)
            
            # unstack to restore the lon & lat dimensions
            dataset = dataset.unstack('grid_cells')
            
            # rename the dataset's "prcp" variable (which we've overwritten with computed values)
            variable_name = 'dummy'
            dataset = dataset.rename({'prcp': variable_name}, True)

            # write the dataset to a new NetCDF file
            dataset.to_netcdf("example_results.nc", encoding = {variable_name: {'_FillValue': np.nan, 'dtype': 'float32'}})

#--------- end example code -----------------------------------------------------------------------------------------------


Errors:

  File "H:\git\mcevoy_indices\src\scripts\xarray_test.py", line 41, in <module>
    dataset = dataset.groupby('grid_cells').apply(double_data)
  File "C:\Anaconda\lib\site-packages\xarray\core\groupby.py", line 469, in apply
    combined = self._concat(applied)
  File "C:\Anaconda\lib\site-packages\xarray\core\groupby.py", line 476, in _concat
    combined = concat(applied, concat_dim, positions=positions)
  File "C:\Anaconda\lib\site-packages\xarray\core\combine.py", line 114, in concat
    return f(objs, dim, data_vars, coords, compat, positions)
  File "C:\Anaconda\lib\site-packages\xarray\core\combine.py", line 268, in _dataset_concat
    combined = Variable.concat(vars, dim, positions)
  File "C:\Anaconda\lib\site-packages\xarray\core\variable.py", line 919, in concat
    variables = list(variables)
  File "C:\Anaconda\lib\site-packages\xarray\core\combine.py", line 262, in ensure_common_dims
    var = var.expand_dims(common_dims, common_shape)
  File "C:\Anaconda\lib\site-packages\xarray\core\variable.py", line 717, in expand_dims
    expanded_data = ops.broadcast_to(self.data, tmp_shape)
  File "C:\Anaconda\lib\site-packages\xarray\core\ops.py", line 67, in f
    return getattr(module, name)(*args, **kwargs)
  File "C:\Anaconda\lib\site-packages\numpy\lib\stride_tricks.py", line 115, in broadcast_to
    return _broadcast_to(array, shape, subok=subok, readonly=True)
  File "C:\Anaconda\lib\site-packages\numpy\lib\stride_tricks.py", line 70, in _broadcast_to
    op_flags=[op_flag], itershape=shape, order='C').itviews[0]
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (2,) and requested shape (1,)


Can anyone suggest why I'm seeing this shape mismatch in the combine phase? If so, is there something I can do in my code to get around it?


A few follow up questions:

Ryan's example code uses a DataArray, whereas my current code operates on a Dataset. I assume that if you apply a function to a grouped DataArray then it's applied to the component (grouped) arrays of the DataArray. But when a function is applied to a grouped Dataset does it get applied to all grouped data variables of the Dataset? Is this the main difference between Datasets and DataArrays to consider when doing a split-apply-combine?

Ryan's example code appears to make several copies of the original DataArray (da, stacked, trend, and trend_unstacked), whereas in my code I've been overwriting my Dataset in order to conserve/reuse memory. Is there a reason to avoid this overwrite/reuse approach or is this reasonable? Does xarray have some magic under the covers which provides views of the original DataArray in order to avoid copies?

Thanks in advance...

--James

Stephan Hoyer

Jun 2, 2016, 12:26:40 PM
to xar...@googlegroups.com
James -- can you save a netCDF file that produces this error somewhere? I can download it and try running it -- this looks like some sort of internal bug that is giving you that error message.

On Wed, Jun 1, 2016 at 9:14 AM, James Adams <mono...@gmail.com> wrote:
> Ryan's example code uses a DataArray, whereas my current code operates on a Dataset. I assume that if you apply a function to a grouped DataArray then it's applied to the component (grouped) arrays of the DataArray. But when a function is applied to a grouped Dataset does it get applied to all grouped data variables of the Dataset? Is this the main difference between Datasets and DataArrays to consider when doing a split-apply-combine?

That's right -- when you group over a Dataset, the individual subgroups are Dataset objects, containing all the original variables from the grouped dataset.
 
> Ryan's example code appears to make several copies of the original DataArray (da, stacked, trend, and trend_unstacked), whereas in my code I've been overwriting my Dataset in order to conserve/reuse memory. Is there a reason to avoid this overwrite/reuse approach or is this reasonable? Does xarray have some magic under the covers which provides views of the original DataArray in order to avoid copies?

In many cases, we avoid making copies in stack/unstack operations because we can reshape the underlying arrays using views. This depends entirely on NumPy, though.
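For example, NumPy's view-vs-copy behavior can be checked directly (a small sketch, not from the original discussion): reshaping a contiguous array returns a view, while reshaping a non-contiguous (e.g. transposed) array forces a copy.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

# reshaping a contiguous array: NumPy returns a view, no data is copied
b = a.reshape(3, 2, 2)
print(np.shares_memory(a, b))   # True

# reshaping a transposed (non-contiguous) array: NumPy must copy
c = a.T.reshape(-1)
print(np.shares_memory(a, c))   # False
```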

More generally, to conserve memory I usually recommend hooking xarray up to dask:

I think it's pretty harmless to reuse the same variable name, though it can make code a little harder to debug. I am not familiar enough with the details of Python's garbage collector to know exactly when intermediate (unused) variables are deleted. xarray itself very rarely supports true inplace operations, mostly because logic with mutable state gets hard to keep track of.

James Adams

Jun 2, 2016, 5:12:28 PM
to xar...@googlegroups.com
Thanks, Stephan!

I've attached the very basic/tiny NetCDF file I've been using for this test, which I built from the Python REPL using xarray. I've tested my code with other files and I see the same errors.


The errors I'm currently seeing: 

Traceback (most recent call last):
  File "H:\git\climate_indices\src\scripts\xarray_test.py", line 41, in <module>
    dataset = dataset.groupby('grid_cells').apply(double_data)
  File "C:\Anaconda\lib\site-packages\xarray\core\groupby.py", line 469, in apply
    combined = self._concat(applied)
  File "C:\Anaconda\lib\site-packages\xarray\core\groupby.py", line 476, in _concat
    combined = concat(applied, concat_dim, positions=positions)
  File "C:\Anaconda\lib\site-packages\xarray\core\combine.py", line 114, in concat
    return f(objs, dim, data_vars, coords, compat, positions)
  File "C:\Anaconda\lib\site-packages\xarray\core\combine.py", line 268, in _dataset_concat
    combined = Variable.concat(vars, dim, positions)
  File "C:\Anaconda\lib\site-packages\xarray\core\variable.py", line 919, in concat
    variables = list(variables)
  File "C:\Anaconda\lib\site-packages\xarray\core\combine.py", line 262, in ensure_common_dims
    var = var.expand_dims(common_dims, common_shape)
  File "C:\Anaconda\lib\site-packages\xarray\core\variable.py", line 717, in expand_dims
    expanded_data = ops.broadcast_to(self.data, tmp_shape)
  File "C:\Anaconda\lib\site-packages\xarray\core\ops.py", line 67, in f
    return getattr(module, name)(*args, **kwargs)
  File "C:\Anaconda\lib\site-packages\numpy\lib\stride_tricks.py", line 115, in broadcast_to
    return _broadcast_to(array, shape, subok=subok, readonly=True)
  File "C:\Anaconda\lib\site-packages\numpy\lib\stride_tricks.py", line 70, in _broadcast_to
    op_flags=[op_flag], itershape=shape, order='C').itviews[0]
ValueError: operands could not be broadcast together with remapped shapes [original->remapped]: (2,) and requested shape (1,)

toy.nc

Stephan Hoyer

Jun 8, 2016, 12:38:21 AM
to xar...@googlegroups.com
Hi James,

Thanks for following up with a GitHub issue:

This definitely *should* work, but we clearly have a few bugs to fix first! Hopefully we'll be able to get this working soon.

Cheers,
Stephan


Daniel Rothenberg

Jun 13, 2016, 12:43:07 PM
to xarray
James, Ryan, and Stephan - 

Is there currently a mechanism in xarray/dask to parallelize this type of split-apply-combine workflow? I'm performing similar timeseries analyses in each grid-cell of high resolution climate model output, and these jobs can take a really long time, even though they're so embarrassingly parallel. 

- Daniel

Stephan Hoyer

Jun 13, 2016, 12:46:08 PM
to xar...@googlegroups.com
Yes, in theory, but in practice it doesn't seem to work yet -- there are still a few bugs to work out.

Note that the current approach won't work with tiled grid output, because xarray's reshape with dask uses a very inefficient approach. We'll need another way to process tiled grids.

Ryan Abernathey

Jun 13, 2016, 12:59:54 PM
to xar...@googlegroups.com
Daniel,

I agree this is a really important feature. A big part of xarray's appeal is its parallel capability, but many real workflows use groupby and therefore can't really leverage dask's parallel features.

I learned a lot about how groupby works through my recent PR (lgtm btw! ;), and now I have some ideas about how this could be done. I think it should be possible to chunk the groups and run apply via dask's map_blocks.

-R

Sent from my iPhone

James Adams

Jun 15, 2016, 10:03:53 AM
to xarray
I'm very interested to see how much (if any) performance improvement I can get using this approach compared with what I'm currently doing with Python's multiprocessing module. My code will certainly be cleaner and more compact if I can leverage xarray for this, assuming it's at least as performant as my current implementation.

Can anyone provide guidance on how I might try to fix this issue? Stephan and Ryan, where would you start looking for the problem(s) that cause this feature not to work as expected?

In the meantime, Daniel, I can send you a copy of some code I have which does this sort of processing using multiprocessing (process pool) in case that'll be helpful.
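Roughly, it follows this pattern (a simplified sketch with a stand-in computation, not my actual code; the real version uses a process pool, which sidesteps the GIL for CPU-bound work, but threads keep this example self-contained):

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

data = np.random.rand(240, 70, 30)  # time x lon x lat
result = np.empty_like(data)

def process_cell(cell):
    # process the full time series for one lon/lat pair
    i, j = cell
    result[:, i, j] = data[:, i, j] * 2  # stand-in computation

# one task per grid cell, farmed out to a worker pool
cells = [(i, j) for i in range(data.shape[1]) for j in range(data.shape[2])]
with ThreadPoolExecutor(max_workers=4) as pool:
    list(pool.map(process_cell, cells))
```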

--James

Daniel Rothenberg

Jun 15, 2016, 1:31:03 PM
to xarray
James, if you have an example to share that would be great. I also came up with a multiprocessing solution, but I too would prefer an xarray/dask one; if I weren't under dissertation crunch pressures I'd work on implementing it.

- Daniel

Stephan Hoyer

Jun 15, 2016, 1:34:42 PM
to xar...@googlegroups.com
See this GitHub issue for some ideas on how to do this: https://github.com/pydata/xarray/issues/585

Basically, we need a version of dask.array's map_blocks function, possibly one that uses xarray's metadata to automatically infer how to put the blocks back together.
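As a tiny illustration of the dask.array building block mentioned above (toy data, with the doubling function standing in for real per-block work):

```python
import numpy as np
import dask.array as da

# a 4x4 array split into four 2x2 chunks
x = da.from_array(np.arange(16.0).reshape(4, 4), chunks=(2, 2))

# the function runs once per chunk; with a parallel scheduler
# the chunks are processed concurrently
y = x.map_blocks(lambda block: block * 2)

result = y.compute()
```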
