dask: `Data.argmax` by davidhassell · Pull Request #339 · NCAS-CMS/cf-python

Expand Up @@ -7542,98 +7542,80 @@ def argmax(self, axis=None, unravel=False): If no axis is specified then the returned index locates the maximum of the whole data.
In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.
**Performance**
If the data index is returned as a `tuple` (see the *unravel* parameter) then all delayed operations are computed.
:Parameters:
axis: `int`, optional The specified axis over which to locate the maximum values. By default the maximum over the flattened data is located. (Previous wording: "By default the maximum over the whole data is located.")
unravel: `bool`, optional If True, then when locating the maximum over the whole data, return the location as a tuple of indices for each axis. By default an index to the flattened array is returned in this case. Ignored if locating the maxima over a subset of the axes.
If True, then when locating the maximum over the whole data, return the location as a `tuple` containing one index for each axis. By default an index to the flattened array is returned in this case. Ignored if locating the maxima over a subset of the axes.
:Returns:
`Data` or `tuple` (previously: `int` or `tuple` or `Data`) The location of the maximum, or maxima.
**Examples**
>>> d = cf.Data(np.arange(6).reshape(2, 3)) >>> print(d.array) [[0 1 2] [3 4 5]] >>> a = d.argmax() >>> a <CF Data(): 5> >>> a.array 5
>>> index = d.argmax(unravel=True) >>> index (1, 2) >>> d[index] <CF Data(1, 1): [[5]]>
>>> d = cf.Data(numpy.arange(120).reshape(4, 5, 6)) >>> d.argmax() 119 >>> d.argmax(unravel=True) (3, 4, 5) >>> d.argmax(axis=0) <CF Data(5, 6): [[3, ..., 3]]> <CF Data(3): [1, 1, 1]> >>> d.argmax(axis=1) <CF Data(4, 6): [[4, ..., 4]]> >>> d.argmax(axis=2) <CF Data(4, 5): [[5, ..., 5]]> <CF Data(2): [2, 2]>
""" if axis is not None: ndim = self._ndim if -ndim - 1 <= axis < 0: axis += ndim + 1 elif not 0 <= axis <= ndim: raise ValueError( "Can't argmax: Invalid axis specification: Expected " "-{0}<=axis<{0}, got axis={1}".format(ndim, axis) )
if ndim == 1 and axis == 0: axis = None # --- End: if Only the location of the first occurrence is returned:
if axis is None: config = self.partition_configuration(readonly=True)
out = []
for partition in self.partitions.matrix.flat: partition.open(config) array = partition.array index = np.unravel_index(array.argmax(), array.shape) mx = array[index] index = [x[0] + i for x, i in zip(partition.location, index)] out.append((mx, index)) partition.close()
mx, index = sorted(out)[-1]
if unravel: return tuple(index)
return np.ravel_multi_index(index, self.shape)
# Parse axis ndim = self._ndim if -ndim - 1 <= axis < 0: axis += ndim + 1 elif not 0 <= axis <= ndim: raise ValueError( "Can't argmax: Invalid axis specification: Expected " "-{0}<=axis<{0}, got axis={1}".format(ndim, axis) ) >>> d = cf.Data([0, 4, 2, 3, 4]) >>> d.argmax() <CF Data(): 1>
sections = self.section(axis, chunks=True) for key, d in sections.items(): array = d.varray.argmax(axis=axis) array = np.expand_dims(array, axis) sections[key] = type(self)( array, self.Units, fill_value=self.fill_value ) >>> d = cf.Data(np.arange(6).reshape(2, 3)) >>> d[1, 1] = 5 >>> print(d.array) [[0 1 2] [3 5 5]] >>> d.argmax(1) <CF Data(2): [2, 1]>
out = self.reconstruct_sectioned_data(sections) """ dx = self._get_dask() a = dx.argmax(axis=axis)
out.squeeze(axis, inplace=True) if unravel and (axis is None or self.ndim <= 1): # Return a multidimensional index tuple return tuple(np.array(da.unravel_index(a, self.shape)))
return out return type(self)(a)
def get_data(self, default=ValueError(), _units=None, _fill_value=None): """Returns the data. Expand Down