From 408013c22be864a4ad263bdc5fbcf426103c7adb Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Thu, 14 Nov 2024 18:18:58 +0800 Subject: [PATCH] feat: add distance functions (#4987) * feat: add distance functions Signed-off-by: Zhenchi * fix: f64 instead Signed-off-by: Zhenchi * address comments Signed-off-by: Zhenchi * tiny adjust Signed-off-by: Zhenchi --------- Signed-off-by: Zhenchi --- Cargo.lock | 24 +- src/common/function/Cargo.toml | 1 + src/common/function/src/function_registry.rs | 4 + src/common/function/src/scalars.rs | 1 + src/common/function/src/scalars/vector.rs | 31 ++ .../function/src/scalars/vector/distance.rs | 469 ++++++++++++++++++ src/datatypes/src/value.rs | 6 +- .../common/types/vector/vector.result | 180 +++++++ .../standalone/common/types/vector/vector.sql | 49 ++ 9 files changed, 755 insertions(+), 10 deletions(-) create mode 100644 src/common/function/src/scalars/vector.rs create mode 100644 src/common/function/src/scalars/vector/distance.rs diff --git a/Cargo.lock b/Cargo.lock index c42187922c2a..b4c70b0dbe6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1041,7 +1041,7 @@ dependencies = [ "bitflags 2.6.0", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -2088,6 +2088,7 @@ dependencies = [ "serde", "serde_json", "session", + "simsimd", "snafu 0.8.5", "sql", "statrs", @@ -5090,7 +5091,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.4.10", + "socket2 0.5.7", "tokio", "tower-service", "tracing", @@ -6080,7 +6081,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -8822,7 +8823,7 @@ checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", "heck 0.5.0", - "itertools 0.10.5", + "itertools 0.12.1", "log", "multimap", "once_cell", @@ -8874,7 +8875,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.10.5", + "itertools 0.12.1", "proc-macro2", "quote", "syn 2.0.79", @@ -9036,7 +9037,7 @@ dependencies = [ "indoc", "libc", "memoffset 0.9.1", - "parking_lot 0.11.2", + "parking_lot 0.12.3", "portable-atomic", "pyo3-build-config", "pyo3-ffi", @@ -11198,6 +11199,15 @@ dependencies = [ "time", ] +[[package]] +name = "simsimd" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efc843bc8f12d9c8e6b734a0fe8918fc497b42f6ae0f347dbfdad5b5138ab9b4" +dependencies = [ + "cc", +] + [[package]] name = "siphasher" version = "0.3.11" @@ -13981,7 +13991,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/src/common/function/Cargo.toml b/src/common/function/Cargo.toml index 6c1ecc2d381e..cb876b352dd9 100644 --- a/src/common/function/Cargo.toml +++ b/src/common/function/Cargo.toml @@ -41,6 +41,7 @@ s2 = { version = "0.0.12", optional = true } serde.workspace = true serde_json.workspace = true session.workspace = true +simsimd = "4" snafu.workspace = true sql.workspace = true statrs = "0.16" diff --git a/src/common/function/src/function_registry.rs b/src/common/function/src/function_registry.rs index 46af3b761072..04d68a93d85e 100644 --- a/src/common/function/src/function_registry.rs +++ b/src/common/function/src/function_registry.rs @@ -27,6 +27,7 @@ use crate::scalars::matches::MatchesFunction; use crate::scalars::math::MathFunction; use crate::scalars::numpy::NumpyFunction; use crate::scalars::timestamp::TimestampFunction; +use crate::scalars::vector::VectorFunction; use crate::system::SystemFunction; use crate::table::TableFunction; @@ -120,6 +121,9 @@ pub static FUNCTION_REGISTRY: Lazy> = Lazy::new(|| { // Json related functions JsonFunction::register(&function_registry); + // Vector related functions + VectorFunction::register(&function_registry); + // Geo functions #[cfg(feature = "geo")] crate::scalars::geo::GeoFunctions::register(&function_registry); diff --git a/src/common/function/src/scalars.rs b/src/common/function/src/scalars.rs index f60cf2b0d98b..52a238273d99 100644 --- a/src/common/function/src/scalars.rs +++ b/src/common/function/src/scalars.rs @@ -21,6 +21,7 @@ pub mod json; pub mod matches; pub mod math; pub mod numpy; +pub mod vector; #[cfg(test)] pub(crate) mod test; diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs new file mode 100644 index 000000000000..67b812fd09f0 --- /dev/null +++ b/src/common/function/src/scalars/vector.rs @@ -0,0 +1,31 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod distance; + +use std::sync::Arc; + +use distance::{CosDistanceFunction, DotProductFunction, L2SqDistanceFunction}; + +use crate::function_registry::FunctionRegistry; + +pub(crate) struct VectorFunction; + +impl VectorFunction { + pub fn register(registry: &FunctionRegistry) { + registry.register(Arc::new(CosDistanceFunction)); + registry.register(Arc::new(DotProductFunction)); + registry.register(Arc::new(L2SqDistanceFunction)); + } +} diff --git a/src/common/function/src/scalars/vector/distance.rs b/src/common/function/src/scalars/vector/distance.rs new file mode 100644 index 000000000000..c1259c229821 --- /dev/null +++ b/src/common/function/src/scalars/vector/distance.rs @@ -0,0 +1,469 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Cow; +use std::fmt::Display; +use std::sync::Arc; + +use common_query::error::{InvalidFuncArgsSnafu, Result}; +use common_query::prelude::Signature; +use datatypes::prelude::ConcreteDataType; +use datatypes::scalars::ScalarVectorBuilder; +use datatypes::value::ValueRef; +use datatypes::vectors::{Float64VectorBuilder, MutableVector, Vector, VectorRef}; +use snafu::ensure; + +use crate::function::{Function, FunctionContext}; +use crate::helper; + +macro_rules! define_distance_function { + ($StructName:ident, $display_name:expr, $similarity_method:ident) => { + + /// A function calculates the distance between two vectors. + + #[derive(Debug, Clone, Default)] + pub struct $StructName; + + impl Function for $StructName { + fn name(&self) -> &str { + $display_name + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::float64_datatype()) + } + + fn signature(&self) -> Signature { + helper::one_of_sigs2( + vec![ + ConcreteDataType::string_datatype(), + ConcreteDataType::binary_datatype(), + ], + vec![ + ConcreteDataType::string_datatype(), + ConcreteDataType::binary_datatype(), + ], + ) + } + + fn eval(&self, _func_ctx: FunctionContext, columns: &[VectorRef]) -> Result { + ensure!( + columns.len() == 2, + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the args is not correct, expect exactly two, have: {}", + columns.len() + ), + } + ); + let arg0 = &columns[0]; + let arg1 = &columns[1]; + + let size = arg0.len(); + let mut result = Float64VectorBuilder::with_capacity(size); + if size == 0 { + return Ok(result.to_vector()); + } + + let arg0_const = parse_if_constant_string(arg0)?; + let arg1_const = parse_if_constant_string(arg1)?; + + for i in 0..size { + let vec0 = match arg0_const.as_ref() { + Some(a) => Some(Cow::Borrowed(a.as_slice())), + None => as_vector(arg0.get_ref(i))?, + }; + let vec1 = match arg1_const.as_ref() { + Some(b) => Some(Cow::Borrowed(b.as_slice())), + None => as_vector(arg1.get_ref(i))?, + }; + + if let (Some(vec0), Some(vec1)) = (vec0, vec1) { + ensure!( + vec0.len() == vec1.len(), + InvalidFuncArgsSnafu { + err_msg: format!( + "The length of the vectors must match to calculate distance, have: {} vs {}", + vec0.len(), + vec1.len() + ), + } + ); + + let f = ::$similarity_method; + // Safe: checked if the length of the vectors match + let d = f(vec0.as_ref(), vec1.as_ref()).unwrap(); + result.push(Some(d)); + } else { + result.push_null(); + } + } + + return Ok(result.to_vector()); + } + } + + impl Display for $StructName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", $display_name.to_ascii_uppercase()) + } + } + } +} + +define_distance_function!(CosDistanceFunction, "cos_distance", cos); +define_distance_function!(L2SqDistanceFunction, "l2sq_distance", l2sq); +define_distance_function!(DotProductFunction, "dot_product", dot); + +/// Parse a vector value if the value is a constant string. +fn parse_if_constant_string(arg: &Arc) -> Result>> { + if !arg.is_const() { + return Ok(None); + } + if arg.data_type() != ConcreteDataType::string_datatype() { + return Ok(None); + } + arg.get_ref(0) + .as_string() + .unwrap() // Safe: checked if it is a string + .map(parse_f32_vector_from_string) + .transpose() +} + +/// Convert a value to a vector value. +/// Supported data types are binary and string. +fn as_vector(arg: ValueRef<'_>) -> Result>> { + match arg.data_type() { + ConcreteDataType::Binary(_) => arg + .as_binary() + .unwrap() // Safe: checked if it is a binary + .map(|bytes| Ok(Cow::Borrowed(binary_as_vector(bytes)?))) + .transpose(), + ConcreteDataType::String(_) => arg + .as_string() + .unwrap() // Safe: checked if it is a string + .map(|s| Ok(Cow::Owned(parse_f32_vector_from_string(s)?))) + .transpose(), + ConcreteDataType::Null(_) => Ok(None), + _ => InvalidFuncArgsSnafu { + err_msg: format!("Unsupported data type: {:?}", arg.data_type()), + } + .fail(), + } +} + +/// Convert a u8 slice to a vector value. +fn binary_as_vector(bytes: &[u8]) -> Result<&[f32]> { + if bytes.len() % 4 != 0 { + return InvalidFuncArgsSnafu { + err_msg: format!("Invalid binary length of vector: {}", bytes.len()), + } + .fail(); + } + + unsafe { + let num_floats = bytes.len() / 4; + let floats: &[f32] = std::slice::from_raw_parts(bytes.as_ptr() as *const f32, num_floats); + Ok(floats) + } +} + +/// Parse a string to a vector value. +/// Valid inputs are strings like "[1.0, 2.0, 3.0]". +fn parse_f32_vector_from_string(s: &str) -> Result> { + let trimmed = s.trim(); + if !trimmed.starts_with('[') || !trimmed.ends_with(']') { + return InvalidFuncArgsSnafu { + err_msg: format!( + "Failed to parse {s} to Vector value: not properly enclosed in brackets" + ), + } + .fail(); + } + let content = trimmed[1..trimmed.len() - 1].trim(); + if content.is_empty() { + return Ok(Vec::new()); + } + + content + .split(',') + .map(|s| s.trim().parse::()) + .collect::>() + .map_err(|e| { + InvalidFuncArgsSnafu { + err_msg: format!("Failed to parse {s} to Vector value: {e}"), + } + .build() + }) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datatypes::vectors::{BinaryVector, ConstantVector, StringVector}; + + use super::*; + + #[test] + fn test_distance_string_string() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let vec1 = Arc::new(StringVector::from(vec![ + Some("[0.0, 1.0]"), + Some("[1.0, 0.0]"), + None, + Some("[1.0, 0.0]"), + ])) as VectorRef; + let vec2 = Arc::new(StringVector::from(vec![ + Some("[0.0, 1.0]"), + Some("[0.0, 1.0]"), + Some("[0.0, 1.0]"), + None, + ])) as VectorRef; + + let result = func + .eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + + let result = func + .eval(FunctionContext::default(), &[vec2, vec1]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + } + } + + #[test] + fn test_distance_binary_binary() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let vec1 = Arc::new(BinaryVector::from(vec![ + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 128, 63, 0, 0, 0, 0]), + None, + Some(vec![0, 0, 128, 63, 0, 0, 0, 0]), + ])) as VectorRef; + let vec2 = Arc::new(BinaryVector::from(vec![ + // [0.0, 1.0] + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + None, + ])) as VectorRef; + + let result = func + .eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + + let result = func + .eval(FunctionContext::default(), &[vec2, vec1]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + } + } + + #[test] + fn test_distance_string_binary() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let vec1 = Arc::new(StringVector::from(vec![ + Some("[0.0, 1.0]"), + Some("[1.0, 0.0]"), + None, + Some("[1.0, 0.0]"), + ])) as VectorRef; + let vec2 = Arc::new(BinaryVector::from(vec![ + // [0.0, 1.0] + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + None, + ])) as VectorRef; + + let result = func + .eval(FunctionContext::default(), &[vec1.clone(), vec2.clone()]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + + let result = func + .eval(FunctionContext::default(), &[vec2, vec1]) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(result.get(3).is_null()); + } + } + + #[test] + fn test_distance_const_string() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let const_str = Arc::new(ConstantVector::new( + Arc::new(StringVector::from(vec!["[0.0, 1.0]"])), + 4, + )); + + let vec1 = Arc::new(StringVector::from(vec![ + Some("[0.0, 1.0]"), + Some("[1.0, 0.0]"), + None, + Some("[1.0, 0.0]"), + ])) as VectorRef; + let vec2 = Arc::new(BinaryVector::from(vec![ + // [0.0, 1.0] + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + Some(vec![0, 0, 0, 0, 0, 0, 128, 63]), + None, + ])) as VectorRef; + + let result = func + .eval( + FunctionContext::default(), + &[const_str.clone(), vec1.clone()], + ) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(!result.get(3).is_null()); + + let result = func + .eval( + FunctionContext::default(), + &[vec1.clone(), const_str.clone()], + ) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(result.get(2).is_null()); + assert!(!result.get(3).is_null()); + + let result = func + .eval( + FunctionContext::default(), + &[const_str.clone(), vec2.clone()], + ) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(!result.get(2).is_null()); + assert!(result.get(3).is_null()); + + let result = func + .eval( + FunctionContext::default(), + &[vec2.clone(), const_str.clone()], + ) + .unwrap(); + + assert!(!result.get(0).is_null()); + assert!(!result.get(1).is_null()); + assert!(!result.get(2).is_null()); + assert!(result.get(3).is_null()); + } + } + + #[test] + fn test_invalid_vector_length() { + let funcs = [ + Box::new(CosDistanceFunction {}) as Box, + Box::new(L2SqDistanceFunction {}) as Box, + Box::new(DotProductFunction {}) as Box, + ]; + + for func in funcs { + let vec1 = Arc::new(StringVector::from(vec!["[1.0]"])) as VectorRef; + let vec2 = Arc::new(StringVector::from(vec!["[1.0, 1.0]"])) as VectorRef; + let result = func.eval(FunctionContext::default(), &[vec1, vec2]); + assert!(result.is_err()); + + let vec1 = Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63]])) as VectorRef; + let vec2 = + Arc::new(BinaryVector::from(vec![vec![0, 0, 128, 63, 0, 0, 0, 64]])) as VectorRef; + let result = func.eval(FunctionContext::default(), &[vec1, vec2]); + assert!(result.is_err()); + } + } + + #[test] + fn test_parse_vector_from_string() { + let result = parse_f32_vector_from_string("[1.0, 2.0, 3.0]").unwrap(); + assert_eq!(result, vec![1.0, 2.0, 3.0]); + + let result = parse_f32_vector_from_string("[]").unwrap(); + assert_eq!(result, Vec::::new()); + + let result = parse_f32_vector_from_string("[1.0, a, 3.0]"); + assert!(result.is_err()); + } + + #[test] + fn test_binary_as_vector() { + let bytes = [0, 0, 128, 63]; + let result = binary_as_vector(&bytes).unwrap(); + assert_eq!(result, &[1.0]); + + let invalid_bytes = [0, 0, 128]; + let result = binary_as_vector(&invalid_bytes); + assert!(result.is_err()); + } +} diff --git a/src/datatypes/src/value.rs b/src/datatypes/src/value.rs index d0f36139ca4a..b57b364cf3e5 100644 --- a/src/datatypes/src/value.rs +++ b/src/datatypes/src/value.rs @@ -1089,7 +1089,7 @@ macro_rules! impl_as_for_value_ref { }; } -impl ValueRef<'_> { +impl<'a> ValueRef<'a> { define_data_type_func!(ValueRef); /// Returns true if this is null. @@ -1098,12 +1098,12 @@ impl ValueRef<'_> { } /// Cast itself to binary slice. - pub fn as_binary(&self) -> Result> { + pub fn as_binary(&self) -> Result> { impl_as_for_value_ref!(self, Binary) } /// Cast itself to string slice. - pub fn as_string(&self) -> Result> { + pub fn as_string(&self) -> Result> { impl_as_for_value_ref!(self, String) } diff --git a/tests/cases/standalone/common/types/vector/vector.result b/tests/cases/standalone/common/types/vector/vector.result index d9b5a2e61e70..ee9bbf45af25 100644 --- a/tests/cases/standalone/common/types/vector/vector.result +++ b/tests/cases/standalone/common/types/vector/vector.result @@ -31,6 +31,186 @@ SELECT * FROM t; | 1970-01-01 00:00:00.003000 | "[7,8,9]" | +----------------------------+-----------+ +SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t; + ++-----------------------------------------------------------+ +| round(cos_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) | ++-----------------------------------------------------------+ +| 1.0 | +| 1.0 | +| 1.0 | ++-----------------------------------------------------------+ + +SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; + ++-------------------------+--------------------------+-----+ +| ts | v | d | ++-------------------------+--------------------------+-----+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 1.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 1.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 1.0 | ++-------------------------+--------------------------+-----+ + +SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 4) FROM t; + ++-----------------------------------------------------------+ +| round(cos_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) | ++-----------------------------------------------------------+ +| 0.0406 | +| 0.0018 | +| 0.0 | ++-----------------------------------------------------------+ + +SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; + ++-------------------------+--------------------------+--------+ +| ts | v | d | ++-------------------------+--------------------------+--------+ +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.0018 | +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0406 | ++-------------------------+--------------------------+--------+ + +SELECT round(cos_distance(v, v), 4) FROM t; + ++---------------------------------------+ +| round(cos_distance(t.v,t.v),Int64(4)) | ++---------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++---------------------------------------+ + +-- Unexpected dimension -- +SELECT cos_distance(v, '[1.0]') FROM t; + +Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 + +-- Invalid type -- +SELECT cos_distance(v, 1.0) FROM t; + +Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 + +SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t; + ++------------------------------------------------------------+ +| round(l2sq_distance(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) | ++------------------------------------------------------------+ +| 14.0 | +| 77.0 | +| 194.0 | ++------------------------------------------------------------+ + +SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; + ++-------------------------+--------------------------+-------+ +| ts | v | d | ++-------------------------+--------------------------+-------+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 14.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 77.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 | ++-------------------------+--------------------------+-------+ + +SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) FROM t; + ++------------------------------------------------------------+ +| round(l2sq_distance(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) | ++------------------------------------------------------------+ +| 108.0 | +| 27.0 | +| 0.0 | ++------------------------------------------------------------+ + +SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; + ++-------------------------+--------------------------+-------+ +| ts | v | d | ++-------------------------+--------------------------+-------+ +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 27.0 | +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 108.0 | ++-------------------------+--------------------------+-------+ + +SELECT round(l2sq_distance(v, v), 4) FROM t; + ++----------------------------------------+ +| round(l2sq_distance(t.v,t.v),Int64(4)) | ++----------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++----------------------------------------+ + +-- Unexpected dimension -- +SELECT l2sq_distance(v, '[1.0]') FROM t; + +Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 + +-- Invalid type -- +SELECT l2sq_distance(v, 1.0) FROM t; + +Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 + +SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) FROM t; + ++----------------------------------------------------------+ +| round(dot_product(t.v,Utf8("[0.0, 0.0, 0.0]")),Int64(4)) | ++----------------------------------------------------------+ +| 0.0 | +| 0.0 | +| 0.0 | ++----------------------------------------------------------+ + +SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; + ++-------------------------+--------------------------+-----+ +| ts | v | d | ++-------------------------+--------------------------+-----+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 0.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 0.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 0.0 | ++-------------------------+--------------------------+-----+ + +SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 4) FROM t; + ++----------------------------------------------------------+ +| round(dot_product(Utf8("[7.0, 8.0, 9.0]"),t.v),Int64(4)) | ++----------------------------------------------------------+ +| 50.0 | +| 122.0 | +| 194.0 | ++----------------------------------------------------------+ + +SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; + ++-------------------------+--------------------------+-------+ +| ts | v | d | ++-------------------------+--------------------------+-------+ +| 1970-01-01T00:00:00.001 | 0000803f0000004000004040 | 50.0 | +| 1970-01-01T00:00:00.002 | 000080400000a0400000c040 | 122.0 | +| 1970-01-01T00:00:00.003 | 0000e0400000004100001041 | 194.0 | ++-------------------------+--------------------------+-------+ + +SELECT round(dot_product(v, v), 4) FROM t; + ++--------------------------------------+ +| round(dot_product(t.v,t.v),Int64(4)) | ++--------------------------------------+ +| 14.0 | +| 77.0 | +| 194.0 | ++--------------------------------------+ + +-- Unexpected dimension -- +SELECT dot_product(v, '[1.0]') FROM t; + +Error: 3001(EngineExecuteQuery), Invalid function args: The length of the vectors must match to calculate distance, have: 3 vs 1 + +-- Invalid type -- +SELECT dot_product(v, 1.0) FROM t; + +Error: 3001(EngineExecuteQuery), Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2 + -- Unexpected dimension -- INSERT INTO t VALUES (4, "[1.0]"); diff --git a/tests/cases/standalone/common/types/vector/vector.sql b/tests/cases/standalone/common/types/vector/vector.sql index 376f356aaa66..cea3ef406c63 100644 --- a/tests/cases/standalone/common/types/vector/vector.sql +++ b/tests/cases/standalone/common/types/vector/vector.sql @@ -11,6 +11,55 @@ SELECT * FROM t; -- SQLNESS PROTOCOL POSTGRES SELECT * FROM t; +SELECT round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t; + +SELECT *, round(cos_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; + +SELECT round(cos_distance('[7.0, 8.0, 9.0]', v), 4) FROM t; + +SELECT *, round(cos_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; + +SELECT round(cos_distance(v, v), 4) FROM t; + +-- Unexpected dimension -- +SELECT cos_distance(v, '[1.0]') FROM t; + +-- Invalid type -- +SELECT cos_distance(v, 1.0) FROM t; + +SELECT round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) FROM t; + +SELECT *, round(l2sq_distance(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; + +SELECT round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) FROM t; + +SELECT *, round(l2sq_distance('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; + +SELECT round(l2sq_distance(v, v), 4) FROM t; + +-- Unexpected dimension -- +SELECT l2sq_distance(v, '[1.0]') FROM t; + +-- Invalid type -- +SELECT l2sq_distance(v, 1.0) FROM t; + + +SELECT round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) FROM t; + +SELECT *, round(dot_product(v, '[0.0, 0.0, 0.0]'), 4) as d FROM t ORDER BY d; + +SELECT round(dot_product('[7.0, 8.0, 9.0]', v), 4) FROM t; + +SELECT *, round(dot_product('[7.0, 8.0, 9.0]', v), 4) as d FROM t ORDER BY d; + +SELECT round(dot_product(v, v), 4) FROM t; + +-- Unexpected dimension -- +SELECT dot_product(v, '[1.0]') FROM t; + +-- Invalid type -- +SELECT dot_product(v, 1.0) FROM t; + -- Unexpected dimension -- INSERT INTO t VALUES (4, "[1.0]");