Skip to content

Commit

Permalink
#14930: Remove unnecessary usage of creation ops
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Nov 12, 2024
1 parent 9db5fb5 commit becbf96
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ std::vector<ComplexTensor> _polar_bw(const ComplexTensor& grad, const ComplexTen
std::vector<ComplexTensor> grad_tensor;
ComplexTensor result = ttnn::polar(input, output_mem_config);
Tensor abs_result = ttnn::abs(result, output_mem_config);
Tensor sgn_result_r = ttnn::where(ttnn::eqz(abs_result, output_mem_config), ttnn::zeros_like(result.real(), result.real().get_dtype(), result.real().get_layout(), std::nullopt, output_mem_config), ttnn::multiply(result.real(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config );
Tensor sgn_result_i = ttnn::where(ttnn::eqz(abs_result, output_mem_config), ttnn::zeros_like(result.imag(), result.imag().get_dtype(), result.imag().get_layout(), std::nullopt, output_mem_config), ttnn::multiply(result.imag(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config );
Tensor sgn_result_r = ttnn::where(ttnn::eqz(abs_result, output_mem_config), 0.0f, ttnn::multiply(result.real(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config );
Tensor sgn_result_i = ttnn::where(ttnn::eqz(abs_result, output_mem_config), 0.0f, ttnn::multiply(result.imag(), ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config), output_mem_config );
abs_result.deallocate();
ComplexTensor sgn_result = ComplexTensor({ sgn_result_r, sgn_result_i });
sgn_result_r.deallocate();
sgn_result_i.deallocate();
Tensor grad_abs = ttnn::real(ttnn::operations::complex_binary::_mul(ttnn::conj(grad, output_mem_config), sgn_result, output_mem_config), output_mem_config);
sgn_result.deallocate();
ComplexTensor flip_tensor = ComplexTensor({ttnn::zeros_like(input.real(), input.real().get_dtype(), input.real().get_layout(), std::nullopt, output_mem_config), ttnn::full_like(input.imag(), 1.0f) });
ComplexTensor flip_tensor = ComplexTensor({ttnn::zeros_like(input.real(), input.real().get_dtype(), input.real().get_layout(), std::nullopt, output_mem_config), ttnn::ones_like(input.imag()) });
Tensor grad_angle = ttnn::real(ttnn::operations::complex_binary::_mul(ttnn::conj(grad, output_mem_config), ttnn::operations::complex_binary::_mul(result, flip_tensor, output_mem_config), output_mem_config), output_mem_config);
result.deallocate();
flip_tensor.deallocate();
Expand Down Expand Up @@ -74,8 +74,8 @@ std::vector<ComplexTensor> _angle_bw(const Tensor& grad, const ComplexTensor& in
const Tensor &inp_i = input.imag();
Tensor condition_zero = ttnn::logical_and(ttnn::eqz(input.real(),output_mem_config), ttnn::eqz(input.imag(),output_mem_config), std::nullopt, output_mem_config);
Tensor abs_squared = ttnn::reciprocal(ttnn::add(ttnn::square(inp_r, output_mem_config), ttnn::square(inp_i, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
Tensor res_real = ttnn::where(condition_zero, ttnn::zeros_like(inp_r, inp_r.get_dtype(), inp_r.get_layout(), std::nullopt, output_mem_config), ttnn::multiply(grad, ttnn::multiply(ttnn::neg(inp_i, output_mem_config), abs_squared, std::nullopt, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
Tensor res_imag = ttnn::where(condition_zero, ttnn::zeros_like(inp_i, inp_i.get_dtype(), inp_i.get_layout(), std::nullopt, output_mem_config), ttnn::multiply(grad, ttnn::multiply(inp_r, abs_squared, std::nullopt, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
Tensor res_real = ttnn::where(condition_zero, 0.0f, ttnn::multiply(grad, ttnn::multiply(ttnn::neg(inp_i, output_mem_config), abs_squared, std::nullopt, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
Tensor res_imag = ttnn::where(condition_zero, 0.0f, ttnn::multiply(grad, ttnn::multiply(inp_r, abs_squared, std::nullopt, output_mem_config), std::nullopt, output_mem_config), output_mem_config);
condition_zero.deallocate();
abs_squared.deallocate();
ComplexTensor grad_result = ComplexTensor({res_real, res_imag});
Expand All @@ -99,8 +99,8 @@ std::vector<ComplexTensor> _conj_bw(const ComplexTensor& grad, const ComplexTens
std::vector<ComplexTensor> _complex_abs_bw(const Tensor& grad, const ComplexTensor& input, const MemoryConfig& output_mem_config) {
std::vector<ComplexTensor> grad_tensor;
Tensor result = ttnn::abs(input, output_mem_config);
Tensor grad_inp_r = ttnn::where(ttnn::eqz(result, output_mem_config), ttnn::zeros_like(result, result.get_dtype(), result.get_layout(), std::nullopt, output_mem_config), ttnn::multiply(grad, ttnn::multiply(input.real(), ttnn::reciprocal(result, output_mem_config), std::nullopt, output_mem_config),std::nullopt, output_mem_config), output_mem_config );
Tensor grad_inp_i = ttnn::where(ttnn::eqz(result, output_mem_config), ttnn::zeros_like(result, result.get_dtype(), result.get_layout(), std::nullopt, output_mem_config), ttnn::multiply(grad, ttnn::multiply(input.imag(), ttnn::reciprocal(result, output_mem_config), std::nullopt, output_mem_config),std::nullopt, output_mem_config), output_mem_config );
Tensor grad_inp_r = ttnn::where(ttnn::eqz(result, output_mem_config), 0.0f, ttnn::multiply(grad, ttnn::multiply(input.real(), ttnn::reciprocal(result, output_mem_config), std::nullopt, output_mem_config),std::nullopt, output_mem_config), output_mem_config );
Tensor grad_inp_i = ttnn::where(ttnn::eqz(result, output_mem_config), 0.0f, ttnn::multiply(grad, ttnn::multiply(input.imag(), ttnn::reciprocal(result, output_mem_config), std::nullopt, output_mem_config),std::nullopt, output_mem_config), output_mem_config );
ComplexTensor grad_inp = ComplexTensor({ grad_inp_r, grad_inp_i});
result.deallocate();
grad_inp_r.deallocate();
Expand Down

0 comments on commit becbf96

Please sign in to comment.