chore: initial import
This commit is contained in:
258
api.py
Normal file
258
api.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from Runner import Runner
|
||||
from Memory import MemoryManager
|
||||
from PromptLibrary import PromptLibrary
|
||||
|
||||
app = FastAPI(title="ADHDbot API")
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ChatTurn(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
userId: Optional[str] = None
|
||||
category: str = "general"
|
||||
promptName: str = "welcome"
|
||||
context: str = "API triggered run"
|
||||
history: List[ChatTurn] = Field(default_factory=list)
|
||||
modeHint: Optional[str] = None
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
userId: Optional[str]
|
||||
category: str
|
||||
promptName: str
|
||||
context: str
|
||||
message: str
|
||||
|
||||
|
||||
class NoteCreate(BaseModel):
|
||||
note: str = Field(..., min_length=1, description="Raw text of the note")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class NotesResponse(BaseModel):
|
||||
userId: str
|
||||
notes: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class MemoryResponse(BaseModel):
|
||||
userId: str
|
||||
summaries: List[Dict[str, Any]]
|
||||
notes: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class PromptCatalogResponse(BaseModel):
|
||||
catalog: Dict[str, Dict[str, str]]
|
||||
|
||||
|
||||
class ContextRequest(BaseModel):
|
||||
context: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
class ActionItemRecord(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
cadence: str
|
||||
details: Optional[str] = None
|
||||
interval_minutes: Optional[int] = None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
progress: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ActionItemCreate(BaseModel):
|
||||
title: str = Field(..., min_length=1)
|
||||
cadence: str = Field(default="daily")
|
||||
details: Optional[str] = None
|
||||
interval_minutes: Optional[int] = Field(default=None, ge=0)
|
||||
|
||||
|
||||
class ActionItemUpdate(BaseModel):
|
||||
title: Optional[str] = Field(default=None, min_length=1)
|
||||
cadence: Optional[str] = None
|
||||
details: Optional[str] = None
|
||||
interval_minutes: Optional[int] = Field(default=None, ge=0)
|
||||
|
||||
|
||||
class ActionItemProgressCreate(BaseModel):
|
||||
status: str = Field(default="update", min_length=1)
|
||||
note: Optional[str] = None
|
||||
|
||||
|
||||
class ActionItemsResponse(BaseModel):
|
||||
userId: str
|
||||
action_items: List[ActionItemRecord]
|
||||
|
||||
|
||||
def defaultUserId():
|
||||
return os.getenv("TARGET_USER_ID")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@router.post("/run", response_model=RunResponse)
|
||||
def run_bot(request: RunRequest):
|
||||
userId = ensureUserId(request.userId)
|
||||
message = Runner.run(
|
||||
userId,
|
||||
request.category,
|
||||
request.promptName,
|
||||
request.context,
|
||||
history=[turn.model_dump() for turn in request.history],
|
||||
modeHint=request.modeHint,
|
||||
)
|
||||
return RunResponse(
|
||||
userId=userId,
|
||||
category=request.category,
|
||||
promptName=request.promptName,
|
||||
context=request.context,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/users/{userId}/notes", response_model=NotesResponse)
|
||||
def get_user_notes(userId: str, limit: int = 10):
|
||||
memory = MemoryManager.loadUserMemory(userId)
|
||||
notes = memory.get("notes", [])
|
||||
if limit > 0:
|
||||
notes = notes[-limit:]
|
||||
return NotesResponse(userId=userId, notes=notes)
|
||||
|
||||
|
||||
@router.post("/users/{userId}/notes", response_model=NotesResponse, status_code=201)
|
||||
def create_user_note(userId: str, payload: NoteCreate):
|
||||
cleaned_metadata = payload.metadata or {}
|
||||
MemoryManager.recordNote(userId, payload.note, cleaned_metadata)
|
||||
updated = MemoryManager.loadUserMemory(userId)
|
||||
return NotesResponse(userId=userId, notes=updated.get("notes", []))
|
||||
|
||||
|
||||
@router.get("/users/{userId}/memory", response_model=MemoryResponse)
|
||||
def get_user_memory(userId: str):
|
||||
memory = MemoryManager.loadUserMemory(userId)
|
||||
return MemoryResponse(
|
||||
userId=userId,
|
||||
notes=memory.get("notes", []),
|
||||
summaries=memory.get("summaries", []),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/users/{userId}/actions", response_model=ActionItemsResponse)
|
||||
def list_user_actions(userId: str):
|
||||
actions = MemoryManager.listActionItems(userId)
|
||||
return ActionItemsResponse(userId=userId, action_items=actions)
|
||||
|
||||
|
||||
@router.post("/users/{userId}/actions", response_model=ActionItemsResponse, status_code=201)
|
||||
def create_user_action(userId: str, payload: ActionItemCreate):
|
||||
created = MemoryManager.createActionItem(
|
||||
userId,
|
||||
payload.title,
|
||||
cadence=payload.cadence,
|
||||
intervalMinutes=payload.interval_minutes,
|
||||
details=payload.details,
|
||||
)
|
||||
if not created:
|
||||
raise HTTPException(status_code=400, detail="title is required")
|
||||
actions = MemoryManager.listActionItems(userId)
|
||||
return ActionItemsResponse(userId=userId, action_items=actions)
|
||||
|
||||
|
||||
@router.put("/users/{userId}/actions/{actionId}", response_model=ActionItemRecord)
|
||||
def update_user_action(userId: str, actionId: str, payload: ActionItemUpdate):
|
||||
updated = MemoryManager.updateActionItem(
|
||||
userId,
|
||||
actionId,
|
||||
{
|
||||
"title": payload.title,
|
||||
"details": payload.details,
|
||||
"cadence": payload.cadence,
|
||||
"interval_minutes": payload.interval_minutes,
|
||||
},
|
||||
)
|
||||
if not updated:
|
||||
raise HTTPException(status_code=404, detail="Action item not found")
|
||||
return ActionItemRecord.model_validate(updated)
|
||||
|
||||
|
||||
@router.delete("/users/{userId}/actions/{actionId}", status_code=204)
|
||||
def delete_user_action(userId: str, actionId: str):
|
||||
deleted = MemoryManager.deleteActionItem(userId, actionId)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Action item not found")
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/users/{userId}/actions/{actionId}/progress",
|
||||
response_model=ActionItemRecord,
|
||||
status_code=201,
|
||||
)
|
||||
def add_action_progress(userId: str, actionId: str, payload: ActionItemProgressCreate):
|
||||
recorded = MemoryManager.recordActionProgress(
|
||||
userId,
|
||||
actionId,
|
||||
status=payload.status,
|
||||
note=payload.note,
|
||||
)
|
||||
if not recorded:
|
||||
raise HTTPException(status_code=404, detail="Action item not found")
|
||||
action = resolve_action_or_404(userId, actionId)
|
||||
return ActionItemRecord.model_validate(action)
|
||||
|
||||
|
||||
@router.post("/prompts/reload", response_model=PromptCatalogResponse)
|
||||
def reload_prompts():
|
||||
PromptLibrary.reloadCatalog()
|
||||
return PromptCatalogResponse(catalog=PromptLibrary.promptCatalog)
|
||||
|
||||
|
||||
@router.get("/prompts", response_model=PromptCatalogResponse)
|
||||
def list_prompts():
|
||||
if not PromptLibrary.promptCatalog:
|
||||
PromptLibrary.reloadCatalog()
|
||||
return PromptCatalogResponse(catalog=PromptLibrary.promptCatalog)
|
||||
|
||||
|
||||
@router.post("/users/{userId}/notes/test", response_model=RunResponse)
|
||||
def force_note_capture(userId: str, payload: ContextRequest):
|
||||
"""Helper endpoint for QA to trigger the welcome prompt with a custom context."""
|
||||
message = Runner.run(userId, "general", "welcome", payload.context)
|
||||
return RunResponse(
|
||||
userId=userId,
|
||||
category="general",
|
||||
promptName="welcome",
|
||||
context=payload.context,
|
||||
message=message,
|
||||
)
|
||||
|
||||
|
||||
def ensureUserId(userId: Optional[str]) -> str:
|
||||
resolved = userId or defaultUserId()
|
||||
if not resolved:
|
||||
raise HTTPException(status_code=400, detail="userId is required")
|
||||
return resolved
|
||||
|
||||
|
||||
def resolve_action_or_404(userId: str, actionId: str) -> Dict[str, Any]:
|
||||
actions = MemoryManager.listActionItems(userId)
|
||||
for action in actions:
|
||||
if action.get("id") == actionId:
|
||||
return action
|
||||
raise HTTPException(status_code=404, detail="Action item not found")
|
||||
|
||||
|
||||
app.include_router(router)
|
||||
app.include_router(router, prefix="/api")
|
||||
Reference in New Issue
Block a user