Skip to content

Commit

Permalink
Support ADD_TIFLASH_ON_DEMAND option
Browse files Browse the repository at this point in the history
Signed-off-by: JaySon-Huang <tshent@qq.com>
  • Loading branch information
JaySon-Huang committed Nov 7, 2024
1 parent 3e99f50 commit 67bf470
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 2 deletions.
92 changes: 91 additions & 1 deletion tests/sqlalchemy/test_sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
import numpy as np
import sqlalchemy
from sqlalchemy import URL, create_engine, Column, Integer, select
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.exc import OperationalError
from tidb_vector.sqlalchemy import VectorType, VectorAdaptor
from tidb_vector.sqlalchemy import VectorType, VectorAdaptor, VectorIndex
import tidb_vector
from ..config import TestConfig

Expand Down Expand Up @@ -385,3 +386,92 @@ def test_index_and_search(self):
)
assert len(items) == 2
assert items[0].distance == 0.0


class TestSQLAlchemyVectorIndex:
    """Checks for ``VectorIndex``: the compiled DDL text and queries on an indexed table."""

    def setup_class(self):
        # Drop any leftover table first so the class starts from a fresh one.
        item_table = Item2Model.__table__
        item_table.drop(bind=engine, checkfirst=True)
        item_table.create(bind=engine)

    def teardown_class(self):
        item_table = Item2Model.__table__
        item_table.drop(bind=engine, checkfirst=True)

    def test_create_vector_index_statement(self):
        """Compile ``CreateIndex`` for both distance functions and compare the SQL text."""
        from sqlalchemy.sql.ddl import CreateIndex

        embedding_col = Item2Model.__table__.c.embedding

        index_l2 = VectorIndex(
            "idx_embedding_l2",
            sqlalchemy.func.vec_l2_distance(embedding_col),
        )
        ddl_l2 = CreateIndex(index_l2).compile(dialect=engine.dialect)
        assert ddl_l2.string == "CREATE VECTOR INDEX idx_embedding_l2 ON sqlalchemy_item2 ((vec_l2_distance(embedding))) ADD_TIFLASH_ON_DEMAND"

        index_cos = VectorIndex(
            "idx_embedding_cos",
            sqlalchemy.func.vec_cosine_distance(embedding_col),
        )
        ddl_cos = CreateIndex(index_cos).compile(dialect=engine.dialect)
        assert ddl_cos.string == "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 ((vec_cosine_distance(embedding))) ADD_TIFLASH_ON_DEMAND"

    def test_query_with_index(self):
        """Create both vector indexes, run filter and ordering queries, then drop the indexes."""
        item_table = Item2Model.__table__
        embedding_col = item_table.c.embedding

        # Create one index per distance function.
        index_l2 = VectorIndex(
            "idx_embedding_l2",
            sqlalchemy.func.vec_l2_distance(embedding_col),
        )
        index_l2.create(engine)
        index_cos = VectorIndex(
            "idx_embedding_cos",
            sqlalchemy.func.vec_cosine_distance(embedding_col),
        )
        index_cos.create(engine)

        # NOTE(review): check_indexes is not defined on this class in the code
        # shown here -- confirm it is provided elsewhere, otherwise this call
        # raises AttributeError.
        self.check_indexes(item_table, ["idx_embedding_l2", "idx_embedding_cos"])

        with Session() as session:
            rows = [
                Item2Model(embedding=[1, 2, 3]),
                Item2Model(embedding=[1, 2, 3.2]),
            ]
            session.add_all(rows)
            session.commit()

            # L2 distance: threshold filter, then nearest-first ordering.
            within_l2 = Item2Model.embedding.l2_distance([1, 2, 3.1]) < 0.2
            matches_l2 = session.scalars(select(Item2Model).filter(within_l2)).all()
            assert len(matches_l2) == 2

            l2_expr = Item2Model.embedding.l2_distance([1, 2, 3])
            ranked_l2_query = session.query(Item2Model.id, l2_expr.label("distance"))
            ranked_l2 = ranked_l2_query.order_by(l2_expr).limit(5).all()
            assert len(ranked_l2) == 2
            assert ranked_l2[0].distance == 0.0

            # Cosine distance: same two checks.
            within_cos = Item2Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2
            matches_cos = session.scalars(select(Item2Model).filter(within_cos)).all()
            assert len(matches_cos) == 2

            cos_expr = Item2Model.embedding.cosine_distance([1, 2, 3])
            ranked_cos_query = session.query(Item2Model.id, cos_expr.label("distance"))
            ranked_cos = ranked_cos_query.order_by(cos_expr).limit(5).all()
            assert len(ranked_cos) == 2
            assert ranked_cos[0].distance == 0.0

        # Drop the indexes created at the start of this test.
        index_l2.drop(engine)
        index_cos.drop(engine)
3 changes: 2 additions & 1 deletion tidb_vector/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .vector_type import VectorType
from .adaptor import VectorAdaptor
from .index import VectorIndex

__all__ = ["VectorType", "VectorAdaptor"]
__all__ = ["VectorType", "VectorAdaptor", "VectorIndex"]
29 changes: 29 additions & 0 deletions tidb_vector/sqlalchemy/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Optional, Any

import sqlalchemy

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Index

class VectorIndex(Index):
    """An ``Index`` whose ``mysql`` dialect options are preset for vector indexes.

    Construction always passes ``unique=False`` to ``Index``, sets the
    ``mysql`` ``prefix`` option to ``"VECTOR"`` and turns on the ``mysql``
    ``add_tiflash_on_demand`` option.
    """

    def __init__(
        self,
        name: Optional[str],
        *expressions,  # _DDLColumnArgument
        _table: Optional[Any] = None,
        **dialect_kw: Any,
    ):
        """Build the index and apply the vector-specific ``mysql`` options.

        Args:
            name: Index name, forwarded to ``Index``.
            *expressions: Column or SQL expressions to index, forwarded to ``Index``.
            _table: Forwarded unchanged to ``Index``.
            **dialect_kw: Extra dialect keyword arguments, forwarded to ``Index``.
        """
        super().__init__(
            name,
            *expressions,
            unique=False,
            _table=_table,
            **dialect_kw,
        )
        mysql_options = self.dialect_options["mysql"]
        mysql_options["prefix"] = "VECTOR"
        # add tiflash automatically when creating vector index
        mysql_options["add_tiflash_on_demand"] = True

# VectorIndex.argument_for("mysql", "add_tiflash_on_demand", None)

@compiles(sqlalchemy.schema.CreateIndex)
def compile_create_vector_index(
    create_index_elem: sqlalchemy.sql.ddl.CreateIndex,
    compiler: sqlalchemy.sql.compiler.DDLCompiler,
    **kw,
):
    """Compile ``CREATE INDEX`` and append ``ADD_TIFLASH_ON_DEMAND`` when requested.

    The statement text comes from ``compiler.visit_create_index``. The suffix
    is added only if the index's ``mysql`` dialect options contain a truthy
    ``add_tiflash_on_demand`` entry; otherwise the text is returned unchanged.
    """
    statement = compiler.visit_create_index(create_index_elem, **kw)
    mysql_options = create_index_elem.element.dialect_options.get("mysql", {})
    if not mysql_options.get("add_tiflash_on_demand"):
        return statement
    return statement + " ADD_TIFLASH_ON_DEMAND"

0 comments on commit 67bf470

Please sign in to comment.