[API Compatibility] add `out` parameter to `sqrt` by aquagull · Pull Request #74795 · PaddlePaddle/Paddle
| class TestSqrtOutAPI(unittest.TestCase): | ||
| def test_out_in_dygraph(self): | ||
| paddle.disable_static() | ||
| np.random.seed(2024) | ||
| x = paddle.to_tensor( | ||
| np.random.rand(5, 7).astype('float32'), stop_gradient=False | ||
| ) | ||
|
|
||
| def run_case(case_type): | ||
| out_buf = paddle.zeros_like(x) | ||
| out_buf.stop_gradient = False | ||
|
|
||
| if case_type == 'return': | ||
| y = paddle.sqrt(x) | ||
| elif case_type == 'input_out': | ||
| paddle.sqrt(x, out=out_buf) | ||
| y = out_buf | ||
| elif case_type == 'both_return': | ||
| y = paddle.sqrt(x, out=out_buf) | ||
| elif case_type == 'both_input_out': | ||
| _ = paddle.sqrt(x, out=out_buf) | ||
| y = out_buf | ||
| else: | ||
| raise AssertionError | ||
|
|
||
| ref = paddle._C_ops.sqrt(x) | ||
| np.testing.assert_allclose( | ||
| y.numpy(), ref.numpy(), rtol=1e-6, atol=1e-6 | ||
| ) | ||
|
|
||
| loss = (y * 2).mean() | ||
| loss.backward() | ||
| return y.numpy(), x.grad.numpy() | ||
|
|
||
| # run four scenarios | ||
| y1, g1 = run_case('return') | ||
| x.clear_gradient() | ||
| y2, g2 = run_case('input_out') | ||
| x.clear_gradient() | ||
| y3, g3 = run_case('both_return') | ||
| x.clear_gradient() | ||
| y4, g4 = run_case('both_input_out') | ||
|
|
||
| np.testing.assert_allclose(y1, y2, rtol=1e-6, atol=1e-6) | ||
| np.testing.assert_allclose(y1, y3, rtol=1e-6, atol=1e-6) | ||
| np.testing.assert_allclose(y1, y4, rtol=1e-6, atol=1e-6) | ||
| np.testing.assert_allclose(g1, g2, rtol=1e-6, atol=1e-6) | ||
| np.testing.assert_allclose(g1, g3, rtol=1e-6, atol=1e-6) | ||
| np.testing.assert_allclose(g1, g4, rtol=1e-6, atol=1e-6) | ||
|
|
||
| paddle.enable_static() | ||
|
|
||
|
|