tf.RegisterGradient

A decorator for registering the gradient function for an op type.

Compat aliases for migration

See Migration guide for more details.

tf.compat.v1.RegisterGradient

tf.RegisterGradient(
    op_type
)

Use this decorator only when defining a new op type. For an op with m inputs and n outputs, the gradient function takes the original Operation and n Tensor objects (the gradients with respect to each output of the op) and returns m Tensor objects (the partial gradients with respect to each input of the op).

For example, assuming that operations of type "Sub" take two inputs x and y, and return a single output x - y, the following gradient function would be registered:

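@tf.RegisterGradient("Sub")
def _sub_grad(unused_op, grad):
  # d(x - y)/dx = 1 and d(x - y)/dy = -1, so the incoming gradient is
  # passed through unchanged for x and negated for y.
  return grad, tf.negative(grad)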

The decorator argument op_type is the string type of an operation. This corresponds to the OpDef.name field for the proto that defines the operation.

Args

op_type The string type of an operation. This corresponds to the OpDef.name field for the proto that defines the operation.

Raises

TypeError If op_type is not a string.

Methods

__call__

__call__(
    f: _T
) -> _T

Registers the function f as the gradient function for op_type.

Last updated 2024-04-26 UTC.