Files
controls-web/ai_agents/cane_agent/ui/backend/main.py
2026-02-17 09:29:34 -06:00

182 lines
5.8 KiB
Python

from __future__ import annotations
import logging
import json
from pathlib import Path
from typing import Literal, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from db_agent.client import SqlAgentClient, default_client
from db_agent.log_utils import log_interaction
from db_agent.sql_executor import SqlExecutionResult, SqlValidationError, execute_sql
# Module-level logger for the UI backend; query_endpoint uses it to record
# unexpected SQL execution failures together with their traceback.
logger = logging.getLogger("db_agent_ui")
class QueryRequest(BaseModel):
    """Request body accepted by ``POST /api/query``."""

    # The only required field: the question forwarded to the SQL agent.
    question: str = Field(
        ...,
        description="Natural language question",
    )
    # Passed through to the agent as table hints; omitted means no hints.
    tables: Optional[list[str]] = Field(
        default=None,
        description="Optional table hints",
    )
    # Off by default, so a plain request only generates SQL and never runs it.
    execute: bool = Field(
        default=False,
        description="Run the generated SQL against MSSQL",
    )
    # Must be at least 1; forwarded to the SQL executor when execute is set.
    max_rows: int = Field(
        default=500,
        ge=1,
        description="Row cap applied via TOP clause",
    )
    # Echoed back in the response and written to the interaction log.
    feedback: Optional[str] = Field(
        default=None,
        description="Optional user feedback tag",
    )
class QueryResponse(BaseModel):
    """Response body returned by ``POST /api/query``.

    ``sql`` and ``summary`` are always present. Every other field defaults
    to ``None`` and is filled in only when the endpoint has a value for it.
    """

    # --- Always populated from the agent's answer ---
    sql: str
    summary: str

    # --- Populated after a successful execution ---
    sanitized_sql: Optional[str] = None
    rows: Optional[list[dict[str, object]]] = None
    columns: Optional[list[str]] = None
    row_count: Optional[int] = None

    # --- Diagnostics and pass-through values ---
    llm_warning: Optional[str] = None
    execution_status: Optional[str] = None
    execution_error: Optional[str] = None
    user_feedback: Optional[str] = None
# ASGI application object; every route in this module is registered on it.
app = FastAPI(title="SugarScale SQL Agent UI", version="0.1.0")

# CORS policy.
# Wildcard origins must not be combined with credentialed requests: with
# allow_credentials=True the middleware would echo back whichever origin
# asked, letting any website issue cookie-bearing requests to this API.
# Origins stay open for tooling, so credentials are switched off instead.
# The bundled frontend is served from this same app (same origin) and does
# not rely on CORS at all.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# The frontend lives in a "frontend" directory that is a sibling of the
# directory containing this file.
frontend_dir = Path(__file__).resolve().parent.parent / "frontend"
static_dir = frontend_dir / "static"

# Static assets are mounted only when the directory is present, so the API
# can still start when no frontend assets have been deployed.
if static_dir.exists():
    app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Process-wide cache for the SQL agent client; filled on first use.
_agent_client: Optional[SqlAgentClient] = None


def get_client() -> SqlAgentClient:
    """Return the shared SQL agent client, creating it on the first call.

    The client is built with ``default_client()`` once and then reused for
    every later call.
    """
    global _agent_client
    if _agent_client is not None:
        return _agent_client
    _agent_client = default_client()
    return _agent_client
class FeedbackRequest(BaseModel):
    """Request body accepted by ``POST /api/feedback``."""

    # Required: the question and the SQL the user is rating.
    question: str = Field(
        ...,
        description="Question that produced the answer",
    )
    sql: str = Field(
        ...,
        description="Generated SQL that was reviewed",
    )

    # Optional context carried over from the original query response.
    summary: Optional[str] = None
    sanitized_sql: Optional[str] = None

    # Required verdict; only these two literal values pass validation.
    feedback: Literal["correct", "incorrect"] = Field(
        ...,
        description="User feedback tag",
    )
