Skip to content

Commit

Permalink
Fixed a number of duplication problems
Browse files Browse the repository at this point in the history
  • Loading branch information
Chanakan5591 committed Aug 21, 2024
1 parent bc8b169 commit 710fea0
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 110 deletions.
5 changes: 2 additions & 3 deletions cyntelligence/FileAnalyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ def __init__(self, file_hashes: list[str]):
def get_vt(self):
if VIRUSTOTAL_SOURCE:
return self.vt.get_info()

return None

def get_all_info(self):
full_info = [{"files_VirusTotal": self.get_vt()}]

print(full_info)
return full_info
return full_info
7 changes: 4 additions & 3 deletions cyntelligence/IPEnrich.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def __init__(self, ip_set: list[str]):
def get_vt(self):
if VIRUSTOTAL_SOURCE:
return self.vt.get_info()

return None

def get_abuseipdb(self):
Expand All @@ -21,6 +21,7 @@ def get_abuseipdb(self):

# All in this case only applied to enabled TIP
def get_all_info(self):
full_info = [{"AbuseIPDB": self.get_abuseipdb()}, {"VirusTotal": self.get_vt()}]
full_info = [{"ip_AbuseIPDB": self.get_abuseipdb()}, {"ip_VirusTotal": self.get_vt()}]
print("CALLING VT")

return full_info
return full_info
9 changes: 9 additions & 0 deletions cyntelligence/datasource/BaseSource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod

class BaseSource(ABC):
    """Abstract contract for a data source that resolves one query string.

    Concrete sources receive the raw query at construction time and
    implement :meth:`get_info` to return their findings.
    """

    def __init__(self, query: str):
        # Keep the query verbatim; subclasses decide how to interpret it.
        self.query = query

    @abstractmethod
    def get_info(self) -> list[dict]:
        """Return the gathered information as a list of dicts.

        The base implementation yields a single empty mapping so that
        subclasses may delegate via ``super()`` as a neutral default.
        """
        return [{}]
Empty file.
7 changes: 4 additions & 3 deletions cyntelligence/intelsource/AbuseIPDB.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# To access the platform's API endpoint
from typing import Any
import requests

# Cache 3rd party
Expand All @@ -19,13 +20,13 @@ def __init__(self, targets: list[str]):
self.ip_set = list(dict.fromkeys(targets)) # remove duplicated elements

@functools.cache
def _get_info_cache(self, ip) -> any:
def _get_info_cache(self, ip) -> Any:
response = requests.get(self.base_url.format(ip), headers=self.req_headers)

if response.status_code == 200: # success
body_response = response.json()
return body_response

# if not 200
return False

Expand All @@ -43,4 +44,4 @@ def get_info(self):
'totalReports': response['totalReports']
}
})
return full_info
return full_info
4 changes: 2 additions & 2 deletions cyntelligence/intelsource/BaseSource.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ def __init__(self, targets: list[str]):
self.targets = targets

@abstractmethod
def get_info(self):
pass
def get_info(self) -> list[dict]:
return [{}]
23 changes: 17 additions & 6 deletions cyntelligence/intelsource/VirusTotal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, targets: list[str], type: Literal['ip', 'domain', 'url', 'has

@functools.cache
def _get_info_cache(self, target):
info = None
if self.type == 'ip':
info = self.vt.get_object(f'/ip_addresses/{target}')
elif self.type == 'domain':
Expand All @@ -32,9 +33,15 @@ def _get_info_cache(self, target):

def get_info(self):
full_info = []
for target in self.targets:
for target in self.targets:
info = self._get_info_cache(target)

print(info)

if not info:
full_info.append({target: {}})
continue

useful_keys = ['whois', 'continent', 'meaningful_name', 'creation_date', 'last_submission_date']
final_info = {}

Expand All @@ -47,15 +54,19 @@ def get_info(self):

final_info['engines'] = []

for engine_name, engine_info in info.get('last_analysis_results').items():
engine_names = list(info.get('last_analysis_results').keys())
engine_names_to_process = engine_names[:10]

for engine_name in engine_names_to_process:
engine_info = info.get('last_analysis_results')[engine_name]
final_info[f'engine_{engine_name}_method'] = engine_info['method']
final_info[f'engine_{engine_name}_category'] = engine_info['category']
final_info[f'engine_{engine_name}_result'] = engine_info['result']


full_info.append({target: final_info})
return full_info
full_info.append({target: final_info})

return full_info


def close_vt(self):
self.vt.close()
self.vt.close()
54 changes: 37 additions & 17 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

# To write JSON output temporarily to file
import tempfile
from typing import cast

# For loading API Keys from the env
from dotenv import load_dotenv
Expand All @@ -18,7 +19,7 @@
# LLMs
from langchain_openai import ChatOpenAI
from langchain_core.tools import create_retriever_tool, tool
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage, BaseMessage, ToolMessage

## LLMs RAGs
from langchain_chroma import Chroma
Expand All @@ -27,6 +28,7 @@
)
from langchain_community.document_loaders import JSONLoader
from langchain_text_splitters import RecursiveJsonSplitter
from langchain_community.document_transformers import EmbeddingsRedundantFilter

