Skip to content

Commit

Permalink
Added device parameter to allow usage with MPS (Apple Silicon) and Vulkan (#32)
Browse files Browse the repository at this point in the history

* Bumps dependencies for compatibility with libtorch 2.2

* Adds device parameter to enable usage with MPS (Apple Silicon) and Vulkan

* Disables default features for rust-bert
  • Loading branch information
Luxbit authored Sep 17, 2024
1 parent 89d505c commit 7bc6d37
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 24 deletions.
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sbert"
version = "0.4.1"
version = "0.5.0"
authors = ["Chady Dimachkie <cpcdoy@gmail.com>"]
edition = "2018"
description = "Rust implementation of Sentence Bert (SBert)"
Expand All @@ -16,15 +16,15 @@ log = "0.4"
num_cpus = "1.13"
prost = "0.9"
rayon = "1.5"
rust-bert = "0.21.0"
rust-bert = { git = "https://github.com/guillaume-be/rust-bert", rev = "29f9a7a", default-features = false }
rust_tokenizers = "7.0"
serde = "1.0"
strum = "0.23"
strum_macros = "0.23"
tch = "0.13.0"
tch = "0.15.0"
thiserror = "1.0"
tokenizers = "0.11"
torch-sys = "0.13.0"
tokenizers = "0.15"
torch-sys = "0.15.0"

[dev-dependencies]
criterion = "0.3"
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ You can use different versions of the models that use different tokenizers:

```Rust
// To use Hugging Face tokenizer
let sbert_model = SBertHF::new(home.to_str().unwrap());
let sbert_model = SBertHF::new(home.to_str().unwrap(), None);

// To use Rust-tokenizers
let sbert_model = SBertRT::new(home.to_str().unwrap());
let sbert_model = SBertRT::new(home.to_str().unwrap(), None);
```

Now, you can encode your sentences:
Expand Down
2 changes: 1 addition & 1 deletion benches/bench_distilroberta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn bench_distilroberta_rust_tokenizers_sentencepiece(c: &mut Criterion) {
home.push("distilroberta_toxicity");

println!("Loading distilroberta ...");
let sbert_model = DistilRobertaForSequenceClassificationRT::new(home).unwrap();
let sbert_model = DistilRobertaForSequenceClassificationRT::new(home, None).unwrap();

let text = "TTThis player needs tp be reported lolz.";
c.bench_function(
Expand Down
4 changes: 2 additions & 2 deletions benches/bench_sbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ fn bench_sbert_rust_tokenizers(c: &mut Criterion) {
home.push("distiluse-base-multilingual-cased");

println!("Loading sbert ...");
let sbert_model = SBertRT::new(home).unwrap();
let sbert_model = SBertRT::new(home, None).unwrap();

let text = "TTThis player needs tp be reported lolz.";
c.bench_function("Encode batch, safe sbert rust tokenizer, total 1", |b| {
Expand All @@ -53,7 +53,7 @@ fn bench_sbert_hugging_face_tokenizers(c: &mut Criterion) {
home.push("distiluse-base-multilingual-cased");

println!("Loading sbert ...");
let sbert_model = SBertHF::new(home).unwrap();
let sbert_model = SBertHF::new(home, None).unwrap();

let text = "TTThis player needs tp be reported lolz.";
c.bench_function(
Expand Down
5 changes: 2 additions & 3 deletions src/layers/dense.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ impl Config for DenseConfig {}
pub struct Dense {
linear: nn::Linear,
_conf: DenseConfig,

}

impl Dense {
pub fn new<P: Into<PathBuf>>(root: P) -> Result<Dense, Error> {
pub fn new<P: Into<PathBuf>>(root: P, device: Device) -> Result<Dense, Error> {
let dense_dir = root.into().join("2_Dense");
log::info!("Loading conf {:?}", dense_dir);

let device = Device::cuda_if_available();
//let device = Device::Cpu;
let mut vs_dense = nn::VarStore::new(device);

let init_conf = nn::LinearConfig {
Expand Down
4 changes: 2 additions & 2 deletions src/models/distilroberta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ impl<T> DistilRobertaForSequenceClassification<T>
where
T: Tokenizer + Send + Sync,
{
pub fn new<P>(root: P) -> Result<Self, Error>
pub fn new<P>(root: P, device: Option<Device>) -> Result<Self, Error>
where
P: Into<PathBuf>,
{
Expand All @@ -36,7 +36,7 @@ where

let config = BertConfig::from_file(&config_file);

let device = Device::cuda_if_available();
let device = device.unwrap_or(Device::cuda_if_available());
log::info!("Using device {:?}", device);

let mut vs = nn::VarStore::new(device);
Expand Down
10 changes: 6 additions & 4 deletions src/models/sbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl<T> SBert<T>
where
T: Tokenizer + Send + Sync,
{
pub fn new<P>(root: P) -> Result<Self, Error>
pub fn new<P>(root: P, device: Option<Device>) -> Result<Self, Error>
where
P: Into<PathBuf>,
{
Expand All @@ -44,11 +44,13 @@ where
let nb_layers = config.n_layers as usize;
let nb_heads = config.n_heads as usize;

let device = device.unwrap_or(Device::cuda_if_available());
log::info!("Using device {:?}", device);

let pooling = Pooling::new(root.clone());
let dense = Dense::new(root)?;
let dense = Dense::new(root, device)?;


let device = Device::cuda_if_available();
log::info!("Using device {:?}", device);

let mut vs = nn::VarStore::new(device);

Expand Down
2 changes: 1 addition & 1 deletion src/tokenizers/hf_tokenizers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl Tokenizer for HFTokenizer {
let stride = 0;
let strategy = TruncationStrategy::LongestFirst;
let direction = TruncationDirection::Right;
tokenizer.with_truncation(Some(TruncationParams {
let _ = tokenizer.with_truncation(Some(TruncationParams {
max_length,
stride,
strategy,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_distilroberta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ mod tests {

println!("Loading distilroberta ...");
let before = Instant::now();
let sbert_model = DistilRobertaForSequenceClassificationRT::new(home).unwrap();
let sbert_model = DistilRobertaForSequenceClassificationRT::new(home, None).unwrap();
println!("Elapsed time: {:.2?}", before.elapsed());

let mut texts = Vec::new();
Expand Down
6 changes: 3 additions & 3 deletions tests/test_sbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ mod tests {

println!("Loading sbert ...");
let before = Instant::now();
let sbert_model = SBertRT::new(home).unwrap();
let sbert_model = SBertRT::new(home, None).unwrap();
println!("Elapsed time: {:.2?}", before.elapsed());

let mut texts = Vec::new();
Expand Down Expand Up @@ -104,7 +104,7 @@ mod tests {

println!("Loading sbert ...");
let before = Instant::now();
let sbert_model = SBertHF::new(home).unwrap();
let sbert_model = SBertHF::new(home, None).unwrap();
println!("Elapsed time: {:.2?}", before.elapsed());

let mut texts = Vec::new();
Expand Down Expand Up @@ -137,7 +137,7 @@ mod tests {

println!("Loading sbert ...");
let before = Instant::now();
let sbert_model = SBertHF::new(home).unwrap();
let sbert_model = SBertHF::new(home, None).unwrap();
println!("Elapsed time: {:.2?}", before.elapsed());

let mut texts = Vec::new();
Expand Down

0 comments on commit 7bc6d37

Please sign in to comment.