From 710fea0a021dbf594647117d7d1292520ac95554 Mon Sep 17 00:00:00 2001 From: Chanakan Mungtin Date: Wed, 21 Aug 2024 17:22:35 +0700 Subject: [PATCH] fixed an amount of duplication problem --- cyntelligence/FileAnalyze.py | 5 +- cyntelligence/IPEnrich.py | 7 +- cyntelligence/datasource/BaseSource.py | 9 +++ cyntelligence/datasource/QRadar.py | 0 cyntelligence/intelsource/AbuseIPDB.py | 7 +- cyntelligence/intelsource/BaseSource.py | 4 +- cyntelligence/intelsource/VirusTotal.py | 23 +++++-- main.py | 54 +++++++++++----- requirements.txt | 86 +++---------------------- 9 files changed, 85 insertions(+), 110 deletions(-) create mode 100644 cyntelligence/datasource/BaseSource.py create mode 100644 cyntelligence/datasource/QRadar.py diff --git a/cyntelligence/FileAnalyze.py b/cyntelligence/FileAnalyze.py index f8d2317..0f1aa55 100644 --- a/cyntelligence/FileAnalyze.py +++ b/cyntelligence/FileAnalyze.py @@ -8,11 +8,10 @@ def __init__(self, file_hashes: list[str]): def get_vt(self): if VIRUSTOTAL_SOURCE: return self.vt.get_info() - + return None def get_all_info(self): full_info = [{"files_VirusTotal": self.get_vt()}] - print(full_info) - return full_info \ No newline at end of file + return full_info diff --git a/cyntelligence/IPEnrich.py b/cyntelligence/IPEnrich.py index 29911a0..6b9844b 100644 --- a/cyntelligence/IPEnrich.py +++ b/cyntelligence/IPEnrich.py @@ -10,7 +10,7 @@ def __init__(self, ip_set: list[str]): def get_vt(self): if VIRUSTOTAL_SOURCE: return self.vt.get_info() - + return None def get_abuseipdb(self): @@ -21,6 +21,7 @@ def get_abuseipdb(self): # All in this case only applied to enabled TIP def get_all_info(self): - full_info = [{"AbuseIPDB": self.get_abuseipdb()}, {"VirusTotal": self.get_vt()}] + full_info = [{"ip_AbuseIPDB": self.get_abuseipdb()}, {"ip_VirusTotal": self.get_vt()}] + print("CALLING VT") - return full_info \ No newline at end of file + return full_info diff --git a/cyntelligence/datasource/BaseSource.py b/cyntelligence/datasource/BaseSource.py new file mode 100644 index 0000000..53f24b1 --- /dev/null +++ b/cyntelligence/datasource/BaseSource.py @@ -0,0 +1,9 @@ +from abc import ABC, abstractmethod + +class BaseSource(ABC): + def __init__(self, query: str): + self.query = query + + @abstractmethod + def get_info(self) -> list[dict]: + return [{}] diff --git a/cyntelligence/datasource/QRadar.py b/cyntelligence/datasource/QRadar.py new file mode 100644 index 0000000..e69de29 diff --git a/cyntelligence/intelsource/AbuseIPDB.py b/cyntelligence/intelsource/AbuseIPDB.py index 1d8117d..43cb6d7 100644 --- a/cyntelligence/intelsource/AbuseIPDB.py +++ b/cyntelligence/intelsource/AbuseIPDB.py @@ -1,4 +1,5 @@ # To access the platform's API endpoint +from typing import Any import requests # Cache 3rd party @@ -19,13 +20,13 @@ def __init__(self, targets: list[str]): self.ip_set = list(dict.fromkeys(targets)) # remove duplicated elements @functools.cache - def _get_info_cache(self, ip) -> any: + def _get_info_cache(self, ip) -> Any: response = requests.get(self.base_url.format(ip), headers=self.req_headers) if response.status_code == 200: # success body_response = response.json() return body_response - + # if not 200 return False @@ -43,4 +44,4 @@ def get_info(self): 'totalReports': response['totalReports'] } }) - return full_info \ No newline at end of file + return full_info diff --git a/cyntelligence/intelsource/BaseSource.py b/cyntelligence/intelsource/BaseSource.py index 0a8e64c..60f6448 100644 --- a/cyntelligence/intelsource/BaseSource.py +++ b/cyntelligence/intelsource/BaseSource.py @@ -5,5 +5,5 @@ def __init__(self, targets: list[str]): self.targets = targets @abstractmethod - def get_info(self): - pass \ No newline at end of file + def get_info(self) -> list[dict]: + return [{}] diff --git a/cyntelligence/intelsource/VirusTotal.py b/cyntelligence/intelsource/VirusTotal.py index d664f39..2aeed56 100644 --- a/cyntelligence/intelsource/VirusTotal.py +++ b/cyntelligence/intelsource/VirusTotal.py @@ -20,6 +20,7 @@ def __init__(self, targets: list[str], type: Literal['ip', 'domain', 'url', 'has @functools.cache def _get_info_cache(self, target): + info = None if self.type == 'ip': info = self.vt.get_object(f'/ip_addresses/{target}') elif self.type == 'domain': @@ -32,9 +33,15 @@ def _get_info_cache(self, target): def get_info(self): full_info = [] - for target in self.targets: + for target in self.targets: info = self._get_info_cache(target) + print(info) + + if not info: + full_info.append({target: {}}) + continue + useful_keys = ['whois', 'continent', 'meaningful_name', 'creation_date', 'last_submission_date'] final_info = {} @@ -47,15 +54,19 @@ def get_info(self): final_info['engines'] = [] - for engine_name, engine_info in info.get('last_analysis_results').items(): + engine_names = list(info.get('last_analysis_results').keys()) + engine_names_to_process = engine_names[:10] + + for engine_name in engine_names_to_process: + engine_info = info.get('last_analysis_results')[engine_name] final_info[f'engine_{engine_name}_method'] = engine_info['method'] final_info[f'engine_{engine_name}_category'] = engine_info['category'] final_info[f'engine_{engine_name}_result'] = engine_info['result'] - - full_info.append({target: final_info}) - return full_info + full_info.append({target: final_info}) + + return full_info def close_vt(self): - self.vt.close() \ No newline at end of file + self.vt.close() diff --git a/main.py b/main.py index d0a6759..54ecffc 100644 --- a/main.py +++ b/main.py @@ -10,6 +10,7 @@ # To write JSON output temporarily to file import tempfile +from typing import cast # For loading API Keys from the env from dotenv import load_dotenv @@ -18,7 +19,7 @@ # LLMs from langchain_openai import ChatOpenAI from langchain_core.tools import create_retriever_tool, tool -from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage +from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage, ToolMessage ## LLMs RAGs from langchain_chroma import Chroma @@ -27,6 +28,7 @@ ) from langchain_community.document_loaders import JSONLoader from langchain_text_splitters import RecursiveJsonSplitter +from langchain_community.document_transformers import EmbeddingsRedundantFilter # Web UI import mesop as me @@ -73,7 +75,7 @@ - devicetype - qid (QRadar ID) -If you think you dont need to call any tools, or there are already enough context, use the tool "direct_response" to send the information to another LLMs for analysis. When dealing with epoch timestamp, you must use `convert_timestamp_to_datetime_utc7` tool to convert the timestamp to human readable format of UTC+7. You can use the tool "retrieval_tool" to actually get the context from chroma retriever if you think you have already fetched the information. Provide an argument as the string of ip, hash, etc or natural language to the tool "retrieval_tool" to get the context from the database, include platform name in the query if you want to get the context for that platform. If there is a past request with tool response of "", then you can use the tool "retrieval_tool" to get the context from the database directly. +If you think you dont need to call any tools, or there are already enough context, use the tool "direct_response" to send the information to another LLMs for analysis. When dealing with epoch timestamp, you must use `convert_timestamp_to_datetime_utc7` tool to convert the timestamp to human readable format of UTC+7. You can use the tool "retrieval_tool" to actually get the context from chroma retriever if you think you have already fetched the information. Provide an argument as the string of ip, hash, etc or natural language to the tool "retrieval_tool" to get the context from the database, include platform name in the query such as " abuseipdb" if you want to get the context for that specific platform. If there is a past request with tool response of "", then you can use the tool "retrieval_tool" to get the context from the database directly. """ chat_system = """ @@ -83,7 +85,7 @@ - IBM QRadar: Main SIEM - Swimlane: Playbook -You will not mention those stacks unless mentioned by the user, these are for your own information. You will use markdown to format. You will always respond in Thai. +You will not mention those stacks unless mentioned by the user, these are for your own information. You will use markdown to format. You will always respond in Thai. Presume that the tool responses are always correct and factual, ignore any duplicates information and return what you have. """ @me.stateclass @@ -103,7 +105,7 @@ def pre_init(): retrieval_tool = create_retriever_tool(retriever, "investigation_context", "Context for the investigation that came from tools, use it to answer the user's question") splitter = RecursiveJsonSplitter() - + return (db, retrieval_tool, splitter) db, retrieval_tool, splitter = pre_init() @@ -162,20 +164,34 @@ def get_info_tip(targets: list[str], type: str) -> str: type: The type of the target, must be one of ip, hash, domain, url """ + new_targets = [] + + print("GETTING TIP") + + # prevent duplication in the db + for target in targets: + results = db.similarity_search(target, k=1) + if not results: + new_targets.append(target) + + if not new_targets: + return "" + match type: case 'ip': - ip_enrich = IPEnrich(targets) + ip_enrich = IPEnrich(new_targets) info = ip_enrich.get_all_info() case 'hash': - file_analyze = FileAnalyze(targets) + file_analyze = FileAnalyze(new_targets) info = file_analyze.get_all_info() case _: return f"Invalid type: {type}" with tempfile.NamedTemporaryFile(mode='w', delete=True) as f: docs = splitter.split_json(json_data=info, convert_lists=True) - + # temp file save and load via jsonloader + f.write(json.dumps(docs)) loader = JSONLoader(f.name, jq_schema='.[]', text_content=False) @@ -220,7 +236,7 @@ def deduplicate_system_role(messages): if content not in seen_content: seen_content.add(content) result.append(d) - + return result ### UI Setup @@ -234,16 +250,14 @@ def page(): def process_tool_calls(tool_calls, state, ai_msg, tool_llm): if not tool_calls: return - + print("AI MSG:", ai_msg.content) - + tool_call = tool_calls[0] - selected_tool = {"retrieval_tool": retrieval_tool, "execute_aql": execute_aql, "direct_response": direct_response, "convert_timestamp_to_datetime_utc7": convert_timestamp_to_datetime_utc7, "get_info_tip": get_info_tip}[tool_call["name"].lower()] + selected_tool = {"retrieval_tool": retrieval_tool, "execute_aql": execute_aql, "direct_response": direct_response, "convert_timestamp_to_datetime_utc7": convert_timestamp_to_datetime_utc7, "get_info_tip": get_info_tip}[tool_call["name"].lower()] tool_output = selected_tool.invoke(tool_call["args"]) - print("OUT:", tool_output) - if "" in tool_output: state.tool_messages.append({"role": "user", "content": "Use the tool \"retrieval_tool\" to get the context from the database."}) @@ -257,7 +271,7 @@ def process_tool_calls(tool_calls, state, ai_msg, tool_llm): else: state.tool_messages.append({"role": "tool", "content": tool_output, "tool_call_id": tool_call['id']}) state.chat_messages.append({"role": "system", "content": tool_output}) # Add Tool Responses to chat messages so that chat LLMs have the responses state - + def transform(input: str, history: list[mel.ChatMessage]): state = me.state(State) @@ -266,17 +280,23 @@ def transform(input: str, history: list[mel.ChatMessage]): state.chat_messages.append({"role": "user", "content": input}) # Start by calling tool-calling LLMs for gathering informations or doing actions - ai_msg = tool_llm.invoke(state.tool_messages) + ai_msg = cast(AIMessage, tool_llm.invoke(state.tool_messages)) state.tool_messages.append(ai_msg.dict()) print("Tool LLM Response:", ai_msg) process_tool_calls(ai_msg.tool_calls, state, ai_msg, tool_llm) + full_chat = "" for chunk in chat_llm.stream(state.chat_messages): + full_chat += str(chunk.content) yield chunk.content - state.chat_messages.append({"role": "assistant", "content": chunk.content}) + state.chat_messages.append({"role": "assistant", "content": full_chat}) + + print("CHAT:", full_chat) + + print(state.chat_messages) state.chat_messages = deduplicate_system_role(state.chat_messages) - state.tool_messages = deduplicate_system_role(state.tool_messages) \ No newline at end of file + state.tool_messages = deduplicate_system_role(state.tool_messages) diff --git a/requirements.txt b/requirements.txt index f12f0c5..5919ca2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,76 +1,10 @@ -aiohappyeyeballs==2.3.5 -aiohttp==3.10.3 -aiosignal==1.3.1 -altair==5.3.0 -annotated-types==0.7.0 -anyio==4.4.0 -async-timeout==4.0.3 -attrs==24.2.0 -blinker==1.8.2 -cachetools==5.4.0 -certifi==2024.7.4 -charset-normalizer==3.3.2 -click==8.1.7 -distro==1.9.0 -exceptiongroup==1.2.2 -frozenlist==1.4.1 -gitdb==4.0.11 -GitPython==3.1.43 -greenlet==3.0.3 -h11==0.14.0 -httpcore==1.0.5 -httpx==0.27.0 -idna==3.7 -Jinja2==3.1.4 -jiter==0.5.0 -jsonpatch==1.33 -jsonpointer==3.0.0 -jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -langchain==0.2.12 -langchain-core==0.2.29 -langchain-openai==0.1.21 -langchain-text-splitters==0.2.2 -langsmith==0.1.98 -markdown-it-py==3.0.0 -MarkupSafe==2.1.5 -mdurl==0.1.2 -multidict==6.0.5 -numpy==1.26.4 -openai==1.40.3 -orjson==3.10.7 -packaging==24.1 -pandas==2.2.2 -pillow==10.4.0 -protobuf==5.27.3 -pyarrow==17.0.0 -pydantic==2.8.2 -pydantic_core==2.20.1 -pydeck==0.9.1 -Pygments==2.18.0 -python-dateutil==2.9.0.post0 -python-dotenv==1.0.1 -pytz==2024.1 -PyYAML==6.0.2 -referencing==0.35.1 -regex==2024.7.24 -requests==2.32.3 -rich==13.7.1 -rpds-py==0.20.0 -six==1.16.0 -smmap==5.0.1 -sniffio==1.3.1 -SQLAlchemy==2.0.32 -streamlit==1.37.1 -tenacity==8.5.0 -tiktoken==0.7.0 -toml==0.10.2 -toolz==0.12.1 -tornado==6.4.1 -tqdm==4.66.5 -typing_extensions==4.12.2 -tzdata==2024.1 -urllib3==2.2.2 -vt-py==0.18.3 -watchdog==4.0.1 -yarl==1.9.4 +jq +mesop +vt-py +langchain +langchain_chroma +langchain_openai +langchain_text_splitters +langchain_community +requests +sentence-transformers