diff --git a/CHANGELOG.md b/CHANGELOG.md index bfbff068..e6903d40 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased ### Changed -## v0.16.0 +## v0.18.0 +### Changed +- PyTorch v2.5 support + +## v0.17.0 ### Changed - PyTorch v2.4 support diff --git a/Cargo.toml b/Cargo.toml index f0f633db..11736027 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tch" -version = "0.17.0" +version = "0.18.0" authors = ["Laurent Mazare "] edition = "2021" build = "build.rs" @@ -22,7 +22,7 @@ libc = "0.2.0" ndarray = "0.15" rand = "0.8" thiserror = "1" -torch-sys = { version = "0.17.0", path = "torch-sys" } +torch-sys = { version = "0.18.0", path = "torch-sys" } zip = "0.6" half = "2" safetensors = "0.3.0" diff --git a/README.md b/README.md index f6786754..b815dc7e 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ The code generation part for the C api on top of libtorch comes from ## Getting Started -This crate requires the C++ PyTorch library (libtorch) in version *v2.4.1* to be available on +This crate requires the C++ PyTorch library (libtorch) in version *v2.5.0* to be available on your system. You can either: - Use the system-wide libtorch installation (default). @@ -85,7 +85,7 @@ seem to include `libtorch.a` by default so this would have to be compiled manually, e.g. via the following: ```bash -git clone -b v2.4.1 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1 +git clone -b v2.5.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1 cd pytorch-static USE_CUDA=OFF BUILD_SHARED_LIBS=OFF python setup.py build # export LIBTORCH to point at the build directory in pytorch-static. 
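As a quick sanity check after pointing a build at libtorch v2.5.0 (whether the system-wide install, a manual static build as above, or a `LIBTORCH` override), a small smoke test along these lines confirms the new toolchain links and runs; the snippet is illustrative and not part of this change set:

```rust
use tch::Tensor;

fn main() {
    // A trivial tensor op exercises the libtorch linkage end to end.
    let t = Tensor::from_slice(&[3i64, 1, 4, 1, 5]);
    println!("{:?}", &t * 2);
    // Reports whether the linked libtorch build can see a CUDA device.
    println!("cuda available: {}", tch::Cuda::is_available());
}
```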
diff --git a/examples/python-extension/Cargo.toml b/examples/python-extension/Cargo.toml index 8fd38bab..989b71e8 100644 --- a/examples/python-extension/Cargo.toml +++ b/examples/python-extension/Cargo.toml @@ -18,6 +18,6 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.21", features = ["extension-module"] } -pyo3-tch = { path = "../../pyo3-tch", version = "0.17.0" } -tch = { path = "../..", features = ["python-extension"], version = "0.17.0" } -torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.17.0" } +pyo3-tch = { path = "../../pyo3-tch", version = "0.18.0" } +tch = { path = "../..", features = ["python-extension"], version = "0.18.0" } +torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.18.0" } diff --git a/gen/gen.ml b/gen/gen.ml index e08306e9..c5cf4181 100644 --- a/gen/gen.ml +++ b/gen/gen.ml @@ -882,7 +882,7 @@ let run let () = run - ~yaml_filename:"third_party/pytorch/Declarations-v2.4.0.yaml" + ~yaml_filename:"third_party/pytorch/Declarations-v2.5.0.yaml" ~cpp_filename:"torch-sys/libtch/torch_api_generated" ~ffi_filename:"torch-sys/src/c_generated.rs" ~wrapper_filename:"src/wrappers/tensor_generated.rs" diff --git a/pyo3-tch/Cargo.toml b/pyo3-tch/Cargo.toml index 67d3352d..34e24c07 100644 --- a/pyo3-tch/Cargo.toml +++ b/pyo3-tch/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pyo3-tch" -version = "0.17.0" +version = "0.18.0" authors = ["Laurent Mazare "] edition = "2021" build = "build.rs" @@ -12,6 +12,6 @@ categories = ["science"] license = "MIT/Apache-2.0" [dependencies] -tch = { path = "..", features = ["python-extension"], version = "0.17.0" } -torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.17.0" } +tch = { path = "..", features = ["python-extension"], version = "0.18.0" } +torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.18.0" } pyo3 = { version = "0.21", features = ["extension-module"] } diff --git a/src/wrappers/tensor_fallible_generated.rs b/src/wrappers/tensor_fallible_generated.rs index dfbd6118..a252d204 100644 --- a/src/wrappers/tensor_fallible_generated.rs +++ b/src/wrappers/tensor_fallible_generated.rs @@ -3176,6 +3176,7 @@ impl Tensor { dropout_p: f64, is_causal: bool, scale: impl Into>, + enable_gqa: bool, ) -> Result { let scale = scale.into(); let return_; @@ -3188,7 +3189,8 @@ impl Tensor { dropout_p, if is_causal { 1 } else { 0 }, scale.unwrap_or(std::f64::NAN), - scale.is_none() as i8 + scale.is_none() as i8, + if enable_gqa { 1 } else { 0 } ) ); Ok(return_) @@ -4815,6 +4817,18 @@ impl Tensor { Ok(Tensor { c_tensor: c_tensors[0] }) } + pub fn f_internal_nested_get_max_seqlen(&self) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; + unsafe_torch_err!(atg__nested_get_max_seqlen(c_tensors.as_mut_ptr(), self.c_tensor)); + Ok(Tensor { c_tensor: c_tensors[0] }) + } + + pub fn f_internal_nested_get_min_seqlen(&self) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; + unsafe_torch_err!(atg__nested_get_min_seqlen(c_tensors.as_mut_ptr(), self.c_tensor)); + Ok(Tensor { c_tensor: c_tensors[0] }) + } + pub fn f_internal_nested_get_offsets(&self) -> Result { let mut c_tensors = [std::ptr::null_mut(); 1]; unsafe_torch_err!(atg__nested_get_offsets(c_tensors.as_mut_ptr(), self.c_tensor)); @@ -4943,6 +4957,8 @@ impl Tensor { dummy: &Tensor, lengths: Option, ragged_idx: i64, + min_seqlen: Option, + max_seqlen: Option, ) -> Result { let mut c_tensors = [std::ptr::null_mut(); 1]; 
unsafe_torch_err!(atg__nested_view_from_jagged( @@ -4951,7 +4967,9 @@ impl Tensor { offsets.c_tensor, dummy.c_tensor, lengths.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), - ragged_idx + ragged_idx, + min_seqlen.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), + max_seqlen.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor) )); Ok(Tensor { c_tensor: c_tensors[0] }) } @@ -4962,6 +4980,8 @@ impl Tensor { dummy: &Tensor, lengths: Option, ragged_idx: i64, + min_seqlen: Option, + max_seqlen: Option, ) -> Result { let mut c_tensors = [std::ptr::null_mut(); 1]; unsafe_torch_err!(atg__nested_view_from_jagged_copy( @@ -4970,7 +4990,9 @@ impl Tensor { offsets.c_tensor, dummy.c_tensor, lengths.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), - ragged_idx + ragged_idx, + min_seqlen.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), + max_seqlen.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor) )); Ok(Tensor { c_tensor: c_tensors[0] }) } @@ -4982,6 +5004,8 @@ impl Tensor { dummy: &Tensor, lengths: Option, ragged_idx: i64, + min_seqlen: Option, + max_seqlen: Option, ) -> Result { let mut c_tensors = [std::ptr::null_mut(); 1]; unsafe_torch_err!(atg__nested_view_from_jagged_copy_out( @@ -4991,7 +5015,9 @@ impl Tensor { offsets.c_tensor, dummy.c_tensor, lengths.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), - ragged_idx + ragged_idx, + min_seqlen.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), + max_seqlen.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor) )); Ok(Tensor { c_tensor: c_tensors[0] }) } @@ -5434,6 +5460,21 @@ impl Tensor { Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] })) } + pub fn f_internal_safe_softmax( + &self, + dim: i64, + dtype: impl Into>, + ) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; + unsafe_torch_err!(atg__safe_softmax( + c_tensors.as_mut_ptr(), + self.c_tensor, + dim, + dtype.into().map_or(-1, |s| s.c_int()) + )); + Ok(Tensor { c_tensor: c_tensors[0] }) + } + pub fn f_internal_sample_dirichlet(&self) -> Result { let mut c_tensors = [std::ptr::null_mut(); 1]; unsafe_torch_err!(atg__sample_dirichlet(c_tensors.as_mut_ptr(), self.c_tensor)); @@ -5465,10 +5506,39 @@ impl Tensor { is_causal: bool, dropout_mask: Option, scale: impl Into>, + enable_gqa: bool, ) -> Result<(Tensor, Tensor), TchError> { let scale = scale.into(); let mut c_tensors = [std::ptr::null_mut(); 2]; unsafe_torch_err!(atg__scaled_dot_product_attention_math( + c_tensors.as_mut_ptr(), + query.c_tensor, + key.c_tensor, + value.c_tensor, + attn_mask.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), + dropout_p, + if is_causal { 1 } else { 0 }, + dropout_mask.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), + scale.unwrap_or(std::f64::NAN), + scale.is_none() as i8, + if enable_gqa { 1 } else { 0 } + )); + Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] })) + } + + pub fn f_internal_scaled_dot_product_attention_math_for_mps>( + query: &Tensor, + key: &Tensor, + value: &Tensor, + attn_mask: Option, + dropout_p: f64, + is_causal: bool, + dropout_mask: Option, + scale: impl Into>, + ) -> Result<(Tensor, Tensor), TchError> { + let scale = scale.into(); + let mut c_tensors = [std::ptr::null_mut(); 2]; + unsafe_torch_err!(atg__scaled_dot_product_attention_math_for_mps( c_tensors.as_mut_ptr(), query.c_tensor, key.c_tensor, @@ -5490,14 +5560,15 @@ impl Tensor { value: &Tensor, out: &Tensor, logsumexp: 
&Tensor, + philox_seed: &Tensor, + philox_offset: &Tensor, + attn_bias: &Tensor, cum_seq_q: &Tensor, cum_seq_k: &Tensor, max_q: i64, max_k: i64, dropout_p: f64, is_causal: bool, - philox_seed: &Tensor, - philox_offset: &Tensor, scale: impl Into>, ) -> Result<(Tensor, Tensor, Tensor), TchError> { let scale = scale.into(); @@ -5510,14 +5581,15 @@ impl Tensor { value.c_tensor, out.c_tensor, logsumexp.c_tensor, + philox_seed.c_tensor, + philox_offset.c_tensor, + attn_bias.c_tensor, cum_seq_q.c_tensor, cum_seq_k.c_tensor, max_q, max_k, dropout_p, if is_causal { 1 } else { 0 }, - philox_seed.c_tensor, - philox_offset.c_tensor, scale.unwrap_or(std::f64::NAN), scale.is_none() as i8 )); @@ -5668,55 +5740,53 @@ impl Tensor { pub fn f_internal_scaled_mm>( &self, mat2: &Tensor, + scale_a: &Tensor, + scale_b: &Tensor, bias: Option, - out_dtype: impl Into>, - scale_a: Option, - scale_b: Option, scale_result: Option, + out_dtype: impl Into>, use_fast_accum: bool, - ) -> Result<(Tensor, Tensor), TchError> { - let mut c_tensors = [std::ptr::null_mut(); 2]; + ) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; unsafe_torch_err!(atg__scaled_mm( c_tensors.as_mut_ptr(), self.c_tensor, mat2.c_tensor, + scale_a.c_tensor, + scale_b.c_tensor, bias.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), - out_dtype.into().map_or(-1, |s| s.c_int()), - scale_a.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), - scale_b.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), scale_result.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), + out_dtype.into().map_or(-1, |s| s.c_int()), if use_fast_accum { 1 } else { 0 } )); - Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] })) + Ok(Tensor { c_tensor: c_tensors[0] }) } pub fn f_internal_scaled_mm_out>( &self, out: &Tensor, - out_amax: &Tensor, mat2: &Tensor, + scale_a: &Tensor, + scale_b: &Tensor, bias: Option, - out_dtype: impl Into>, - scale_a: Option, - scale_b: Option, scale_result: Option, + out_dtype: impl Into>, use_fast_accum: bool, - ) -> Result<(Tensor, Tensor), TchError> { - let mut c_tensors = [std::ptr::null_mut(); 2]; + ) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; unsafe_torch_err!(atg__scaled_mm_out( c_tensors.as_mut_ptr(), out.c_tensor, - out_amax.c_tensor, self.c_tensor, mat2.c_tensor, + scale_a.c_tensor, + scale_b.c_tensor, bias.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), - out_dtype.into().map_or(-1, |s| s.c_int()), - scale_a.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), - scale_b.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), scale_result.as_ref().map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor), + out_dtype.into().map_or(-1, |s| s.c_int()), if use_fast_accum { 1 } else { 0 } )); - Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] })) + Ok(Tensor { c_tensor: c_tensors[0] }) } pub fn f_internal_scatter_reduce( @@ -6877,6 +6947,17 @@ impl Tensor { Ok(Tensor { c_tensor: c_tensors[0] }) } + pub fn f_internal_spsolve(a: &Tensor, b: &Tensor, left: bool) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; + unsafe_torch_err!(atg__spsolve( + c_tensors.as_mut_ptr(), + a.c_tensor, + b.c_tensor, + if left { 1 } else { 0 } + )); + Ok(Tensor { c_tensor: c_tensors[0] }) + } + pub fn f_internal_stack>( tensors: &[T], dim: i64, @@ -7946,6 +8027,42 @@ impl Tensor { Ok(Tensor { c_tensor: c_tensors[0] }) } + pub fn f_internal_unsafe_masked_index, S: Into>( + &self, + mask: &Tensor, + 
indices: &[Option], + fill: S, + ) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; + unsafe_torch_err!(atg__unsafe_masked_index( + c_tensors.as_mut_ptr(), + self.c_tensor, + mask.c_tensor, + ptr_list_opt(indices).as_ptr(), + indices.len() as i32, + fill.into().c_scalar + )); + Ok(Tensor { c_tensor: c_tensors[0] }) + } + + pub fn f_internal_unsafe_masked_index_put_accumulate>( + &self, + mask: &Tensor, + indices: &[Option], + values: &Tensor, + ) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; + unsafe_torch_err!(atg__unsafe_masked_index_put_accumulate( + c_tensors.as_mut_ptr(), + self.c_tensor, + mask.c_tensor, + ptr_list_opt(indices).as_ptr(), + indices.len() as i32, + values.c_tensor + )); + Ok(Tensor { c_tensor: c_tensors[0] }) + } + pub fn f_internal_unsafe_view(&self, size: impl IntList) -> Result { let mut c_tensors = [std::ptr::null_mut(); 1]; unsafe_torch_err!(atg__unsafe_view( @@ -8878,6 +8995,46 @@ impl Tensor { Ok((Tensor { c_tensor: c_tensors[0] }, Tensor { c_tensor: c_tensors[1] })) } + pub fn f_internal_wrapped_linear_prepack( + weight: &Tensor, + weight_scale: &Tensor, + weight_zero_point: &Tensor, + bias: &Tensor, + ) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; + unsafe_torch_err!(atg__wrapped_linear_prepack( + c_tensors.as_mut_ptr(), + weight.c_tensor, + weight_scale.c_tensor, + weight_zero_point.c_tensor, + bias.c_tensor + )); + Ok(Tensor { c_tensor: c_tensors[0] }) + } + + pub fn f_internal_wrapped_quantized_linear_prepacked( + &self, + input_scale: &Tensor, + input_zero_point: &Tensor, + packed_weight: &Tensor, + output_scale: &Tensor, + output_zero_point: &Tensor, + out_channel: i64, + ) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; + unsafe_torch_err!(atg__wrapped_quantized_linear_prepacked( + c_tensors.as_mut_ptr(), + self.c_tensor, + input_scale.c_tensor, + input_zero_point.c_tensor, + packed_weight.c_tensor, + output_scale.c_tensor, + output_zero_point.c_tensor, + out_channel + )); + Ok(Tensor { c_tensor: c_tensors[0] }) + } + pub fn f_abs(&self) -> Result { let mut c_tensors = [std::ptr::null_mut(); 1]; unsafe_torch_err!(atg_abs(c_tensors.as_mut_ptr(), self.c_tensor)); @@ -24066,6 +24223,21 @@ impl Tensor { Ok(Tensor { c_tensor: c_tensors[0] }) } + pub fn f_mean_dtype_out( + &self, + out: &Tensor, + dtype: impl Into>, + ) -> Result { + let mut c_tensors = [std::ptr::null_mut(); 1]; + unsafe_torch_err!(atg_mean_dtype_out( + c_tensors.as_mut_ptr(), + out.c_tensor, + self.c_tensor, + dtype.into().map_or(-1, |s| s.c_int()) + )); + Ok(Tensor { c_tensor: c_tensors[0] }) + } + pub fn f_mean_out( &self, out: &Tensor, @@ -27609,6 +27781,7 @@ impl Tensor { sequences: &[T], batch_first: bool, padding_value: f64, + padding_side: &str, ) -> Result { let mut c_tensors = [std::ptr::null_mut(); 1]; unsafe_torch_err!(atg_pad_sequence( @@ -27616,7 +27789,9 @@ impl Tensor { ptr_list(sequences).as_ptr(), sequences.len() as i32, if batch_first { 1 } else { 0 }, - padding_value + padding_value, + padding_side.as_ptr(), + padding_side.len() as i32 )); Ok(Tensor { c_tensor: c_tensors[0] }) } @@ -30447,6 +30622,7 @@ impl Tensor { dropout_p: f64, is_causal: bool, scale: impl Into>, + enable_gqa: bool, ) -> Result { let scale = scale.into(); let mut c_tensors = [std::ptr::null_mut(); 1]; @@ -30459,7 +30635,8 @@ impl Tensor { dropout_p, if is_causal { 1 } else { 0 }, scale.unwrap_or(std::f64::NAN), - scale.is_none() as i8 + scale.is_none() as i8, + if enable_gqa { 1 } else { 0 } )); Ok(Tensor { c_tensor: c_tensors[0] }) } diff --git 
a/src/wrappers/tensor_generated.rs b/src/wrappers/tensor_generated.rs index 5d73be2f..6c0c9f22 100644 --- a/src/wrappers/tensor_generated.rs +++ b/src/wrappers/tensor_generated.rs @@ -2034,9 +2034,10 @@ impl Tensor { dropout_p: f64, is_causal: bool, scale: impl Into>, + enable_gqa: bool, ) -> i64 { Tensor::f_internal_fused_sdp_choice( - query, key, value, attn_mask, dropout_p, is_causal, scale, + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, ) .unwrap() } @@ -2944,6 +2945,14 @@ impl Tensor { self.f_internal_nested_get_lengths().unwrap() } + pub fn internal_nested_get_max_seqlen(&self) -> Tensor { + self.f_internal_nested_get_max_seqlen().unwrap() + } + + pub fn internal_nested_get_min_seqlen(&self) -> Tensor { + self.f_internal_nested_get_min_seqlen().unwrap() + } + pub fn internal_nested_get_offsets(&self) -> Tensor { self.f_internal_nested_get_offsets().unwrap() } @@ -3017,8 +3026,13 @@ impl Tensor { dummy: &Tensor, lengths: Option, ragged_idx: i64, + min_seqlen: Option, + max_seqlen: Option, ) -> Tensor { - self.f_internal_nested_view_from_jagged(offsets, dummy, lengths, ragged_idx).unwrap() + self.f_internal_nested_view_from_jagged( + offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen, + ) + .unwrap() } pub fn internal_nested_view_from_jagged_copy>( @@ -3027,8 +3041,13 @@ impl Tensor { dummy: &Tensor, lengths: Option, ragged_idx: i64, + min_seqlen: Option, + max_seqlen: Option, ) -> Tensor { - self.f_internal_nested_view_from_jagged_copy(offsets, dummy, lengths, ragged_idx).unwrap() + self.f_internal_nested_view_from_jagged_copy( + offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen, + ) + .unwrap() } pub fn internal_nested_view_from_jagged_copy_out>( @@ -3038,9 +3057,13 @@ impl Tensor { dummy: &Tensor, lengths: Option, ragged_idx: i64, + min_seqlen: Option, + max_seqlen: Option, ) -> Tensor { - self.f_internal_nested_view_from_jagged_copy_out(out, offsets, dummy, lengths, ragged_idx) - .unwrap() + self.f_internal_nested_view_from_jagged_copy_out( + out, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen, + ) + .unwrap() } pub fn internal_new_zeros_with_same_feature_meta( @@ -3244,6 +3267,10 @@ impl Tensor { Tensor::f_internal_rowwise_prune(weight, mask, compressed_indices_dtype).unwrap() } + pub fn internal_safe_softmax(&self, dim: i64, dtype: impl Into>) -> Tensor { + self.f_internal_safe_softmax(dim, dtype).unwrap() + } + pub fn internal_sample_dirichlet(&self) -> Tensor { self.f_internal_sample_dirichlet().unwrap() } @@ -3265,6 +3292,7 @@ impl Tensor { is_causal: bool, dropout_mask: Option, scale: impl Into>, + enable_gqa: bool, ) -> (Tensor, Tensor) { Tensor::f_internal_scaled_dot_product_attention_math( query, @@ -3275,6 +3303,30 @@ impl Tensor { is_causal, dropout_mask, scale, + enable_gqa, + ) + .unwrap() + } + + pub fn internal_scaled_dot_product_attention_math_for_mps>( + query: &Tensor, + key: &Tensor, + value: &Tensor, + attn_mask: Option, + dropout_p: f64, + is_causal: bool, + dropout_mask: Option, + scale: impl Into>, + ) -> (Tensor, Tensor) { + Tensor::f_internal_scaled_dot_product_attention_math_for_mps( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + dropout_mask, + scale, ) .unwrap() } @@ -3286,14 +3338,15 @@ impl Tensor { value: &Tensor, out: &Tensor, logsumexp: &Tensor, + philox_seed: &Tensor, + philox_offset: &Tensor, + attn_bias: &Tensor, cum_seq_q: &Tensor, cum_seq_k: &Tensor, max_q: i64, max_k: i64, dropout_p: f64, is_causal: bool, - philox_seed: &Tensor, - philox_offset: &Tensor, scale: impl 
Into>, ) -> (Tensor, Tensor, Tensor) { Tensor::f_internal_scaled_dot_product_cudnn_attention_backward( @@ -3303,14 +3356,15 @@ impl Tensor { value, out, logsumexp, + philox_seed, + philox_offset, + attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, - philox_seed, - philox_offset, scale, ) .unwrap() @@ -3412,20 +3466,20 @@ impl Tensor { pub fn internal_scaled_mm>( &self, mat2: &Tensor, + scale_a: &Tensor, + scale_b: &Tensor, bias: Option, - out_dtype: impl Into>, - scale_a: Option, - scale_b: Option, scale_result: Option, + out_dtype: impl Into>, use_fast_accum: bool, - ) -> (Tensor, Tensor) { + ) -> Tensor { self.f_internal_scaled_mm( mat2, - bias, - out_dtype, scale_a, scale_b, + bias, scale_result, + out_dtype, use_fast_accum, ) .unwrap() @@ -3434,24 +3488,22 @@ impl Tensor { pub fn internal_scaled_mm_out>( &self, out: &Tensor, - out_amax: &Tensor, mat2: &Tensor, + scale_a: &Tensor, + scale_b: &Tensor, bias: Option, - out_dtype: impl Into>, - scale_a: Option, - scale_b: Option, scale_result: Option, + out_dtype: impl Into>, use_fast_accum: bool, - ) -> (Tensor, Tensor) { + ) -> Tensor { self.f_internal_scaled_mm_out( out, - out_amax, mat2, - bias, - out_dtype, scale_a, scale_b, + bias, scale_result, + out_dtype, use_fast_accum, ) .unwrap() @@ -4057,6 +4109,10 @@ impl Tensor { Tensor::f_internal_spdiags_out(out, diagonals, offsets, shape, layout).unwrap() } + pub fn internal_spsolve(a: &Tensor, b: &Tensor, left: bool) -> Tensor { + Tensor::f_internal_spsolve(a, b, left).unwrap() + } + pub fn internal_stack>(tensors: &[T], dim: i64) -> Tensor { Tensor::f_internal_stack(tensors, dim).unwrap() } @@ -4582,6 +4638,24 @@ impl Tensor { self.f_internal_unsafe_index_put(indices, values, accumulate).unwrap() } + pub fn internal_unsafe_masked_index, S: Into>( + &self, + mask: &Tensor, + indices: &[Option], + fill: S, + ) -> Tensor { + self.f_internal_unsafe_masked_index(mask, indices, fill).unwrap() + } + + pub fn internal_unsafe_masked_index_put_accumulate>( + &self, + mask: &Tensor, + indices: &[Option], + values: &Tensor, + ) -> Tensor { + self.f_internal_unsafe_masked_index_put_accumulate(mask, indices, values).unwrap() + } + pub fn internal_unsafe_view(&self, size: impl IntList) -> Tensor { self.f_internal_unsafe_view(size).unwrap() } @@ -5154,6 +5228,36 @@ impl Tensor { Tensor::f_internal_weight_norm_interface_out(out0, out1, v, g, dim).unwrap() } + pub fn internal_wrapped_linear_prepack( + weight: &Tensor, + weight_scale: &Tensor, + weight_zero_point: &Tensor, + bias: &Tensor, + ) -> Tensor { + Tensor::f_internal_wrapped_linear_prepack(weight, weight_scale, weight_zero_point, bias) + .unwrap() + } + + pub fn internal_wrapped_quantized_linear_prepacked( + &self, + input_scale: &Tensor, + input_zero_point: &Tensor, + packed_weight: &Tensor, + output_scale: &Tensor, + output_zero_point: &Tensor, + out_channel: i64, + ) -> Tensor { + self.f_internal_wrapped_quantized_linear_prepacked( + input_scale, + input_zero_point, + packed_weight, + output_scale, + output_zero_point, + out_channel, + ) + .unwrap() + } + pub fn abs(&self) -> Tensor { self.f_abs().unwrap() } @@ -12322,6 +12426,10 @@ impl Tensor { self.f_mean_dim(dim, keepdim, dtype).unwrap() } + pub fn mean_dtype_out(&self, out: &Tensor, dtype: impl Into>) -> Tensor { + self.f_mean_dtype_out(out, dtype).unwrap() + } + pub fn mean_out( &self, out: &Tensor, @@ -14254,8 +14362,9 @@ impl Tensor { sequences: &[T], batch_first: bool, padding_value: f64, + padding_side: &str, ) -> Tensor { - Tensor::f_pad_sequence(sequences, 
batch_first, padding_value).unwrap() + Tensor::f_pad_sequence(sequences, batch_first, padding_value, padding_side).unwrap() } pub fn pairwise_distance(x1: &Tensor, x2: &Tensor, p: f64, eps: f64, keepdim: bool) -> Tensor { @@ -15583,9 +15692,10 @@ impl Tensor { dropout_p: f64, is_causal: bool, scale: impl Into>, + enable_gqa: bool, ) -> Tensor { Tensor::f_scaled_dot_product_attention( - query, key, value, attn_mask, dropout_p, is_causal, scale, + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, ) .unwrap() } diff --git a/torch-sys/Cargo.toml b/torch-sys/Cargo.toml index 7b2050b9..ef7f1604 100644 --- a/torch-sys/Cargo.toml +++ b/torch-sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "torch-sys" -version = "0.17.0" +version = "0.18.0" authors = ["Laurent Mazare "] edition = "2021" build = "build.rs" diff --git a/torch-sys/build.rs b/torch-sys/build.rs index 753b585d..dacc5405 100644 --- a/torch-sys/build.rs +++ b/torch-sys/build.rs @@ -10,7 +10,7 @@ use anyhow::{Context, Result}; use std::path::{Path, PathBuf}; use std::{env, fs, io}; -const TORCH_VERSION: &str = "2.4.1"; +const TORCH_VERSION: &str = "2.5.0"; const PYTHON_PRINT_PYTORCH_DETAILS: &str = r" import torch from torch.utils import cpp_extension @@ -158,7 +158,7 @@ fn version_check(version: &str) -> Result<()> { return Ok(()); } let version = version.trim(); - // Typical version number is 2.4.1+cpu or 2.4.1+cu121 + // Typical version number is 2.5.0+cpu or 2.5.0+cu121 let version = match version.split_once('+') { None => version, Some((version, _)) => version, @@ -314,6 +314,7 @@ impl SystemInfo { "cpu" => "%2Bcpu", "cu118" => "%2Bcu118", "cu121" => "%2Bcu121", + "cu124" => "%2Bcu124", _ => anyhow::bail!("unsupported device {device}, TORCH_CUDA_VERSION may be set incorrectly?"), } ), @@ -337,6 +338,7 @@ impl SystemInfo { "cpu" => "%2Bcpu", "cu118" => "%2Bcu118", "cu121" => "%2Bcu121", + "cu124" => "%2Bcu124", _ => "" }), }; diff --git a/torch-sys/libtch/torch_api_generated.cpp b/torch-sys/libtch/torch_api_generated.cpp index 2a6f8ec8..20ff66f4 100644 --- a/torch-sys/libtch/torch_api_generated.cpp +++ b/torch-sys/libtch/torch_api_generated.cpp @@ -1251,9 +1251,9 @@ void atg__fused_moving_avg_obs_fq_helper_out(tensor *out__, tensor out0, tensor ) } -int64_t atg__fused_sdp_choice(tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, double scale_v, uint8_t scale_null) { +int64_t atg__fused_sdp_choice(tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, double scale_v, uint8_t scale_null, int enable_gqa) { PROTECT( - return torch::_fused_sdp_choice(*query, *key, *value, (attn_mask ? ::std::optional(*attn_mask) : ::std::nullopt), dropout_p, (bool)is_causal, scale_null ? c10::nullopt : c10::optional(scale_v)); + return torch::_fused_sdp_choice(*query, *key, *value, (attn_mask ? ::std::optional(*attn_mask) : ::std::nullopt), dropout_p, (bool)is_causal, scale_null ? 
c10::nullopt : c10::optional(scale_v), (bool)enable_gqa); ) return 0; } @@ -1953,6 +1953,20 @@ void atg__nested_get_lengths(tensor *out__, tensor self) { ) } +void atg__nested_get_max_seqlen(tensor *out__, tensor self) { + PROTECT( + auto outputs__ = torch::_nested_get_max_seqlen(*self); + out__[0] = new torch::Tensor(outputs__); + ) +} + +void atg__nested_get_min_seqlen(tensor *out__, tensor self) { + PROTECT( + auto outputs__ = torch::_nested_get_min_seqlen(*self); + out__[0] = new torch::Tensor(outputs__); + ) +} + void atg__nested_get_offsets(tensor *out__, tensor self) { PROTECT( auto outputs__ = torch::_nested_get_offsets(*self); @@ -2023,23 +2037,23 @@ void atg__nested_view_from_buffer_copy_out(tensor *out__, tensor out, tensor sel ) } -void atg__nested_view_from_jagged(tensor *out__, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx) { +void atg__nested_view_from_jagged(tensor *out__, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx, tensor min_seqlen, tensor max_seqlen) { PROTECT( - auto outputs__ = torch::_nested_view_from_jagged(*self, *offsets, *dummy, (lengths ? ::std::optional(*lengths) : ::std::nullopt), ragged_idx); + auto outputs__ = torch::_nested_view_from_jagged(*self, *offsets, *dummy, (lengths ? ::std::optional(*lengths) : ::std::nullopt), ragged_idx, (min_seqlen ? ::std::optional(*min_seqlen) : ::std::nullopt), (max_seqlen ? ::std::optional(*max_seqlen) : ::std::nullopt)); out__[0] = new torch::Tensor(outputs__); ) } -void atg__nested_view_from_jagged_copy(tensor *out__, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx) { +void atg__nested_view_from_jagged_copy(tensor *out__, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx, tensor min_seqlen, tensor max_seqlen) { PROTECT( - auto outputs__ = torch::_nested_view_from_jagged_copy(*self, *offsets, *dummy, (lengths ? ::std::optional(*lengths) : ::std::nullopt), ragged_idx); + auto outputs__ = torch::_nested_view_from_jagged_copy(*self, *offsets, *dummy, (lengths ? ::std::optional(*lengths) : ::std::nullopt), ragged_idx, (min_seqlen ? ::std::optional(*min_seqlen) : ::std::nullopt), (max_seqlen ? ::std::optional(*max_seqlen) : ::std::nullopt)); out__[0] = new torch::Tensor(outputs__); ) } -void atg__nested_view_from_jagged_copy_out(tensor *out__, tensor out, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx) { +void atg__nested_view_from_jagged_copy_out(tensor *out__, tensor out, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx, tensor min_seqlen, tensor max_seqlen) { PROTECT( - auto outputs__ = torch::_nested_view_from_jagged_copy_out(*out, *self, *offsets, *dummy, (lengths ? ::std::optional(*lengths) : ::std::nullopt), ragged_idx); + auto outputs__ = torch::_nested_view_from_jagged_copy_out(*out, *self, *offsets, *dummy, (lengths ? ::std::optional(*lengths) : ::std::nullopt), ragged_idx, (min_seqlen ? ::std::optional(*min_seqlen) : ::std::nullopt), (max_seqlen ? ::std::optional(*max_seqlen) : ::std::nullopt)); out__[0] = new torch::Tensor(outputs__); ) } @@ -2257,6 +2271,13 @@ void atg__rowwise_prune(tensor *out__, tensor weight, tensor mask, int compresse ) } +void atg__safe_softmax(tensor *out__, tensor self, int64_t dim, int dtype) { + PROTECT( + auto outputs__ = torch::_safe_softmax(*self, dim, dtype < 0 ? 
c10::nullopt : c10::optional(at::ScalarType(dtype))); + out__[0] = new torch::Tensor(outputs__); + ) +} + void atg__sample_dirichlet(tensor *out__, tensor self) { PROTECT( auto outputs__ = torch::_sample_dirichlet(*self); @@ -2278,17 +2299,25 @@ void atg__saturate_weight_to_fp16(tensor *out__, tensor weight) { ) } -void atg__scaled_dot_product_attention_math(tensor *out__, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, tensor dropout_mask, double scale_v, uint8_t scale_null) { +void atg__scaled_dot_product_attention_math(tensor *out__, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, tensor dropout_mask, double scale_v, uint8_t scale_null, int enable_gqa) { PROTECT( - auto outputs__ = torch::_scaled_dot_product_attention_math(*query, *key, *value, (attn_mask ? ::std::optional(*attn_mask) : ::std::nullopt), dropout_p, (bool)is_causal, (dropout_mask ? ::std::optional(*dropout_mask) : ::std::nullopt), scale_null ? c10::nullopt : c10::optional(scale_v)); + auto outputs__ = torch::_scaled_dot_product_attention_math(*query, *key, *value, (attn_mask ? ::std::optional(*attn_mask) : ::std::nullopt), dropout_p, (bool)is_causal, (dropout_mask ? ::std::optional(*dropout_mask) : ::std::nullopt), scale_null ? c10::nullopt : c10::optional(scale_v), (bool)enable_gqa); out__[0] = new torch::Tensor(std::get<0>(outputs__)); out__[1] = new torch::Tensor(std::get<1>(outputs__)); ) } -void atg__scaled_dot_product_cudnn_attention_backward(tensor *out__, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null) { +void atg__scaled_dot_product_attention_math_for_mps(tensor *out__, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, tensor dropout_mask, double scale_v, uint8_t scale_null) { PROTECT( - auto outputs__ = torch::_scaled_dot_product_cudnn_attention_backward(*grad_out, *query, *key, *value, *out, *logsumexp, *cum_seq_q, *cum_seq_k, max_q, max_k, dropout_p, (bool)is_causal, *philox_seed, *philox_offset, scale_null ? c10::nullopt : c10::optional(scale_v)); + auto outputs__ = torch::_scaled_dot_product_attention_math_for_mps(*query, *key, *value, (attn_mask ? ::std::optional(*attn_mask) : ::std::nullopt), dropout_p, (bool)is_causal, (dropout_mask ? ::std::optional(*dropout_mask) : ::std::nullopt), scale_null ? c10::nullopt : c10::optional(scale_v)); + out__[0] = new torch::Tensor(std::get<0>(outputs__)); + out__[1] = new torch::Tensor(std::get<1>(outputs__)); + ) +} + +void atg__scaled_dot_product_cudnn_attention_backward(tensor *out__, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor philox_seed, tensor philox_offset, tensor attn_bias, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, double scale_v, uint8_t scale_null) { + PROTECT( + auto outputs__ = torch::_scaled_dot_product_cudnn_attention_backward(*grad_out, *query, *key, *value, *out, *logsumexp, *philox_seed, *philox_offset, *attn_bias, *cum_seq_q, *cum_seq_k, max_q, max_k, dropout_p, (bool)is_causal, scale_null ? 
c10::nullopt : c10::optional(scale_v)); out__[0] = new torch::Tensor(std::get<0>(outputs__)); out__[1] = new torch::Tensor(std::get<1>(outputs__)); out__[2] = new torch::Tensor(std::get<2>(outputs__)); @@ -2331,19 +2360,17 @@ void atg__scaled_dot_product_flash_attention_for_cpu_backward(tensor *out__, ten ) } -void atg__scaled_mm(tensor *out__, tensor self, tensor mat2, tensor bias, int out_dtype, tensor scale_a, tensor scale_b, tensor scale_result, int use_fast_accum) { +void atg__scaled_mm(tensor *out__, tensor self, tensor mat2, tensor scale_a, tensor scale_b, tensor bias, tensor scale_result, int out_dtype, int use_fast_accum) { PROTECT( - auto outputs__ = torch::_scaled_mm(*self, *mat2, (bias ? ::std::optional(*bias) : ::std::nullopt), out_dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(out_dtype)), (scale_a ? ::std::optional(*scale_a) : ::std::nullopt), (scale_b ? ::std::optional(*scale_b) : ::std::nullopt), (scale_result ? ::std::optional(*scale_result) : ::std::nullopt), (bool)use_fast_accum); - out__[0] = new torch::Tensor(std::get<0>(outputs__)); - out__[1] = new torch::Tensor(std::get<1>(outputs__)); + auto outputs__ = torch::_scaled_mm(*self, *mat2, *scale_a, *scale_b, (bias ? ::std::optional(*bias) : ::std::nullopt), (scale_result ? ::std::optional(*scale_result) : ::std::nullopt), out_dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(out_dtype)), (bool)use_fast_accum); + out__[0] = new torch::Tensor(outputs__); ) } -void atg__scaled_mm_out(tensor *out__, tensor out, tensor out_amax, tensor self, tensor mat2, tensor bias, int out_dtype, tensor scale_a, tensor scale_b, tensor scale_result, int use_fast_accum) { +void atg__scaled_mm_out(tensor *out__, tensor out, tensor self, tensor mat2, tensor scale_a, tensor scale_b, tensor bias, tensor scale_result, int out_dtype, int use_fast_accum) { PROTECT( - auto outputs__ = torch::_scaled_mm_out(*out, *out_amax, *self, *mat2, (bias ? ::std::optional(*bias) : ::std::nullopt), out_dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(out_dtype)), (scale_a ? ::std::optional(*scale_a) : ::std::nullopt), (scale_b ? ::std::optional(*scale_b) : ::std::nullopt), (scale_result ? ::std::optional(*scale_result) : ::std::nullopt), (bool)use_fast_accum); - out__[0] = new torch::Tensor(std::get<0>(outputs__)); - out__[1] = new torch::Tensor(std::get<1>(outputs__)); + auto outputs__ = torch::_scaled_mm_out(*out, *self, *mat2, *scale_a, *scale_b, (bias ? ::std::optional(*bias) : ::std::nullopt), (scale_result ? ::std::optional(*scale_result) : ::std::nullopt), out_dtype < 0 ? 
c10::nullopt : c10::optional(at::ScalarType(out_dtype)), (bool)use_fast_accum); + out__[0] = new torch::Tensor(outputs__); ) } @@ -2818,6 +2845,13 @@ void atg__spdiags_out(tensor *out__, tensor out, tensor diagonals, tensor offset ) } +void atg__spsolve(tensor *out__, tensor A, tensor B, int left) { + PROTECT( + auto outputs__ = torch::_spsolve(*A, *B, (bool)left); + out__[0] = new torch::Tensor(outputs__); + ) +} + void atg__stack(tensor *out__, tensor *tensors_data, int tensors_len, int64_t dim) { PROTECT( auto outputs__ = torch::_stack(of_carray_tensor(tensors_data, tensors_len), dim); @@ -3277,6 +3311,20 @@ void atg__unsafe_index_put(tensor *out__, tensor self, tensor *indices_data, int ) } +void atg__unsafe_masked_index(tensor *out__, tensor self, tensor mask, tensor *indices_data, int indices_len, scalar fill) { + PROTECT( + auto outputs__ = torch::_unsafe_masked_index(*self, *mask, of_carray_tensor_opt(indices_data, indices_len), *fill); + out__[0] = new torch::Tensor(outputs__); + ) +} + +void atg__unsafe_masked_index_put_accumulate(tensor *out__, tensor self, tensor mask, tensor *indices_data, int indices_len, tensor values) { + PROTECT( + auto outputs__ = torch::_unsafe_masked_index_put_accumulate(*self, *mask, of_carray_tensor_opt(indices_data, indices_len), *values); + out__[0] = new torch::Tensor(outputs__); + ) +} + void atg__unsafe_view(tensor *out__, tensor self, int64_t *size_data, int size_len) { PROTECT( auto outputs__ = torch::_unsafe_view(*self, torch::IntArrayRef(size_data, size_len)); @@ -3612,6 +3660,20 @@ void atg__weight_norm_interface_out(tensor *out__, tensor out0, tensor out1, ten ) } +void atg__wrapped_linear_prepack(tensor *out__, tensor weight, tensor weight_scale, tensor weight_zero_point, tensor bias) { + PROTECT( + auto outputs__ = torch::_wrapped_linear_prepack(*weight, *weight_scale, *weight_zero_point, *bias); + out__[0] = new torch::Tensor(outputs__); + ) +} + +void atg__wrapped_quantized_linear_prepacked(tensor *out__, tensor input, tensor input_scale, tensor input_zero_point, tensor packed_weight, tensor output_scale, tensor output_zero_point, int64_t out_channel) { + PROTECT( + auto outputs__ = torch::_wrapped_quantized_linear_prepacked(*input, *input_scale, *input_zero_point, *packed_weight, *output_scale, *output_zero_point, out_channel); + out__[0] = new torch::Tensor(outputs__); + ) +} + void atg_abs(tensor *out__, tensor self) { PROTECT( auto outputs__ = torch::abs(*self); @@ -11694,6 +11756,13 @@ void atg_mean_dim(tensor *out__, tensor self, int64_t *dim_data, int dim_len, in ) } +void atg_mean_dtype_out(tensor *out__, tensor out, tensor self, int dtype) { + PROTECT( + auto outputs__ = torch::mean_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); + out__[0] = new torch::Tensor(outputs__); + ) +} + void atg_mean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? 
c10::nullopt : c10::optional(at::ScalarType(dtype))); @@ -13240,9 +13309,9 @@ void atg_pad(tensor *out__, tensor self, int64_t *pad_data, int pad_len, char* m ) } -void atg_pad_sequence(tensor *out__, tensor *sequences_data, int sequences_len, int batch_first, double padding_value) { +void atg_pad_sequence(tensor *out__, tensor *sequences_data, int sequences_len, int batch_first, double padding_value, char* padding_side_ptr, int padding_side_len) { PROTECT( - auto outputs__ = torch::pad_sequence(of_carray_tensor(sequences_data, sequences_len), (bool)batch_first, padding_value); + auto outputs__ = torch::pad_sequence(of_carray_tensor(sequences_data, sequences_len), (bool)batch_first, padding_value, std::string(padding_side_ptr, padding_side_len)); out__[0] = new torch::Tensor(outputs__); ) } @@ -14722,9 +14791,9 @@ void atg_scalar_tensor_out(tensor *out__, tensor out, scalar s) { ) } -void atg_scaled_dot_product_attention(tensor *out__, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, double scale_v, uint8_t scale_null) { +void atg_scaled_dot_product_attention(tensor *out__, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, double scale_v, uint8_t scale_null, int enable_gqa) { PROTECT( - auto outputs__ = torch::scaled_dot_product_attention(*query, *key, *value, (attn_mask ? ::std::optional(*attn_mask) : ::std::nullopt), dropout_p, (bool)is_causal, scale_null ? c10::nullopt : c10::optional(scale_v)); + auto outputs__ = torch::scaled_dot_product_attention(*query, *key, *value, (attn_mask ? ::std::optional(*attn_mask) : ::std::nullopt), dropout_p, (bool)is_causal, scale_null ? c10::nullopt : c10::optional(scale_v), (bool)enable_gqa); out__[0] = new torch::Tensor(outputs__); ) } diff --git a/torch-sys/libtch/torch_api_generated.h b/torch-sys/libtch/torch_api_generated.h index 00f2b92e..39b8588b 100644 --- a/torch-sys/libtch/torch_api_generated.h +++ b/torch-sys/libtch/torch_api_generated.h @@ -171,7 +171,7 @@ void atg__fused_dropout_out(tensor *, tensor out0, tensor out1, tensor self, dou void atg__fused_moving_avg_obs_fq_helper(tensor *, tensor self, tensor observer_on, tensor fake_quant_on, tensor running_min, tensor running_max, tensor scale, tensor zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int per_row_fake_quant, int symmetric_quant); void atg__fused_moving_avg_obs_fq_helper_functional(tensor *, tensor self, tensor observer_on, tensor fake_quant_on, tensor running_min, tensor running_max, tensor scale, tensor zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int per_row_fake_quant, int symmetric_quant); void atg__fused_moving_avg_obs_fq_helper_out(tensor *, tensor out0, tensor out1, tensor self, tensor observer_on, tensor fake_quant_on, tensor running_min, tensor running_max, tensor scale, tensor zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int per_row_fake_quant, int symmetric_quant); -int64_t atg__fused_sdp_choice(tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, double scale_v, uint8_t scale_null); +int64_t atg__fused_sdp_choice(tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, double scale_v, uint8_t scale_null, int enable_gqa); void atg__fw_primal(tensor *, tensor self, int64_t level); void atg__fw_primal_copy(tensor *, tensor self, int64_t level); void atg__fw_primal_copy_out(tensor *, 
tensor out, tensor self, int64_t level); @@ -263,6 +263,8 @@ void atg__nested_from_padded_and_nested_example_out(tensor *, tensor out, tensor void atg__nested_from_padded_out(tensor *, tensor out, tensor padded, tensor cpu_nested_shape_example, int fuse_transform_0213); void atg__nested_get_jagged_dummy(tensor *, tensor any); void atg__nested_get_lengths(tensor *, tensor self); +void atg__nested_get_max_seqlen(tensor *, tensor self); +void atg__nested_get_min_seqlen(tensor *, tensor self); void atg__nested_get_offsets(tensor *, tensor self); int64_t atg__nested_get_ragged_idx(tensor self); void atg__nested_get_values(tensor *, tensor self); @@ -273,9 +275,9 @@ void atg__nested_sum_backward(tensor *, tensor grad, tensor self, int64_t *dim_d void atg__nested_view_from_buffer(tensor *, tensor self, tensor nested_size, tensor nested_strides, tensor offsets); void atg__nested_view_from_buffer_copy(tensor *, tensor self, tensor nested_size, tensor nested_strides, tensor offsets); void atg__nested_view_from_buffer_copy_out(tensor *, tensor out, tensor self, tensor nested_size, tensor nested_strides, tensor offsets); -void atg__nested_view_from_jagged(tensor *, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx); -void atg__nested_view_from_jagged_copy(tensor *, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx); -void atg__nested_view_from_jagged_copy_out(tensor *, tensor out, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx); +void atg__nested_view_from_jagged(tensor *, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx, tensor min_seqlen, tensor max_seqlen); +void atg__nested_view_from_jagged_copy(tensor *, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx, tensor min_seqlen, tensor max_seqlen); +void atg__nested_view_from_jagged_copy_out(tensor *, tensor out, tensor self, tensor offsets, tensor dummy, tensor lengths, int64_t ragged_idx, tensor min_seqlen, tensor max_seqlen); void atg__new_zeros_with_same_feature_meta(tensor *, tensor self, tensor other, int64_t self_num_batch_dims); void atg__new_zeros_with_same_feature_meta_out(tensor *, tensor out, tensor self, tensor other, int64_t self_num_batch_dims); int atg__nnpack_available(); @@ -306,17 +308,19 @@ void atg__resize_output(tensor *, tensor self, int64_t *size_data, int size_len, void atg__resize_output_(tensor *, tensor self, int64_t *size_data, int size_len, int device); void atg__resize_output_out(tensor *, tensor out, tensor self, int64_t *size_data, int size_len, int device); void atg__rowwise_prune(tensor *, tensor weight, tensor mask, int compressed_indices_dtype); +void atg__safe_softmax(tensor *, tensor self, int64_t dim, int dtype); void atg__sample_dirichlet(tensor *, tensor self); void atg__sample_dirichlet_out(tensor *, tensor out, tensor self); void atg__saturate_weight_to_fp16(tensor *, tensor weight); -void atg__scaled_dot_product_attention_math(tensor *, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, tensor dropout_mask, double scale_v, uint8_t scale_null); -void atg__scaled_dot_product_cudnn_attention_backward(tensor *, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null); +void atg__scaled_dot_product_attention_math(tensor *, tensor 
query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, tensor dropout_mask, double scale_v, uint8_t scale_null, int enable_gqa); +void atg__scaled_dot_product_attention_math_for_mps(tensor *, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, tensor dropout_mask, double scale_v, uint8_t scale_null); +void atg__scaled_dot_product_cudnn_attention_backward(tensor *, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor philox_seed, tensor philox_offset, tensor attn_bias, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, double scale_v, uint8_t scale_null); void atg__scaled_dot_product_efficient_attention(tensor *, tensor query, tensor key, tensor value, tensor attn_bias, int compute_log_sumexp, double dropout_p, int is_causal, double scale_v, uint8_t scale_null); void atg__scaled_dot_product_flash_attention_backward(tensor *, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, tensor cum_seq_q, tensor cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int is_causal, tensor philox_seed, tensor philox_offset, double scale_v, uint8_t scale_null); void atg__scaled_dot_product_flash_attention_for_cpu(tensor *, tensor query, tensor key, tensor value, double dropout_p, int is_causal, tensor attn_mask, double scale_v, uint8_t scale_null); void atg__scaled_dot_product_flash_attention_for_cpu_backward(tensor *, tensor grad_out, tensor query, tensor key, tensor value, tensor out, tensor logsumexp, double dropout_p, int is_causal, tensor attn_mask, double scale_v, uint8_t scale_null); -void atg__scaled_mm(tensor *, tensor self, tensor mat2, tensor bias, int out_dtype, tensor scale_a, tensor scale_b, tensor scale_result, int use_fast_accum); -void atg__scaled_mm_out(tensor *, tensor out, tensor out_amax, tensor self, tensor mat2, tensor bias, int out_dtype, tensor scale_a, tensor scale_b, tensor scale_result, int use_fast_accum); +void atg__scaled_mm(tensor *, tensor self, tensor mat2, tensor scale_a, tensor scale_b, tensor bias, tensor scale_result, int out_dtype, int use_fast_accum); +void atg__scaled_mm_out(tensor *, tensor out, tensor self, tensor mat2, tensor scale_a, tensor scale_b, tensor bias, tensor scale_result, int out_dtype, int use_fast_accum); void atg__scatter_reduce(tensor *, tensor self, int64_t dim, tensor index, tensor src, char* reduce_ptr, int reduce_len, int include_self); void atg__scatter_reduce_(tensor *, tensor self, int64_t dim, tensor index, tensor src, char* reduce_ptr, int reduce_len, int include_self); void atg__scatter_reduce_two_out(tensor *, tensor out, tensor self, int64_t dim, tensor index, tensor src, char* reduce_ptr, int reduce_len, int include_self); @@ -383,6 +387,7 @@ void atg__sparse_sum_dim_out(tensor *, tensor out, tensor self, int64_t *dim_dat void atg__sparse_sum_dtype(tensor *, tensor self, int dtype); void atg__spdiags(tensor *, tensor diagonals, tensor offsets, int64_t *shape_data, int shape_len, int8_t layout); void atg__spdiags_out(tensor *, tensor out, tensor diagonals, tensor offsets, int64_t *shape_data, int shape_len, int8_t layout); +void atg__spsolve(tensor *, tensor A, tensor B, int left); void atg__stack(tensor *, tensor *tensors_data, int tensors_len, int64_t dim); void atg__stack_out(tensor *, tensor out, tensor *tensors_data, int tensors_len, int64_t dim); void atg__standard_gamma(tensor *, tensor self); @@ -446,6 +451,8 @@ void 
atg__unique_out(tensor *, tensor out0, tensor out1, tensor self, int sorted void atg__unpack_dual(tensor *, tensor dual, int64_t level); void atg__unsafe_index(tensor *, tensor self, tensor *indices_data, int indices_len); void atg__unsafe_index_put(tensor *, tensor self, tensor *indices_data, int indices_len, tensor values, int accumulate); +void atg__unsafe_masked_index(tensor *, tensor self, tensor mask, tensor *indices_data, int indices_len, scalar fill); +void atg__unsafe_masked_index_put_accumulate(tensor *, tensor self, tensor mask, tensor *indices_data, int indices_len, tensor values); void atg__unsafe_view(tensor *, tensor self, int64_t *size_data, int size_len); void atg__unsafe_view_out(tensor *, tensor out, tensor self, int64_t *size_data, int size_len); void atg__upsample_bicubic2d_aa(tensor *, tensor self, int64_t *output_size_data, int output_size_len, int align_corners, double scales_h_v, uint8_t scales_h_null, double scales_w_v, uint8_t scales_w_null); @@ -494,6 +501,8 @@ void atg__weight_norm_interface(tensor *, tensor v, tensor g, int64_t dim); void atg__weight_norm_interface_backward(tensor *, tensor grad_w, tensor saved_v, tensor saved_g, tensor saved_norms, int64_t dim); void atg__weight_norm_interface_backward_out(tensor *, tensor out0, tensor out1, tensor grad_w, tensor saved_v, tensor saved_g, tensor saved_norms, int64_t dim); void atg__weight_norm_interface_out(tensor *, tensor out0, tensor out1, tensor v, tensor g, int64_t dim); +void atg__wrapped_linear_prepack(tensor *, tensor weight, tensor weight_scale, tensor weight_zero_point, tensor bias); +void atg__wrapped_quantized_linear_prepacked(tensor *, tensor input, tensor input_scale, tensor input_zero_point, tensor packed_weight, tensor output_scale, tensor output_zero_point, int64_t out_channel); void atg_abs(tensor *, tensor self); void atg_abs_(tensor *, tensor self); void atg_abs_out(tensor *, tensor out, tensor self); @@ -1622,6 +1631,7 @@ void atg_maximum(tensor *, tensor self, tensor other); void atg_maximum_out(tensor *, tensor out, tensor self, tensor other); void atg_mean(tensor *, tensor self, int dtype); void atg_mean_dim(tensor *, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype); +void atg_mean_dtype_out(tensor *, tensor out, tensor self, int dtype); void atg_mean_out(tensor *, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype); void atg_median(tensor *, tensor self); void atg_median_dim(tensor *, tensor self, int64_t dim, int keepdim); @@ -1832,7 +1842,7 @@ void atg_outer(tensor *, tensor self, tensor vec2); void atg_outer_out(tensor *, tensor out, tensor self, tensor vec2); int64_t atg_output_nr(tensor self); void atg_pad(tensor *, tensor self, int64_t *pad_data, int pad_len, char* mode_ptr, int mode_len, double value_v, uint8_t value_null); -void atg_pad_sequence(tensor *, tensor *sequences_data, int sequences_len, int batch_first, double padding_value); +void atg_pad_sequence(tensor *, tensor *sequences_data, int sequences_len, int batch_first, double padding_value, char* padding_side_ptr, int padding_side_len); void atg_pairwise_distance(tensor *, tensor x1, tensor x2, double p, double eps, int keepdim); void atg_pdist(tensor *, tensor self, double p); void atg_permute(tensor *, tensor self, int64_t *dims_data, int dims_len); @@ -2042,7 +2052,7 @@ void atg_rsub_scalar_out(tensor *, tensor out, tensor self, scalar other); void atg_rsub_tensor_out(tensor *, tensor out, tensor self, tensor other); void atg_scalar_tensor(tensor *, scalar s, int 
options_kind, int options_device); void atg_scalar_tensor_out(tensor *, tensor out, scalar s); -void atg_scaled_dot_product_attention(tensor *, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, double scale_v, uint8_t scale_null); +void atg_scaled_dot_product_attention(tensor *, tensor query, tensor key, tensor value, tensor attn_mask, double dropout_p, int is_causal, double scale_v, uint8_t scale_null, int enable_gqa); void atg_scatter(tensor *, tensor self, int64_t dim, tensor index, tensor src); void atg_scatter_(tensor *, tensor self, int64_t dim, tensor index, tensor src); void atg_scatter_add(tensor *, tensor self, int64_t dim, tensor index, tensor src); diff --git a/torch-sys/src/c_generated.rs b/torch-sys/src/c_generated.rs index 69d04bb8..fc427297 100644 --- a/torch-sys/src/c_generated.rs +++ b/torch-sys/src/c_generated.rs @@ -1293,6 +1293,7 @@ extern "C" { is_causal_: c_int, scale_v: f64, scale_null: i8, + enable_gqa_: c_int, ) -> i64; pub fn atg__fw_primal(out__: *mut *mut C_tensor, self_: *mut C_tensor, level_: i64); pub fn atg__fw_primal_copy(out__: *mut *mut C_tensor, self_: *mut C_tensor, level_: i64); @@ -1947,6 +1948,8 @@ extern "C" { ); pub fn atg__nested_get_jagged_dummy(out__: *mut *mut C_tensor, any_: *mut C_tensor); pub fn atg__nested_get_lengths(out__: *mut *mut C_tensor, self_: *mut C_tensor); + pub fn atg__nested_get_max_seqlen(out__: *mut *mut C_tensor, self_: *mut C_tensor); + pub fn atg__nested_get_min_seqlen(out__: *mut *mut C_tensor, self_: *mut C_tensor); pub fn atg__nested_get_offsets(out__: *mut *mut C_tensor, self_: *mut C_tensor); pub fn atg__nested_get_ragged_idx(self_: *mut C_tensor) -> i64; pub fn atg__nested_get_values(out__: *mut *mut C_tensor, self_: *mut C_tensor); @@ -2000,6 +2003,8 @@ extern "C" { dummy_: *mut C_tensor, lengths_: *mut C_tensor, ragged_idx_: i64, + min_seqlen_: *mut C_tensor, + max_seqlen_: *mut C_tensor, ); pub fn atg__nested_view_from_jagged_copy( out__: *mut *mut C_tensor, @@ -2008,6 +2013,8 @@ extern "C" { dummy_: *mut C_tensor, lengths_: *mut C_tensor, ragged_idx_: i64, + min_seqlen_: *mut C_tensor, + max_seqlen_: *mut C_tensor, ); pub fn atg__nested_view_from_jagged_copy_out( out__: *mut *mut C_tensor, @@ -2017,6 +2024,8 @@ extern "C" { dummy_: *mut C_tensor, lengths_: *mut C_tensor, ragged_idx_: i64, + min_seqlen_: *mut C_tensor, + max_seqlen_: *mut C_tensor, ); pub fn atg__new_zeros_with_same_feature_meta( out__: *mut *mut C_tensor, @@ -2205,6 +2214,12 @@ extern "C" { mask_: *mut C_tensor, compressed_indices_dtype_: c_int, ); + pub fn atg__safe_softmax( + out__: *mut *mut C_tensor, + self_: *mut C_tensor, + dim_: i64, + dtype_: c_int, + ); pub fn atg__sample_dirichlet(out__: *mut *mut C_tensor, self_: *mut C_tensor); pub fn atg__sample_dirichlet_out( out__: *mut *mut C_tensor, @@ -2223,6 +2238,19 @@ extern "C" { dropout_mask_: *mut C_tensor, scale_v: f64, scale_null: i8, + enable_gqa_: c_int, + ); + pub fn atg__scaled_dot_product_attention_math_for_mps( + out__: *mut *mut C_tensor, + query_: *mut C_tensor, + key_: *mut C_tensor, + value_: *mut C_tensor, + attn_mask_: *mut C_tensor, + dropout_p_: f64, + is_causal_: c_int, + dropout_mask_: *mut C_tensor, + scale_v: f64, + scale_null: i8, ); pub fn atg__scaled_dot_product_cudnn_attention_backward( out__: *mut *mut C_tensor, @@ -2232,14 +2260,15 @@ extern "C" { value_: *mut C_tensor, out_: *mut C_tensor, logsumexp_: *mut C_tensor, + philox_seed_: *mut C_tensor, + philox_offset_: *mut C_tensor, + attn_bias_: *mut C_tensor, cum_seq_q_: 
*mut C_tensor, cum_seq_k_: *mut C_tensor, max_q_: i64, max_k_: i64, dropout_p_: f64, is_causal_: c_int, - philox_seed_: *mut C_tensor, - philox_offset_: *mut C_tensor, scale_v: f64, scale_null: i8, ); @@ -2303,24 +2332,23 @@ extern "C" { out__: *mut *mut C_tensor, self_: *mut C_tensor, mat2_: *mut C_tensor, - bias_: *mut C_tensor, - out_dtype_: c_int, scale_a_: *mut C_tensor, scale_b_: *mut C_tensor, + bias_: *mut C_tensor, scale_result_: *mut C_tensor, + out_dtype_: c_int, use_fast_accum_: c_int, ); pub fn atg__scaled_mm_out( out__: *mut *mut C_tensor, out_: *mut C_tensor, - out_amax_: *mut C_tensor, self_: *mut C_tensor, mat2_: *mut C_tensor, - bias_: *mut C_tensor, - out_dtype_: c_int, scale_a_: *mut C_tensor, scale_b_: *mut C_tensor, + bias_: *mut C_tensor, scale_result_: *mut C_tensor, + out_dtype_: c_int, use_fast_accum_: c_int, ); pub fn atg__scatter_reduce( @@ -2824,6 +2852,12 @@ extern "C" { shape_len: c_int, layout_: i8, ); + pub fn atg__spsolve( + out__: *mut *mut C_tensor, + A_: *mut C_tensor, + B_: *mut C_tensor, + left_: c_int, + ); pub fn atg__stack( out__: *mut *mut C_tensor, tensors_data: *const *mut C_tensor, @@ -3271,6 +3305,22 @@ extern "C" { values_: *mut C_tensor, accumulate_: c_int, ); + pub fn atg__unsafe_masked_index( + out__: *mut *mut C_tensor, + self_: *mut C_tensor, + mask_: *mut C_tensor, + indices_data: *const *mut C_tensor, + indices_len: c_int, + fill_: *mut C_scalar, + ); + pub fn atg__unsafe_masked_index_put_accumulate( + out__: *mut *mut C_tensor, + self_: *mut C_tensor, + mask_: *mut C_tensor, + indices_data: *const *mut C_tensor, + indices_len: c_int, + values_: *mut C_tensor, + ); pub fn atg__unsafe_view( out__: *mut *mut C_tensor, self_: *mut C_tensor, @@ -3692,6 +3742,23 @@ extern "C" { g_: *mut C_tensor, dim_: i64, ); + pub fn atg__wrapped_linear_prepack( + out__: *mut *mut C_tensor, + weight_: *mut C_tensor, + weight_scale_: *mut C_tensor, + weight_zero_point_: *mut C_tensor, + bias_: *mut C_tensor, + ); + pub fn atg__wrapped_quantized_linear_prepacked( + out__: *mut *mut C_tensor, + input_: *mut C_tensor, + input_scale_: *mut C_tensor, + input_zero_point_: *mut C_tensor, + packed_weight_: *mut C_tensor, + output_scale_: *mut C_tensor, + output_zero_point_: *mut C_tensor, + out_channel_: i64, + ); pub fn atg_abs(out__: *mut *mut C_tensor, self_: *mut C_tensor); pub fn atg_abs_(out__: *mut *mut C_tensor, self_: *mut C_tensor); pub fn atg_abs_out(out__: *mut *mut C_tensor, out_: *mut C_tensor, self_: *mut C_tensor); @@ -9881,6 +9948,12 @@ extern "C" { keepdim_: c_int, dtype_: c_int, ); + pub fn atg_mean_dtype_out( + out__: *mut *mut C_tensor, + out_: *mut C_tensor, + self_: *mut C_tensor, + dtype_: c_int, + ); pub fn atg_mean_out( out__: *mut *mut C_tensor, out_: *mut C_tensor, @@ -11398,6 +11471,8 @@ extern "C" { sequences_len: c_int, batch_first_: c_int, padding_value_: f64, + padding_side_ptr: *const u8, + padding_side_len: c_int, ); pub fn atg_pairwise_distance( out__: *mut *mut C_tensor, @@ -12580,6 +12655,7 @@ extern "C" { is_causal_: c_int, scale_v: f64, scale_null: i8, + enable_gqa_: c_int, ); pub fn atg_scatter( out__: *mut *mut C_tensor,
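The two generated-API changes most likely to touch downstream code are the new trailing `enable_gqa` flag on `Tensor::scaled_dot_product_attention` and the new `padding_side` argument on `Tensor::pad_sequence`. A minimal sketch of both calls against the v0.18.0 signatures above; the shapes, device, and grouped-query head counts are illustrative assumptions, not taken from the diff:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let device = Device::cuda_if_available();

    // Grouped-query attention: 8 query heads sharing 2 key/value heads
    // (batch=2, seq=16, head_dim=64). These shapes are illustrative only.
    let q = Tensor::randn(&[2, 8, 16, 64], (Kind::Float, device));
    let k = Tensor::randn(&[2, 2, 16, 64], (Kind::Float, device));
    let v = Tensor::randn(&[2, 2, 16, 64], (Kind::Float, device));

    let out = Tensor::scaled_dot_product_attention(
        &q,
        &k,
        &v,
        None::<Tensor>, // attn_mask
        0.0,            // dropout_p
        true,           // is_causal
        None::<f64>,    // scale: default 1/sqrt(head_dim)
        true,           // enable_gqa: new trailing argument in this release
    );
    println!("attention output shape: {:?}", out.size());

    // pad_sequence now also takes a padding side; PyTorch accepts "right" or "left".
    let a = Tensor::arange(3, (Kind::Float, device));
    let b = Tensor::arange(5, (Kind::Float, device));
    let padded = Tensor::pad_sequence(&[a, b], true, 0.0, "left");
    println!("padded shape: {:?}", padded.size());
}
```

Existing callers written against v0.17.0 need to supply the extra arguments; passing `false` for `enable_gqa` and `"right"` for `padding_side` reproduces the previous behaviour.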