Skip to content

Commit

Permalink
Add Discord scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Nov 5, 2024
1 parent d17b626 commit 9bebbff
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 13 deletions.
20 changes: 10 additions & 10 deletions .github/workflows/train_workflow.yml
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
name: Training Workflow
on:
workflow_dispatch:
inputs:
script_content:
description: 'Content of train.py'
required: true
type: string # Explicitly specify the type

jobs:
train:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'

- name: Install dependencies
run: |
pip install -r ci-requirements.txt
- name: Run training
- name: Create and run training script
run: |
echo "${{ inputs.script_content }}" > train.py
cat train.py # Debug: print the content
python train.py > training.log 2>&1
- name: Upload logs
uses: actions/upload-artifact@v3
if: always() # Upload logs whether the job succeeds or fails
Expand Down
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

This is the code for the Discord bot we'll be using to queue jobs to a cluster of GPUs that our generous sponsors have provided.

The key idea is that we're using Github Actions as a job scheduling engine and primarily making the Discord bot interact with the cluster via issuing Github Actions and and monitoring their status
The key idea is that we're using Github Actions as a job scheduling engine and primarily making the Discord bot interact with the cluster via issuing Github Actions and and monitoring their status and while we're focused on having a nice user experience on discord.gg/gpumode, we're happy to accept PRs that make it easier for other Discord communities to hook GPUs.

## How to run the bot locally

Expand All @@ -20,10 +20,18 @@ Every triggered job is containerized so we don't have to worry too much about se

Instead of testing on GPU MODE directly we can leverage a staging environment called "Discord Cluster Staging". If you need access to this server please ping "Seraphim"

Bot needs to be invited using an oauth2 token and needs the `Message Content Intent` permission

The bot also needs to permissions to read and write messages which is easy to setup if you click on https://discord.com/api/oauth2/authorize?client_id=1303135152091697183&permissions=68608&scope=bot%20applications.commands

### How to add a new GPU to the cluster

Github has some nice instructions here https://docs.github.com/en/actions/hosting-your-own-runners/managing-self-hosted-runners/adding-self-hosted-runners but essentially the whole thing works by running a script on some GPU people own.

### Future work
* Maybe we shouldn't use Github Action and can roll our own thing?
* Make registering new GPUs simpler
* Make registering new GPUs simpler

## Acknowledgements
* Luca Antiga did something very similar for the NeurIPS LLM efficiency competition, it was great!
* Midjourney was a similar inspiration in terms of UX
219 changes: 219 additions & 0 deletions discord-bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
from dotenv import load_dotenv
from github import Github
import os
import time
from datetime import datetime, timezone
import requests
import discord
import asyncio
import logging

# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
logger.info("Environment variables loaded")

# Validate environment variables
if not os.getenv('DISCORD_TOKEN'):
logger.error("DISCORD_TOKEN not found in environment variables")
raise ValueError("DISCORD_TOKEN not found")
if not os.getenv('GITHUB_TOKEN'):
logger.error("GITHUB_TOKEN not found in environment variables")
raise ValueError("GITHUB_TOKEN not found")
if not os.getenv('GITHUB_REPO'):
logger.error("GITHUB_REPO not found in environment variables")
raise ValueError("GITHUB_REPO not found")

logger.info(f"Using GitHub repo: {os.getenv('GITHUB_REPO')}")

# Bot setup with minimal intents
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)

async def trigger_github_action(script_content):
"""
Triggers the GitHub action with custom train.py contents
"""
logger.info("Attempting to trigger GitHub action")
gh = Github(os.getenv('GITHUB_TOKEN'))
repo = gh.get_repo(os.getenv('GITHUB_REPO'))

try:
# Record the time before triggering
trigger_time = datetime.now(timezone.utc)

# Log workflow attempt
logger.info(f"Looking for workflow 'train_workflow.yml' in repo {os.getenv('GITHUB_REPO')}")

# Trigger the workflow with the script content
workflow = repo.get_workflow("train_workflow.yml")
logger.info("Found workflow, attempting to dispatch")

success = workflow.create_dispatch("main", {'script_content': script_content})
logger.info(f"Workflow dispatch result: {success}")

if success:
# Wait a moment for the run to be created
await asyncio.sleep(2)

# Get runs created after our trigger time
runs = list(workflow.get_runs())
logger.info(f"Found {len(runs)} total runs")

for run in runs:
logger.info(f"Checking run {run.id} created at {run.created_at}")
if run.created_at.replace(tzinfo=timezone.utc) > trigger_time:
logger.info(f"Found matching run with ID: {run.id}")
return run.id

logger.warning("No matching runs found after trigger")
return None

except Exception as e:
logger.error(f"Error in trigger_github_action: {str(e)}", exc_info=True)
return None

