Skip to content

Commit

Permalink
Merge pull request #899 from de-vri-es/fix-glog-pytorch-2.3.0
Browse files Browse the repository at this point in the history
Fix glog for pytorch 2.3.x branch.
  • Loading branch information
LaurentMazare authored Oct 15, 2024
2 parents 269ff36 + a9d496e commit 155735a
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tch"
version = "0.16.0"
version = "0.16.1"
authors = ["Laurent Mazare <lmazare@gmail.com>"]
edition = "2021"
build = "build.rs"
Expand All @@ -22,7 +22,7 @@ libc = "0.2.0"
ndarray = "0.15"
rand = "0.8"
thiserror = "1"
torch-sys = { version = "0.16.0", path = "torch-sys" }
torch-sys = { version = "0.16.1", path = "torch-sys" }
zip = "0.6"
half = "2"
safetensors = "0.3.0"
Expand Down
2 changes: 1 addition & 1 deletion examples/min-gpt/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ fn causal_self_attention(p: &nn::Path, cfg: Config) -> impl ModuleT {
let q = xs.apply(&query).view(sizes).transpose(1, 2);
let v = xs.apply(&value).view(sizes).transpose(1, 2);
let att = q.matmul(&k.transpose(-2, -1)) * (1.0 / f64::sqrt(sizes[3] as f64));
let att = att.masked_fill(&mask.i((.., .., ..sz_t, ..sz_t)).eq(0.), std::f64::NEG_INFINITY);
let att = att.masked_fill(&mask.i((.., .., ..sz_t, ..sz_t)).eq(0.), f64::NEG_INFINITY);
let att = att.softmax(-1, Kind::Float).dropout(cfg.attn_pdrop, train);
let ys = att.matmul(&v).transpose(1, 2).contiguous().view([sz_b, sz_t, sz_c]);
ys.apply(&proj).dropout(cfg.resid_pdrop, train)
Expand Down
6 changes: 3 additions & 3 deletions examples/python-extension/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.21", features = ["extension-module"] }
pyo3-tch = { path = "../../pyo3-tch", version = "0.16.0" }
tch = { path = "../..", features = ["python-extension"], version = "0.16.0" }
torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.16.0" }
pyo3-tch = { path = "../../pyo3-tch", version = "0.16.1" }
tch = { path = "../..", features = ["python-extension"], version = "0.16.1" }
torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.16.1" }
4 changes: 2 additions & 2 deletions examples/yolo/darknet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct Block {

impl Block {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
match self.parameters.get(key) {
None => bail!("cannot find {} in {}", key, self.block_type),
Some(value) => Ok(value),
}
Expand All @@ -32,7 +32,7 @@ pub struct Darknet {

impl Darknet {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
match self.parameters.get(key) {
None => bail!("cannot find {} in net parameters", key),
Some(value) => Ok(value),
}
Expand Down
8 changes: 4 additions & 4 deletions pyo3-tch/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pyo3-tch"
version = "0.16.0"
version = "0.16.1"
authors = ["Laurent Mazare <lmazare@gmail.com>"]
edition = "2021"
build = "build.rs"
Expand All @@ -12,6 +12,6 @@ categories = ["science"]
license = "MIT/Apache-2.0"

[dependencies]
tch = { path = "..", features = ["python-extension"], version = "0.16.0" }
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.16.0" }
pyo3 = { version = "0.21", features = ["extension-module"] }
tch = { path = "..", features = ["python-extension"], version = "0.16.1" }
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.16.1" }
pyo3 = { version = "0.21", features = ["extension-module"] }
4 changes: 2 additions & 2 deletions src/nn/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
// Optimize the case for which a single C++ code can be done.
if cst == 0. {
Tensor::f_zeros(dims, (Kind::Float, device))
} else if (cst - 1.).abs() <= std::f64::EPSILON {
} else if (cst - 1.).abs() <= f64::EPSILON {
Tensor::f_ones(dims, (Kind::Float, device))
} else {
Tensor::f_ones(dims, (Kind::Float, device)).map(|t| t * cst)
Expand All @@ -117,7 +117,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
Tensor::f_zeros(dims, (Kind::Float, device))?.f_uniform_(lo, up)
}
Init::Randn { mean, stdev } => {
if mean == 0. && (stdev - 1.).abs() <= std::f64::EPSILON {
if mean == 0. && (stdev - 1.).abs() <= f64::EPSILON {
Tensor::f_randn(dims, (Kind::Float, device))
} else {
Tensor::f_randn(dims, (Kind::Float, device)).map(|t| t * stdev + mean)
Expand Down
2 changes: 1 addition & 1 deletion torch-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "torch-sys"
version = "0.16.0"
version = "0.16.1"
authors = ["Laurent Mazare <lmazare@gmail.com>"]
edition = "2021"
build = "build.rs"
Expand Down
6 changes: 4 additions & 2 deletions torch-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,10 @@ impl SystemInfo {
.pic(true)
.warnings(false)
.includes(&self.libtorch_include_dirs)
.flag(&format!("-Wl,-rpath={}", self.libtorch_lib_dir.display()))
.flag(format!("-Wl,-rpath={}", self.libtorch_lib_dir.display()))
.flag("-std=c++17")
.flag(&format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi))
.flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi))
.flag("-DGLOG_USE_GLOG_EXPORT")
.files(&c_files)
.compile("tch");
}
Expand All @@ -398,6 +399,7 @@ impl SystemInfo {
.warnings(false)
.includes(&self.libtorch_include_dirs)
.flag("/std:c++17")
.flag("/p:DefineConstants=GLOG_USE_GLOG_EXPORT")
.files(&c_files)
.compile("tch");
}
Expand Down

0 comments on commit 155735a

Please sign in to comment.