@app.post("/api/query", response_model=QueryResponse)
def query_endpoint(payload: QueryRequest) -> QueryResponse:
    """Turn a natural-language question into SQL and optionally execute it.

    The question (plus optional table hints) is sent to the SQL agent. When
    ``payload.execute`` is set and the agent returned a statement starting
    with ``SELECT``, the statement is run through ``execute_sql`` with the
    requested row cap. Every interaction, successful or failed, is handed to
    ``log_interaction``.

    Args:
        payload: Validated request body.

    Returns:
        The generated SQL and summary, plus rows/columns/row count when the
        statement was executed successfully. If execution was requested but
        skipped (no SQL, or not a SELECT), ``llm_warning`` explains why,
        unless the agent already supplied its own warning.

    Raises:
        HTTPException: 400 when ``execute_sql`` rejects the statement with
            ``SqlValidationError``; 500 for any other execution failure.
    """
    client = get_client()
    response = client.query(question=payload.question, table_hints=payload.tables)

    # ``or ""`` also covers an explicit null from the model, which
    # ``.get(key, "")`` alone would pass through as None and crash ``.strip()``.
    sql = (response.get("sql") or "").strip()
    summary = (response.get("summary") or "").strip()
    warning = response.get("llm_warning")
    raw_response = json.dumps(response, ensure_ascii=False)

    sanitized_sql = None
    rows = None
    columns = None
    row_count = None
    execution_status = None
    execution_error = None

    def record_interaction(*, best_effort: bool = False) -> None:
        # Reads the enclosing locals at call time, so the log entry reflects
        # whatever execution state has been reached when it is invoked.
        try:
            log_interaction(
                question=payload.question,
                generated_sql=sql,
                summary=summary,
                sanitized_sql=sanitized_sql,
                row_count=row_count,
                execution_status=execution_status,
                execution_error=execution_error,
                raw_response=raw_response,
                user_feedback=payload.feedback,
                metadata={
                    "source": "ui",
                    "execute": payload.execute,
                    "tables": payload.tables,
                    "llm_warning": warning,
                },
            )
        except Exception:  # pylint: disable=broad-except
            if not best_effort:
                raise
            # On the failure path the HTTP error below is what the caller
            # must see; a logging problem must not replace it.
            logger.exception("Failed to log failed query interaction")

    if payload.execute:
        if not sql:
            warning = warning or "Model did not return executable SQL; nothing was run."
        elif not sql.lower().startswith("select"):
            warning = warning or "Model did not return a valid SELECT statement; execution skipped."
        else:
            try:
                exec_result: SqlExecutionResult = execute_sql(
                    sql,
                    max_rows=payload.max_rows,
                )
            except SqlValidationError as exc:
                # Record the failure before raising; otherwise the error text
                # captured here would never reach the interaction log.
                # NOTE(review): "error" is an assumed status label — confirm
                # against what log_interaction consumers expect.
                execution_status = "error"
                execution_error = str(exc)
                record_interaction(best_effort=True)
                raise HTTPException(status_code=400, detail=execution_error) from exc
            except Exception as exc:  # pylint: disable=broad-except
                logger.exception("SQL execution error")
                execution_status = "error"
                execution_error = str(exc)
                record_interaction(best_effort=True)
                # The client gets a generic message; details stay in the logs.
                raise HTTPException(status_code=500, detail="SQL execution failed") from exc
            else:
                sanitized_sql = exec_result.sql
                rows = exec_result.rows
                columns = exec_result.columns
                row_count = len(rows)
                execution_status = "success"

    record_interaction()

    return QueryResponse(
        sql=sql,
        summary=summary,
        sanitized_sql=sanitized_sql,
        rows=rows,
        columns=columns,
        row_count=row_count,
        llm_warning=warning,
        execution_status=execution_status,
        execution_error=execution_error,
        user_feedback=payload.feedback,
    )
@app.post("/api/feedback")
def record_feedback(payload: FeedbackRequest) -> dict[str, str]:
    """Store a user's verdict on a previously generated answer.

    Nothing is executed here: the feedback is written to the interaction log
    with the execution-related fields left empty and metadata marking the
    entry as feedback-only.

    Args:
        payload: Validated feedback body.

    Returns:
        ``{"status": "recorded"}`` once the entry has been logged.
    """
    log_fields = {
        "question": payload.question,
        "generated_sql": payload.sql,
        # A missing summary is logged as an empty string.
        "summary": payload.summary or "",
        "sanitized_sql": payload.sanitized_sql,
        # No execution took place for a feedback-only entry.
        "row_count": None,
        "execution_status": None,
        "execution_error": None,
        "raw_response": None,
        "user_feedback": payload.feedback.lower(),
        "metadata": {
            "source": "ui-feedback",
            "feedback_only": True,
        },
    }
    log_interaction(**log_fields)
    return {"status": "recorded"}
@app.get("/health")
def healthcheck() -> dict[str, str]:
    """Report that the service is up; always returns ``{"status": "ok"}``."""
    status_payload = {"status": "ok"}
    return status_payload
@app.get("/")
def index() -> FileResponse:
    """Serve the frontend's ``index.html`` from the frontend directory.

    Returns:
        A file response for ``frontend_dir / "index.html"``.

    Raises:
        HTTPException: 404 when the file is not present. The static mount is
            already conditional on the frontend existing, so a missing
            frontend is an expected deployment state and should yield a
            clear not-found response rather than a server error.
    """
    index_path = frontend_dir / "index.html"
    if not index_path.is_file():
        raise HTTPException(status_code=404, detail="Frontend index.html not found")
    return FileResponse(index_path)