# Web UI
import mesop as me
Expand Down Expand Up @@ -73,7 +75,7 @@
- devicetype
- qid (QRadar ID)
If you think you dont need to call any tools, or there are already enough context, use the tool "direct_response" to send the information to another LLMs for analysis. When dealing with epoch timestamp, you must use `convert_timestamp_to_datetime_utc7` tool to convert the timestamp to human readable format of UTC+7. You can use the tool "retrieval_tool" to actually get the context from chroma retriever if you think you have already fetched the information. Provide an argument as the string of ip, hash, etc or natural language to the tool "retrieval_tool" to get the context from the database, include platform name in the query if you want to get the context for that platform. If there is a past request with tool response of "<ADDED_TO_RETRIEVER>", then you can use the tool "retrieval_tool" to get the context from the database directly.
If you think you dont need to call any tools, or there are already enough context, use the tool "direct_response" to send the information to another LLMs for analysis. When dealing with epoch timestamp, you must use `convert_timestamp_to_datetime_utc7` tool to convert the timestamp to human readable format of UTC+7. You can use the tool "retrieval_tool" to actually get the context from chroma retriever if you think you have already fetched the information. Provide an argument as the string of ip, hash, etc or natural language to the tool "retrieval_tool" to get the context from the database, include platform name in the query such as "<IP_ADDRESS> abuseipdb" if you want to get the context for that specific platform. If there is a past request with tool response of "<ADDED_TO_RETRIEVER>", then you can use the tool "retrieval_tool" to get the context from the database directly.
"""

chat_system = """
Expand All @@ -83,7 +85,7 @@
- IBM QRadar: Main SIEM
- Swimlane: Playbook
You will not mention those stacks unless mentioned by the user, these are for your own information. You will use markdown to format. You will always respond in Thai.
You will not mention those stacks unless mentioned by the user, these are for your own information. You will use markdown to format. You will always respond in Thai. Presume that the tool responses are always correct and factual, ignore any duplicates information and return what you have.
"""

@me.stateclass
Expand All @@ -103,7 +105,7 @@ def pre_init():
retrieval_tool = create_retriever_tool(retriever, "investigation_context", "Context for the investigation that came from tools, use it to answer the user's question")

splitter = RecursiveJsonSplitter()

return (db, retrieval_tool, splitter)

db, retrieval_tool, splitter = pre_init()
Expand Down Expand Up @@ -162,20 +164,34 @@ def get_info_tip(targets: list[str], type: str) -> str:
type: The type of the target, must be one of ip, hash, domain, url
"""

new_targets = []

print("GETTING TIP")

# prevent duplication in the db
for target in targets:
results = db.similarity_search(target, k=1)
if not results:
new_targets.append(target)

if not new_targets:
return "<ADDED_TO_RETRIEVER>"

match type:
case 'ip':
ip_enrich = IPEnrich(targets)
ip_enrich = IPEnrich(new_targets)
info = ip_enrich.get_all_info()
case 'hash':
file_analyze = FileAnalyze(targets)
file_analyze = FileAnalyze(new_targets)
info = file_analyze.get_all_info()
case _:
return f"Invalid type: {type}"

with tempfile.NamedTemporaryFile(mode='w', delete=True) as f:
docs = splitter.split_json(json_data=info, convert_lists=True)

# temp file save and load via jsonloader

f.write(json.dumps(docs))

loader = JSONLoader(f.name, jq_schema='.[]', text_content=False)
Expand Down Expand Up @@ -220,7 +236,7 @@ def deduplicate_system_role(messages):
if content not in seen_content:
seen_content.add(content)
result.append(d)

return result