async def download_artifact(run_id):
"""
Downloads the training log artifact from the workflow run
"""
logger.info(f"Attempting to download artifacts for run {run_id}")
gh = Github(os.getenv('GITHUB_TOKEN'))
repo = gh.get_repo(os.getenv('GITHUB_REPO'))

try:
# Get the specific run
run = repo.get_workflow_run(run_id)

# Get artifacts from the run
artifacts = run.get_artifacts()
logger.info(f"Found {artifacts.totalCount} artifacts")

for artifact in artifacts:
logger.info(f"Found artifact: {artifact.name}")
if artifact.name == 'training-logs':
# Download the artifact
url = artifact.archive_download_url
headers = {'Authorization': f'token {os.getenv("GITHUB_TOKEN")}'}
response = requests.get(url, headers=headers)

if response.status_code == 200:
logger.info("Successfully downloaded artifact")
with open('training.log.zip', 'wb') as f:
f.write(response.content)

# Read the log file from the zip
with zipfile.ZipFile('training.log.zip') as z:
with z.open('training.log') as f:
logs = f.read().decode('utf-8')

# Clean up the zip file
os.remove('training.log.zip')
return logs
else:
logger.error(f"Failed to download artifact. Status code: {response.status_code}")

logger.warning("No training-logs artifact found")
return "No training logs found in artifacts"
except Exception as e:
logger.error(f"Error in download_artifact: {str(e)}", exc_info=True)
return f"Error downloading artifacts: {str(e)}"

async def check_workflow_status(run_id, message):
"""
Monitors the GitHub Action workflow status and updates Discord
"""
logger.info(f"Starting to monitor workflow status for run {run_id}")
gh = Github(os.getenv('GITHUB_TOKEN'))
repo = gh.get_repo(os.getenv('GITHUB_REPO'))

while True:
try:
run = repo.get_workflow_run(run_id)
logger.info(f"Current status: {run.status}")

if run.status == "completed":
logger.info("Workflow completed, downloading artifacts")
logs = await download_artifact(run_id)
return run.conclusion, logs, run.html_url

await message.channel.send(f"Workflow still running... Status: {run.status}\nLive view: {run.html_url}")
await asyncio.sleep(30)
except Exception as e:
logger.error(f"Error in check_workflow_status: {str(e)}", exc_info=True)
return "error", str(e), None

@client.event
async def on_ready():
logger.info(f'Logged in as {client.user}')

@client.event
async def on_message(message):
# Ignore messages from the bot itself
if message.author == client.user:
return

# Check if the bot is mentioned and there's an attachment
if client.user in message.mentions:
logger.info(f"Bot mentioned in message with {len(message.attachments)} attachments")
if message.attachments:
for attachment in message.attachments:
logger.info(f"Processing attachment: {attachment.filename}")
if attachment.filename == "train.py":
await message.channel.send("Found train.py! Starting training process...")

try:
# Download the file content
logger.info("Downloading train.py content")
script_content = await attachment.read()
script_content = script_content.decode('utf-8')
logger.info("Successfully read train.py content")

# Trigger GitHub Action
run_id = await trigger_github_action(script_content)

if run_id:
logger.info(f"Successfully triggered workflow with run ID: {run_id}")
await message.channel.send(f"GitHub Action triggered successfully! Run ID: {run_id}\nMonitoring progress...")

# Monitor the workflow
status, logs, url = await check_workflow_status(run_id, message)

# Send results back to Discord
await message.channel.send(f"Training completed with status: {status}")

# Split logs if they're too long for Discord's message limit
if len(logs) > 1900:
chunks = [logs[i:i+1900] for i in range(0, len(logs), 1900)]
for i, chunk in enumerate(chunks):
await message.channel.send(f"```\nLogs (part {i+1}/{len(chunks)}):\n{chunk}\n```")
else:
await message.channel.send(f"```\nLogs:\n{logs}\n```")

if url:
await message.channel.send(f"View the full run at: {url}")
else:
logger.error("Failed to trigger GitHub Action")
await message.channel.send("Failed to trigger GitHub Action. Please check the configuration.")

except Exception as e:
logger.error(f"Error processing request: {str(e)}", exc_info=True)
await message.channel.send(f"Error processing request: {str(e)}")

break

if not any(att.filename == "train.py" for att in message.attachments):
await message.channel.send("Please attach a file named 'train.py' to your message.")

# Run the bot
if __name__ == "__main__":
logger.info("Starting bot...")
client.run(os.getenv('DISCORD_TOKEN'))
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
PyGithub
aiohttp
discord
discord.py
audioop-lts # discord.py imports using * syntax
python-dotenv
requests

0 comments on commit 9bebbff

Please sign in to comment.