diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs
index ef1ba84804d4..99e0be2d5817 100644
--- a/crates/llama-cpp-server/src/lib.rs
+++ b/crates/llama-cpp-server/src/lib.rs
@@ -284,7 +284,7 @@ async fn resolve_model_path(model_id: &str) -> String {
             GGML_MODEL_PARTITIONED_PREFIX.to_owned()
         ))
     } else {
-        let (registry, name) = parse_model_id(model_id);
+        let (registry, name) = parse_model_id(model_id).unwrap();
         let registry = ModelRegistry::new(registry).await;
         registry
             .get_model_entry_path(name)
@@ -311,9 +311,9 @@ async fn resolve_prompt_info(model_id: &str) -> PromptInfo {
     if path.exists() {
         PromptInfo::read(path.join("tabby.json"))
     } else {
-        let (registry, name) = parse_model_id(model_id);
+        let (registry, name) = parse_model_id(model_id).unwrap();
         let registry = ModelRegistry::new(registry).await;
-        let model_info = registry.get_model_info(name);
+        let model_info = registry.get_model_info(name).unwrap();
         PromptInfo {
             prompt_template: model_info.prompt_template.to_owned(),
             chat_template: model_info.chat_template.to_owned(),
diff --git a/crates/tabby-common/src/registry.rs b/crates/tabby-common/src/registry.rs
index 94c3c3eb8c11..774843bd5d6e 100644
--- a/crates/tabby-common/src/registry.rs
+++ b/crates/tabby-common/src/registry.rs
@@ -1,12 +1,12 @@
 use std::{fs, path::PathBuf};
 
-use anyhow::{Context, Result};
+use anyhow::{anyhow, Context, Result};
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
 
 use crate::path::models_dir;
 
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
 pub struct ModelInfo {
     pub name: String,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -29,7 +29,7 @@ pub struct ModelInfo {
     pub partition_urls: Option<Vec<PartitionModelUrl>>,
 }
 
-#[derive(Serialize, Deserialize)]
+#[derive(Serialize, Deserialize, Clone)]
 pub struct PartitionModelUrl {
     pub urls: Vec<String>,
     pub sha256: String,
@@ -65,7 +65,7 @@ fn load_local_registry(registry: &str) -> Result<Vec<ModelInfo>> {
     Ok(serdeconv::from_json_file(models_json_file(registry))?)
 }
 
-#[derive(Default)]
+#[derive(Default, Clone)]
 pub struct ModelRegistry {
     pub name: String,
     pub models: Vec<ModelInfo>,
@@ -155,29 +155,34 @@ impl ModelRegistry {
         Ok(())
     }
 
-    pub fn save_model_info(&self, name: &str) {
-        let model_info = self.get_model_info(name);
+    pub fn save_model_info(&self, name: &str) -> Result<()> {
+        let model_info = self.get_model_info(name)?;
         let path = self.get_model_dir(name).join("tabby.json");
-        fs::create_dir_all(path.parent().unwrap()).unwrap();
-        serdeconv::to_json_file(model_info, path).unwrap();
+        fs::create_dir_all(
+            path.parent()
+                .ok_or_else(|| anyhow!("Fail to create model directory"))?,
+        )?;
+        serdeconv::to_json_file(model_info, path)?;
+
+        Ok(())
     }
 
-    pub fn get_model_info(&self, name: &str) -> &ModelInfo {
+    pub fn get_model_info(&self, name: &str) -> Result<&ModelInfo> {
         self.models
             .iter()
             .find(|x| x.name == name)
-            .unwrap_or_else(|| panic!("Invalid model_id <{}/{}>", self.name, name))
+            .ok_or_else(|| anyhow!("Invalid model_id <{}/{}>", self.name, name))
     }
 }
 
-pub fn parse_model_id(model_id: &str) -> (&str, &str) {
+pub fn parse_model_id(model_id: &str) -> Result<(&str, &str)> {
     let parts: Vec<_> = model_id.split('/').collect();
     if parts.len() == 1 {
-        ("TabbyML", parts[0])
+        Ok(("TabbyML", parts[0]))
     } else if parts.len() == 2 {
-        (parts[0], parts[1])
+        Ok((parts[0], parts[1]))
     } else {
-        panic!("Invalid model id {}", model_id);
+        Err(anyhow!("Invalid model id {}", model_id))
     }
 }
 
diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs
index 0c1cc942fae7..780c39699e73 100644
--- a/crates/tabby-download/src/lib.rs
+++ b/crates/tabby-download/src/lib.rs
@@ -2,8 +2,8 @@ use std::{fs, io};
 
 use aim_downloader::{bar::WrappedBar, error::DownloadError, hash::HashChecker, https};
-use anyhow::{bail, Result};
-use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry};
+use anyhow::{anyhow, bail, Result};
+use tabby_common::registry::{ModelInfo, ModelRegistry};
 use tokio_retry::{
     strategy::{jitter, ExponentialBackoff},
     Retry,
 };
@@ -72,7 +72,7 @@ async fn download_model_impl(
     name: &str,
     prefer_local_file: bool,
 ) -> Result<()> {
-    let model_info = registry.get_model_info(name);
+    let model_info = registry.get_model_info(name)?;
     registry.migrate_legacy_model_path(name)?;
 
     let urls = filter_download_address(model_info);
@@ -187,15 +187,14 @@ async fn download_file(
     Ok(())
 }
 
-pub async fn download_model(model_id: &str, prefer_local_file: bool) {
-    let (registry, name) = parse_model_id(model_id);
-
-    let registry = ModelRegistry::new(registry).await;
-
-    let handler = |err| panic!("Failed to fetch model '{}' due to '{}'", model_id, err);
-    download_model_impl(&registry, name, prefer_local_file)
+pub async fn download_model(
+    registry: &ModelRegistry,
+    model: &str,
+    prefer_local_file: bool,
+) -> Result<()> {
+    download_model_impl(registry, model, prefer_local_file)
         .await
-        .unwrap_or_else(handler)
+        .map_err(|err| anyhow!("Failed to fetch model '{}' due to '{}'", model, err))
 }
 
 #[cfg(test)]
diff --git a/crates/tabby/src/download.rs b/crates/tabby/src/download.rs
index dcc36c1d116e..c57c05680958 100644
--- a/crates/tabby/src/download.rs
+++ b/crates/tabby/src/download.rs
@@ -1,7 +1,8 @@
 use clap::Args;
-use tabby_download::download_model;
 use tracing::info;
 
+use crate::services::model::Downloader;
+
 #[derive(Args)]
 pub struct DownloadArgs {
     /// model id to fetch.
@@ -14,6 +15,14 @@ pub struct DownloadArgs {
 }
 
 pub async fn main(args: &DownloadArgs) {
-    download_model(&args.model, args.prefer_local_file).await;
+    let mut downloader = Downloader::new();
+    let (registry, _) = downloader
+        .get_model_registry_and_info(&args.model)
+        .await
+        .unwrap();
+    downloader
+        .download_model(&registry, &args.model, args.prefer_local_file)
+        .await
+        .unwrap();
     info!("model '{}' is ready", args.model);
 }
diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs
index 23a6cf5fe9b5..a251b086e039 100644
--- a/crates/tabby/src/serve.rs
+++ b/crates/tabby/src/serve.rs
@@ -29,7 +29,7 @@ use crate::{
         embedding,
         event::create_event_logger,
         health,
-        model::download_model_if_needed,
+        model::Downloader,
         tantivy::IndexReaderProvider,
     },
     to_local_config, Device,
@@ -113,7 +113,10 @@ pub struct ServeArgs {
 }
 
 pub async fn main(config: &Config, args: &ServeArgs) {
     let config = merge_args(config, args);
-    load_model(&config).await;
+    if let Err(e) = load_model(&config).await {
+        warn!("Failed to load model: {}", e);
+        std::process::exit(1);
+    }
 
     let tx = try_run_spinner();
@@ -210,18 +213,22 @@ pub async fn main(config: &Config, args: &ServeArgs) {
     run_app(api, Some(ui), args.host, args.port).await
 }
 
-async fn load_model(config: &Config) {
+async fn load_model(config: &Config) -> anyhow::Result<()> {
+    let mut downloader = Downloader::new();
+
     if let Some(ModelConfig::Local(ref model)) = config.model.completion {
-        download_model_if_needed(&model.model_id).await;
+        downloader.download_completion(&model.model_id).await?;
     }
 
     if let Some(ModelConfig::Local(ref model)) = config.model.chat {
-        download_model_if_needed(&model.model_id).await;
+        downloader.download_chat(&model.model_id).await?;
     }
 
     if let ModelConfig::Local(ref model) = config.model.embedding {
-        download_model_if_needed(&model.model_id).await;
+        downloader.download_embedding(&model.model_id).await?;
     }
+
+    Ok(())
 }
 
 async fn api_router(
diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs
index d433a75b1e94..87389ab64eaf 100644
--- a/crates/tabby/src/services/model/mod.rs
+++ b/crates/tabby/src/services/model/mod.rs
@@ -1,7 +1,11 @@
-use std::{fs, sync::Arc};
+use std::{collections::hash_map, fs, sync::Arc};
 
+use anyhow::{anyhow, bail, Result};
 pub use llama_cpp_server::PromptInfo;
-use tabby_common::config::ModelConfig;
+use tabby_common::{
+    config::ModelConfig,
+    registry::{parse_model_id, ModelInfo, ModelRegistry},
+};
 use tabby_download::download_model;
 use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream, Embedding};
 use tracing::info;
@@ -79,11 +83,85 @@ async fn load_completion_and_chat(
     (completion, prompt, chat)
 }
+pub struct Downloader {
+    registries: hash_map::HashMap<String, ModelRegistry>,
+}
 
-pub async fn download_model_if_needed(model: &str) {
-    if fs::metadata(model).is_ok() {
-        info!("Loading model from local path {}", model);
-    } else {
-        download_model(model, true).await;
+impl Downloader {
+    pub fn new() -> Self {
+        Self {
+            registries: hash_map::HashMap::new(),
+        }
+    }
+
+    pub async fn get_model_registry_and_info(
+        &mut self,
+        model_id: &str,
+    ) -> Result<(ModelRegistry, ModelInfo)> {
+        let (registry_name, model_name) = parse_model_id(model_id)?;
+
+        let registry = if let Some(registry) = self.registries.get(registry_name) {
+            registry.clone()
+        } else {
+            let registry = ModelRegistry::new(registry_name).await;
+            self.registries
+                .insert(registry_name.to_owned(), registry.clone());
+            registry
+        };
+
+        let info = registry.get_model_info(model_name)?.clone();
+
+        Ok((registry, info))
+    }
+
+    pub async fn download_model(
+        &self,
+        registry: &ModelRegistry,
+        model_id: &str,
+        prefer_local_file: bool,
+    ) -> Result<()> {
+        let (_, model_name) = parse_model_id(model_id)?;
+        download_model(registry, model_name, prefer_local_file).await
+    }
+
+    async fn download_model_with_validation_if_needed(
+        &mut self,
+        model_id: &str,
+        validation: fn(&ModelInfo) -> Result<()>,
+    ) -> Result<()> {
+        if fs::metadata(model_id).is_ok() {
+            info!("Loading model from local path {}", model_id);
+            return Ok(());
+        }
+
+        let (registry, info) = self.get_model_registry_and_info(model_id).await?;
+        validation(&info).map_err(|err| anyhow!("Fail to load model {}: {}", model_id, err))?;
+
+        self.download_model(&registry, model_id, true).await
+    }
+
+    pub async fn download_completion(&mut self, model_id: &str) -> Result<()> {
+        self.download_model_with_validation_if_needed(model_id, |info| {
+            if info.prompt_template.is_none() {
+                bail!("Model doesn't support completion");
+            }
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn download_chat(&mut self, model_id: &str) -> Result<()> {
+        self.download_model_with_validation_if_needed(model_id, |info| {
+            if info.chat_template.is_none() {
+                bail!("Model doesn't support chat");
+            }
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn download_embedding(&mut self, model_id: &str) -> Result<()> {
+        self.download_model_with_validation_if_needed(model_id, |_| Ok(()))
+            .await
     }
 }
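
Note: below is a minimal sketch of how the new `Downloader` API from `crates/tabby/src/services/model/mod.rs` can be driven by a caller that wants the registry metadata up front, the same two-step flow `crates/tabby/src/download.rs` now uses, but with errors propagated via `?` instead of `.unwrap()`. The wrapper function and its name are hypothetical; only the `Downloader` methods, their signatures, and the `ModelInfo::prompt_template` field come from this diff, and the snippet assumes it lives inside the `tabby` crate where `crate::services::model::Downloader` is visible.

```rust
use anyhow::Result;

use crate::services::model::Downloader;

// Hypothetical caller: resolve registry metadata first, validate it, then download.
// Every fallible step returns anyhow::Result instead of panicking, which is the
// point of this refactor.
async fn fetch_completion_model(model_id: &str) -> Result<()> {
    let mut downloader = Downloader::new();

    // Parses "<registry>/<name>" (defaulting to the TabbyML registry for bare names)
    // and caches the ModelRegistry, so later calls for the same registry reuse it.
    let (registry, info) = downloader.get_model_registry_and_info(model_id).await?;

    // The same check download_completion() performs through its validation closure.
    if info.prompt_template.is_none() {
        anyhow::bail!("model '{}' does not ship a prompt template", model_id);
    }

    // Delegates to tabby_download::download_model(), preferring an existing local file.
    downloader.download_model(&registry, model_id, true).await
}
```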