Matrix Triangular Solve GPU op by c0g · Pull Request #5010 · tensorflow/tensorflow

@c0g

Adds a GPU version of the triangular solver op using cuBLAS trsm.

@c0g

@tensorflow-jenkins

Can one of the admins verify this patch?

@mention-bot

@c0g c0g mentioned this pull request

Oct 17, 2016

ebrevdo

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please don't change whitespace in lines you aren't otherwise modifying

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please undo these changes.

}

int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final
{

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you probably need to multiply this by Eigen's AddCost / MulCost to get appropriate estimates

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I messed up the formatting on the original op due to changing white space. Everything in MatrixTriangularSolveOp is original code from TF.

I'm happy to modify the original code to include AddCost/MulCost.

{
double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
double cost = rows * rows * num_rhss;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are you doing this in double precision?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the original tensorflow op code (see above). Possibly because double can represent larger numbers than int? I doubt anyone will try a matrix multiply large enough to overflow a 64 bit int though. Happy to change to int64.


if __name__ == "__main__":
tf.test.main()
tf.test.main() No newline at end of file

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

undo whitespace changes

} else {
output.noalias() = triangle.solve(rhs);
}
trans = perftools::gputools::blas::Transpose::kNoTranspose;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you use longer variable names instead of 'trans', 'lda', 'ldb', etc?

then in ThenBlasTrsm call, you can do e.g. upper_or_lower /* uplo */, ..., etc.

easier for other users to read.

cublas_m, cublas_n, 1.0, matrix_ptr, lda, &out_ptr,
ldb)
.ok();
// LOG(INFO) << blas_launch_status;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove commented out code

stream
->ThenBlasTrsm(perftools::gputools::blas::Side::kRight, uplo, trans,
perftools::gputools::blas::Diagonal::kNonUnit,
cublas_m, cublas_n, 1.0, matrix_ptr, lda, &out_ptr,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be Scalar(1) instead of 1.0 (which is a double)

@ebrevdo

Thanks for the PR! A few comments.

ebrevdo

@@ -24,15 +24,17 @@
class MatrixTriangularSolveOpTest(tf.test.TestCase):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ebrevdo

Ah in that case let's leave it as double. You're right.

On Oct 17, 2016 9:20 AM, "c0g" notifications@github.com wrote:

@c0g commented on this pull request.

In tensorflow/core/kernels/matrix_triangular_solve_op.cc
#5010:

  • {
  •    Base::ValidateSquareSolver(context, input_matrix_shapes);
    
  • }
  • TensorShapes GetOutputMatrixShapes(
  •    const TensorShapes& input_matrix_shapes) const final
    
  • {
  •    return TensorShapes({ TensorShape({ input_matrix_shapes[0].dim_size(1),
    
  •        input_matrix_shapes[1].dim_size(1) }) });
    
  • }
  • int64 GetCostPerUnit(const TensorShapes& input_matrix_shapes) const final
  • {
  •    double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
    
  •    double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
    
  •    double cost = rows \* rows \* num_rhss;
    

This is the original tensorflow op code (see above). Possibly because
double can represent larger numbers than int? I doubt anyone will try a
matrix multiply large enough to overflow a 64 bit int though. Happy to
change to int64.


You are receiving this because you were assigned.
Reply to this email directly, view it on GitHub
#5010, or mute the thread
https://github.com/notifications/unsubscribe-auth/ABtim1ZVajP4Q6Mph62ruIpHz25QrFSYks5q06BhgaJpZM4KYwau
.

@drpngx

Do we still need changes?

@drpngx

Jenkins, test this please.

@c0g

I have not yet made the requested changes.

On 19 Oct 2016, at 07:03, drpngx notifications@github.com wrote:

Do we still need changes?


You are receiving this because you were mentioned.
Reply to this email directly, view it on GitHub, or mute the thread.

@c0g

@c0g

@ebrevdo : I reverted to the original file to fix white space then made your other requested changes.

@drpngx

I don't see the revert. Please push?

@c0g

@drpngx you're right I have no idea what happened to them. The changes should now be up. Sorry for the lag/commit spam.

@drpngx

Jenkins, test this please.

@alexggmatthews

Seems to be passing now, only looks to be waiting on verification of @ebrevdo requested changes.

ebrevdo

double rows = static_cast<double>(input_matrix_shapes[0].dim_size(0));
double num_rhss = static_cast<double>(input_matrix_shapes[1].dim_size(1));
double cost = rows * rows * num_rhss;
double cost = rows * rows * num_rhss *

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whitespace?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry! Should be gone now.

ebrevdo

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small nit; otherwise LGTM.

@c0g

drpngx


#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, one last thing: we need an extra space before the //

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@c0g

@drpngx

Jenkins, test this please.