from __future__ import annotations import logging import json import base64 import os import secrets from pathlib import Path from typing import Iterable, Literal, Optional from fastapi import FastAPI, HTTPException, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from starlette.middleware.base import BaseHTTPMiddleware from pydantic import BaseModel, Field from db_agent.client import SqlAgentClient, default_client from db_agent.log_utils import log_interaction from db_agent.sql_executor import SqlExecutionResult, SqlValidationError, execute_sql logger = logging.getLogger("db_agent_ui") class BasicAuthMiddleware(BaseHTTPMiddleware): """Protect routes with HTTP Basic auth when credentials are configured.""" def __init__( self, app: FastAPI, *, username: str, password: str, exclude_prefixes: Iterable[str] = (), ) -> None: super().__init__(app) self.username = username self.password = password self.exclude_prefixes = tuple(exclude_prefixes) async def dispatch(self, request: Request, call_next): # type: ignore[override] if self.exclude_prefixes and request.url.path.startswith(self.exclude_prefixes): return await call_next(request) header = request.headers.get("Authorization") if not header or not header.startswith("Basic "): return self._unauthorized() try: decoded = base64.b64decode(header[6:]).decode("utf-8") except (ValueError, UnicodeDecodeError): return self._unauthorized() provided_user, _, provided_password = decoded.partition(":") if not (secrets.compare_digest(provided_user, self.username) and secrets.compare_digest(provided_password, self.password)): return self._unauthorized() return await call_next(request) @staticmethod def _unauthorized() -> Response: return Response(status_code=401, headers={"WWW-Authenticate": "Basic realm=\"SQL Agent UI\""}) class QueryRequest(BaseModel): question: str = Field(..., description="Natural language question") tables: Optional[list[str]] = Field(default=None, description="Optional table hints") execute: bool = Field(default=False, description="Run the generated SQL against MSSQL") max_rows: int = Field(default=500, ge=1, description="Row cap applied via TOP clause") feedback: Optional[str] = Field(default=None, description="Optional user feedback tag") class QueryResponse(BaseModel): sql: str summary: str sanitized_sql: Optional[str] = None rows: Optional[list[dict[str, object]]] = None columns: Optional[list[str]] = None row_count: Optional[int] = None llm_warning: Optional[str] = None execution_status: Optional[str] = None execution_error: Optional[str] = None user_feedback: Optional[str] = None app = FastAPI(title="SugarScale SQL Agent UI", version="0.1.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) basic_user = os.getenv("UI_BASIC_USER") basic_password = os.getenv("UI_BASIC_PASSWORD") if basic_user and basic_password: logger.info("Enabling HTTP Basic auth for UI endpoints") app.add_middleware( BasicAuthMiddleware, username=basic_user, password=basic_password, exclude_prefixes=("/health",), ) elif basic_user or basic_password: logger.warning("UI basic auth is partially configured; set both UI_BASIC_USER and UI_BASIC_PASSWORD to enable it.") frontend_dir = Path(__file__).resolve().parent.parent / "frontend" static_dir = frontend_dir / "static" if static_dir.exists(): app.mount("/static", StaticFiles(directory=static_dir), name="static") _agent_client: Optional[SqlAgentClient] = None def get_client() -> SqlAgentClient: global _agent_client if _agent_client is None: _agent_client = default_client() return _agent_client class FeedbackRequest(BaseModel): question: str = Field(..., description="Question that produced the answer") sql: str = Field(..., description="Generated SQL that was reviewed") summary: Optional[str] = None sanitized_sql: Optional[str] = None feedback: Literal["correct", "incorrect"] = Field(..., description="User feedback tag") @app.post("/api/query", response_model=QueryResponse) def query_endpoint(payload: QueryRequest) -> QueryResponse: client = get_client() response = client.query(question=payload.question, table_hints=payload.tables) sql = response.get("sql", "").strip() summary = response.get("summary", "").strip() warning = response.get("llm_warning") raw_response = json.dumps(response, ensure_ascii=False) sanitized_sql = None rows = None columns = None row_count = None execution_status = None execution_error = None normalized_sql = sql.lstrip() has_sql = bool(normalized_sql) sql_is_select = normalized_sql.lower().startswith("select") if has_sql else False if payload.execute: if not has_sql: warning = warning or "Model did not return executable SQL; nothing was run." elif not sql_is_select: warning = warning or "Model did not return a valid SELECT statement; execution skipped." else: try: exec_result: SqlExecutionResult = execute_sql( normalized_sql, max_rows=payload.max_rows, ) except SqlValidationError as exc: execution_error = str(exc) raise HTTPException(status_code=400, detail=str(exc)) from exc except Exception as exc: # pylint: disable=broad-except logger.exception("SQL execution error") execution_error = str(exc) raise HTTPException(status_code=500, detail="SQL execution failed") from exc else: sanitized_sql = exec_result.sql rows = exec_result.rows columns = exec_result.columns row_count = len(rows) execution_status = "success" log_interaction( question=payload.question, generated_sql=sql, summary=summary, sanitized_sql=sanitized_sql, row_count=row_count, execution_status=execution_status, execution_error=execution_error, raw_response=raw_response, user_feedback=payload.feedback, metadata={ "source": "ui", "execute": payload.execute, "tables": payload.tables, "llm_warning": warning, }, ) return QueryResponse( sql=sql, summary=summary, sanitized_sql=sanitized_sql, rows=rows, columns=columns, row_count=row_count, llm_warning=warning, execution_status=execution_status, execution_error=execution_error, user_feedback=payload.feedback, ) @app.post("/api/feedback") def record_feedback(payload: FeedbackRequest) -> dict[str, str]: feedback_value = payload.feedback.lower() log_interaction( question=payload.question, generated_sql=payload.sql, summary=payload.summary or "", sanitized_sql=payload.sanitized_sql, row_count=None, execution_status=None, execution_error=None, raw_response=None, user_feedback=feedback_value, metadata={ "source": "ui-feedback", "feedback_only": True, }, ) return {"status": "recorded"} @app.get("/health") def healthcheck() -> dict[str, str]: return {"status": "ok"} @app.get("/") def index() -> FileResponse: return FileResponse(frontend_dir / "index.html")