259 lines
7.5 KiB
Python
259 lines
7.5 KiB
Python
import os
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import APIRouter, FastAPI, HTTPException, Response
|
|
from pydantic import BaseModel, Field
|
|
|
|
from Runner import Runner
|
|
from Memory import MemoryManager
|
|
from PromptLibrary import PromptLibrary
|
|
|
|
# Handlers are attached to a standalone router rather than to the app directly,
# so the same set of routes can be mounted on the app more than once.
router = APIRouter()

# ASGI application object; the title is shown in the generated OpenAPI docs.
app = FastAPI(title="ADHDbot API")
|
|
|
|
|
|
class ChatTurn(BaseModel):
    """One earlier message of a conversation, forwarded to ``Runner.run`` as history."""

    # Who produced the turn; free-form text, not checked against a fixed set.
    role: str = Field(...)
    # Text of the turn.
    content: str = Field(...)
|
|
|
|
|
|
class RunRequest(BaseModel):
    """Body of ``POST /run``: which prompt to run, for which user, with what input."""

    # When omitted, the TARGET_USER_ID environment variable is used instead.
    userId: Optional[str] = Field(default=None)
    category: str = Field(default="general")
    promptName: str = Field(default="welcome")
    context: str = Field(default="API triggered run")
    # Prior turns, converted to plain dicts before being handed to Runner.run.
    history: List[ChatTurn] = Field(default_factory=list)
    # Passed through unchanged to Runner.run as ``modeHint``.
    modeHint: Optional[str] = Field(default=None)
|
|
|
|
|
|
class RunResponse(BaseModel):
    """Result of a prompt run: the resolved inputs echoed back plus the generated text."""

    # Required but nullable (no default is given).
    userId: Optional[str] = Field(...)
    category: str = Field(...)
    promptName: str = Field(...)
    context: str = Field(...)
    # Text returned by Runner.run.
    message: str = Field(...)
|
|
|
|
|
|
class NoteCreate(BaseModel):
    """Body of ``POST /users/{userId}/notes``."""

    # Required and must contain at least one character.
    note: str = Field(default=..., min_length=1, description="Raw text of the note")
    # Arbitrary key/value data handed to MemoryManager.recordNote with the text.
    metadata: Dict[str, Any] = Field(default_factory=lambda: {})
|
|
|
|
|
|
class NotesResponse(BaseModel):
    """A user's notes; each note's shape is whatever MemoryManager stores."""

    userId: str = Field(...)
    notes: List[Dict[str, Any]] = Field(...)
|
|
|
|
|
|
class MemoryResponse(BaseModel):
    """Snapshot of a user's memory: stored summaries and notes, passed through as dicts."""

    userId: str = Field(...)
    summaries: List[Dict[str, Any]] = Field(...)
    notes: List[Dict[str, Any]] = Field(...)
|
|
|
|
|
|
class PromptCatalogResponse(BaseModel):
    """The prompt catalog as held by ``PromptLibrary.promptCatalog``."""

    # Two-level string mapping; presumably category -> prompt name -> text (TODO confirm).
    catalog: Dict[str, Dict[str, str]] = Field(...)
|
|
|
|
|
|
class ContextRequest(BaseModel):
    """Body carrying only a non-empty context string for a prompt run."""

    context: str = Field(default=..., min_length=1)
|
|
|
|
|
|
class ActionItemRecord(BaseModel):
    """An action item as returned by MemoryManager, validated for API responses."""

    id: str = Field(...)
    title: str = Field(...)
    cadence: str = Field(...)
    details: Optional[str] = Field(default=None)
    # Unit is minutes per the name; what the interval drives is not visible here.
    interval_minutes: Optional[int] = Field(default=None)
    # Timestamps are kept as strings; their format is decided by MemoryManager.
    created_at: str = Field(...)
    updated_at: str = Field(...)
    # Presumably the entries added via POST .../progress (TODO confirm).
    progress: List[Dict[str, Any]] = Field(default_factory=lambda: [])
|
|
|
|
|
|
class ActionItemCreate(BaseModel):
    """Body for creating an action item; only ``title`` is required."""

    title: str = Field(default=..., min_length=1)
    cadence: str = "daily"
    details: Optional[str] = None
    # Non-negative when supplied.
    interval_minutes: Optional[int] = Field(ge=0, default=None)
|
|
|
|
|
|
class ActionItemUpdate(BaseModel):
    """Body for updating an action item; every field is optional and defaults to None."""

    # When supplied, must be non-empty.
    title: Optional[str] = Field(min_length=1, default=None)
    cadence: Optional[str] = Field(default=None)
    details: Optional[str] = Field(default=None)
    # When supplied, must be non-negative.
    interval_minutes: Optional[int] = Field(ge=0, default=None)
