tf.tensor_scatter_nd_add

Adds sparse updates to an existing tensor according to indices.


Compat aliases for migration

See Migration guide for more details.

tf.compat.v1.tensor_scatter_add, tf.compat.v1.tensor_scatter_nd_add

tf.tensor_scatter_nd_add(
    tensor: Annotated[Any, TV_TensorScatterAdd_T],
    indices: Annotated[Any, TV_TensorScatterAdd_Tindices],
    updates: Annotated[Any, TV_TensorScatterAdd_T],
    name=None
) -> Annotated[Any, TV_TensorScatterAdd_T]

This operation creates a new tensor by adding sparse updates to the passed-in tensor. It is very similar to tf.compat.v1.scatter_nd_add, except that the updates are added onto an existing tensor rather than a variable. If the memory for the existing tensor cannot be reused, a copy is made and updated.
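To make the tensor-versus-variable distinction concrete, here is a minimal sketch with illustrative values. It shows that tf.tensor_scatter_nd_add leaves its input unchanged and returns a new tensor, whereas a variable-based scatter add (here the tf.Variable.scatter_nd_add method) updates the variable in place.

```python
import tensorflow as tf

# Functional form: the input tensor is left as-is and a new tensor is returned.
tensor = tf.zeros([4], dtype=tf.float32)
result = tf.tensor_scatter_nd_add(tensor, [[1], [3]], [5.0, 7.0])
print(tensor.numpy())  # [0. 0. 0. 0.]
print(result.numpy())  # [0. 5. 0. 7.]

# Variable form: scatter_nd_add mutates the variable's stored value.
var = tf.Variable(tf.zeros([4], dtype=tf.float32))
var.scatter_nd_add([[1], [3]], [5.0, 7.0])
print(var.numpy())     # [0. 5. 0. 7.]
```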

indices is an integer tensor containing indices into a new tensor of shape tensor.shape. The last dimension of indices can be at most the rank of tensor.shape:

indices.shape[-1] <= tensor.shape.rank

Each index vector along the last dimension of indices addresses either a single element (if indices.shape[-1] = tensor.shape.rank) or a slice (if indices.shape[-1] < tensor.shape.rank) of tensor: it indexes the first indices.shape[-1] dimensions of tensor, and any remaining dimensions form the slice. updates is a tensor with shape

indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
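The shape rule can be checked mechanically. The helper expected_updates_shape below is hypothetical and not part of the TensorFlow API; it simply encodes the two shape formulas above and is then used to build a valid updates tensor for a slice-wise scatter into a rank-3 tensor.

```python
import tensorflow as tf

def expected_updates_shape(tensor_shape, indices_shape):
    """Hypothetical helper: the shape `updates` must have for the given inputs."""
    index_depth = indices_shape[-1]
    # indices.shape[-1] <= tensor.shape.rank
    if index_depth > len(tensor_shape):
        raise ValueError("indices.shape[-1] exceeds the rank of tensor")
    # indices.shape[:-1] + tensor.shape[indices.shape[-1]:]
    return list(indices_shape[:-1]) + list(tensor_shape[index_depth:])

tensor = tf.zeros([4, 5, 6])
# Three index rows of depth 2: each row (i, j) addresses the slice tensor[i, j, :].
indices = tf.constant([[0, 1], [2, 3], [3, 4]])
shape = expected_updates_shape(tensor.shape.as_list(), indices.shape.as_list())
print(shape)         # [3, 6]

updates = tf.ones(shape)
result = tf.tensor_scatter_nd_add(tensor, indices, updates)
print(result.shape)  # (4, 5, 6)
```

Because the index depth (2) is smaller than the rank (3), each update is a length-6 slice rather than a single element.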

The simplest form of tensor_scatter_nd_add is to add individual elements to a tensor by index. For example, say we want to add 4 elements to a rank-1 tensor with 8 elements.

In Python, this scatter add operation would look like this:

import tensorflow as tf

indices = tf.constant([[4], [3], [1], [7]])  # one index row per element
updates = tf.constant([9, 10, 11, 12])       # one value per index row
tensor = tf.ones([8], dtype=tf.int32)
updated = tf.tensor_scatter_nd_add(tensor, indices, updates)
# updated (shape (8,), dtype int32) now holds:
# [ 1, 12,  1, 11, 10,  1,  1, 13]
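For the element-wise case, the result can be cross-checked against NumPy. This sketch assumes np.add.at as the reference: it is unbuffered, so repeated indices accumulate. The last lines show that, as we understand the op, tf.tensor_scatter_nd_add likewise sums updates whose index rows repeat.

```python
import numpy as np
import tensorflow as tf

indices = np.array([[4], [3], [1], [7]])
updates = np.array([9, 10, 11, 12], dtype=np.int32)

# NumPy reference: add each update at the position named by its index row.
expected = np.ones(8, dtype=np.int32)
np.add.at(expected, indices[:, 0], updates)

actual = tf.tensor_scatter_nd_add(tf.ones([8], dtype=tf.int32), indices, updates)
print(np.array_equal(expected, actual.numpy()))  # True

# Repeated index rows: both updates for position 1 are summed.
dup = tf.tensor_scatter_nd_add(tf.zeros([3], dtype=tf.int32), [[1], [1]], [2, 3])
print(dup.numpy())  # [0 5 0]
```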

We can also add to entire slices of a higher-rank tensor at once. For example, suppose we want to add two matrices of new values to two slices along the first dimension of a rank-3 tensor.

In Python, this scatter add operation would look like this:

import tensorflow as tf

# Each index row selects a whole 4x4 matrix along dimension 0.
indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]]])
tensor = tf.ones([4, 4, 4], dtype=tf.int32)
updated = tf.tensor_scatter_nd_add(tensor, indices, updates)
# updated (shape (4, 4, 4), dtype int32) now holds:
# [[[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
#  [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]],
#  [[6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8], [9, 9, 9, 9]],
#  [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]
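The general rule (element or slice, depending on index depth) can be written as a plain loop. The function scatter_nd_add_reference below is a hypothetical, slow NumPy reference for illustration only, not how TensorFlow implements the op; here it reproduces the rank-3 slice example above.

```python
import numpy as np
import tensorflow as tf

def scatter_nd_add_reference(tensor, indices, updates):
    """Hypothetical NumPy reference: loop over index rows (illustration only)."""
    out = np.array(tensor, copy=True)                 # never modify the input
    indices = np.asarray(indices)
    updates = np.asarray(updates, dtype=out.dtype)
    depth = indices.shape[-1]
    # Flatten the leading (batch) dimensions so each row is one index tuple.
    flat_idx = indices.reshape(-1, depth)
    flat_upd = updates.reshape((flat_idx.shape[0],) + out.shape[depth:])
    for idx, upd in zip(flat_idx, flat_upd):
        # Element if depth == rank, otherwise a slice over the trailing dims.
        out[tuple(idx)] += upd
    return out

matrix = [[5] * 4, [6] * 4, [7] * 4, [8] * 4]
indices = [[0], [2]]
updates = [matrix, matrix]
tensor = np.ones([4, 4, 4], dtype=np.int32)

expected = scatter_nd_add_reference(tensor, indices, updates)
actual = tf.tensor_scatter_nd_add(tensor, indices, updates)
print(np.array_equal(expected, actual.numpy()))  # True
```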

Args

tensor: A Tensor. Tensor to copy/update.
indices: A Tensor. Must be one of the following types: int32, int64. Index tensor.
updates: A Tensor. Must have the same type as tensor. Updates to scatter into output.
name: A name for the operation (optional).
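As a sketch of the type constraints listed above: indices may be int32 or int64, and updates must already share the dtype of tensor. Casting with tf.cast is one straightforward way to align a mismatched updates tensor; the values and the name string are illustrative.

```python
import tensorflow as tf

tensor = tf.zeros([2, 3], dtype=tf.float32)
indices = tf.constant([[0, 2], [1, 0]], dtype=tf.int64)  # int64 indices are accepted
int_updates = tf.constant([4, 6])                         # int32: does not match tensor

# updates must have the same dtype as tensor, so cast before scattering.
updates = tf.cast(int_updates, tensor.dtype)
result = tf.tensor_scatter_nd_add(tensor, indices, updates, name="scatter_add_example")
print(result.numpy())
# [[0. 0. 4.]
#  [6. 0. 0.]]
```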

Returns

A Tensor. Has the same type as tensor.

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.

Last updated 2024-04-26 UTC.