Implement scalar_prod by sebasv · Pull Request #505 · rust-ndarray/ndarray
Thanks for this! Everything looks good except for a few changes to the test. (See the individual comments.)
Note for my future self: Ordinarily, I would avoid duplication of logic like in unrolled_sum and unrolled_prod, but I think it's fine in this case:
- I anticipate sum and product being the only cases like this
- I don't anticipate needing to ever really modify unrolled_sum/unrolled_prod, so we don't have to worry much about keeping things in sync
If we add more than two functions like this that could be combined, though, we should combine them by e.g. taking a closure as a parameter.
> I anticipate sum and product being the only cases like this
Actually, I realised I would also greatly benefit from scalar_min and scalar_max. Shall I try to write up a macro to cover all four cases?
We could implement scalar_min and scalar_max for A: Ord. However, I'd just do it in terms of fold with something like this (taking advantage of the first method from PR #507):
    impl<A, S, D> ArrayBase<S, D>
    where
        S: Data<Elem = A>,
        D: Dimension,
    {
        /// Returns the minimum element, or `None` if the array is empty.
        fn scalar_min(&self) -> Option<&A>
        where
            A: Ord,
        {
            let first = self.first()?;
            Some(self.fold(first, |acc, x| acc.min(x)))
        }
    }
We don't need to manually unroll this because the compiler does a good job automatically (checked with Compiler Explorer using the -O compiler option).
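The same "take the first element, then fold" pattern can be tried without ndarray. The following is only a sketch on a plain slice, using the standard library; `slice_min` is a hypothetical name chosen for illustration and is not part of ndarray or of this PR.

```rust
/// Returns the minimum element of `xs`, or `None` if `xs` is empty.
/// Hypothetical slice-based stand-in for the `scalar_min` idea above.
pub fn slice_min<A: Ord>(xs: &[A]) -> Option<&A> {
    // `first()` yields `None` for an empty slice, so `?` returns early.
    let first = xs.first()?;
    // Fold over references; `std::cmp::min` compares the referenced values.
    Some(xs.iter().fold(first, |acc, x| std::cmp::min(acc, x)))
}
```

Seeding the fold with the first element avoids needing a "largest possible value" for `A`, which is why only `A: Ord` is required.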
The desired behavior for floating-point types depends on the use-case because of NaN. One option is
arr.fold(::std::f64::NAN, |acc, &x| acc.min(x))
which ignores NaN values. (It returns NaN only if there are no non-NaN values.) The compiler does a decent job automatically unrolling this, so we don't need to manually unroll in this case either.
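To make the NaN behavior concrete, here is a small standard-library sketch of the same fold on a slice of `f64`. The function name `nan_ignoring_min` is an assumption made for this example; it relies on `f64::min` returning the other argument when one argument is NaN.

```rust
/// Minimum of `xs` that skips NaN values.
/// Returns NaN only if `xs` contains no non-NaN value (including when it is empty).
/// Hypothetical helper name, used here for illustration only.
pub fn nan_ignoring_min(xs: &[f64]) -> f64 {
    // Starting from NaN works because `f64::min` returns the non-NaN argument
    // whenever exactly one of the two arguments is NaN.
    xs.iter().fold(f64::NAN, |acc, &x| acc.min(x))
}
```

A caller that needs NaN to propagate instead would have to check `is_nan()` explicitly, since this fold silently drops NaN inputs.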
Will you please squash the commits into one? I don't mind squashing them myself, but then GitHub won't consider the PR merged.
Edit: It looks like you might have given me permission to push to the master branch on sebasv/ndarray since you submitted a PR using that branch? If so, and you don't mind me modifying your master branch, I can squash the commits for you.
(Ordinarily, I would just use GitHub's "Squash and merge", but that option is disabled for this repo, I don't have the permissions to enable it, and I haven't heard from @bluss in a while.)
I'll squash the commits. I also put the unrolled code in a macro, is this desired or do you want to stick with separate unrolled code for prod/sum and possible future cases? Current commit does not have the macro.
    // eightfold unrolled so that floating point can be vectorized
    // (even with strict floating point accuracy semantics)
    macro_rules! unrolled_fold {
        ($xs:expr, $unity:expr, $operation:expr) => {{
            let mut collected = $unity();
            let (mut p0, mut p1, mut p2, mut p3,
                 mut p4, mut p5, mut p6, mut p7) =
                ($unity(), $unity(), $unity(), $unity(),
                 $unity(), $unity(), $unity(), $unity());
            while $xs.len() >= 8 {
                p0 = $operation(p0, $xs[0].clone());
                p1 = $operation(p1, $xs[1].clone());
                p2 = $operation(p2, $xs[2].clone());
                p3 = $operation(p3, $xs[3].clone());
                p4 = $operation(p4, $xs[4].clone());
                p5 = $operation(p5, $xs[5].clone());
                p6 = $operation(p6, $xs[6].clone());
                p7 = $operation(p7, $xs[7].clone());

                $xs = &$xs[8..];
            }
            collected = $operation(collected.clone(), $operation(p0, p4));
            collected = $operation(collected.clone(), $operation(p1, p5));
            collected = $operation(collected.clone(), $operation(p2, p6));
            collected = $operation(collected.clone(), $operation(p3, p7));

            // make it clear to the optimizer that this loop is short
            // and can not be autovectorized.
            for i in 0..$xs.len() {
                if i >= 7 { break; }
                collected = $operation(collected.clone(), $xs[i].clone());
            }

            collected
        }}
    }

    /// Compute the sum of the values in `xs`
    pub fn unrolled_sum<A>(mut xs: &[A]) -> A
        where A: Clone + Add<Output=A> + libnum::Zero,
    {
        unrolled_fold!(xs, A::zero, A::add)
    }

    /// Compute the product of the values in `xs`
    pub fn unrolled_prod<A>(mut xs: &[A]) -> A
        where A: Clone + Mul<Output=A> + libnum::One,
    {
        unrolled_fold!(xs, A::one, A::mul)
    }
Sure, a macro would be nice. By the way, I just noticed that the temporary variable in scalar_prod is named sum when it would be better named prod.
Fwiw, I prefer using generic functions over macros when possible. For example:
    pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
    where
        A: Clone,
        I: Fn() -> A,
        F: Fn(A, A) -> A,
    {
        // eightfold unrolled so that floating point can be vectorized
        // (even with strict floating point accuracy semantics)
        let mut acc = init();
        let (mut p0, mut p1, mut p2, mut p3,
             mut p4, mut p5, mut p6, mut p7) =
            (init(), init(), init(), init(),
             init(), init(), init(), init());
        while xs.len() >= 8 {
            p0 = f(p0, xs[0].clone());
            p1 = f(p1, xs[1].clone());
            p2 = f(p2, xs[2].clone());
            p3 = f(p3, xs[3].clone());
            p4 = f(p4, xs[4].clone());
            p5 = f(p5, xs[5].clone());
            p6 = f(p6, xs[6].clone());
            p7 = f(p7, xs[7].clone());

            xs = &xs[8..];
        }
        acc = f(acc.clone(), f(p0, p4));
        acc = f(acc.clone(), f(p1, p5));
        acc = f(acc.clone(), f(p2, p6));
        acc = f(acc.clone(), f(p3, p7));

        // make it clear to the optimizer that this loop is short
        // and can not be autovectorized.
        for i in 0..xs.len() {
            if i >= 7 { break; }
            acc = f(acc.clone(), xs[i].clone());
        }
        acc
    }
This can be called like this for a sum:
numeric_util::unrolled_fold(slc, A::zero, A::add)
or like this for a product:
numeric_util::unrolled_fold(slc, A::one, A::mul)
This generates basically the same code as the non-generic version (tested with Compiler Explorer with -C target-cpu=native -C opt-level=3).
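For readers who want to try the generic version outside the crate, here is a self-contained sketch using only the standard library. It follows the same eight-accumulator structure, but the wrappers `sum_f64` and `prod_f64` and the use of plain closures in place of `A::zero`/`A::add` and `A::one`/`A::mul` are assumptions made for this example, not the code merged in the PR.

```rust
/// Fold `xs` with `f`, keeping eight independent partial accumulators so the
/// main loop has no dependency chain between consecutive elements.
pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
where
    A: Clone,
    I: Fn() -> A,
    F: Fn(A, A) -> A,
{
    let mut acc = init();
    let (mut p0, mut p1, mut p2, mut p3, mut p4, mut p5, mut p6, mut p7) =
        (init(), init(), init(), init(), init(), init(), init(), init());
    // Process full chunks of eight elements.
    while xs.len() >= 8 {
        p0 = f(p0, xs[0].clone());
        p1 = f(p1, xs[1].clone());
        p2 = f(p2, xs[2].clone());
        p3 = f(p3, xs[3].clone());
        p4 = f(p4, xs[4].clone());
        p5 = f(p5, xs[5].clone());
        p6 = f(p6, xs[6].clone());
        p7 = f(p7, xs[7].clone());
        xs = &xs[8..];
    }
    // Combine the partial accumulators pairwise.
    acc = f(acc, f(p0, p4));
    acc = f(acc, f(p1, p5));
    acc = f(acc, f(p2, p6));
    acc = f(acc, f(p3, p7));
    // At most seven elements remain; the explicit bound mirrors the original
    // hint to the optimizer that this tail loop is short.
    for i in 0..xs.len() {
        if i >= 7 {
            break;
        }
        acc = f(acc, xs[i].clone());
    }
    acc
}

/// Hypothetical wrapper: sum of `xs`, with 0.0 as the identity.
pub fn sum_f64(xs: &[f64]) -> f64 {
    unrolled_fold(xs, || 0.0, |a, b| a + b)
}

/// Hypothetical wrapper: product of `xs`, with 1.0 as the identity.
pub fn prod_f64(xs: &[f64]) -> f64 {
    unrolled_fold(xs, || 1.0, |a, b| a * b)
}
```

Note that `init` must return the identity of `f` (zero for addition, one for multiplication), because it seeds all nine accumulators. The operation also needs to be associative and commutative for the result to match a sequential fold, and for floating-point inputs the regrouping can change rounding.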
Ready for review. I agree that this does not call for a macro, unless unrolled_dot is to be included as well, but I really don't expect a lot more variants to show up that need to be unrolled.