Skip to content

Commit

Permalink
feature: Agent / Proxy 类操作支持多接入点 Agent 操作 (closed #1714)
Browse files Browse the repository at this point in the history
  • Loading branch information
wyyalt committed Aug 25, 2023
1 parent 6192e9f commit 4d162f0
Show file tree
Hide file tree
Showing 12 changed files with 128 additions and 44 deletions.
30 changes: 19 additions & 11 deletions apps/backend/components/collections/agent_new/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@
from ..base import BaseService, CommonData


@dataclass
class AgentCommonData(CommonData):
    """
    Common data handed to agent services; extends ``CommonData`` with access-point fields.
    """

    default_ap: models.AccessPoint  # default access point
    host_id__ap_map: Dict[int, models.AccessPoint]  # host ID -> access point mapping
    agent_step_adapter: AgentStepAdapter  # adapter for the AgentStep
    injected_ap_id: int  # injected AP ID


class AgentBaseService(BaseService, metaclass=abc.ABCMeta):
"""
AGENT安装基类
Expand Down Expand Up @@ -305,17 +317,13 @@ def maintain_agent_proc_status_uniqueness(self, bk_host_ids: Set[int]) -> None:
proc_statuses_to_be_created.append(models.ProcessStatus(bk_host_id=host_id, **self.agent_proc_common_data))
models.ProcessStatus.objects.bulk_create(proc_statuses_to_be_created, batch_size=self.batch_size)


@dataclass
class AgentCommonData(CommonData):
# 默认接入点
default_ap: models.AccessPoint
# 主机ID - 接入点 映射关系
host_id__ap_map: Dict[int, models.AccessPoint]
# AgentStep 适配器
agent_step_adapter: AgentStepAdapter
# 注入AP_ID
injected_ap_id: int
def get_host_ap(self, common_data: AgentCommonData, host: models.Host) -> models.AccessPoint:
    """
    Resolve the access point to use for ``host``.

    An injected AP ID takes priority: when ``common_data.injected_ap_id`` is truthy the
    access point is looked up in ``common_data.ap_id_obj_map``; otherwise the entry for
    ``host.bk_host_id`` in ``common_data.host_id__ap_map`` is returned.

    :param common_data: shared data carrying the AP lookup tables and the injected AP ID
    :param host: host whose access point is being resolved
    :return: the access point selected by the rule above
    """
    injected_ap_id = common_data.injected_ap_id
    if not injected_ap_id:
        # Nothing injected: fall back to the per-host mapping.
        return common_data.host_id__ap_map[host.bk_host_id]
    return common_data.ap_id_obj_map[injected_ap_id]


class RetryHandler:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@

class PushUpgradeFileService(AgentTransferFileService):
def get_file_target_path(self, data, common_data: AgentCommonData, host: models.Host) -> str:
return common_data.host_id__ap_map[host.bk_host_id].get_agent_config(host.os_type)["temp_path"]
host_ap = self.get_host_ap(common_data=common_data, host=host)
return host_ap.get_agent_config(host.os_type)["temp_path"]

def get_upgrade_package_source_path(self, common_data: AgentCommonData, host: models.Host) -> Tuple[str, str]:
"""
获取升级包源路径
"""
host_ap = self.get_host_ap(common_data=common_data, host=host)
# 1.x 升级到 1.x,使用老的路径,升级包直接放在 download 目录下
agent_path = root_path = common_data.host_id__ap_map[host.bk_host_id].nginx_path or settings.DOWNLOAD_PATH
agent_path = root_path = host_ap.nginx_path or settings.DOWNLOAD_PATH
if not common_data.agent_step_adapter.is_legacy:
# 2.x 升级到 2.x,根据操作系统、CPU 架构等组合路径
agent_path = os.path.join(root_path, "agent", host.os_type.lower(), host.cpu_arch.lower())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def get_script_content(self, data, common_data: AgentCommonData, host: models.Ho
# 路径处理器
path_handler = PathHandler(host.os_type)
general_node_type = self.get_general_node_type(host.node_type)
setup_path = common_data.host_id__ap_map[host.bk_host_id].get_agent_config(host.os_type)["setup_path"]
host_ap: models.AccessPoint = self.get_host_ap(common_data=common_data, host=host)
setup_path = host_ap.get_agent_config(host.os_type)["setup_path"]
agent_path = path_handler.join(setup_path, general_node_type, "bin")
if common_data.agent_step_adapter.is_legacy:
return f"cd {agent_path} && ./gse_agent --reload"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ class RenderAndPushGseConfigService(AgentPushConfigService):
def get_config_info_list(self, data, common_data: AgentCommonData, host: models.Host) -> List[Dict[str, Any]]:
file_name = common_data.agent_step_adapter.get_main_config_filename()
general_node_type = self.get_general_node_type(host.node_type)
host_ap: models.AccessPoint = self.get_host_ap(common_data=common_data, host=host)
content = common_data.agent_step_adapter.get_config(
host=host, filename=file_name, node_type=general_node_type, ap=common_data.host_id__ap_map[host.bk_host_id]
host=host, filename=file_name, node_type=general_node_type, ap=host_ap
)
return [{"file_name": file_name, "content": content}]

def get_file_target_path(self, data, common_data: AgentCommonData, host: models.Host) -> str:
general_node_type = self.get_general_node_type(host.node_type)
path_handler = PathHandler(host.os_type)
setup_path = common_data.host_id__ap_map[host.bk_host_id].get_agent_config(host.os_type)["setup_path"]
host_ap: models.AccessPoint = self.get_host_ap(common_data=common_data, host=host)
setup_path = host_ap.get_agent_config(host.os_type)["setup_path"]
return path_handler.join(setup_path, general_node_type, "etc")
3 changes: 2 additions & 1 deletion apps/backend/components/collections/agent_new/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def get_script_content(self, data, common_data: AgentCommonData, host: models.Ho
ctl_exe_name = ("gsectl", "gsectl.bat")[host.os_type == constants.OsType.WINDOWS]
cmd_suffix = ("restart >/dev/null 2>&1", "restart")[host.os_type == constants.OsType.WINDOWS]
general_node_type = self.get_general_node_type(host.node_type)
setup_path = common_data.host_id__ap_map[host.bk_host_id].get_agent_config(host.os_type)["setup_path"]
host_ap: models.AccessPoint = self.get_host_ap(common_data=common_data, host=host)
setup_path = host_ap.get_agent_config(host.os_type)["setup_path"]
agent_path = path_handler.join(setup_path, general_node_type, "bin", ctl_exe_name)

return f"{agent_path} {cmd_suffix}"
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def script_name(self):
def get_script_content(self, data, common_data: AgentCommonData, host: models.Host) -> str:
agent_upgrade_pkg_name = self.get_agent_pkg_name(common_data, host, is_upgrade=True)
general_node_type = self.get_general_node_type(host.node_type)
agent_config = common_data.host_id__ap_map[host.bk_host_id].get_agent_config(host.os_type)
host_ap: models.AccessPoint = self.get_host_ap(common_data=common_data, host=host)
agent_config = host_ap.get_agent_config(host.os_type)

if host.os_type == constants.OsType.WINDOWS:
scripts = WINDOWS_UPGRADE_CMD_TEMPLATE.format(
Expand Down
17 changes: 14 additions & 3 deletions apps/backend/subscription/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
)
from apps.backend.utils.data_renderer import nested_render_data
from apps.component.esbclient import client_v2
from apps.core.ipchooser.tools.base import HostQuerySqlHelper
from apps.exceptions import ComponentCallError
from apps.node_man import constants, models
from apps.node_man import tools as node_man_tools
from apps.utils.basic import chunk_lists, distinct_dict_list, order_dict
from apps.utils.batch_request import batch_request, request_multi_thread
from apps.utils.cache import func_cache_decorator
from apps.utils.time_handler import strftime_local
from apps.core.ipchooser.tools.base import HostQuerySqlHelper

logger = logging.getLogger("app")

Expand Down Expand Up @@ -681,7 +681,7 @@ def wrapper(scope: Dict[str, Union[Dict, Any]], *args, **kwargs) -> Dict[str, Di
"object_type": scope["object_type"],
"node_type": scope["node_type"],
"nodes": list(nodes),
"instance_selector": scope.get("instance_selector")
"instance_selector": scope.get("instance_selector"),
},
**kwargs,
}
Expand Down Expand Up @@ -806,10 +806,14 @@ def get_instances_by_scope(scope: Dict[str, Union[Dict, int, Any]]) -> Dict[str,

if not need_register:
# 补充必要的主机或实例相关信息

add_host_info_to_instances(bk_biz_id, scope, instances)
add_scope_info_to_instances(nodes, scope, instances, module_to_topo)
add_process_info_to_instances(bk_biz_id, scope, instances)

# 回填原始参数
add_meta_info_to_instance(scope, instances)

instances_dict = {}
data = {
"object_type": scope["object_type"],
Expand All @@ -831,7 +835,7 @@ def get_instances_by_scope(scope: Dict[str, Union[Dict, int, Any]]) -> Dict[str,
instance_selector_host_ids = HostQuerySqlHelper.multiple_cond_sql(
params={"bk_host_id": bk_host_ids, "conditions": instance_selector},
biz_scope=[bk_biz_id],
return_all_node_type=True
return_all_node_type=True,
).values_list("bk_host_id", flat=True)

selector_instances_dict = {}
Expand All @@ -847,6 +851,13 @@ def get_instances_by_scope(scope: Dict[str, Union[Dict, int, Any]]) -> Dict[str,
return instances_dict


def add_meta_info_to_instance(scope: Dict, instances: Dict):
    """
    Merge the caller-supplied node parameters back into each instance's host record.

    Only applies when ``scope["object_type"]`` is the HOST object type. Every entry of
    ``scope["nodes"]`` that carries a non-None ``bk_host_id`` is indexed by that ID, and
    each instance's ``instance["host"]`` dict is updated in place with the matching node
    (keys from the node overwrite existing host keys). Instances without a matching node
    are left unchanged.

    :param scope: subscription scope; reads ``object_type`` and ``nodes``
    :param instances: iterated as a sequence of instance dicts, each with a ``host`` dict
        NOTE(review): annotated as Dict but consumed as an iterable of dicts - confirm
    :return: None, ``instances`` is mutated in place
    """
    if scope["object_type"] != models.Subscription.ObjectType.HOST:
        return

    # Index the original nodes by host ID; nodes lacking a host ID cannot be matched.
    node_by_host_id = {}
    for node in scope["nodes"]:
        if node.get("bk_host_id") is not None:
            node_by_host_id[node["bk_host_id"]] = node

    for inst in instances:
        host_record = inst["host"]
        host_record.update(node_by_host_id.get(host_record["bk_host_id"], {}))


def add_host_info_to_instances(bk_biz_id: int, scope: Dict, instances: Dict):
"""
补全实例的主机信息
Expand Down
16 changes: 6 additions & 10 deletions apps/node_man/handlers/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from django.conf import settings
from django.core.paginator import Paginator
from django.db.models import Q
from django.db import transaction
from django.db.models import Q
from django.utils import timezone
from django.utils.translation import get_language
from django.utils.translation import ugettext_lazy as _
Expand Down Expand Up @@ -187,9 +187,7 @@ def list(self, params: dict, username: str):
biz_scope_query_q = Q()
else:
biz_scope_query_q = reduce(
operator.or_,
[Q(bk_biz_scope__contains=bk_biz_id) for bk_biz_id in biz_scope],
Q()
operator.or_, [Q(bk_biz_scope__contains=bk_biz_id) for bk_biz_id in biz_scope], Q()
)
# 仅查询所有业务时,自身创建的 job 可见
if not search_biz_ids:
Expand Down Expand Up @@ -616,15 +614,13 @@ def update_host(accept_list: list, ip_filter_list: list, is_manual: bool = False

return update_data_info["subscription_host_ids"], ip_filter_list

def operate(self, job_type, bk_host_ids, bk_biz_scope, extra_params, extra_config):
def operate(self, job_type, hosts, bk_biz_scope, extra_params, extra_config):
"""
用于只有bk_host_id参数的下线、重启等操作
"""
# 校验器进行校验

subscription = self.create_subscription(
job_type, bk_host_ids, extra_params=extra_params, extra_config=extra_config
)
subscription = self.create_subscription(job_type, hosts, extra_params=extra_params, extra_config=extra_config)

return tools.JobTools.create_job(
job_type=job_type,
Expand All @@ -634,9 +630,9 @@ def operate(self, job_type, bk_host_ids, bk_biz_scope, extra_params, extra_confi
statistics={
"success_count": 0,
"failed_count": 0,
"pending_count": len(bk_host_ids),
"pending_count": len(hosts),
"running_count": 0,
"total_count": len(bk_host_ids),
"total_count": len(hosts),
},
)

Expand Down
4 changes: 2 additions & 2 deletions apps/node_man/handlers/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,14 +331,14 @@ def operate(params: dict, username: str):
).values("bk_host_id", "bk_biz_id", "bk_cloud_id", "inner_ip", "node_type", "os_type")

# 校验器进行校验
db_host_ids, host_biz_scope = operate_validator(list(db_host_sql))
db_hosts, host_biz_scope = operate_validator(list(db_host_sql))

plugin_name__job_id__map = {}
for plugin_params in params["plugin_params_list"]:
plugin_name, plugin_version = plugin_params["name"], plugin_params["version"]
subscription_create_result = PluginHandler.create_subscription(
job_type=params["job_type"],
nodes=db_host_ids,
nodes=db_hosts,
name=plugin_name,
version=plugin_version,
keep_config=plugin_params.get("keep_config"),
Expand Down
11 changes: 8 additions & 3 deletions apps/node_man/handlers/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def install_validate(
return ip_filter_list, accept_list, proxy_not_alive


def operate_validator(db_host_sql):
def operate_validator(db_host_sql, host_info: typing.Dict[int, typing.Dict[str, typing.Any]] = {}):
"""
用于operate任务的校验
:param db_host_sql: 用户操作主机的详细信息
Expand Down Expand Up @@ -595,6 +595,11 @@ def operate_validator(db_host_sql):
# 获得业务ID
host_biz_scope = list({host["bk_biz_id"] for host in db_host_sql})

db_host_ids = [{"bk_host_id": host_id} for host_id in permission_host_ids]
db_hosts: typing.List[typing.Dict[str, typing.Any]] = []

return db_host_ids, host_biz_scope
for host_id in permission_host_ids:
_host = {"bk_host_id": host_id}
_host.update(host_info.get(host_id, {}))
db_hosts.append(_host)

return db_hosts, host_biz_scope
69 changes: 63 additions & 6 deletions apps/node_man/serializers/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class HostSerializer(serializers.Serializer):
peer_exchange_switch_for_agent = serializers.IntegerField(label=_("加速设置"), required=False, default=0)
bt_speed_limit = serializers.IntegerField(label=_("传输限速"), required=False)
data_path = serializers.CharField(label=_("数据文件路径"), required=False, allow_blank=True)
is_need_inject_ap_id = serializers.IntegerField(label=_("是否需要注入ap_id到meta"), required=False, default=False)
is_need_inject_ap_id = serializers.BooleanField(label=_("是否需要注入ap_id到meta"), required=False, default=False)
enable_compression = serializers.BooleanField(label=_("数据压缩开关"), required=False, default=False)

def validate(self, attrs):
Expand Down Expand Up @@ -273,6 +273,26 @@ def validate(self, attrs):
return attrs


class OperateHostSerializer(serializers.Serializer):
    """
    Serializer for a single host entry of an operate-type job.
    """

    bk_host_id = serializers.IntegerField(label=_("主机ID"))
    ap_id = serializers.IntegerField(label=_("接入点ID"), required=False)
    is_use_ap_map = serializers.BooleanField(label=_("是否使用映射接入点"), required=False, default=False)

    # Derived in validate(); callers are not expected to pass the field below.
    is_need_inject_ap_id = serializers.BooleanField(label=_("是否需要注入ap_id到meta"), required=False, default=False)

    def validate(self, attrs):
        """
        Set ``is_need_inject_ap_id`` to True when an ``ap_id`` is given or ``is_use_ap_map`` is set.

        :param attrs: validated field values
        :return: the same ``attrs`` mapping, possibly with ``is_need_inject_ap_id`` forced to True
        """
        has_explicit_ap = attrs.get("ap_id") is not None
        # ``is_use_ap_map`` is only consulted when no explicit ap_id was supplied.
        if has_explicit_ap or attrs["is_use_ap_map"]:
            attrs["is_need_inject_ap_id"] = True

        return attrs


class OperateSerializer(serializers.Serializer):
agent_setup_info = AgentSetupInfoSerializer(label=_("Agent 设置信息"), required=False)
job_type = serializers.ChoiceField(label=_("任务类型"), choices=list(constants.JOB_TYPE_DICT))
Expand All @@ -281,6 +301,7 @@ class OperateSerializer(serializers.Serializer):
conditions = serializers.ListField(label=_("搜索条件"), required=False, child=serializers.DictField())
bk_host_id = serializers.ListField(label=_("主机ID"), required=False, child=serializers.IntegerField())
exclude_hosts = serializers.ListField(label=_("跨页全选排除主机"), required=False, child=serializers.IntegerField())
hosts = OperateHostSerializer(label=_("主机信息"), required=False, many=True)
is_install_latest_plugins = serializers.BooleanField(label=_("是否安装最新版本插件"), required=False, default=True)
is_install_other_agent = serializers.BooleanField(label=_("是否为安装额外Agent"), required=False, default=False)

Expand All @@ -302,9 +323,39 @@ def validate(self, attrs):

if attrs.get("exclude_hosts") is not None and attrs.get("bk_host_id") is not None:
raise ValidationError(_("跨页全选模式下不允许传bk_host_id参数"))
if attrs.get("exclude_hosts") is None and attrs.get("bk_host_id") is None:
if all(
[
attrs.get("exclude_hosts") is None,
attrs.get("bk_host_id") is None,
attrs.get("hosts") is None,
]
):
raise ValidationError(_("必须选择一种模式(【是否跨页全选】)"))

if attrs.get("hosts", []):
# 如果主机使用映射,回填映射ap id
use_ap_map_host_ids: typing.List[int] = [
host["bk_host_id"] for host in attrs["hosts"] if host["is_use_ap_map"]
]
if use_ap_map_host_ids:
gray_ap_map: typing.Dict[int, int] = GrayHandler.get_gray_ap_map()
host_queryset = models.Host.objects.filter(bk_host_id__in=use_ap_map_host_ids).values(
"bk_host_id", "ap_id"
)
host_id_ap_map: typing.Dict[int, int] = {_host["bk_host_id"]: _host["ap_id"] for _host in host_queryset}
for host in attrs["hosts"]:
if not host["is_use_ap_map"]:
continue

try:
host["ap_id"] = gray_ap_map[host_id_ap_map[host["bk_host_id"]]]
except KeyError:
raise ValidationError(
_("缺少与主机ID: {bk_host_id} AP ID: {ap_id} 对应的接入点映射,请联系管理员配置").format(
bk_host_id=host["bk_host_id"], ap_id=host_id_ap_map[host["bk_host_id"]]
)
)

if attrs["node_type"] == constants.NodeType.PROXY:
# 是否为针对代理的操作,用户有权限获取的业务
# 格式 { bk_biz_id: bk_biz_name , ...}
Expand All @@ -318,6 +369,10 @@ def validate(self, attrs):
filter_node_types = [constants.NodeType.AGENT, constants.NodeType.PAGENT]
is_proxy = False

host_info: typing.Dict[int, typing.Dict[str, typing.Any]] = {
_host["bk_host_id"]: _host for _host in attrs.get("hosts", [])
}

if attrs.get("exclude_hosts") is not None:
# 跨页全选
db_host_sql = (
Expand All @@ -328,12 +383,13 @@ def validate(self, attrs):

else:
# 不是跨页全选
input_bk_host_ids: typing.List[int] = attrs.get("bk_host_id", []) or host_info.keys()
db_host_sql = models.Host.objects.filter(
bk_host_id__in=attrs["bk_host_id"], node_type__in=filter_node_types
bk_host_id__in=input_bk_host_ids, node_type__in=filter_node_types
).values("bk_host_id", "bk_biz_id", "bk_cloud_id", "inner_ip", "node_type", "os_type")

bk_host_ids, bk_biz_scope = validator.operate_validator(list(db_host_sql))
attrs["bk_host_ids"] = bk_host_ids
db_hosts, bk_biz_scope = validator.operate_validator(list(db_host_sql), host_info=host_info)
attrs["hosts"] = db_hosts
attrs["bk_biz_scope"] = bk_biz_scope

set_agent_setup_info_to_attrs(attrs)
Expand All @@ -345,7 +401,7 @@ def validate(self, attrs):
# 没有 V2 接入点或者全部为 V2 接入点时,无需进行重定向处理
return attrs

host_ids: typing.List[int] = [host_info["bk_host_id"] for host_info in bk_host_ids]
host_ids: typing.List[int] = [host_info["bk_host_id"] for host_info in db_hosts]

# 进入灰度的管控区域,所属管控区域主机接入点重定向到 V2
gse_v2_cloud_ids: typing.Set[int] = set(
Expand Down Expand Up @@ -384,6 +440,7 @@ def validate(self, attrs):
GrayHandler.activate(host_nodes=update_result["host_nodes"], rollback=False, only_status=True)
except ApiError:
pass

return attrs


Expand Down
Loading

0 comments on commit 4d162f0

Please sign in to comment.