fix: correct negative axis handling in roll function by f14XuanLv · Pull Request #2878 · xtensor-stack/xtensor
This PR fixes a bug in xt::roll(e, shift, axis) where negative axis indices (e.g., -1 for the last axis) were incorrectly rejected.
This caused valid negative indices like -1 to incorrectly trigger the bounds check exception.
auto cpy = empty_like(e);
const auto& shape = cpy. shape();
- std::size_t saxis = static_cast<std::size_t>(axis);
- if (axis < 0)
- {
- axis += std::ptrdiff_t(cpy. dimension());
- }
+ const auto dim = cpy.dimension();
- if (saxis >= cpy.dimension() || axis < 0)
+ if (axis < -static_cast<std::ptrdiff_t>(dim) || axis >= static_cast<std::ptrdiff_t>(dim))
{
- XTENSOR_THROW(std::runtime_error, "axis is no within shape dimension.");
+ XTENSOR_THROW(std::runtime_error, "axis is not within shape dimension.");
}
+ std::size_t saxis = normalize_axis(dim, axis);
+
const auto axis_dim = static_cast<std::ptrdiff_t>(shape[saxis]); xarray<double> expected8 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}};
ASSERT_EQ(expected8, xt::roll(e2, -2, /*axis*/ 2));
+ // Boundary error cases
+ EXPECT_THROW(xt::roll(e2, 1, /*axis*/ 3), std::runtime_error);
+ EXPECT_THROW(xt::roll(e2, 1, /*axis*/ -4), std::runtime_error);
+
+ // Negative axis indices
+ xarray<double> expected9 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}};
+ ASSERT_EQ(expected9, xt::roll(e2, -2, /*axis*/ -1));
+
+ xarray<double> expected10 = {{{1, 2, 3}}, {{4, 5, 6}}, {{7, 8, 9}}};
+ ASSERT_EQ(expected10, xt::roll(e2, -2, /*axis*/ -2));
+
+ xarray<double> expected11 = {{{4, 5, 6}}, {{7, 8, 9}}, {{1, 2, 3}}};
+ ASSERT_EQ(expected11, xt::roll(e2, 2, /*axis*/ -3));
}