Row indexing a dataset with numpy integers
Feature request
Allow indexing datasets with a scalar numpy integer type.
Motivation
Indexing a dataset with a scalar `numpy.int*` object raises a `TypeError`. This is caused by the type check in `key_to_query_type` in `datasets/formatting/formatting.py`:
```python
def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
    if isinstance(key, int):
        return "row"
    elif isinstance(key, str):
        return "column"
    elif isinstance(key, (slice, range, Iterable)):
        return "batch"
    _raise_bad_key_type(key)
```
In the row case, it checks whether `key` is an `int`, which is `False` when `key` is integer-like but not a built-in Python `int`. This is counterintuitive, because a NumPy array of `np.int64` values is accepted in the batch case.
For example:
```python
import numpy as np
import datasets

dataset = datasets.Dataset.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})

# Regular indexing
dataset[0]
dataset[:2]

# Indexing with numpy data types (expect same results)
idx = np.asarray([0, 1])
dataset[idx]     # Succeeds when using an array of np.int64 values
dataset[idx[0]]  # Fails with TypeError when using scalar np.int64
```
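The root cause can be reproduced with plain `isinstance` checks and no `datasets` import at all. The sketch below (assuming only that NumPy is installed) walks through the branches of `key_to_query_type`. The array matches the `Iterable` branch. The scalar taken from it is not an `int`, not a `str`, and not iterable, so it matches no branch.

```python
from collections.abc import Iterable

import numpy as np

idx = np.asarray([0, 1])
scalar = idx[0]  # a NumPy integer scalar (e.g. np.int64), not a built-in int

# The array takes the "batch" branch because it is iterable.
array_is_iterable = isinstance(idx, Iterable)

# The scalar matches none of the branches: not an int, not a str, not iterable.
scalar_is_int = isinstance(scalar, int)
scalar_is_str = isinstance(scalar, str)
scalar_is_iterable = isinstance(scalar, Iterable)

# It is still a NumPy integer, and int() turns it into a built-in int.
scalar_is_np_integer = isinstance(scalar, np.integer)
converted_is_int = isinstance(int(scalar), int)

print(array_is_iterable, scalar_is_int, scalar_is_str, scalar_is_iterable)
# True False False False
print(scalar_is_np_integer, converted_is_int)
# True True
```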
Users can work around this by wrapping the index in `int()` (e.g. `dataset[int(idx[0])]`). Alternatively, the check in `key_to_query_type` could be relaxed to accept any integral type:
```diff
+import numbers
+
 def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
-    if isinstance(key, int):
+    if isinstance(key, numbers.Integral):
         return "row"
     elif isinstance(key, str):
         return "column"
     elif isinstance(key, (slice, range, Iterable)):
         return "batch"
     _raise_bad_key_type(key)
```
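The proposed change can be checked in isolation with a standalone copy of the function. This is a sketch, not the `datasets` source: the `datasets` helper `_raise_bad_key_type` is replaced here with a plain `TypeError` so the snippet runs on its own. It relies on NumPy registering `np.integer` as a virtual subclass of `numbers.Integral`.

```python
import numbers
from collections.abc import Iterable
from typing import Union

import numpy as np


def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
    """Standalone sketch of the proposed check (not the datasets source)."""
    if isinstance(key, numbers.Integral):  # accepts int and np.integer scalars
        return "row"
    elif isinstance(key, str):
        return "column"
    elif isinstance(key, (slice, range, Iterable)):
        return "batch"
    # Stand-in for the datasets helper _raise_bad_key_type.
    raise TypeError(f"Wrong key type: {key!r} of type {type(key)}")


idx = np.asarray([0, 1])
print(key_to_query_type(0))       # row
print(key_to_query_type(idx[0]))  # row
print(key_to_query_type(idx))     # batch
print(key_to_query_type("a"))     # column
```

The `str` branch stays ahead of the `Iterable` branch, so string keys are still treated as column names even though strings are iterable.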
For comparison, pandas provides an `is_integer` check, which relies on `is_integer_object` defined in pandas/_libs/utils.pxd:
```cython
cdef inline bint is_integer_object(object obj) noexcept:
    """
    Cython equivalent of

    `isinstance(val, (int, np.integer)) and not isinstance(val, (bool, np.timedelta64))`

    Parameters
    ----------
    val : object

    Returns
    -------
    is_integer : bool

    Notes
    -----
    This counts np.timedelta64 objects as integers.
    """
    return (not PyBool_Check(obj) and isinstance(obj, (int, cnp.integer))
            and not is_timedelta64_object(obj))
```
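This approach is less flexible because it checks explicitly for NumPy integers. It is still worth noting that pandas needed to make sure the key is not a `bool`. Because `bool` is a subclass of `int`, `True` and `False` pass both `isinstance(key, int)` and `isinstance(key, numbers.Integral)`.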
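If the `bool` exclusion is wanted, a pandas-style filter can be layered on top of `numbers.Integral`. Below is a minimal sketch, assuming NumPy. `is_integer_key` is a hypothetical name, not an existing `datasets` or pandas function. Like pandas, it also excludes `np.timedelta64`. That type subclasses `np.signedinteger`, so it would otherwise pass `isinstance(..., numbers.Integral)`.

```python
import numbers

import numpy as np


def is_integer_key(key: object) -> bool:
    """Hypothetical helper: integral keys, minus bool and np.timedelta64.

    Mirrors the exclusions in pandas' is_integer_object, but accepts any
    numbers.Integral instead of only int and np.integer.
    """
    return isinstance(key, numbers.Integral) and not isinstance(
        key, (bool, np.timedelta64)
    )


print(is_integer_key(np.int64(2)))             # True
print(is_integer_key(True))                    # False
print(is_integer_key(np.timedelta64(1, "D")))  # False
print(is_integer_key(2.0))                     # False
```

`np.bool_` does not need an explicit exclusion, because NumPy does not register it as `numbers.Integral`.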
Your contribution
I can submit a pull request with the above change after checking that indexing succeeds with NumPy integer types. If a different integer check would be preferred, I could use that instead.
If there is a reason not to support this behavior, that is fine too.