|
|
|
|
|
|
class ActionItemProgressCreate(BaseModel):
    """Body for recording progress against an action item."""

    # Free-form status label; "update" when the caller does not provide one.
    status: str = Field(min_length=1, default="update")
    note: Optional[str] = Field(default=None)
|
|
|
|
|
|
class ActionItemsResponse(BaseModel):
    """All action items of one user."""

    userId: str = Field(...)
    action_items: List[ActionItemRecord] = Field(...)
|
|
|
|
|
|
def defaultUserId() -> Optional[str]:
    """Return the fallback user id taken from the ``TARGET_USER_ID`` env variable.

    Surrounding whitespace (e.g. a trailing newline picked up from an env file)
    is stripped, so the id matches what clients send. An unset or blank
    variable yields ``None``, so callers only ever get a usable id or ``None``.
    """
    raw = os.getenv("TARGET_USER_ID")
    if raw is None:
        return None
    cleaned = raw.strip()
    # Collapse "" (blank after stripping) to None for a single "missing" value.
    return cleaned or None
|
|
|
|
|
|
@router.get("/health")
def health():
    """Liveness probe; unconditionally reports an ok status."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
@router.post("/run", response_model=RunResponse)
def run_bot(request: RunRequest):
    """Run a prompt through ``Runner`` and echo the inputs alongside its reply.

    The user id comes from the request, or from the TARGET_USER_ID fallback
    when the request omits it.

    Raises:
        HTTPException: 400 (via ensureUserId) when no user id can be resolved.
    """
    resolved_user = ensureUserId(request.userId)
    # Runner receives plain dicts rather than ChatTurn model instances.
    prior_turns = [turn.model_dump() for turn in request.history]
    reply = Runner.run(
        resolved_user,
        request.category,
        request.promptName,
        request.context,
        history=prior_turns,
        modeHint=request.modeHint,
    )
    return RunResponse(
        userId=resolved_user,
        category=request.category,
        promptName=request.promptName,
        context=request.context,
        message=reply,
    )
|
|
|
|
|
|
@router.get("/users/{userId}/notes", response_model=NotesResponse)
def get_user_notes(userId: str, limit: int = 10):
    """Return the stored notes for ``userId``.

    When ``limit`` is positive, only the last ``limit`` notes in stored order
    are returned (presumably the newest; ordering is set by MemoryManager).
    Zero or a negative ``limit`` returns every note.
    """
    all_notes = MemoryManager.loadUserMemory(userId).get("notes", [])
    selected = all_notes[-limit:] if limit > 0 else all_notes
    return NotesResponse(userId=userId, notes=selected)
|
|
|
|
|
|
@router.post("/users/{userId}/notes", response_model=NotesResponse, status_code=201)
def create_user_note(userId: str, payload: NoteCreate):
    """Record a note for ``userId`` and respond with the user's full note list."""
    metadata = payload.metadata if payload.metadata else {}
    MemoryManager.recordNote(userId, payload.note, metadata)
    # Reload memory so the response reflects MemoryManager's state after the write.
    refreshed_notes = MemoryManager.loadUserMemory(userId).get("notes", [])
    return NotesResponse(userId=userId, notes=refreshed_notes)
|
|
|
|
|
|
@router.get("/users/{userId}/memory", response_model=MemoryResponse)
def get_user_memory(userId: str):
    """Return the stored notes and summaries for ``userId`` (empty lists when absent)."""
    stored = MemoryManager.loadUserMemory(userId)
    notes = stored.get("notes", [])
    summaries = stored.get("summaries", [])
    return MemoryResponse(userId=userId, notes=notes, summaries=summaries)
|
|
|
|
|
|
@router.get("/users/{userId}/actions", response_model=ActionItemsResponse)
def list_user_actions(userId: str):
    """List every action item MemoryManager holds for ``userId``."""
    return ActionItemsResponse(
        userId=userId,
        action_items=MemoryManager.listActionItems(userId),
    )
|
|
|
|
|
|
@router.post("/users/{userId}/actions", response_model=ActionItemsResponse, status_code=201)
def create_user_action(userId: str, payload: ActionItemCreate):
    """Create an action item for ``userId`` and respond with the user's updated list.

    Raises:
        HTTPException: 400 when MemoryManager returns a falsy result.
    """
    new_item = MemoryManager.createActionItem(
        userId,
        payload.title,
        details=payload.details,
        cadence=payload.cadence,
        intervalMinutes=payload.interval_minutes,
    )
    if new_item:
        return ActionItemsResponse(
            userId=userId,
            action_items=MemoryManager.listActionItems(userId),
        )
    # The request model already rejects an empty title, so a falsy result
    # presumably means MemoryManager refused the title (TODO confirm).
    raise HTTPException(status_code=400, detail="title is required")
