Properly merge original Discord bot with JurySystem DBT integration
This commit is contained in:
542
bot/bot.py
542
bot/bot.py
@@ -1,16 +1,48 @@
|
||||
"""
|
||||
bot.py - Discord bot client with session management and command routing
|
||||
|
||||
Features:
|
||||
- Login flow with username/password
|
||||
- Session management with JWT tokens
|
||||
- AI-powered command parsing via registry
|
||||
- Background task loop for polling
|
||||
- JurySystem DBT integration for mental health support
|
||||
"""
|
||||
|
||||
import discord
|
||||
from discord.ext import tasks
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import base64
|
||||
import requests
|
||||
import bcrypt
|
||||
import pickle
|
||||
import numpy as np
|
||||
from openai import OpenAI
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
# --- Configuration ---
|
||||
CONFIG_PATH = os.getenv("CONFIG_PATH", "config.json")
|
||||
KNOWLEDGE_BASE_PATH = os.getenv(
|
||||
"KNOWLEDGE_BASE_PATH", "bot/data/dbt_knowledge.embeddings.json"
|
||||
)
|
||||
from bot.command_registry import get_handler, list_registered, register_module
|
||||
import ai.parser as ai_parser
|
||||
import bot.commands.routines # noqa: F401 - registers handler
|
||||
import bot.commands.medications # noqa: F401 - registers handler
|
||||
import bot.commands.knowledge # noqa: F401 - registers handler
|
||||
|
||||
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN")
|
||||
API_URL = os.getenv("API_URL", "http://app:5000")
|
||||
|
||||
user_sessions = {}
|
||||
login_state = {}
|
||||
message_history = {}
|
||||
user_cache = {}
|
||||
CACHE_FILE = "/app/user_cache.pkl"
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
client = discord.Client(intents=intents)
|
||||
|
||||
|
||||
# ==================== JurySystem Integration ====================
|
||||
|
||||
|
||||
class SimpleVectorStore:
|
||||
@@ -46,23 +78,28 @@ class SimpleVectorStore:
|
||||
|
||||
|
||||
class JurySystem:
|
||||
"""DBT Knowledge Base Query System"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = self.load_config()
|
||||
config_path = os.getenv("CONFIG_PATH", "config.json")
|
||||
kb_path = os.getenv(
|
||||
"KNOWLEDGE_BASE_PATH", "bot/data/dbt_knowledge.embeddings.json"
|
||||
)
|
||||
|
||||
with open(config_path, "r") as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
self.client = OpenAI(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key=self.config["openrouter_api_key"],
|
||||
)
|
||||
self.vector_store = SimpleVectorStore()
|
||||
self.load_knowledge_base()
|
||||
self._load_knowledge_base(kb_path)
|
||||
|
||||
def load_config(self):
|
||||
with open(CONFIG_PATH, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def load_knowledge_base(self):
|
||||
print(f"Loading knowledge base from {KNOWLEDGE_BASE_PATH}...")
|
||||
def _load_knowledge_base(self, kb_path):
|
||||
print(f"Loading DBT knowledge base from {kb_path}...")
|
||||
try:
|
||||
with open(KNOWLEDGE_BASE_PATH, "r", encoding="utf-8") as f:
|
||||
with open(kb_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
vectors = []
|
||||
metadata = []
|
||||
@@ -72,38 +109,41 @@ class JurySystem:
|
||||
{"id": item["id"], "source": item["source"], "text": item["text"]}
|
||||
)
|
||||
self.vector_store.add(vectors, metadata)
|
||||
print(f"Loaded {len(vectors)} chunks into vector store.")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {KNOWLEDGE_BASE_PATH} not found.")
|
||||
raise
|
||||
print(f"Loaded {len(vectors)} chunks into DBT vector store.")
|
||||
except Exception as e:
|
||||
print(f"Error loading knowledge base: {e}")
|
||||
print(f"Error loading DBT knowledge base: {e}")
|
||||
raise
|
||||
|
||||
def process_query(self, query):
|
||||
async def query(self, query_text):
|
||||
"""Query the DBT knowledge base"""
|
||||
try:
|
||||
# Get embedding
|
||||
response = self.client.embeddings.create(
|
||||
model="qwen/qwen3-embedding-8b", input=query
|
||||
model="qwen/qwen3-embedding-8b", input=query_text
|
||||
)
|
||||
query_emb = response.data[0].embedding
|
||||
|
||||
# Search
|
||||
context_chunks = self.vector_store.search(query_emb, top_k=5)
|
||||
|
||||
if not context_chunks:
|
||||
return "I couldn't find any relevant information in the knowledge base."
|
||||
return "I couldn't find relevant DBT information for that query."
|
||||
|
||||
# Generate answer
|
||||
context_text = "\n\n---\n\n".join(
|
||||
[chunk["metadata"]["text"] for chunk in context_chunks]
|
||||
)
|
||||
|
||||
system_prompt = """You are a helpful AI assistant specializing in DBT (Dialectical Behavior Therapy).
|
||||
Use the provided context to answer the user's question.
|
||||
system_prompt = """You are a helpful DBT (Dialectical Behavior Therapy) assistant.
|
||||
Use the provided context from the DBT Skills Training Handouts to answer the user's question.
|
||||
If the answer is not in the context, say you don't know based on the provided text.
|
||||
Be concise and compassionate."""
|
||||
Be concise, compassionate, and practical."""
|
||||
|
||||
user_prompt = f"Context:\n{context_text}\n\nQuestion: {query}"
|
||||
user_prompt = f"Context:\n{context_text}\n\nQuestion: {query_text}"
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model="openai/gpt-4o-mini",
|
||||
model=self.config.get("models", {}).get(
|
||||
"generator", "openai/gpt-4o-mini"
|
||||
),
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
@@ -113,45 +153,429 @@ Be concise and compassionate."""
|
||||
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
return f"Error processing query: {e}"
|
||||
return f"Error querying DBT knowledge base: {e}"
|
||||
|
||||
|
||||
# Initialize the Jury System
|
||||
print("Initializing AI Jury System...")
|
||||
jury_system = JurySystem()
|
||||
print("Jury System ready!")
|
||||
|
||||
# Discord Bot Setup
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
bot = commands.Bot(command_prefix="!", intents=intents)
|
||||
# Initialize JurySystem
|
||||
jury_system = None
|
||||
try:
|
||||
jury_system = JurySystem()
|
||||
print("✓ JurySystem (DBT) initialized successfully")
|
||||
except Exception as e:
|
||||
print(f"⚠ JurySystem initialization failed: {e}")
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
print(f"Bot logged in as {bot.user}")
|
||||
# ==================== Original Bot Functions ====================
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_message(message):
|
||||
if message.author == bot.user:
|
||||
def decodeJwtPayload(token):
|
||||
payload = token.split(".")[1]
|
||||
payload += "=" * (4 - len(payload) % 4)
|
||||
return json.loads(base64.urlsafe_b64decode(payload))
|
||||
|
||||
|
||||
def apiRequest(method, endpoint, token=None, data=None):
|
||||
url = f"{API_URL}{endpoint}"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
try:
|
||||
resp = getattr(requests, method)(url, headers=headers, json=data, timeout=10)
|
||||
try:
|
||||
return resp.json(), resp.status_code
|
||||
except ValueError:
|
||||
return {}, resp.status_code
|
||||
except requests.RequestException:
|
||||
return {"error": "API unavailable"}, 503
|
||||
|
||||
|
||||
def loadCache():
|
||||
try:
|
||||
if os.path.exists(CACHE_FILE):
|
||||
with open(CACHE_FILE, "rb") as f:
|
||||
global user_cache
|
||||
user_cache = pickle.load(f)
|
||||
print(f"Loaded cache for {len(user_cache)} users")
|
||||
except Exception as e:
|
||||
print(f"Error loading cache: {e}")
|
||||
|
||||
|
||||
def saveCache():
|
||||
try:
|
||||
with open(CACHE_FILE, "wb") as f:
|
||||
pickle.dump(user_cache, f)
|
||||
except Exception as e:
|
||||
print(f"Error saving cache: {e}")
|
||||
|
||||
|
||||
def hashPassword(password):
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def verifyPassword(password, hashed):
|
||||
return bcrypt.checkpw(password.encode("utf-8"), hashed.encode("utf-8"))
|
||||
|
||||
|
||||
def getCachedUser(discord_id):
|
||||
return user_cache.get(discord_id)
|
||||
|
||||
|
||||
def setCachedUser(discord_id, user_data):
|
||||
user_cache[discord_id] = user_data
|
||||
saveCache()
|
||||
|
||||
|
||||
def negotiateToken(discord_id, username, password):
|
||||
cached = getCachedUser(discord_id)
|
||||
if (
|
||||
cached
|
||||
and cached.get("username") == username
|
||||
and verifyPassword(password, cached.get("hashed_password"))
|
||||
):
|
||||
result, status = apiRequest(
|
||||
"post", "/api/login", data={"username": username, "password": password}
|
||||
)
|
||||
if status == 200 and "token" in result:
|
||||
token = result["token"]
|
||||
payload = decodeJwtPayload(token)
|
||||
user_uuid = payload["sub"]
|
||||
setCachedUser(
|
||||
discord_id,
|
||||
{
|
||||
"hashed_password": cached["hashed_password"],
|
||||
"user_uuid": user_uuid,
|
||||
"username": username,
|
||||
},
|
||||
)
|
||||
return token, user_uuid
|
||||
return None, None
|
||||
|
||||
result, status = apiRequest(
|
||||
"post", "/api/login", data={"username": username, "password": password}
|
||||
)
|
||||
if status == 200 and "token" in result:
|
||||
token = result["token"]
|
||||
payload = decodeJwtPayload(token)
|
||||
user_uuid = payload["sub"]
|
||||
setCachedUser(
|
||||
discord_id,
|
||||
{
|
||||
"hashed_password": hashPassword(password),
|
||||
"user_uuid": user_uuid,
|
||||
"username": username,
|
||||
},
|
||||
)
|
||||
return token, user_uuid
|
||||
return None, None
|
||||
|
||||
|
||||
async def handleAuthFailure(message):
|
||||
discord_id = message.author.id
|
||||
user_sessions.pop(discord_id, None)
|
||||
await message.channel.send(
|
||||
"Your session has expired. Send any message to log in again."
|
||||
)
|
||||
|
||||
|
||||
async def handleLoginStep(message):
|
||||
discord_id = message.author.id
|
||||
state = login_state[discord_id]
|
||||
|
||||
if state["step"] == "username":
|
||||
state["username"] = message.content.strip()
|
||||
state["step"] = "password"
|
||||
await message.channel.send("Password?")
|
||||
|
||||
elif state["step"] == "password":
|
||||
username = state["username"]
|
||||
password = message.content.strip()
|
||||
del login_state[discord_id]
|
||||
|
||||
token, user_uuid = negotiateToken(discord_id, username, password)
|
||||
|
||||
if token and user_uuid:
|
||||
user_sessions[discord_id] = {
|
||||
"token": token,
|
||||
"user_uuid": user_uuid,
|
||||
"username": username,
|
||||
}
|
||||
registered = ", ".join(list_registered()) or "none"
|
||||
await message.channel.send(
|
||||
f"Welcome back **{username}**!\n\n"
|
||||
f"Registered modules: {registered}\n\n"
|
||||
f"Send 'help' for available commands."
|
||||
)
|
||||
else:
|
||||
await message.channel.send(
|
||||
"Invalid credentials. Send any message to try again."
|
||||
)
|
||||
|
||||
|
||||
async def sendHelpMessage(message):
|
||||
help_msg = """**🤖 Synculous Bot - Natural Language Commands**
|
||||
|
||||
Just talk to me naturally! Here are some examples:
|
||||
|
||||
**💊 Medications:**
|
||||
• "add lsd 50 mcg every tuesday at 4:20pm"
|
||||
• "take my wellbutrin"
|
||||
• "what meds do i have today?"
|
||||
• "show my refills"
|
||||
• "snooze my reminder for 30 minutes"
|
||||
• "check adherence"
|
||||
|
||||
**📋 Routines:**
|
||||
• "create morning routine with brush teeth, shower, eat"
|
||||
• "start my morning routine"
|
||||
• "done" (complete current step)
|
||||
• "skip" (skip current step)
|
||||
• "pause/resume" (pause or continue)
|
||||
• "what steps are in my routine?"
|
||||
• "schedule workout for monday wednesday friday at 7am"
|
||||
• "show my stats"
|
||||
|
||||
**🧠 DBT Support:**
|
||||
• "how do I use distress tolerance?"
|
||||
• "explain radical acceptance"
|
||||
• "give me a DBT skill for anger"
|
||||
• "what are the TIPP skills?"
|
||||
|
||||
**💡 Tips:**
|
||||
• I understand natural language, typos, and slang
|
||||
• If I'm unsure, I'll ask for clarification
|
||||
• For important actions, I'll ask you to confirm with "yes" or "no"
|
||||
• When you're in a routine, shortcuts like "done", "skip", "pause" work automatically"""
|
||||
await message.channel.send(help_msg)
|
||||
|
||||
|
||||
async def checkActiveSession(session):
|
||||
"""Check if user has an active routine session and return details."""
|
||||
token = session.get("token")
|
||||
if not token:
|
||||
return None
|
||||
|
||||
resp, status = apiRequest("get", "/api/sessions/active", token)
|
||||
if status == 200 and "session" in resp:
|
||||
return resp
|
||||
return None
|
||||
|
||||
|
||||
async def handleConfirmation(message, session):
|
||||
"""Handle yes/no confirmation responses. Returns True if handled."""
|
||||
discord_id = message.author.id
|
||||
user_input = message.content.lower().strip()
|
||||
|
||||
if "pending_confirmations" not in session:
|
||||
return False
|
||||
|
||||
pending = session["pending_confirmations"]
|
||||
if not pending:
|
||||
return False
|
||||
|
||||
confirmation_id = list(pending.keys())[-1]
|
||||
confirmation_data = pending[confirmation_id]
|
||||
|
||||
if user_input in ("yes", "y", "yeah", "sure", "ok", "confirm"):
|
||||
del pending[confirmation_id]
|
||||
interaction_type = confirmation_data.get("interaction_type")
|
||||
handler = get_handler(interaction_type)
|
||||
|
||||
if handler:
|
||||
fake_parsed = confirmation_data.copy()
|
||||
fake_parsed["needs_confirmation"] = False
|
||||
await handler(message, session, fake_parsed)
|
||||
return True
|
||||
|
||||
elif user_input in ("no", "n", "nah", "cancel", "abort"):
|
||||
del pending[confirmation_id]
|
||||
await message.channel.send("❌ Cancelled.")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def handleActiveSessionShortcuts(message, session, active_session):
|
||||
"""Handle shortcuts like 'done', 'skip', 'next' when in active session."""
|
||||
user_input = message.content.lower().strip()
|
||||
|
||||
shortcuts = {
|
||||
"done": ("routine", "complete"),
|
||||
"finished": ("routine", "complete"),
|
||||
"complete": ("routine", "complete"),
|
||||
"next": ("routine", "complete"),
|
||||
"skip": ("routine", "skip"),
|
||||
"pass": ("routine", "skip"),
|
||||
"pause": ("routine", "pause"),
|
||||
"hold": ("routine", "pause"),
|
||||
"resume": ("routine", "resume"),
|
||||
"continue": ("routine", "resume"),
|
||||
"stop": ("routine", "cancel"),
|
||||
"quit": ("routine", "cancel"),
|
||||
"abort": ("routine", "abort"),
|
||||
}
|
||||
|
||||
if user_input in shortcuts:
|
||||
interaction_type, action = shortcuts[user_input]
|
||||
handler = get_handler(interaction_type)
|
||||
if handler:
|
||||
fake_parsed = {"action": action}
|
||||
await handler(message, session, fake_parsed)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def handleDBTQuery(message):
|
||||
"""Handle DBT-related queries using JurySystem"""
|
||||
if not jury_system:
|
||||
return False
|
||||
|
||||
# Keywords that indicate a DBT query
|
||||
dbt_keywords = [
|
||||
"dbt",
|
||||
"distress tolerance",
|
||||
"emotion regulation",
|
||||
"interpersonal effectiveness",
|
||||
"mindfulness",
|
||||
"radical acceptance",
|
||||
"wise mind",
|
||||
"tipp",
|
||||
"dearman",
|
||||
"check the facts",
|
||||
"opposite action",
|
||||
"cope ahead",
|
||||
"abc please",
|
||||
"stop skill",
|
||||
"pros and cons",
|
||||
"half smile",
|
||||
"willing hands",
|
||||
]
|
||||
|
||||
user_input_lower = message.content.lower()
|
||||
is_dbt_query = any(keyword in user_input_lower for keyword in dbt_keywords)
|
||||
|
||||
if is_dbt_query:
|
||||
async with message.channel.typing():
|
||||
response = await jury_system.query(message.content)
|
||||
await message.channel.send(f"🧠 **DBT Support:**\n{response}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def routeCommand(message):
|
||||
discord_id = message.author.id
|
||||
session = user_sessions[discord_id]
|
||||
user_input = message.content.lower()
|
||||
|
||||
if "help" in user_input or "what can i say" in user_input:
|
||||
await sendHelpMessage(message)
|
||||
return
|
||||
|
||||
# Process all messages as DBT queries
|
||||
if not message.content.startswith("!"):
|
||||
async with message.channel.typing():
|
||||
response = jury_system.process_query(message.content)
|
||||
await message.reply(response)
|
||||
# Check for active session first
|
||||
active_session = await checkActiveSession(session)
|
||||
|
||||
await bot.process_commands(message)
|
||||
# Handle confirmation responses
|
||||
confirmation_handled = await handleConfirmation(message, session)
|
||||
if confirmation_handled:
|
||||
return
|
||||
|
||||
# Handle shortcuts when in active session
|
||||
if active_session:
|
||||
shortcut_handled = await handleActiveSessionShortcuts(
|
||||
message, session, active_session
|
||||
)
|
||||
if shortcut_handled:
|
||||
return
|
||||
|
||||
# Check for DBT queries
|
||||
dbt_handled = await handleDBTQuery(message)
|
||||
if dbt_handled:
|
||||
return
|
||||
|
||||
async with message.channel.typing():
|
||||
history = message_history.get(discord_id, [])
|
||||
|
||||
# Add context about active session to help AI understand
|
||||
context = ""
|
||||
if active_session:
|
||||
session_data = active_session.get("session", {})
|
||||
routine_name = session_data.get("routine_name", "a routine")
|
||||
current_step = session_data.get("current_step_index", 0) + 1
|
||||
total_steps = active_session.get("total_steps", 0)
|
||||
context = f"\n[Context: User is currently in active session for '{routine_name}', on step {current_step} of {total_steps}. They can say 'done', 'skip', 'pause', 'resume', or 'stop'.]"
|
||||
|
||||
parsed = await ai_parser.parse(
|
||||
message.content + context, "command_parser", history=history
|
||||
)
|
||||
|
||||
if discord_id not in message_history:
|
||||
message_history[discord_id] = []
|
||||
message_history[discord_id].append((message.content, parsed))
|
||||
message_history[discord_id] = message_history[discord_id][-5:]
|
||||
|
||||
if "needs_clarification" in parsed:
|
||||
await message.channel.send(
|
||||
f"I'm not quite sure what you mean. {parsed['needs_clarification']}"
|
||||
)
|
||||
return
|
||||
|
||||
if "error" in parsed:
|
||||
await message.channel.send(
|
||||
f"I had trouble understanding that: {parsed['error']}"
|
||||
)
|
||||
return
|
||||
|
||||
interaction_type = parsed.get("interaction_type")
|
||||
handler = get_handler(interaction_type)
|
||||
|
||||
if handler:
|
||||
await handler(message, session, parsed)
|
||||
else:
|
||||
registered = ", ".join(list_registered()) or "none"
|
||||
await message.channel.send(
|
||||
f"Unknown command type '{interaction_type}'. Registered modules: {registered}"
|
||||
)
|
||||
|
||||
|
||||
@bot.command(name="ask")
|
||||
async def ask_dbt(ctx, *, question):
|
||||
"""Ask a DBT-related question"""
|
||||
async with ctx.typing():
|
||||
response = jury_system.process_query(question)
|
||||
await ctx.send(response)
|
||||
@client.event
|
||||
async def on_ready():
|
||||
print(f"Bot logged in as {client.user}")
|
||||
loadCache()
|
||||
backgroundLoop.start()
|
||||
|
||||
|
||||
bot.run(DISCORD_BOT_TOKEN)
|
||||
@client.event
|
||||
async def on_message(message):
|
||||
if message.author == client.user:
|
||||
return
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
return
|
||||
|
||||
discord_id = message.author.id
|
||||
|
||||
if discord_id in login_state:
|
||||
await handleLoginStep(message)
|
||||
return
|
||||
|
||||
if discord_id not in user_sessions:
|
||||
login_state[discord_id] = {"step": "username"}
|
||||
await message.channel.send("Welcome! Send your username to log in.")
|
||||
return
|
||||
|
||||
await routeCommand(message)
|
||||
|
||||
|
||||
@tasks.loop(seconds=60)
|
||||
async def backgroundLoop():
|
||||
"""Override this in your domain module or extend as needed."""
|
||||
pass
|
||||
|
||||
|
||||
@backgroundLoop.before_loop
|
||||
async def beforeBackgroundLoop():
|
||||
await client.wait_until_ready()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
client.run(DISCORD_BOT_TOKEN)
|
||||
|
||||
Reference in New Issue
Block a user