Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(download): use downloader to cache registry and check capability before download #3480

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/llama-cpp-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@
GGML_MODEL_PARTITIONED_PREFIX.to_owned()
))
} else {
let (registry, name) = parse_model_id(model_id);
let (registry, name) = parse_model_id(model_id).unwrap();

Check warning on line 287 in crates/llama-cpp-server/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/llama-cpp-server/src/lib.rs#L287

Added line #L287 was not covered by tests
let registry = ModelRegistry::new(registry).await;
registry
.get_model_entry_path(name)
Expand All @@ -311,9 +311,9 @@
if path.exists() {
PromptInfo::read(path.join("tabby.json"))
} else {
let (registry, name) = parse_model_id(model_id);
let (registry, name) = parse_model_id(model_id).unwrap();

Check warning on line 314 in crates/llama-cpp-server/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/llama-cpp-server/src/lib.rs#L314

Added line #L314 was not covered by tests
let registry = ModelRegistry::new(registry).await;
let model_info = registry.get_model_info(name);
let model_info = registry.get_model_info(name).unwrap();

Check warning on line 316 in crates/llama-cpp-server/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/llama-cpp-server/src/lib.rs#L316

Added line #L316 was not covered by tests
PromptInfo {
prompt_template: model_info.prompt_template.to_owned(),
chat_template: model_info.chat_template.to_owned(),
Expand Down
33 changes: 19 additions & 14 deletions crates/tabby-common/src/registry.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::{fs, path::PathBuf};

use anyhow::{Context, Result};
use anyhow::{anyhow, Context, Result};
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};

use crate::path::models_dir;

#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Clone)]
pub struct ModelInfo {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
Expand All @@ -29,7 +29,7 @@
pub partition_urls: Option<Vec<PartitionModelUrl>>,
}

#[derive(Serialize, Deserialize)]
#[derive(Serialize, Deserialize, Clone)]
pub struct PartitionModelUrl {
pub urls: Vec<String>,
pub sha256: String,
Expand Down Expand Up @@ -65,7 +65,7 @@
Ok(serdeconv::from_json_file(models_json_file(registry))?)
}

#[derive(Default)]
#[derive(Default, Clone)]
pub struct ModelRegistry {
pub name: String,
pub models: Vec<ModelInfo>,
Expand Down Expand Up @@ -155,29 +155,34 @@
Ok(())
}

pub fn save_model_info(&self, name: &str) {
let model_info = self.get_model_info(name);
pub fn save_model_info(&self, name: &str) -> Result<()> {
let model_info = self.get_model_info(name)?;

Check warning on line 159 in crates/tabby-common/src/registry.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-common/src/registry.rs#L158-L159

Added lines #L158 - L159 were not covered by tests
let path = self.get_model_dir(name).join("tabby.json");
fs::create_dir_all(path.parent().unwrap()).unwrap();
serdeconv::to_json_file(model_info, path).unwrap();
fs::create_dir_all(
path.parent()
.ok_or_else(|| anyhow!("Fail to create model directory"))?,
)?;
serdeconv::to_json_file(model_info, path)?;

Check warning on line 165 in crates/tabby-common/src/registry.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-common/src/registry.rs#L161-L165

Added lines #L161 - L165 were not covered by tests

Ok(())

Check warning on line 167 in crates/tabby-common/src/registry.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-common/src/registry.rs#L167

Added line #L167 was not covered by tests
}

pub fn get_model_info(&self, name: &str) -> &ModelInfo {
pub fn get_model_info(&self, name: &str) -> Result<&ModelInfo> {

Check warning on line 170 in crates/tabby-common/src/registry.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-common/src/registry.rs#L170

Added line #L170 was not covered by tests
self.models
.iter()
.find(|x| x.name == name)
.unwrap_or_else(|| panic!("Invalid model_id <{}/{}>", self.name, name))
.ok_or_else(|| anyhow!("Invalid model_id <{}/{}>", self.name, name))

Check warning on line 174 in crates/tabby-common/src/registry.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-common/src/registry.rs#L174

Added line #L174 was not covered by tests
}
}

pub fn parse_model_id(model_id: &str) -> (&str, &str) {
pub fn parse_model_id(model_id: &str) -> Result<(&str, &str)> {

Check warning on line 178 in crates/tabby-common/src/registry.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-common/src/registry.rs#L178

Added line #L178 was not covered by tests
let parts: Vec<_> = model_id.split('/').collect();
if parts.len() == 1 {
("TabbyML", parts[0])
Ok(("TabbyML", parts[0]))

Check warning on line 181 in crates/tabby-common/src/registry.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-common/src/registry.rs#L181

Added line #L181 was not covered by tests
} else if parts.len() == 2 {
(parts[0], parts[1])
Ok((parts[0], parts[1]))

Check warning on line 183 in crates/tabby-common/src/registry.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-common/src/registry.rs#L183

Added line #L183 was not covered by tests
} else {
panic!("Invalid model id {}", model_id);
Err(anyhow!("Invalid model id {}", model_id))

Check warning on line 185 in crates/tabby-common/src/registry.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-common/src/registry.rs#L185

Added line #L185 was not covered by tests
}
}

Expand Down
21 changes: 10 additions & 11 deletions crates/tabby-download/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
use std::{fs, io};

use aim_downloader::{bar::WrappedBar, error::DownloadError, hash::HashChecker, https};
use anyhow::{bail, Result};
use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry};
use anyhow::{anyhow, bail, Result};
use tabby_common::registry::{ModelInfo, ModelRegistry};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
Expand Down Expand Up @@ -72,7 +72,7 @@
name: &str,
prefer_local_file: bool,
) -> Result<()> {
let model_info = registry.get_model_info(name);
let model_info = registry.get_model_info(name)?;

Check warning on line 75 in crates/tabby-download/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-download/src/lib.rs#L75

Added line #L75 was not covered by tests
registry.migrate_legacy_model_path(name)?;

let urls = filter_download_address(model_info);
Expand Down Expand Up @@ -187,15 +187,14 @@
Ok(())
}

pub async fn download_model(model_id: &str, prefer_local_file: bool) {
let (registry, name) = parse_model_id(model_id);

let registry = ModelRegistry::new(registry).await;

let handler = |err| panic!("Failed to fetch model '{}' due to '{}'", model_id, err);
download_model_impl(&registry, name, prefer_local_file)
pub async fn download_model(
registry: &ModelRegistry,
model: &str,
prefer_local_file: bool,
) -> Result<()> {
download_model_impl(registry, model, prefer_local_file)

Check warning on line 195 in crates/tabby-download/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-download/src/lib.rs#L190-L195

Added lines #L190 - L195 were not covered by tests
.await
.unwrap_or_else(handler)
.map_err(|err| anyhow!("Failed to fetch model '{}' due to '{}'", model, err))

Check warning on line 197 in crates/tabby-download/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-download/src/lib.rs#L197

Added line #L197 was not covered by tests
}

#[cfg(test)]
Expand Down
13 changes: 11 additions & 2 deletions crates/tabby/src/download.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use clap::Args;
use tabby_download::download_model;
use tracing::info;

use crate::services::model::Downloader;

#[derive(Args)]
pub struct DownloadArgs {
/// model id to fetch.
Expand All @@ -14,6 +15,14 @@
}

pub async fn main(args: &DownloadArgs) {
download_model(&args.model, args.prefer_local_file).await;
let mut downloader = Downloader::new();
let (registry, _) = downloader
.get_model_registry_and_info(&args.model)
.await
.unwrap();
downloader
.download_model(&registry, &args.model, args.prefer_local_file)
.await
.unwrap();

Check warning on line 26 in crates/tabby/src/download.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/download.rs#L18-L26

Added lines #L18 - L26 were not covered by tests
info!("model '{}' is ready", args.model);
}
19 changes: 13 additions & 6 deletions crates/tabby/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
embedding,
event::create_event_logger,
health,
model::download_model_if_needed,
model::Downloader,
tantivy::IndexReaderProvider,
},
to_local_config, Device,
Expand Down Expand Up @@ -113,7 +113,10 @@
pub async fn main(config: &Config, args: &ServeArgs) {
let config = merge_args(config, args);

load_model(&config).await;
if let Err(e) = load_model(&config).await {
warn!("Failed to load model: {}", e);
std::process::exit(1);
}

Check warning on line 119 in crates/tabby/src/serve.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/serve.rs#L116-L119

Added lines #L116 - L119 were not covered by tests

let tx = try_run_spinner();

Expand Down Expand Up @@ -210,18 +213,22 @@
run_app(api, Some(ui), args.host, args.port).await
}

async fn load_model(config: &Config) {
async fn load_model(config: &Config) -> anyhow::Result<()> {
let mut downloader = Downloader::new();

Check warning on line 217 in crates/tabby/src/serve.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/serve.rs#L216-L217

Added lines #L216 - L217 were not covered by tests

if let Some(ModelConfig::Local(ref model)) = config.model.completion {
download_model_if_needed(&model.model_id).await;
downloader.download_completion(&model.model_id).await?;

Check warning on line 220 in crates/tabby/src/serve.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/serve.rs#L220

Added line #L220 was not covered by tests
}

if let Some(ModelConfig::Local(ref model)) = config.model.chat {
download_model_if_needed(&model.model_id).await;
downloader.download_chat(&model.model_id).await?;

Check warning on line 224 in crates/tabby/src/serve.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/serve.rs#L224

Added line #L224 was not covered by tests
}

if let ModelConfig::Local(ref model) = config.model.embedding {
download_model_if_needed(&model.model_id).await;
downloader.download_embedding(&model.model_id).await?;

Check warning on line 228 in crates/tabby/src/serve.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/serve.rs#L228

Added line #L228 was not covered by tests
}

Ok(())

Check warning on line 231 in crates/tabby/src/serve.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/serve.rs#L231

Added line #L231 was not covered by tests
}

async fn api_router(
Expand Down
92 changes: 85 additions & 7 deletions crates/tabby/src/services/model/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use std::{fs, sync::Arc};
use std::{collections::hash_map, fs, sync::Arc};

use anyhow::{anyhow, bail, Result};
pub use llama_cpp_server::PromptInfo;
use tabby_common::config::ModelConfig;
use tabby_common::{
config::ModelConfig,
registry::{parse_model_id, ModelInfo, ModelRegistry},
};
use tabby_download::download_model;
use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
use tracing::info;
Expand Down Expand Up @@ -79,11 +83,85 @@

(completion, prompt, chat)
}
pub struct Downloader {
registries: hash_map::HashMap<String, ModelRegistry>,
}

pub async fn download_model_if_needed(model: &str) {
if fs::metadata(model).is_ok() {
info!("Loading model from local path {}", model);
} else {
download_model(model, true).await;
impl Downloader {
pub fn new() -> Self {
Self {
registries: hash_map::HashMap::new(),
}
}

Check warning on line 95 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L91-L95

Added lines #L91 - L95 were not covered by tests

pub async fn get_model_registry_and_info(
&mut self,
model_id: &str,
) -> Result<(ModelRegistry, ModelInfo)> {
let (registry_name, model_name) = parse_model_id(model_id)?;

Check warning on line 101 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L97-L101

Added lines #L97 - L101 were not covered by tests

let registry = if let Some(registry) = self.registries.get(registry_name) {
registry.clone()

Check warning on line 104 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L103-L104

Added lines #L103 - L104 were not covered by tests
} else {
let registry = ModelRegistry::new(registry_name).await;
self.registries
.insert(registry_name.to_owned(), registry.clone());
registry

Check warning on line 109 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L106-L109

Added lines #L106 - L109 were not covered by tests
};

let info = registry.get_model_info(model_name)?.clone();

Ok((registry, info))
}

Check warning on line 115 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L112-L115

Added lines #L112 - L115 were not covered by tests

pub async fn download_model(
&self,
registry: &ModelRegistry,
model_id: &str,
prefer_local_file: bool,
) -> Result<()> {
let (_, model_name) = parse_model_id(model_id)?;
download_model(registry, model_name, prefer_local_file).await
}

Check warning on line 125 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L117-L125

Added lines #L117 - L125 were not covered by tests

async fn download_model_with_validation_if_needed(
&mut self,
model_id: &str,
validation: fn(&ModelInfo) -> Result<()>,
) -> Result<()> {
if fs::metadata(model_id).is_ok() {
info!("Loading model from local path {}", model_id);
return Ok(());
}

Check warning on line 135 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L127-L135

Added lines #L127 - L135 were not covered by tests

let (registry, info) = self.get_model_registry_and_info(model_id).await?;
validation(&info).map_err(|err| anyhow!("Fail to load model {}: {}", model_id, err))?;

Check warning on line 138 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L137-L138

Added lines #L137 - L138 were not covered by tests

self.download_model(&registry, model_id, true).await
}

Check warning on line 141 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L140-L141

Added lines #L140 - L141 were not covered by tests

pub async fn download_completion(&mut self, model_id: &str) -> Result<()> {
self.download_model_with_validation_if_needed(model_id, |info| {
if info.prompt_template.is_none() {
bail!("Model doesn't support completion");
}
Ok(())
})
.await
}

Check warning on line 151 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L143-L151

Added lines #L143 - L151 were not covered by tests

pub async fn download_chat(&mut self, model_id: &str) -> Result<()> {
self.download_model_with_validation_if_needed(model_id, |info| {
if info.chat_template.is_none() {
bail!("Model doesn't support chat");
}
Ok(())
})
.await
}

Check warning on line 161 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L153-L161

Added lines #L153 - L161 were not covered by tests

pub async fn download_embedding(&mut self, model_id: &str) -> Result<()> {
self.download_model_with_validation_if_needed(model_id, |_| Ok(()))
.await

Check warning on line 165 in crates/tabby/src/services/model/mod.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby/src/services/model/mod.rs#L163-L165

Added lines #L163 - L165 were not covered by tests
}
}
Loading