|
|
|
|
|
|
@router.put("/users/{userId}/actions/{actionId}", response_model=ActionItemRecord)
def update_user_action(userId: str, actionId: str, payload: ActionItemUpdate):
    """Apply ``payload`` to one action item and return the updated record.

    Raises:
        HTTPException: 404 when MemoryManager returns a falsy result.
    """
    # NOTE(review): fields the client omitted are forwarded as None; whether
    # MemoryManager treats None as "leave unchanged" or "clear" is not visible
    # here — confirm before relying on partial updates.
    changes = {
        "title": payload.title,
        "details": payload.details,
        "cadence": payload.cadence,
        "interval_minutes": payload.interval_minutes,
    }
    result = MemoryManager.updateActionItem(userId, actionId, changes)
    if result:
        return ActionItemRecord.model_validate(result)
    raise HTTPException(status_code=404, detail="Action item not found")
|
|
|
|
|
|
@router.delete("/users/{userId}/actions/{actionId}", status_code=204)
def delete_user_action(userId: str, actionId: str):
    """Delete one action item: empty 204 on success, 404 when nothing was deleted."""
    if MemoryManager.deleteActionItem(userId, actionId):
        return Response(status_code=204)
    raise HTTPException(status_code=404, detail="Action item not found")
|
|
|
|
|
|
@router.post(
    "/users/{userId}/actions/{actionId}/progress",
    response_model=ActionItemRecord,
    status_code=201,
)
def add_action_progress(userId: str, actionId: str, payload: ActionItemProgressCreate):
    """Record a progress entry on an action item and return the item afterwards.

    Raises:
        HTTPException: 404 when MemoryManager does not record the progress, or
            when the item cannot be found in the user's list afterwards.
    """
    was_recorded = MemoryManager.recordActionProgress(
        userId, actionId, status=payload.status, note=payload.note
    )
    if not was_recorded:
        raise HTTPException(status_code=404, detail="Action item not found")
    # Re-fetch the item so the response reflects its state after recording.
    return ActionItemRecord.model_validate(resolve_action_or_404(userId, actionId))
|
|
|
|
|
|
@router.post("/prompts/reload", response_model=PromptCatalogResponse)
def reload_prompts():
    """Force PromptLibrary to reload its catalog, then return the reloaded catalog."""
    PromptLibrary.reloadCatalog()
    fresh_catalog = PromptLibrary.promptCatalog
    return PromptCatalogResponse(catalog=fresh_catalog)
|
|
|
|
|
|
@router.get("/prompts", response_model=PromptCatalogResponse)
def list_prompts():
    """Return the prompt catalog, loading it first if it is currently empty."""
    catalog_is_empty = not PromptLibrary.promptCatalog
    if catalog_is_empty:
        PromptLibrary.reloadCatalog()
    return PromptCatalogResponse(catalog=PromptLibrary.promptCatalog)
|
|
|
|
|
|
@router.post("/users/{userId}/notes/test", response_model=RunResponse)
def force_note_capture(userId: str, payload: ContextRequest):
    """QA helper: run the general/welcome prompt for ``userId`` with a custom context.

    Unlike ``/run``, no history or mode hint is passed to the runner.
    """
    category, prompt_name = "general", "welcome"
    reply = Runner.run(userId, category, prompt_name, payload.context)
    return RunResponse(
        userId=userId,
        category=category,
        promptName=prompt_name,
        context=payload.context,
        message=reply,
    )
|
|
|
|
|
|
def ensureUserId(userId: Optional[str]) -> str:
    """Return ``userId`` when it is truthy, otherwise the ``defaultUserId()`` fallback.

    The fallback is only consulted when ``userId`` is None or empty.

    Raises:
        HTTPException: 400 when neither source yields a non-empty id.
    """
    if userId:
        return userId
    fallback = defaultUserId()
    if fallback:
        return fallback
    raise HTTPException(status_code=400, detail="userId is required")
|
|
|
|
|
|
def resolve_action_or_404(userId: str, actionId: str) -> Dict[str, Any]:
    """Return the first of ``userId``'s action items whose ``"id"`` equals ``actionId``.

    Raises:
        HTTPException: 404 when no item carries that id.
    """
    found = next(
        (
            item
            for item in MemoryManager.listActionItems(userId)
            if item.get("id") == actionId
        ),
        None,
    )
    if found is None:
        raise HTTPException(status_code=404, detail="Action item not found")
    return found
|
|
|
|
|
|
# Serve every route twice: once at the root and once under /api.
# An empty prefix is include_router's default, i.e. mounting at the root.
for _prefix in ("", "/api"):
    app.include_router(router, prefix=_prefix)
del _prefix
|