Compare commits
2 Commits
b3dab95cf9
...
c7be19611a
| Author | SHA1 | Date | |
|---|---|---|---|
| c7be19611a | |||
| b1bb05e879 |
560
bot/bot.py
560
bot/bot.py
@@ -1,409 +1,195 @@
|
|||||||
"""
|
|
||||||
bot.py - Discord bot client with session management and command routing
|
|
||||||
|
|
||||||
Features:
|
|
||||||
- Login flow with username/password
|
|
||||||
- Session management with JWT tokens
|
|
||||||
- AI-powered command parsing via registry
|
|
||||||
- Background task loop for polling
|
|
||||||
"""
|
|
||||||
|
|
||||||
import discord
|
|
||||||
from discord.ext import tasks
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
import base64
|
import time
|
||||||
import requests
|
import numpy as np
|
||||||
import bcrypt
|
from openai import OpenAI
|
||||||
import pickle
|
|
||||||
|
|
||||||
from bot.command_registry import get_handler, list_registered
|
# --- Configuration ---
|
||||||
import ai.parser as ai_parser
|
CONFIG_PATH = 'config.json'
|
||||||
import bot.commands.routines # noqa: F401 - registers handler
|
KNOWLEDGE_BASE_PATH = 'dbt_knowledge.json'
|
||||||
import bot.commands.medications # noqa: F401 - registers handler
|
|
||||||
import bot.commands.knowledge # noqa: F401 - registers handler
|
|
||||||
|
|
||||||
DISCORD_BOT_TOKEN = os.getenv("DISCORD_BOT_TOKEN")
|
class SimpleVectorStore:
|
||||||
API_URL = os.getenv("API_URL", "http://app:5000")
|
"""A simple in-memory vector store using NumPy."""
|
||||||
|
def __init__(self):
|
||||||
|
self.vectors = []
|
||||||
|
self.metadata = []
|
||||||
|
|
||||||
user_sessions = {}
|
def add(self, vectors, metadatas):
|
||||||
login_state = {}
|
self.vectors.extend(vectors)
|
||||||
message_history = {}
|
self.metadata.extend(metadatas)
|
||||||
user_cache = {}
|
|
||||||
CACHE_FILE = "/app/user_cache.pkl"
|
|
||||||
|
|
||||||
intents = discord.Intents.default()
|
def search(self, query_vector, top_k=5):
|
||||||
intents.message_content = True
|
if not self.vectors:
|
||||||
|
return []
|
||||||
|
|
||||||
client = discord.Client(intents=intents)
|
# Convert to numpy arrays for efficient math
|
||||||
|
query_vec = np.array(query_vector)
|
||||||
|
doc_vecs = np.array(self.vectors)
|
||||||
|
|
||||||
|
# Cosine Similarity: (A . B) / (||A|| * ||B||)
|
||||||
|
# Note: Both vectors must have the same dimension (e.g., 4096)
|
||||||
|
norms = np.linalg.norm(doc_vecs, axis=1)
|
||||||
|
|
||||||
|
# Avoid division by zero
|
||||||
|
valid_indices = norms > 0
|
||||||
|
scores = np.zeros(len(doc_vecs))
|
||||||
|
|
||||||
|
# Calculate dot product
|
||||||
|
dot_products = np.dot(doc_vecs, query_vec)
|
||||||
|
|
||||||
|
# Calculate cosine similarity only for valid norms
|
||||||
|
scores[valid_indices] = dot_products[valid_indices] / (norms[valid_indices] * np.linalg.norm(query_vec))
|
||||||
|
|
||||||
|
# Get top_k indices
|
||||||
|
top_indices = np.argsort(scores)[-top_k:][::-1]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for idx in top_indices:
|
||||||
|
results.append({
|
||||||
|
"metadata": self.metadata[idx],
|
||||||
|
"score": scores[idx]
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|
||||||
def decodeJwtPayload(token):
|
class JurySystem:
|
||||||
payload = token.split(".")[1]
|
def __init__(self):
|
||||||
payload += "=" * (4 - len(payload) % 4)
|
self.config = self.load_config()
|
||||||
return json.loads(base64.urlsafe_b64decode(payload))
|
|
||||||
|
# Initialize OpenRouter Client
|
||||||
|
self.client = OpenAI(
|
||||||
|
base_url="https://openrouter.ai/api/v1",
|
||||||
|
api_key=self.config['openrouter_api_key']
|
||||||
|
)
|
||||||
|
|
||||||
|
self.vector_store = SimpleVectorStore()
|
||||||
|
self.load_knowledge_base()
|
||||||
|
|
||||||
|
def load_config(self):
|
||||||
|
with open(CONFIG_PATH, 'r') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
def apiRequest(method, endpoint, token=None, data=None):
|
def load_knowledge_base(self):
|
||||||
url = f"{API_URL}{endpoint}"
|
"""Loads the pre-computed embeddings from the JSON file."""
|
||||||
headers = {"Content-Type": "application/json"}
|
print(f"Loading knowledge base from {KNOWLEDGE_BASE_PATH}...")
|
||||||
if token:
|
|
||||||
headers["Authorization"] = f"Bearer {token}"
|
|
||||||
try:
|
|
||||||
resp = getattr(requests, method)(url, headers=headers, json=data, timeout=10)
|
|
||||||
try:
|
try:
|
||||||
return resp.json(), resp.status_code
|
with open(KNOWLEDGE_BASE_PATH, 'r', encoding='utf-8') as f:
|
||||||
except ValueError:
|
data = json.load(f)
|
||||||
return {}, resp.status_code
|
|
||||||
except requests.RequestException:
|
vectors = []
|
||||||
return {"error": "API unavailable"}, 503
|
metadata = []
|
||||||
|
|
||||||
|
for item in data:
|
||||||
|
vectors.append(item['embedding'])
|
||||||
|
metadata.append({
|
||||||
|
"id": item['id'],
|
||||||
|
"source": item['source'],
|
||||||
|
"text": item['text']
|
||||||
|
})
|
||||||
|
|
||||||
|
self.vector_store.add(vectors, metadata)
|
||||||
|
print(f"Loaded {len(vectors)} chunks into vector store.")
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: {KNOWLEDGE_BASE_PATH} not found. Did you run the embedder script?")
|
||||||
|
exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading knowledge base: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
def retrieve_context(self, query, top_k=5):
|
||||||
def loadCache():
|
print("[1. Retrieving Context...]")
|
||||||
try:
|
|
||||||
if os.path.exists(CACHE_FILE):
|
try:
|
||||||
with open(CACHE_FILE, "rb") as f:
|
# --- CRITICAL FIX: Use the EXACT same model as the embedder ---
|
||||||
global user_cache
|
# Embedder used: "qwen/qwen3-embedding-8b" -> Dimension 4096
|
||||||
user_cache = pickle.load(f)
|
# We must use the same here to avoid shape mismatch.
|
||||||
print(f"Loaded cache for {len(user_cache)} users")
|
response = self.client.embeddings.create(
|
||||||
except Exception as e:
|
model="qwen/qwen3-embedding-8b",
|
||||||
print(f"Error loading cache: {e}")
|
input=query
|
||||||
|
|
||||||
|
|
||||||
def saveCache():
|
|
||||||
try:
|
|
||||||
with open(CACHE_FILE, "wb") as f:
|
|
||||||
pickle.dump(user_cache, f)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error saving cache: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def hashPassword(password):
|
|
||||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def verifyPassword(password, hashed):
|
|
||||||
return bcrypt.checkpw(password.encode("utf-8"), hashed.encode("utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
def getCachedUser(discord_id):
|
|
||||||
return user_cache.get(discord_id)
|
|
||||||
|
|
||||||
|
|
||||||
def setCachedUser(discord_id, user_data):
|
|
||||||
user_cache[discord_id] = user_data
|
|
||||||
saveCache()
|
|
||||||
|
|
||||||
|
|
||||||
def negotiateToken(discord_id, username, password):
|
|
||||||
cached = getCachedUser(discord_id)
|
|
||||||
if (
|
|
||||||
cached
|
|
||||||
and cached.get("username") == username
|
|
||||||
and verifyPassword(password, cached.get("hashed_password"))
|
|
||||||
):
|
|
||||||
result, status = apiRequest(
|
|
||||||
"post", "/api/login", data={"username": username, "password": password}
|
|
||||||
)
|
|
||||||
if status == 200 and "token" in result:
|
|
||||||
token = result["token"]
|
|
||||||
payload = decodeJwtPayload(token)
|
|
||||||
user_uuid = payload["sub"]
|
|
||||||
setCachedUser(
|
|
||||||
discord_id,
|
|
||||||
{
|
|
||||||
"hashed_password": cached["hashed_password"],
|
|
||||||
"user_uuid": user_uuid,
|
|
||||||
"username": username,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return token, user_uuid
|
|
||||||
return None, None
|
query_emb = response.data[0].embedding
|
||||||
|
|
||||||
|
# Search the vector store
|
||||||
|
context_chunks = self.vector_store.search(query_emb, top_k=top_k)
|
||||||
|
|
||||||
|
return context_chunks
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error retrieving context: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
result, status = apiRequest(
|
def generate_answer(self, query, context_chunks):
|
||||||
"post", "/api/login", data={"username": username, "password": password}
|
print("[2. Generating Answer...]")
|
||||||
)
|
|
||||||
if status == 200 and "token" in result:
|
# Build the context string
|
||||||
token = result["token"]
|
context_text = "\n\n---\n\n".join([chunk['metadata']['text'] for chunk in context_chunks])
|
||||||
payload = decodeJwtPayload(token)
|
|
||||||
user_uuid = payload["sub"]
|
system_prompt = """You are a helpful AI assistant specializing in DBT (Dialectical Behavior Therapy).
|
||||||
setCachedUser(
|
Use the provided context to answer the user's question.
|
||||||
discord_id,
|
If the answer is not in the context, say you don't know based on the provided text.
|
||||||
{
|
Be concise and compassionate."""
|
||||||
"hashed_password": hashPassword(password),
|
|
||||||
"user_uuid": user_uuid,
|
|
||||||
"username": username,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return token, user_uuid
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
user_prompt = f"""Context:
|
||||||
|
{context_text}
|
||||||
|
|
||||||
async def handleAuthFailure(message):
|
Question: {query}"""
|
||||||
discord_id = message.author.id
|
|
||||||
user_sessions.pop(discord_id, None)
|
|
||||||
await message.channel.send(
|
|
||||||
"Your session has expired. Send any message to log in again."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
try:
|
||||||
async def handleLoginStep(message):
|
# Using a strong model for the final generation
|
||||||
discord_id = message.author.id
|
response = self.client.chat.completions.create(
|
||||||
state = login_state[discord_id]
|
model="openai/gpt-4o-mini", # You can change this to "qwen/qwen-3-8b" or similar if desired
|
||||||
|
messages=[
|
||||||
if state["step"] == "username":
|
{"role": "system", "content": system_prompt},
|
||||||
state["username"] = message.content.strip()
|
{"role": "user", "content": user_prompt}
|
||||||
state["step"] = "password"
|
],
|
||||||
await message.channel.send("Password?")
|
temperature=0.7
|
||||||
|
|
||||||
elif state["step"] == "password":
|
|
||||||
username = state["username"]
|
|
||||||
password = message.content.strip()
|
|
||||||
del login_state[discord_id]
|
|
||||||
|
|
||||||
token, user_uuid = negotiateToken(discord_id, username, password)
|
|
||||||
|
|
||||||
if token and user_uuid:
|
|
||||||
user_sessions[discord_id] = {
|
|
||||||
"token": token,
|
|
||||||
"user_uuid": user_uuid,
|
|
||||||
"username": username,
|
|
||||||
}
|
|
||||||
registered = ", ".join(list_registered()) or "none"
|
|
||||||
await message.channel.send(
|
|
||||||
f"Welcome back **{username}**!\n\n"
|
|
||||||
f"Registered modules: {registered}\n\n"
|
|
||||||
f"Send 'help' for available commands."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await message.channel.send(
|
|
||||||
"Invalid credentials. Send any message to try again."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error generating answer: {e}"
|
||||||
|
|
||||||
|
def process_query(self, query):
|
||||||
|
# 1. Retrieve
|
||||||
|
context = self.retrieve_context(query)
|
||||||
|
|
||||||
|
if not context:
|
||||||
|
return "I couldn't find any relevant information in the knowledge base."
|
||||||
|
|
||||||
|
# Optional: Print sources for debugging
|
||||||
|
print(f" Found {len(context)} relevant chunks (Top score: {context[0]['score']:.4f})")
|
||||||
|
|
||||||
|
# 2. Generate
|
||||||
|
answer = self.generate_answer(query, context)
|
||||||
|
|
||||||
|
return answer
|
||||||
|
|
||||||
async def sendHelpMessage(message):
|
def main():
|
||||||
help_msg = """**🤖 Synculous Bot - Natural Language Commands**
|
print("Initializing AI Jury System...")
|
||||||
|
system = JurySystem()
|
||||||
Just talk to me naturally! Here are some examples:
|
|
||||||
|
print("\nSystem Ready. Ask a question (or type 'exit').")
|
||||||
**💊 Medications:**
|
|
||||||
• "add lsd 50 mcg every tuesday at 4:20pm"
|
while True:
|
||||||
• "take my wellbutrin"
|
try:
|
||||||
• "what meds do i have today?"
|
user_query = input("\nYou: ").strip()
|
||||||
• "show my refills"
|
|
||||||
• "snooze my reminder for 30 minutes"
|
if user_query.lower() in ['exit', 'quit']:
|
||||||
• "check adherence"
|
print("Goodbye!")
|
||||||
|
break
|
||||||
**📋 Routines:**
|
|
||||||
• "create morning routine with brush teeth, shower, eat"
|
if not user_query:
|
||||||
• "start my morning routine"
|
continue
|
||||||
• "done" (complete current step)
|
|
||||||
• "skip" (skip current step)
|
response = system.process_query(user_query)
|
||||||
• "pause/resume" (pause or continue)
|
print(f"\nAI: {response}")
|
||||||
• "what steps are in my routine?"
|
|
||||||
• "schedule workout for monday wednesday friday at 7am"
|
except KeyboardInterrupt:
|
||||||
• "show my stats"
|
print("\nGoodbye!")
|
||||||
|
break
|
||||||
**💡 Tips:**
|
except Exception as e:
|
||||||
• I understand natural language, typos, and slang
|
print(f"\nAn error occurred: {e}")
|
||||||
• If I'm unsure, I'll ask for clarification
|
|
||||||
• For important actions, I'll ask you to confirm with "yes" or "no"
|
|
||||||
• When you're in a routine, shortcuts like "done", "skip", "pause" work automatically"""
|
|
||||||
await message.channel.send(help_msg)
|
|
||||||
|
|
||||||
|
|
||||||
async def checkActiveSession(session):
|
|
||||||
"""Check if user has an active routine session and return details."""
|
|
||||||
token = session.get("token")
|
|
||||||
if not token:
|
|
||||||
return None
|
|
||||||
|
|
||||||
resp, status = apiRequest("get", "/api/sessions/active", token)
|
|
||||||
if status == 200 and "session" in resp:
|
|
||||||
return resp
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def handleConfirmation(message, session):
|
|
||||||
"""Handle yes/no confirmation responses. Returns True if handled."""
|
|
||||||
discord_id = message.author.id
|
|
||||||
user_input = message.content.lower().strip()
|
|
||||||
|
|
||||||
if "pending_confirmations" not in session:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check for any pending confirmations
|
|
||||||
pending = session["pending_confirmations"]
|
|
||||||
if not pending:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Get the most recent pending confirmation
|
|
||||||
confirmation_id = list(pending.keys())[-1]
|
|
||||||
confirmation_data = pending[confirmation_id]
|
|
||||||
|
|
||||||
if user_input in ("yes", "y", "yeah", "sure", "ok", "confirm"):
|
|
||||||
# Execute the confirmed action
|
|
||||||
del pending[confirmation_id]
|
|
||||||
|
|
||||||
interaction_type = confirmation_data.get("interaction_type")
|
|
||||||
handler = get_handler(interaction_type)
|
|
||||||
|
|
||||||
if handler:
|
|
||||||
# Create a fake parsed object for the handler
|
|
||||||
fake_parsed = confirmation_data.copy()
|
|
||||||
fake_parsed["needs_confirmation"] = False
|
|
||||||
await handler(message, session, fake_parsed)
|
|
||||||
return True
|
|
||||||
|
|
||||||
elif user_input in ("no", "n", "nah", "cancel", "abort"):
|
|
||||||
del pending[confirmation_id]
|
|
||||||
await message.channel.send("❌ Cancelled.")
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def handleActiveSessionShortcuts(message, session, active_session):
|
|
||||||
"""Handle shortcuts like 'done', 'skip', 'next' when in active session."""
|
|
||||||
user_input = message.content.lower().strip()
|
|
||||||
|
|
||||||
# Map common shortcuts to actions
|
|
||||||
shortcuts = {
|
|
||||||
"done": ("routine", "complete"),
|
|
||||||
"finished": ("routine", "complete"),
|
|
||||||
"complete": ("routine", "complete"),
|
|
||||||
"next": ("routine", "complete"),
|
|
||||||
"skip": ("routine", "skip"),
|
|
||||||
"pass": ("routine", "skip"),
|
|
||||||
"pause": ("routine", "pause"),
|
|
||||||
"hold": ("routine", "pause"),
|
|
||||||
"resume": ("routine", "resume"),
|
|
||||||
"continue": ("routine", "resume"),
|
|
||||||
"stop": ("routine", "cancel"),
|
|
||||||
"quit": ("routine", "cancel"),
|
|
||||||
"abort": ("routine", "abort"),
|
|
||||||
}
|
|
||||||
|
|
||||||
if user_input in shortcuts:
|
|
||||||
interaction_type, action = shortcuts[user_input]
|
|
||||||
handler = get_handler(interaction_type)
|
|
||||||
if handler:
|
|
||||||
fake_parsed = {"action": action}
|
|
||||||
await handler(message, session, fake_parsed)
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def routeCommand(message):
|
|
||||||
discord_id = message.author.id
|
|
||||||
session = user_sessions[discord_id]
|
|
||||||
user_input = message.content.lower()
|
|
||||||
|
|
||||||
if "help" in user_input or "what can i say" in user_input:
|
|
||||||
await sendHelpMessage(message)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check for active session first
|
|
||||||
active_session = await checkActiveSession(session)
|
|
||||||
|
|
||||||
# Handle confirmation responses
|
|
||||||
confirmation_handled = await handleConfirmation(message, session)
|
|
||||||
if confirmation_handled:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Handle shortcuts when in active session
|
|
||||||
if active_session:
|
|
||||||
shortcut_handled = await handleActiveSessionShortcuts(
|
|
||||||
message, session, active_session
|
|
||||||
)
|
|
||||||
if shortcut_handled:
|
|
||||||
return
|
|
||||||
|
|
||||||
async with message.channel.typing():
|
|
||||||
history = message_history.get(discord_id, [])
|
|
||||||
|
|
||||||
# Add context about active session to help AI understand
|
|
||||||
context = ""
|
|
||||||
if active_session:
|
|
||||||
session_data = active_session.get("session", {})
|
|
||||||
routine_name = session_data.get("routine_name", "a routine")
|
|
||||||
current_step = session_data.get("current_step_index", 0) + 1
|
|
||||||
total_steps = active_session.get("total_steps", 0)
|
|
||||||
context = f"\n[Context: User is currently in active session for '{routine_name}', on step {current_step} of {total_steps}. They can say 'done', 'skip', 'pause', 'resume', or 'stop'.]"
|
|
||||||
|
|
||||||
parsed = await ai_parser.parse(
|
|
||||||
message.content + context, "command_parser", history=history
|
|
||||||
)
|
|
||||||
|
|
||||||
if discord_id not in message_history:
|
|
||||||
message_history[discord_id] = []
|
|
||||||
message_history[discord_id].append((message.content, parsed))
|
|
||||||
message_history[discord_id] = message_history[discord_id][-5:]
|
|
||||||
|
|
||||||
if "needs_clarification" in parsed:
|
|
||||||
await message.channel.send(
|
|
||||||
f"I'm not quite sure what you mean. {parsed['needs_clarification']}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if "error" in parsed:
|
|
||||||
await message.channel.send(
|
|
||||||
f"I had trouble understanding that: {parsed['error']}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
interaction_type = parsed.get("interaction_type")
|
|
||||||
handler = get_handler(interaction_type)
|
|
||||||
|
|
||||||
if handler:
|
|
||||||
await handler(message, session, parsed)
|
|
||||||
else:
|
|
||||||
registered = ", ".join(list_registered()) or "none"
|
|
||||||
await message.channel.send(
|
|
||||||
f"Unknown command type '{interaction_type}'. Registered modules: {registered}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@client.event
|
|
||||||
async def on_ready():
|
|
||||||
print(f"Bot logged in as {client.user}")
|
|
||||||
loadCache()
|
|
||||||
backgroundLoop.start()
|
|
||||||
|
|
||||||
|
|
||||||
@client.event
|
|
||||||
async def on_message(message):
|
|
||||||
if message.author == client.user:
|
|
||||||
return
|
|
||||||
if not isinstance(message.channel, discord.DMChannel):
|
|
||||||
return
|
|
||||||
|
|
||||||
discord_id = message.author.id
|
|
||||||
|
|
||||||
if discord_id in login_state:
|
|
||||||
await handleLoginStep(message)
|
|
||||||
return
|
|
||||||
|
|
||||||
if discord_id not in user_sessions:
|
|
||||||
login_state[discord_id] = {"step": "username"}
|
|
||||||
await message.channel.send("Welcome! Send your username to log in.")
|
|
||||||
return
|
|
||||||
|
|
||||||
await routeCommand(message)
|
|
||||||
|
|
||||||
|
|
||||||
@tasks.loop(seconds=60)
|
|
||||||
async def backgroundLoop():
|
|
||||||
"""Override this in your domain module or extend as needed."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@backgroundLoop.before_loop
|
|
||||||
async def beforeBackgroundLoop():
|
|
||||||
await client.wait_until_ready()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
client.run(DISCORD_BOT_TOKEN)
|
main()
|
||||||
@@ -53,9 +53,24 @@ def load_knowledge_base(
|
|||||||
with open(file_path, "r") as f:
|
with open(file_path, "r") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
|
||||||
chunks = data.get("chunks", [])
|
# Handle both dict format {"chunks": [...], "embeddings": [...], "metadata": {...}}
|
||||||
embeddings = data.get("embeddings", [])
|
# and legacy list format where data is just the chunks
|
||||||
metadata = data.get("metadata", {})
|
if isinstance(data, dict):
|
||||||
|
chunks = data.get("chunks", [])
|
||||||
|
embeddings = data.get("embeddings", [])
|
||||||
|
metadata = data.get("metadata", {})
|
||||||
|
elif isinstance(data, list):
|
||||||
|
# Legacy format: assume it's just chunks, or list of [chunk, embedding] pairs
|
||||||
|
if data and isinstance(data[0], dict) and "text" in data[0]:
|
||||||
|
# Format: [{"text": "...", "embedding": [...]}, ...]
|
||||||
|
chunks = [item.get("text", "") for item in data]
|
||||||
|
embeddings = [item.get("embedding", []) for item in data]
|
||||||
|
metadata = {"format": "legacy_list_of_dicts"}
|
||||||
|
else:
|
||||||
|
# Unknown list format - can't process
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
# Add file_path to metadata for reference
|
# Add file_path to metadata for reference
|
||||||
metadata["_file_path"] = file_path
|
metadata["_file_path"] = file_path
|
||||||
|
|||||||
Reference in New Issue
Block a user