Implement scalar_prod by sebasv · Pull Request #505 · rust-ndarray/ndarray

@sebasv

jturner314

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this! Everything looks good except for a few changes to the test. (See the individual comments.)

Note for my future self: Ordinarily, I would avoid duplication of logic like in unrolled_sum and unrolled_prod, but I think it's fine in this case:

  • I anticipate sum and product being the only cases like this
  • I don't anticipate needing to ever really modify unrolled_sum/unrolled_prod, so we don't have to worry much about keeping things in sync

If we add more than two functions like this that could be combined, though, we should combine them by e.g. taking a closure as a parameter.

@sebasv

  • I anticipate sum and product being the only cases like this

Actually, I realised I would also greatly benefit from scalar_min and scalar_max. Shall I try to write up a macro to cover all four cases?

@jturner314

We could implement scalar_min and scalar_max for A: Ord. However, I'd just do it in terms of fold with something like this (taking advantage of the first method from PR #507):

impl<A, S, D> ArrayBase<S, D>
where
    S: Data<Elem = A>,
    D: Dimension,
{
    /// Returns the minimum element, or `None` if the array is empty.
    fn scalar_min(&self) -> Option<&A>
    where
        A: Ord,
    {
        let first = self.first()?;
        Some(self.fold(first, |acc, x| acc.min(x)))
    }
}

We don't need to manually unroll this because the compiler does a good job automatically (checked with Compiler Explorer using the -O compiler option).

The desired behavior for floating-point types depends on the use-case because of NaN. One option is

arr.fold(::std::f64::NAN, |acc, &x| acc.min(x))

which ignores NaN values. (It returns NaN only if there are no non-NaN values.) The compiler does a decent job automatically unrolling this, so we don't need to manually unroll in this case either.

@jturner314

Will you please squash the commits into one? I don't mind squashing them myself, but then GitHub won't consider the PR merged.

Edit: It looks like you might have given me permission to push to the master branch on sebasv/ndarray since you submitted a PR using that branch? If so, and you don't mind me modifying your master branch, I can squash the commits for you.

(Ordinarily, I would just use GitHub's "Squash and merge", but that option is disabled for this repo, I don't have the permissions to enable it, and I haven't heard from @bluss in a while.)

@sebasv

I'll squash the commits. I also put the unrolled code in a macro, is this desired or do you want to stick with separate unrolled code for prod/sum and possible future cases? Current commit does not have the macro.

// eightfold unrolled so that floating point can be vectorized
// (even with strict floating point accuracy semantics)
macro_rules! unrolled_fold {
    ($xs:expr, $unity:expr, $operation:expr) => {{
        let mut collected = $unity();
        let (mut p0, mut p1, mut p2, mut p3,
            mut p4, mut p5, mut p6, mut p7) =
            ($unity(), $unity(), $unity(), $unity(),
            $unity(), $unity(), $unity(), $unity());
        while $xs.len() >= 8 {
            p0 = $operation(p0, $xs[0].clone());
            p1 = $operation(p1, $xs[1].clone());
            p2 = $operation(p2, $xs[2].clone());
            p3 = $operation(p3, $xs[3].clone());
            p4 = $operation(p4, $xs[4].clone());
            p5 = $operation(p5, $xs[5].clone());
            p6 = $operation(p6, $xs[6].clone());
            p7 = $operation(p7, $xs[7].clone());

            $xs = &$xs[8..];
        }
        collected = $operation(collected.clone(), $operation(p0, p4));
        collected = $operation(collected.clone(), $operation(p1, p5));
        collected = $operation(collected.clone(), $operation(p2, p6));
        collected = $operation(collected.clone(), $operation(p3, p7));

        // make it clear to the optimizer that this loop is short
        // and can not be autovectorized.
        for i in 0..$xs.len() {
            if i >= 7 { break; }
            collected = $operation(collected.clone(), $xs[i].clone());
        }
        collected
    }}
}

/// Compute the sum of the values in `xs`
pub fn unrolled_sum<A>(mut xs: &[A]) -> A
    where A: Clone + Add<Output=A> + libnum::Zero,
{
unrolled_fold!(xs, A::zero, A::add)
}

/// Compute the product of the values in `xs`
pub fn unrolled_prod<A>(mut xs: &[A]) -> A
    where A: Clone + Mul<Output=A> + libnum::One,
{
    unrolled_fold!(xs, A::one, A::mul)
}

@jturner314

Sure, a macro would be nice. By the way, I just noticed that the temporary variable in scalar_prod is named sum when it would be better named prod.

@jturner314

Fwiw, I prefer using generic functions over macros when possible. For example:

pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
where
    A: Clone,
    I: Fn() -> A,
    F: Fn(A, A) -> A,
{
    // eightfold unrolled so that floating point can be vectorized
    // (even with strict floating point accuracy semantics)
    let mut acc = init();
    let (mut p0, mut p1, mut p2, mut p3,
         mut p4, mut p5, mut p6, mut p7) =
        (init(), init(), init(), init(),
         init(), init(), init(), init());
    while xs.len() >= 8 {
        p0 = f(p0, xs[0].clone());
        p1 = f(p1, xs[1].clone());
        p2 = f(p2, xs[2].clone());
        p3 = f(p3, xs[3].clone());
        p4 = f(p4, xs[4].clone());
        p5 = f(p5, xs[5].clone());
        p6 = f(p6, xs[6].clone());
        p7 = f(p7, xs[7].clone());

        xs = &xs[8..];
    }
    acc = f(acc.clone(), f(p0, p4));
    acc = f(acc.clone(), f(p1, p5));
    acc = f(acc.clone(), f(p2, p6));
    acc = f(acc.clone(), f(p3, p7));

    // make it clear to the optimizer that this loop is short
    // and can not be autovectorized.
    for i in 0..xs.len() {
        if i >= 7 { break; }
        acc = f(acc.clone(), xs[i].clone())
    }
    acc
}

This can be called like this for a sum:

numeric_util::unrolled_fold(slc, A::zero, A::add)

or like this for a product:

numeric_util::unrolled_fold(slc, A::one, A::mul)

This generates basically the same code as the non-generic version (tested with Compiler Explorer with -C target-cpu=native -C opt-level=3).

@sebasv

@sebasv

Ready for review. I agree that this does not call for a macro, unless unrolled_dot is to be included as well, but I really don't expect a lot more variants to show up that need to be unrolled.

@jturner314

Thanks for contributing this!

@sebasv

Thank you for the guidance! I am learning a ton more about safety and optimization.