# Fix: handle both dict and list knowledge-base formats in load_knowledge_base (AttributeError on dict input).
import json
|
|
import time
|
|
import numpy as np
|
|
from openai import OpenAI
|
|
|
|
# --- Configuration ---
# Path to the JSON config file holding the OpenRouter API key.
CONFIG_PATH = 'config.json'
# Pre-computed embeddings file produced by the embedder script.
KNOWLEDGE_BASE_PATH = 'dbt_knowledge.json'
|
|
|
|
class SimpleVectorStore:
    """A simple in-memory vector store using NumPy.

    Keeps embeddings and their metadata in two parallel lists and supports
    brute-force cosine-similarity search over all stored vectors.
    """

    def __init__(self):
        # Parallel lists: vectors[i] is the embedding for metadata[i].
        self.vectors = []
        self.metadata = []

    def add(self, vectors, metadatas):
        """Append embeddings and their metadata dicts (kept in lockstep)."""
        self.vectors.extend(vectors)
        self.metadata.extend(metadatas)

    def search(self, query_vector, top_k=5):
        """Return up to ``top_k`` stored items ranked by cosine similarity.

        Args:
            query_vector: Sequence of floats; must have the same dimension
                as the stored vectors (e.g. 4096).
            top_k: Maximum number of results to return.

        Returns:
            List of ``{"metadata": dict, "score": float}`` dicts, best first.
            Empty list when the store is empty.
        """
        if not self.vectors:
            return []

        # Convert to numpy arrays for efficient math.
        query_vec = np.asarray(query_vector, dtype=float)
        doc_vecs = np.asarray(self.vectors, dtype=float)

        # Cosine similarity: (A . B) / (||A|| * ||B||).
        doc_norms = np.linalg.norm(doc_vecs, axis=1)
        query_norm = np.linalg.norm(query_vec)

        # Guard BOTH norms against zero (the original only guarded the
        # document norms, so a zero query vector produced NaN scores).
        # Zero-norm vectors simply score 0.0.
        scores = np.zeros(len(doc_vecs))
        if query_norm > 0:
            valid = doc_norms > 0
            dot_products = doc_vecs @ query_vec
            scores[valid] = dot_products[valid] / (doc_norms[valid] * query_norm)

        # Indices of the highest scores, best first. A top_k larger than the
        # store size is harmless: the slice just returns everything.
        top_indices = np.argsort(scores)[-top_k:][::-1]

        # float(...) converts np.float64 to a plain Python float so results
        # are JSON-serializable and compare cleanly.
        return [
            {"metadata": self.metadata[idx], "score": float(scores[idx])}
            for idx in top_indices
        ]
|
|
|
|
class JurySystem:
    """RAG pipeline: embed a query, retrieve DBT context, generate an answer.

    On construction it loads ``config.json``, creates an OpenRouter client
    (OpenAI-compatible API), and loads the pre-computed knowledge base into
    an in-memory vector store.
    """

    def __init__(self):
        self.config = self.load_config()

        # Initialize OpenRouter client (OpenAI-compatible endpoint).
        self.client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=self.config['openrouter_api_key']
        )

        self.vector_store = SimpleVectorStore()
        self.load_knowledge_base()

    def load_config(self):
        """Load and return the JSON configuration file as a dict."""
        with open(CONFIG_PATH, 'r') as f:
            return json.load(f)

    def load_knowledge_base(self):
        """Load the pre-computed embeddings from the JSON file.

        Accepts either format the embedder may produce:
        a plain list of chunk records, or a dict wrapper — either
        ``{"chunks": [...]}`` or ``{id: record, ...}``. The dict case
        previously crashed because the loop iterated dict keys (strings)
        instead of records.

        Exits the process (SystemExit) when the file is missing or invalid.
        """
        print(f"Loading knowledge base from {KNOWLEDGE_BASE_PATH}...")
        try:
            with open(KNOWLEDGE_BASE_PATH, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # FIX: normalize both supported on-disk formats to a list of
            # chunk records before iterating.
            if isinstance(data, dict):
                items = data.get('chunks', list(data.values()))
            else:
                items = data

            vectors = []
            metadata = []

            for item in items:
                vectors.append(item['embedding'])
                metadata.append({
                    "id": item['id'],
                    "source": item['source'],
                    "text": item['text']
                })

            self.vector_store.add(vectors, metadata)
            print(f"Loaded {len(vectors)} chunks into vector store.")

        except FileNotFoundError:
            print(f"Error: {KNOWLEDGE_BASE_PATH} not found. Did you run the embedder script?")
            # raise SystemExit instead of the site-provided exit() builtin,
            # which is not guaranteed outside interactive sessions.
            raise SystemExit(1)
        except Exception as e:
            print(f"Error loading knowledge base: {e}")
            raise SystemExit(1)

    def retrieve_context(self, query, top_k=5):
        """Embed ``query`` and return the ``top_k`` best-matching chunks.

        Returns an empty list on any embedding/search error (best-effort).
        """
        print("[1. Retrieving Context...]")

        try:
            # CRITICAL: use the EXACT same model as the embedder script
            # ("qwen/qwen3-embedding-8b", dimension 4096) to avoid a shape
            # mismatch against the stored vectors.
            response = self.client.embeddings.create(
                model="qwen/qwen3-embedding-8b",
                input=query
            )

            query_emb = response.data[0].embedding

            # Search the vector store.
            return self.vector_store.search(query_emb, top_k=top_k)

        except Exception as e:
            print(f"Error retrieving context: {e}")
            return []

    def generate_answer(self, query, context_chunks):
        """Generate an answer to ``query`` grounded in ``context_chunks``.

        Returns the model's reply, or an error string on API failure.
        """
        print("[2. Generating Answer...]")

        # Build the context string from the retrieved chunk texts.
        context_text = "\n\n---\n\n".join([chunk['metadata']['text'] for chunk in context_chunks])

        system_prompt = """You are a helpful AI assistant specializing in DBT (Dialectical Behavior Therapy).
Use the provided context to answer the user's question.
If the answer is not in the context, say you don't know based on the provided text.
Be concise and compassionate."""

        user_prompt = f"""Context:
{context_text}

Question: {query}"""

        try:
            # Using a strong model for the final generation.
            response = self.client.chat.completions.create(
                model="openai/gpt-4o-mini",  # You can change this to "qwen/qwen-3-8b" or similar if desired
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0.7
            )

            return response.choices[0].message.content

        except Exception as e:
            return f"Error generating answer: {e}"

    def process_query(self, query):
        """Full pipeline: retrieve context, then generate the final answer."""
        # 1. Retrieve
        context = self.retrieve_context(query)

        if not context:
            return "I couldn't find any relevant information in the knowledge base."

        # Optional: print sources for debugging.
        print(f"  Found {len(context)} relevant chunks (Top score: {context[0]['score']:.4f})")

        # 2. Generate
        return self.generate_answer(query, context)
|
|
|
|
def main():
    """Interactive REPL: answer DBT questions until the user exits."""
    print("Initializing AI Jury System...")
    system = JurySystem()

    print("\nSystem Ready. Ask a question (or type 'exit').")

    running = True
    while running:
        try:
            query = input("\nYou: ").strip()

            if query.lower() in ['exit', 'quit']:
                print("Goodbye!")
                running = False
            elif query:
                reply = system.process_query(query)
                print(f"\nAI: {reply}")
            # Empty input: just prompt again.

        except KeyboardInterrupt:
            print("\nGoodbye!")
            running = False
        except Exception as e:
            print(f"\nAn error occurred: {e}")


if __name__ == "__main__":
    main()