Initial commit
This commit is contained in:
272
bot/bot.py
Normal file
272
bot/bot.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
bot.py - Discord bot client with session management and command routing
|
||||
|
||||
Features:
|
||||
- Login flow with username/password
|
||||
- Session management with JWT tokens
|
||||
- AI-powered command parsing via registry
|
||||
- Background task loop for polling
|
||||
"""
|
||||
|
||||
import discord
|
||||
from discord.ext import tasks
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import base64
|
||||
import requests
|
||||
import bcrypt
|
||||
import pickle
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from bot.command_registry import get_handler, list_registered
|
||||
import ai.parser as ai_parser
|
||||
|
||||
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN")
|
||||
API_URL = os.getenv("API_URL", "http://app:5000")
|
||||
|
||||
user_sessions = {}
|
||||
login_state = {}
|
||||
message_history = {}
|
||||
user_cache = {}
|
||||
CACHE_FILE = "/app/user_cache.pkl"
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
client = discord.Client(intents=intents)
|
||||
|
||||
|
||||
def decodeJwtPayload(token):
|
||||
payload = token.split(".")[1]
|
||||
payload += "=" * (4 - len(payload) % 4)
|
||||
return json.loads(base64.urlsafe_b64decode(payload))
|
||||
|
||||
|
||||
def apiRequest(method, endpoint, token=None, data=None):
|
||||
url = f"{API_URL}{endpoint}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
try:
|
||||
resp = getattr(requests, method)(url, headers=headers, json=data, timeout=10)
|
||||
try:
|
||||
return resp.json(), resp.status_code
|
||||
except ValueError:
|
||||
return {}, resp.status_code
|
||||
except requests.RequestException:
|
||||
return {"error": "API unavailable"}, 503
|
||||
|
||||
|
||||
def loadCache():
|
||||
try:
|
||||
if os.path.exists(CACHE_FILE):
|
||||
with open(CACHE_FILE, "rb") as f:
|
||||
global user_cache
|
||||
user_cache = pickle.load(f)
|
||||
print(f"Loaded cache for {len(user_cache)} users")
|
||||
except Exception as e:
|
||||
print(f"Error loading cache: {e}")
|
||||
|
||||
|
||||
def saveCache():
|
||||
try:
|
||||
with open(CACHE_FILE, "wb") as f:
|
||||
pickle.dump(user_cache, f)
|
||||
except Exception as e:
|
||||
print(f"Error saving cache: {e}")
|
||||
|
||||
|
||||
def hashPassword(password):
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def verifyPassword(password, hashed):
|
||||
return bcrypt.checkpw(password.encode("utf-8"), hashed.encode("utf-8"))
|
||||
|
||||
|
||||
def getCachedUser(discord_id):
|
||||
return user_cache.get(discord_id)
|
||||
|
||||
|
||||
def setCachedUser(discord_id, user_data):
|
||||
user_cache[discord_id] = user_data
|
||||
saveCache()
|
||||
|
||||
|
||||
def negotiateToken(discord_id, username, password):
|
||||
cached = getCachedUser(discord_id)
|
||||
if (
|
||||
cached
|
||||
and cached.get("username") == username
|
||||
and verifyPassword(password, cached.get("hashed_password"))
|
||||
):
|
||||
result, status = apiRequest(
|
||||
"post", "/api/login", data={"username": username, "password": password}
|
||||
)
|
||||
if status == 200 and "token" in result:
|
||||
token = result["token"]
|
||||
payload = decodeJwtPayload(token)
|
||||
user_uuid = payload["sub"]
|
||||
setCachedUser(
|
||||
discord_id,
|
||||
{
|
||||
"hashed_password": cached["hashed_password"],
|
||||
"user_uuid": user_uuid,
|
||||
"username": username,
|
||||
},
|
||||
)
|
||||
return token, user_uuid
|
||||
return None, None
|
||||
|
||||
result, status = apiRequest(
|
||||
"post", "/api/login", data={"username": username, "password": password}
|
||||
)
|
||||
if status == 200 and "token" in result:
|
||||
token = result["token"]
|
||||
payload = decodeJwtPayload(token)
|
||||
user_uuid = payload["sub"]
|
||||
setCachedUser(
|
||||
discord_id,
|
||||
{
|
||||
"hashed_password": hashPassword(password),
|
||||
"user_uuid": user_uuid,
|
||||
"username": username,
|
||||
},
|
||||
)
|
||||
return token, user_uuid
|
||||
return None, None
|
||||
|
||||
|
||||
async def handleAuthFailure(message):
|
||||
discord_id = message.author.id
|
||||
user_sessions.pop(discord_id, None)
|
||||
await message.channel.send(
|
||||
"Your session has expired. Send any message to log in again."
|
||||
)
|
||||
|
||||
|
||||
async def handleLoginStep(message):
|
||||
discord_id = message.author.id
|
||||
state = login_state[discord_id]
|
||||
|
||||
if state["step"] == "username":
|
||||
state["username"] = message.content.strip()
|
||||
state["step"] = "password"
|
||||
await message.channel.send("Password?")
|
||||
|
||||
elif state["step"] == "password":
|
||||
username = state["username"]
|
||||
password = message.content.strip()
|
||||
del login_state[discord_id]
|
||||
|
||||
token, user_uuid = negotiateToken(discord_id, username, password)
|
||||
|
||||
if token and user_uuid:
|
||||
user_sessions[discord_id] = {
|
||||
"token": token,
|
||||
"user_uuid": user_uuid,
|
||||
"username": username,
|
||||
}
|
||||
registered = ", ".join(list_registered()) or "none"
|
||||
await message.channel.send(
|
||||
f"Welcome back **{username}**!\n\n"
|
||||
f"Registered modules: {registered}\n\n"
|
||||
f"Send 'help' for available commands."
|
||||
)
|
||||
else:
|
||||
await message.channel.send(
|
||||
"Invalid credentials. Send any message to try again."
|
||||
)
|
||||
|
||||
|
||||
async def sendHelpMessage(message):
|
||||
registered = list_registered()
|
||||
help_msg = f"**Available Modules:**\n{chr(10).join(f'- {m}' for m in registered) if registered else '- No modules registered'}\n\nJust talk naturally and I'll help you out!"
|
||||
await message.channel.send(help_msg)
|
||||
|
||||
|
||||
async def routeCommand(message):
|
||||
discord_id = message.author.id
|
||||
session = user_sessions[discord_id]
|
||||
user_input = message.content.lower()
|
||||
|
||||
if "help" in user_input or "what can i say" in user_input:
|
||||
await sendHelpMessage(message)
|
||||
return
|
||||
|
||||
async with message.channel.typing():
|
||||
history = message_history.get(discord_id, [])
|
||||
parsed = ai_parser.parse(message.content, "command_parser", history=history)
|
||||
|
||||
if discord_id not in message_history:
|
||||
message_history[discord_id] = []
|
||||
message_history[discord_id].append((message.content, parsed))
|
||||
message_history[discord_id] = message_history[discord_id][-5:]
|
||||
|
||||
if "needs_clarification" in parsed:
|
||||
await message.channel.send(
|
||||
f"I'm not quite sure what you mean. {parsed['needs_clarification']}"
|
||||
)
|
||||
return
|
||||
|
||||
if "error" in parsed:
|
||||
await message.channel.send(
|
||||
f"I had trouble understanding that: {parsed['error']}"
|
||||
)
|
||||
return
|
||||
|
||||
interaction_type = parsed.get("interaction_type")
|
||||
handler = get_handler(interaction_type)
|
||||
|
||||
if handler:
|
||||
await handler(message, session, parsed)
|
||||
else:
|
||||
registered = ", ".join(list_registered()) or "none"
|
||||
await message.channel.send(
|
||||
f"Unknown command type '{interaction_type}'. Registered modules: {registered}"
|
||||
)
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
print(f"Bot logged in as {client.user}")
|
||||
loadCache()
|
||||
backgroundLoop.start()
|
||||
|
||||
|
||||
@client.event
|
||||
async def on_message(message):
|
||||
if message.author == client.user:
|
||||
return
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
return
|
||||
|
||||
discord_id = message.author.id
|
||||
|
||||
if discord_id in login_state:
|
||||
await handleLoginStep(message)
|
||||
return
|
||||
|
||||
if discord_id not in user_sessions:
|
||||
login_state[discord_id] = {"step": "username"}
|
||||
await message.channel.send("Welcome! Send your username to log in.")
|
||||
return
|
||||
|
||||
await routeCommand(message)
|
||||
|
||||
|
||||
@tasks.loop(seconds=60)
|
||||
async def backgroundLoop():
|
||||
"""Override this in your domain module or extend as needed."""
|
||||
pass
|
||||
|
||||
|
||||
@backgroundLoop.before_loop
|
||||
async def beforeBackgroundLoop():
|
||||
await client.wait_until_ready()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
client.run(DISCORD_BOT_TOKEN)
|
||||
Reference in New Issue
Block a user