### UI Setup
Expand All @@ -234,16 +250,14 @@ def page():
def process_tool_calls(tool_calls, state, ai_msg, tool_llm):
if not tool_calls:
return

print("AI MSG:", ai_msg.content)

tool_call = tool_calls[0]
selected_tool = {"retrieval_tool": retrieval_tool, "execute_aql": execute_aql, "direct_response": direct_response, "convert_timestamp_to_datetime_utc7": convert_timestamp_to_datetime_utc7, "get_info_tip": get_info_tip}[tool_call["name"].lower()]
selected_tool = {"retrieval_tool": retrieval_tool, "execute_aql": execute_aql, "direct_response": direct_response, "convert_timestamp_to_datetime_utc7": convert_timestamp_to_datetime_utc7, "get_info_tip": get_info_tip}[tool_call["name"].lower()]
tool_output = selected_tool.invoke(tool_call["args"])


print("OUT:", tool_output)

if "<ADDED_TO_RETRIEVER>" in tool_output:
state.tool_messages.append({"role": "user", "content": "Use the tool \"retrieval_tool\" to get the context from the database."})

Expand All @@ -257,7 +271,7 @@ def process_tool_calls(tool_calls, state, ai_msg, tool_llm):
else:
state.tool_messages.append({"role": "tool", "content": tool_output, "tool_call_id": tool_call['id']})
state.chat_messages.append({"role": "system", "content": tool_output}) # Add Tool Responses to chat messages so that chat LLMs have the responses state


def transform(input: str, history: list[mel.ChatMessage]):
state = me.state(State)
Expand All @@ -266,17 +280,23 @@ def transform(input: str, history: list[mel.ChatMessage]):
state.chat_messages.append({"role": "user", "content": input})

# Start by calling tool-calling LLMs for gathering informations or doing actions
ai_msg = tool_llm.invoke(state.tool_messages)
ai_msg = cast(AIMessage, tool_llm.invoke(state.tool_messages))
state.tool_messages.append(ai_msg.dict())

print("Tool LLM Response:", ai_msg)

process_tool_calls(ai_msg.tool_calls, state, ai_msg, tool_llm)

full_chat = ""
for chunk in chat_llm.stream(state.chat_messages):
full_chat += str(chunk.content)
yield chunk.content

state.chat_messages.append({"role": "assistant", "content": chunk.content})
state.chat_messages.append({"role": "assistant", "content": full_chat})

print("CHAT:", full_chat)

print(state.chat_messages)

state.chat_messages = deduplicate_system_role(state.chat_messages)
state.tool_messages = deduplicate_system_role(state.tool_messages)
state.tool_messages = deduplicate_system_role(state.tool_messages)
86 changes: 10 additions & 76 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,76 +1,10 @@
aiohappyeyeballs==2.3.5
aiohttp==3.10.3
aiosignal==1.3.1
altair==5.3.0
annotated-types==0.7.0
anyio==4.4.0
async-timeout==4.0.3
attrs==24.2.0
blinker==1.8.2
cachetools==5.4.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
distro==1.9.0
exceptiongroup==1.2.2
frozenlist==1.4.1
gitdb==4.0.11
GitPython==3.1.43
greenlet==3.0.3
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
idna==3.7
Jinja2==3.1.4
jiter==0.5.0
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
langchain==0.2.12
langchain-core==0.2.29
langchain-openai==0.1.21
langchain-text-splitters==0.2.2
langsmith==0.1.98
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
multidict==6.0.5
numpy==1.26.4
openai==1.40.3
orjson==3.10.7
packaging==24.1
pandas==2.2.2
pillow==10.4.0
protobuf==5.27.3
pyarrow==17.0.0
pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
Pygments==2.18.0
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
pytz==2024.1
PyYAML==6.0.2
referencing==0.35.1
regex==2024.7.24
requests==2.32.3
rich==13.7.1
rpds-py==0.20.0
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
SQLAlchemy==2.0.32
streamlit==1.37.1
tenacity==8.5.0
tiktoken==0.7.0
toml==0.10.2
toolz==0.12.1
tornado==6.4.1
tqdm==4.66.5
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
vt-py==0.18.3
watchdog==4.0.1
yarl==1.9.4
jq
mesop
vt-py
langchain
langchain_chroma
langchain_openai
langchain_text_splitters
langchain_community
requests
sentence-transformers

0 comments on commit 710fea0

Please sign in to comment.