diff --git a/src/bk-user/bkuser/apis/login/views.py b/src/bk-user/bkuser/apis/login/views.py index cc68ad862..f98cf3540 100644 --- a/src/bk-user/bkuser/apis/login/views.py +++ b/src/bk-user/bkuser/apis/login/views.py @@ -8,17 +8,15 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ -import operator -from functools import reduce from django.utils.translation import gettext_lazy as _ from rest_framework import generics from rest_framework.response import Response -from bkuser.apps.data_source.models import DataSourceUser, LocalDataSourceIdentityInfo -from bkuser.apps.idp.data_models import convert_match_rules_to_queryset_filter +from bkuser.apps.data_source.models import LocalDataSourceIdentityInfo from bkuser.apps.idp.models import Idp from bkuser.apps.tenant.models import Tenant, TenantUser +from bkuser.biz.idp import AuthenticationMatcher from bkuser.common.error_codes import error_codes from .mixins import LoginApiAccessControlMixin @@ -126,27 +124,14 @@ def post(self, request, *args, **kwargs): # 认证源 idp_id = kwargs["idp_id"] - idp = Idp.objects.filter(owner_tenant_id=tenant_id, id=idp_id).first() - if not idp: + if not Idp.objects.filter(owner_tenant_id=tenant_id, id=idp_id).exists(): raise error_codes.OBJECT_NOT_FOUND.f(_("认证源 {} 不存在").format(idp_id)) # FIXME: 查询是绑定匹配还是直接匹配, # 一般社会化登录都得通过绑定匹配方式,比如QQ,用户得先绑定后才能使用QQ登录 # 直接匹配,一般是企业身份登录方式, # 比如企业内部SAML2.0登录,认证后获取到的用户字段,能直接与数据源里的用户数据字段匹配 - # 将规则转换为Django Queryset 过滤条件, 不同用户之间过滤逻辑是OR - conditions = [ - condition - for userinfo in data["idp_users"] - if (condition := convert_match_rules_to_queryset_filter(idp.data_source_match_rule_objs, userinfo)) - ] - - # 查询数据源用户 - data_source_user_ids = ( - DataSourceUser.objects.filter(reduce(operator.or_, conditions)).values_list("id", flat=True) - if conditions - else [] - ) + data_source_user_ids = AuthenticationMatcher(tenant_id, idp_id).match(data["idp_users"]) # 查询租户用户 tenant_users = TenantUser.objects.filter( diff --git a/src/bk-user/bkuser/apis/web/idp/serializers.py b/src/bk-user/bkuser/apis/web/idp/serializers.py index 9f051dc9c..0086f90a4 100644 --- a/src/bk-user/bkuser/apis/web/idp/serializers.py +++ b/src/bk-user/bkuser/apis/web/idp/serializers.py @@ -19,7 +19,7 @@ from bkuser.apps.data_source.models import DataSource from bkuser.apps.idp.constants import IdpStatus from bkuser.apps.idp.models import Idp, IdpPlugin -from bkuser.apps.tenant.models import UserBuiltinField +from bkuser.apps.tenant.models import TenantUserCustomField, UserBuiltinField from bkuser.idp_plugins.base import get_plugin_cfg_cls from bkuser.idp_plugins.constants import BuiltinIdpPluginEnum from bkuser.utils.pydantic import stringify_pydantic_error @@ -92,18 +92,14 @@ def validate(self, attrs: Dict[str, Any]) -> Dict[str, Any]: if not DataSource.objects.filter(id=attrs["data_source_id"], owner_tenant_id=tenant_id).exists(): raise ValidationError(_("数据源必须是当前租户下的,{} 并不符合").format(attrs["data_source_id"])) - # # 匹配的数据源字段必须是当前租户的用户字段,包括内建字段和自定义字段 + # 匹配的数据源字段必须是当前租户的用户字段,包括内建字段和自定义字段 builtin_fields = set(UserBuiltinField.objects.all().values_list("name", flat=True)) - # custom_fields = set(TenantUserCustomField.objects.filter(tenant_id=tenant_id).values_list("name", flat=True)) - # allowed_target_fields = builtin_fields | custom_fields - # + custom_fields = set(TenantUserCustomField.objects.filter(tenant_id=tenant_id).values_list("name", flat=True)) + allowed_target_fields = builtin_fields | custom_fields + target_fields = {r.get("target_field") for r in attrs["field_compare_rules"]} - # if not_found_fields := target_fields - allowed_target_fields: - # raise ValidationError(_("匹配的数据源字段 {} 不属于用户自定义字段或内置字段").format(not_found_fields)) - if not_found_fields := target_fields - builtin_fields: - raise ValidationError( - _("匹配的数据源字段 {} 不属于用户内置字段,当前仅支持匹配内置字段").format(not_found_fields) - ) + if not_found_fields := target_fields - allowed_target_fields: + raise ValidationError(_("匹配的数据源字段 {} 不属于用户自定义字段或内置字段").format(not_found_fields)) return attrs diff --git a/src/bk-user/bkuser/apps/idp/data_models.py b/src/bk-user/bkuser/apps/idp/data_models.py index f010352bc..13c37b041 100644 --- a/src/bk-user/bkuser/apps/idp/data_models.py +++ b/src/bk-user/bkuser/apps/idp/data_models.py @@ -8,11 +8,8 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ -import operator -from functools import reduce -from typing import Any, Dict, List +from typing import List -from django.db.models import Q from pydantic import BaseModel, TypeAdapter from .constants import AllowBindScopeObjectType @@ -38,59 +35,10 @@ class DataSourceMatchRule(BaseModel): # 字段匹配规则 field_compare_rules: List[FieldCompareRule] - def convert_to_queryset_filter(self, source_data: Dict[str, Any]) -> Q | None: - """ - 将匹配规则转换为Django QuerySet过滤条件 - :param source_data: 认证源数据 - :return Django Queryset Q 查询表达式 - example: - self: - { - "data_source_id": 1, - "field_compare_rules": [ - {"source_field": "user_id", "target_field": "username", "operator": "equal"}, - {"source_field": "telephone", "target_field": "phone", "operator": "equal"}, - ] - } - source_data: {"user_id": "zhangsan", "telephone": "12345678901", "company_email": "test@example.com"} - return: (Q(data_source_id=1) & Q(username="zhangsan") & Q(phone="12345678901")) - """ - conditions = [{"data_source_id": self.data_source_id}] - # 无字段比较,相当于无法匹配,直接返回 - if not self.field_compare_rules: - return None - - # 每个认证源字段与数据源字段的比较规则 - for rule in self.field_compare_rules: - # 数据里没有规则需要比较的字段,则一定无法匹配,所以无需继续 - if rule.source_field not in source_data: - return None - - conditions.append( - { - # Note: 目前仅仅是equal的比较操作符,所以这里暂时简单处理, - # 后续支持其他操作符再抽象出Converter来处理 - rule.target_field: source_data[rule.source_field], - } - ) - - return reduce(operator.and_, [Q(**c) for c in conditions]) - DataSourceMatchRuleList = TypeAdapter(List[DataSourceMatchRule]) -def convert_match_rules_to_queryset_filter( - match_rules: List[DataSourceMatchRule], source_data: Dict[str, Any] -) -> Q | None: - """ - 将规则列表转换为Queryset查询条件 - 不同匹配规则之间的关系是OR, 匹配规则里不同字段的关系是AND - """ - q_list = [q for rule in match_rules if (q := rule.convert_to_queryset_filter(source_data))] - return reduce(operator.or_, q_list) if q_list else None - - def gen_data_source_match_rule_of_local(data_source_id: int) -> DataSourceMatchRule: """生成本地账密认证源的匹配规则""" return DataSourceMatchRule( diff --git a/src/bk-user/bkuser/biz/idp.py b/src/bk-user/bkuser/biz/idp.py new file mode 100644 index 000000000..965baa78e --- /dev/null +++ b/src/bk-user/bkuser/biz/idp.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +""" +TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-用户管理(Bk-User) available. +Copyright (C) 2017-2021 THL A29 Limited, a Tencent company. All rights reserved. +Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. +You may obtain a copy of the License at http://opensource.org/licenses/MIT +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. +""" +import operator +from functools import reduce +from typing import Any, Dict, List + +from django.db.models import Q + +from bkuser.apps.data_source.models import DataSourceUser +from bkuser.apps.idp.data_models import DataSourceMatchRule +from bkuser.apps.idp.models import Idp +from bkuser.apps.tenant.constants import UserFieldDataType +from bkuser.apps.tenant.models import TenantUserCustomField, UserBuiltinField + + +class AuthenticationMatcher: + """认证匹配,用于对认证后的用户字段匹配到对应的数据源""" + + def __init__(self, tenant_id: str, idp_id: str): + # TODO: 后续支持协同租户的数据源用户匹配 + self.idp = Idp.objects.get(id=idp_id, owner_tenant_id=tenant_id) + self.builtin_field_data_type_map = dict(UserBuiltinField.objects.all().values_list("name", "data_type")) + self.custom_field_data_type_map = dict( + TenantUserCustomField.objects.filter(tenant_id=tenant_id).values_list("name", "data_type") + ) + + def match(self, idp_users: List[Dict[str, Any]]) -> List[int]: + """匹配出数据源用户ID""" + # 将规则转换为Django Queryset 过滤条件, 不同用户之间过滤逻辑是OR + conditions = [ + condition for userinfo in idp_users if (condition := self._convert_rules_to_queryset_filter(userinfo)) + ] + + if not conditions: + return [] + # 查询数据源用户 + return DataSourceUser.objects.filter(reduce(operator.or_, conditions)).values_list("id", flat=True) + + def _convert_rules_to_queryset_filter(self, source_data: Dict[str, Any]) -> Q | None: + """ + 将规则列表转换为Queryset查询条件 + 不同匹配规则之间的关系是OR, 匹配规则里不同字段的关系是AND + """ + q_list = [ + q + for rule in self.idp.data_source_match_rule_objs + if (q := self._convert_one_rule_to_queryset_filter(rule, source_data)) + ] + return reduce(operator.or_, q_list) if q_list else None + + def _convert_one_rule_to_queryset_filter( + self, match_rule: DataSourceMatchRule, source_data: Dict[str, Any] + ) -> Q | None: + """ + 将匹配规则转换为Django QuerySet过滤条件 + :param match_rule: 匹配规则 + :param source_data: 认证源数据 + :return Django Queryset Q 查询表达式 + example: + self: + { + "data_source_id": 1, + "field_compare_rules": [ + {"source_field": "user_id", "target_field": "username", "operator": "equal"}, + {"source_field": "telephone", "target_field": "phone", "operator": "equal"}, + ] + } + source_data: {"user_id": "zhangsan", "telephone": "12345678901", "company_email": "test@example.com"} + return: (Q(data_source_id=1) & Q(username="zhangsan") & Q(phone="12345678901")) + """ + conditions = [{"data_source_id": match_rule.data_source_id}] + # 无字段比较,相当于无法匹配,直接返回 + if not match_rule.field_compare_rules: + return None + + # 每个认证源字段与数据源字段的比较规则 + for rule in match_rule.field_compare_rules: + # 数据里没有规则需要比较的字段,则一定无法匹配,无需继续 + if rule.source_field not in source_data: + return None + + filter_key = self._build_field_filter_key(rule.target_field) + if not filter_key: + return None + + # Note: 目前仅仅是equal的比较操作符,所以这里暂时简单处理, + # 后续支持其他操作符再抽象出Converter来处理 + conditions.append({filter_key: source_data[rule.source_field]}) + + return reduce(operator.and_, [Q(**c) for c in conditions]) + + def _build_field_filter_key(self, field: str) -> str | None: + """ + 构建字段的Django过滤Key + 1. 内建字段,key=field + 2. 用户自定义字段,在extras字段里,以JSON方式存储 + - data_type=string/number/enum: key=f"extras__{field}" + - data_type=multi_enum: key=f"extras__{field}__contains" + + Django JSONField查询: https://docs.djangoproject.com/en/3.2/topics/db/queries/#querying-jsonfield + Note: JSON查询性能出现问题时,可以通过额外运维方式创建虚列索引来解决 + ALTER TABLE my_table ADD COLUMN field_col INT AS (JSON_UNQUOTE(JSON_EXTRACT(extras, '$.my_field'))) VIRTUAL; + CREATE INDEX idx_field_col ON my_table (field_col); 最好与data_source_id字段一起联合索引 + """ + # 内建字段 + if field in self.builtin_field_data_type_map: + return field + + # 自定义字段,且data_type=string/number/enum + if field in self.custom_field_data_type_map: + data_type = self.custom_field_data_type_map[field] + # string/number/enum + if data_type in [UserFieldDataType.STRING, UserFieldDataType.NUMBER, UserFieldDataType.ENUM]: + return f"extras__{field}" + + # multi_enum + if data_type in UserFieldDataType.MULTI_ENUM: + return f"extras__{field}__contains" + + # 非预期的字段和数据类型,都无法匹配 + return None diff --git a/src/bk-user/tests/apis/web/idp/test_idp.py b/src/bk-user/tests/apis/web/idp/test_idp.py index 5ed9226b8..2e525d890 100644 --- a/src/bk-user/tests/apis/web/idp/test_idp.py +++ b/src/bk-user/tests/apis/web/idp/test_idp.py @@ -154,12 +154,12 @@ def test_create_with_invalid_data_source_match_rules(self, api_client, wecom_plu request_data["data_source_match_rules"] = [ { "data_source_id": default_data_source.id, - "field_compare_rules": [{"source_field": "user_id", "target_field": "not_builtin_field"}], + "field_compare_rules": [{"source_field": "user_id", "target_field": generate_random_string()}], } ] resp = api_client.post(reverse("idp.list_create"), data=request_data) assert resp.status_code == status.HTTP_400_BAD_REQUEST - assert "当前仅支持匹配内置字段" in resp.data["message"] + assert "不属于用户自定义字段或内置字段" in resp.data["message"] def test_create_with_empty_data_source_match_rules(self, api_client, wecom_plugin_cfg, default_data_source): request_data = { diff --git a/src/bk-user/tests/apps/idp/__init__.py b/src/bk-user/tests/apps/idp/__init__.py deleted file mode 100644 index 1060b7bf4..000000000 --- a/src/bk-user/tests/apps/idp/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# -*- coding: utf-8 -*- -""" -TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-用户管理(Bk-User) available. -Copyright (C) 2017-2021 THL A29 Limited, a Tencent company. All rights reserved. -Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. -You may obtain a copy of the License at http://opensource.org/licenses/MIT -Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -specific language governing permissions and limitations under the License. -""" diff --git a/src/bk-user/tests/apps/idp/test_data_models.py b/src/bk-user/tests/apps/idp/test_data_models.py deleted file mode 100644 index c49d22893..000000000 --- a/src/bk-user/tests/apps/idp/test_data_models.py +++ /dev/null @@ -1,90 +0,0 @@ -# -*- coding: utf-8 -*- -""" -TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-用户管理(Bk-User) available. -Copyright (C) 2017-2021 THL A29 Limited, a Tencent company. All rights reserved. -Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. -You may obtain a copy of the License at http://opensource.org/licenses/MIT -Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -specific language governing permissions and limitations under the License. -""" -import pytest -from bkuser.apps.idp.data_models import DataSourceMatchRule, FieldCompareRule -from django.db.models import Q - -pytestmark = pytest.mark.django_db - - -@pytest.mark.parametrize( - ("source_data", "excepted_queryset"), - [ - # valid - ( - {"user_id": "test_username", "phone": "1234567890123"}, - (Q(data_source_id=1) & Q(username="test_username") & Q(phone="1234567890123")), - ), - ( - {"user_id": "test_username", "phone": "1234567890123", "email": "111@qq.com"}, - (Q(data_source_id=1) & Q(username="test_username") & Q(phone="1234567890123")), - ), - # invalid - ( - {"user_id": "test_username"}, - None, - ), - ], -) -def test_convert_to_queryset_filter_for_source_data(source_data, excepted_queryset): - data_source_match_rule = DataSourceMatchRule( - data_source_id=1, - field_compare_rules=[ - FieldCompareRule(source_field="user_id", target_field="username"), - FieldCompareRule(source_field="phone", target_field="phone"), - ], - ) - queryset = data_source_match_rule.convert_to_queryset_filter(source_data) - - assert queryset == excepted_queryset - - -@pytest.mark.parametrize( - ("rule", "excepted_queryset"), - [ - # No field compare rule - ( - DataSourceMatchRule(data_source_id=1, field_compare_rules=[]), - None, - ), - # one field compare rule - ( - DataSourceMatchRule( - data_source_id=1, - field_compare_rules=[FieldCompareRule(source_field="user_id", target_field="username")], - ), - (Q(data_source_id=1) & Q(username="test_username")), - ), - # Mult field compare rule - ( - DataSourceMatchRule( - data_source_id=1, - field_compare_rules=[ - FieldCompareRule(source_field="user_id", target_field="username"), - FieldCompareRule(source_field="phone", target_field="phone"), - FieldCompareRule(source_field="email", target_field="private_email"), - ], - ), - ( - Q(data_source_id=1) - & Q(username="test_username") - & Q(phone="1234567890123") - & Q(private_email="111@qq.com") - ), - ), - ], -) -def test_convert_to_queryset_filter_for_rule(rule, excepted_queryset): - source_data = {"user_id": "test_username", "phone": "1234567890123", "email": "111@qq.com"} - - queryset = rule.convert_to_queryset_filter(source_data) - - assert queryset == excepted_queryset diff --git a/src/bk-user/tests/apps/sync/test_converters.py b/src/bk-user/tests/apps/sync/test_converters.py index 7fe5db28d..3aace3a68 100644 --- a/src/bk-user/tests/apps/sync/test_converters.py +++ b/src/bk-user/tests/apps/sync/test_converters.py @@ -58,7 +58,17 @@ def test_get_field_mapping_from_tenant_user_fields( ): assert DataSourceUserConverter(bare_local_data_source, logger).field_mapping == [ DataSourceUserFieldMapping(source_field=f, mapping_operation=FieldMappingOperation.DIRECT, target_field=f) - for f in ["username", "full_name", "email", "phone", "phone_country_code", "age", "gender", "region"] + for f in [ + "username", + "full_name", + "email", + "phone", + "phone_country_code", + "age", + "gender", + "region", + "sport_hobby", + ] ] def test_convert_user_enum_field_default(self, bare_local_data_source, tenant_user_custom_fields, logger): @@ -71,6 +81,7 @@ def test_convert_user_enum_field_default(self, bare_local_data_source, tenant_us "phone": "13512345671", "age": "18", "region": "beijing", + "sport_hobby": "golf", }, leaders=[], departments=["company"], @@ -83,7 +94,7 @@ def test_convert_user_enum_field_default(self, bare_local_data_source, tenant_us assert zhangsan.email == "zhangsan@m.com" assert zhangsan.phone == "13512345671" assert zhangsan.phone_country_code == "86" - assert zhangsan.extras == {"age": "18", "gender": "male", "region": "beijing"} + assert zhangsan.extras == {"age": "18", "gender": "male", "region": "beijing", "sport_hobby": "golf"} def test_convert_use_string_field_default(self, bare_local_data_source, tenant_user_custom_fields, logger): raw_lisi = RawDataSourceUser( @@ -96,6 +107,7 @@ def test_convert_use_string_field_default(self, bare_local_data_source, tenant_u "phone_country_code": "63", "age": "28", "gender": "female", + "sport_hobby": "golf", }, leaders=["zhangsan"], departments=["dept_a", "center_aa"], @@ -108,7 +120,7 @@ def test_convert_use_string_field_default(self, bare_local_data_source, tenant_u assert lisi.email == "lisi@m.com" assert lisi.phone == "13512345672" assert lisi.phone_country_code == "63" - assert lisi.extras == {"age": "28", "gender": "female", "region": ""} + assert lisi.extras == {"age": "28", "gender": "female", "region": "", "sport_hobby": "golf"} def test_convert_with_not_same_field_name_mapping(self, bare_local_data_source, tenant_user_custom_fields, logger): raw_lisi = RawDataSourceUser( @@ -121,6 +133,7 @@ def test_convert_with_not_same_field_name_mapping(self, bare_local_data_source, "phone_country_code": "63", "age": "28", "gender": "female", + "sport_hobby": "golf", "custom_region": "shanghai", }, leaders=["zhangsan"], @@ -129,10 +142,10 @@ def test_convert_with_not_same_field_name_mapping(self, bare_local_data_source, converter = DataSourceUserConverter(bare_local_data_source, logger) # 修改数据以生成不同字段名映射的比较麻烦,这里采用的是直接修改 Converter 的 field_mapping 属性 - converter.field_mapping[-1].source_field = "custom_region" + converter.field_mapping[-2].source_field = "custom_region" lisi = converter.convert(raw_lisi) - assert lisi.extras == {"age": "28", "gender": "female", "region": "shanghai"} + assert lisi.extras == {"age": "28", "gender": "female", "sport_hobby": "golf", "region": "shanghai"} def test_convert_with_invalid_username(self, bare_local_data_source, logger): raw_user = RawDataSourceUser(code="test", properties={}, leaders=[], departments=[]) diff --git a/src/bk-user/tests/biz/test_exporters.py b/src/bk-user/tests/biz/test_exporters.py index e3e256938..cdf36036f 100644 --- a/src/bk-user/tests/biz/test_exporters.py +++ b/src/bk-user/tests/biz/test_exporters.py @@ -33,6 +33,7 @@ def test_get_template(self, bare_local_data_source, tenant_user_custom_fields): "年龄/age", "性别/gender", "籍贯/region", + "运动爱好/sport_hobby", ] def test_export(self, full_local_data_source, tenant_user_custom_fields): diff --git a/src/bk-user/tests/biz/test_idp.py b/src/bk-user/tests/biz/test_idp.py new file mode 100644 index 000000000..254497265 --- /dev/null +++ b/src/bk-user/tests/biz/test_idp.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +""" +TencentBlueKing is pleased to support the open source community by making 蓝鲸智云-用户管理(Bk-User) available. +Copyright (C) 2017-2021 THL A29 Limited, a Tencent company. All rights reserved. +Licensed under the MIT License (the "License"); you may not use this file except in compliance with the License. +You may obtain a copy of the License at http://opensource.org/licenses/MIT +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. +""" +import pytest +from bkuser.apps.idp.data_models import DataSourceMatchRule, FieldCompareRule +from bkuser.apps.idp.models import Idp +from bkuser.biz.idp import AuthenticationMatcher +from django.db.models import Q + + +@pytest.mark.django_db() +class TestAuthenticationMatcher: + @pytest.fixture(autouse=True) + def _initialize(self, default_tenant, tenant_user_custom_fields): + # 初始化IDP + default_idp = Idp.objects.filter(owner_tenant_id=default_tenant.id).first() + assert default_idp is not None + + self.matcher = AuthenticationMatcher(default_tenant.id, default_idp.id) + + @pytest.mark.parametrize( + ("field", "excepted_filter_key"), + [ + ("username", "username"), + ("full_name", "full_name"), + ("phone_country_code", "phone_country_code"), + ("phone", "phone"), + ("email", "email"), + ("age", "extras__age"), + ("gender", "extras__gender"), + ("region", "extras__region"), + ("sport_hobby", "extras__sport_hobby__contains"), + ("other_not_found", None), + ], + ) + def test_build_field_filter_key(self, field, excepted_filter_key): + assert self.matcher._build_field_filter_key(field) == excepted_filter_key + + @pytest.mark.parametrize( + ("source_data", "excepted_queryset"), + [ + # valid + ( + {"user_id": "test_username", "phone": "1234567890123"}, + (Q(data_source_id=1) & Q(username="test_username") & Q(phone="1234567890123")), + ), + ( + {"user_id": "test_username", "phone": "1234567890123", "email": "111@qq.com"}, + (Q(data_source_id=1) & Q(username="test_username") & Q(phone="1234567890123")), + ), + # invalid + ( + {"user_id": "test_username"}, + None, + ), + ], + ) + def test_convert_one_rule_to_queryset_filter_for_source_data(self, source_data, excepted_queryset): + data_source_match_rule = DataSourceMatchRule( + data_source_id=1, + field_compare_rules=[ + FieldCompareRule(source_field="user_id", target_field="username"), + FieldCompareRule(source_field="phone", target_field="phone"), + ], + ) + queryset = self.matcher._convert_one_rule_to_queryset_filter(data_source_match_rule, source_data) + + assert queryset == excepted_queryset + + @pytest.mark.parametrize( + ("rule", "excepted_queryset"), + [ + # No field compare rule + ( + DataSourceMatchRule(data_source_id=1, field_compare_rules=[]), + None, + ), + # one field compare rule + ( + DataSourceMatchRule( + data_source_id=1, + field_compare_rules=[FieldCompareRule(source_field="user_id", target_field="username")], + ), + (Q(data_source_id=1) & Q(username="test_username")), + ), + # Mult field compare rule + ( + DataSourceMatchRule( + data_source_id=1, + field_compare_rules=[ + FieldCompareRule(source_field="user_id", target_field="username"), + FieldCompareRule(source_field="phone", target_field="phone"), + FieldCompareRule(source_field="email", target_field="region"), + ], + ), + ( + Q(data_source_id=1) + & Q(username="test_username") + & Q(phone="1234567890123") + & Q(extras__region="111@qq.com") + ), + ), + ], + ) + def test_convert_one_rule_to_queryset_filter_for_rule(self, rule, excepted_queryset): + source_data = {"user_id": "test_username", "phone": "1234567890123", "email": "111@qq.com"} + + queryset = self.matcher._convert_one_rule_to_queryset_filter(rule, source_data) + + assert queryset == excepted_queryset diff --git a/src/bk-user/tests/fixtures/tenant.py b/src/bk-user/tests/fixtures/tenant.py index 4766a35ec..cab8d2847 100644 --- a/src/bk-user/tests/fixtures/tenant.py +++ b/src/bk-user/tests/fixtures/tenant.py @@ -51,4 +51,22 @@ def tenant_user_custom_fields(default_tenant) -> List[TenantUserCustomField]: "required": True, }, ) - return [age_field, gender_field, region_field] + sport_hobby_field, _ = TenantUserCustomField.objects.get_or_create( + tenant=default_tenant, + name="sport_hobby", + defaults={ + "display_name": "运动爱好", + "data_type": UserFieldDataType.MULTI_ENUM, + "required": True, + "default": "running", + "options": { + "running": "跑步", + "swimming": "游泳", + "Basketball": "篮球", + "football": "足球", + "golf": "高尔夫", + "cycling": "骑行", + }, + }, + ) + return [age_field, gender_field, region_field, sport_hobby_field]