"""HTTP API and static frontend host for the SQL agent UI.

Endpoints:

* ``POST /api/query``    - ask the agent for SQL and optionally execute it.
* ``POST /api/feedback`` - record a correct/incorrect tag for an earlier answer.
* ``GET  /health``       - liveness probe.
* ``GET  /``             - serve the frontend's ``index.html``.
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Literal, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from db_agent.client import SqlAgentClient, default_client
from db_agent.log_utils import log_interaction
from db_agent.sql_executor import SqlExecutionResult, SqlValidationError, execute_sql

logger = logging.getLogger("db_agent_ui")


class QueryRequest(BaseModel):
    """Request body for ``POST /api/query``."""

    question: str = Field(..., description="Natural language question")
    tables: Optional[list[str]] = Field(default=None, description="Optional table hints")
    execute: bool = Field(default=False, description="Run the generated SQL against MSSQL")
    max_rows: int = Field(default=500, ge=1, description="Row cap applied via TOP clause")
    feedback: Optional[str] = Field(default=None, description="Optional user feedback tag")


class QueryResponse(BaseModel):
    """Response body for ``POST /api/query``.

    ``sql`` and ``summary`` come from the agent. The remaining fields are only
    populated when execution was requested and succeeded, or when the agent /
    this endpoint produced a warning.
    """

    sql: str
    summary: str
    sanitized_sql: Optional[str] = None
    rows: Optional[list[dict[str, object]]] = None
    columns: Optional[list[str]] = None
    row_count: Optional[int] = None
    llm_warning: Optional[str] = None
    execution_status: Optional[str] = None
    execution_error: Optional[str] = None
    user_feedback: Optional[str] = None


app = FastAPI(title="SugarScale SQL Agent UI", version="0.1.0")

# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive CORS policy. Left unchanged to preserve behavior; confirm it is
# intended for anything beyond local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

frontend_dir = Path(__file__).resolve().parent.parent / "frontend"
static_dir = frontend_dir / "static"
# Only mount static assets when the directory is actually present.
if static_dir.exists():
    app.mount("/static", StaticFiles(directory=static_dir), name="static")

# Lazily created, process-wide agent client (see ``get_client``).
_agent_client: Optional[SqlAgentClient] = None


def get_client() -> SqlAgentClient:
    """Return the shared agent client, creating it on first use."""
    global _agent_client
    if _agent_client is None:
        _agent_client = default_client()
    return _agent_client


class FeedbackRequest(BaseModel):
    """Request body for ``POST /api/feedback``."""

    question: str = Field(..., description="Question that produced the answer")
    sql: str = Field(..., description="Generated SQL that was reviewed")
    summary: Optional[str] = None
    sanitized_sql: Optional[str] = None
    feedback: Literal["correct", "incorrect"] = Field(..., description="User feedback tag")


def _log_query(
    payload: QueryRequest,
    *,
    sql: str,
    summary: str,
    raw_response: str,
    warning: Optional[str],
    sanitized_sql: Optional[str] = None,
    row_count: Optional[int] = None,
    execution_status: Optional[str] = None,
    execution_error: Optional[str] = None,
) -> None:
    """Record one ``/api/query`` interaction via ``log_interaction``."""
    log_interaction(
        question=payload.question,
        generated_sql=sql,
        summary=summary,
        sanitized_sql=sanitized_sql,
        row_count=row_count,
        execution_status=execution_status,
        execution_error=execution_error,
        raw_response=raw_response,
        user_feedback=payload.feedback,
        metadata={
            "source": "ui",
            "execute": payload.execute,
            "tables": payload.tables,
            "llm_warning": warning,
        },
    )


def _log_failed_query(
    payload: QueryRequest,
    *,
    sql: str,
    summary: str,
    raw_response: str,
    warning: Optional[str],
    error: str,
) -> None:
    """Record a failed execution without letting a logging error escape.

    The caller is about to raise an ``HTTPException``; a failure while writing
    the interaction log must not replace that error, so it is only reported to
    the module logger.
    """
    try:
        _log_query(
            payload,
            sql=sql,
            summary=summary,
            raw_response=raw_response,
            warning=warning,
            # NOTE(review): "error" is a new status value next to "success";
            # verify that consumers of the interaction log accept it.
            execution_status="error",
            execution_error=error,
        )
    except Exception:  # pylint: disable=broad-except
        logger.exception("Could not record failed SQL execution")


@app.post("/api/query", response_model=QueryResponse)
def query_endpoint(payload: QueryRequest) -> QueryResponse:
    """Generate SQL for a question and optionally execute it.

    The agent is always queried. When ``payload.execute`` is true and the
    returned SQL starts with ``SELECT`` (case-insensitive), the statement is
    run through ``execute_sql`` with ``payload.max_rows`` as the row cap.
    Missing or non-SELECT SQL is not executed; a warning is returned instead,
    unless the agent already supplied one.

    Every request that reaches the agent is recorded with ``log_interaction``,
    including executions that fail.

    Raises:
        HTTPException: 400 when ``execute_sql`` rejects the statement with
            ``SqlValidationError``; 500 for any other execution failure.
    """
    client = get_client()
    response = client.query(question=payload.question, table_hints=payload.tables)

    # ``or ""`` also covers a key that is present with a null value, which a
    # plain ``.get(key, "")`` would pass through and crash on ``.strip()``.
    sql = (response.get("sql") or "").strip()
    summary = (response.get("summary") or "").strip()
    warning = response.get("llm_warning")
    raw_response = json.dumps(response, ensure_ascii=False)

    sanitized_sql = None
    rows = None
    columns = None
    row_count = None
    execution_status = None

    if payload.execute:
        if not sql:
            warning = warning or "Model did not return executable SQL; nothing was run."
        elif not sql.lower().startswith("select"):
            warning = warning or "Model did not return a valid SELECT statement; execution skipped."
        else:
            try:
                exec_result: SqlExecutionResult = execute_sql(
                    sql,
                    max_rows=payload.max_rows,
                )
            except SqlValidationError as exc:
                # Log before raising: once the HTTPException propagates,
                # nothing further in this function runs.
                _log_failed_query(
                    payload,
                    sql=sql,
                    summary=summary,
                    raw_response=raw_response,
                    warning=warning,
                    error=str(exc),
                )
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            except Exception as exc:  # pylint: disable=broad-except
                logger.exception("SQL execution error")
                _log_failed_query(
                    payload,
                    sql=sql,
                    summary=summary,
                    raw_response=raw_response,
                    warning=warning,
                    error=str(exc),
                )
                # The client only gets a generic message; details stay in logs.
                raise HTTPException(status_code=500, detail="SQL execution failed") from exc

            sanitized_sql = exec_result.sql
            rows = exec_result.rows
            columns = exec_result.columns
            row_count = len(rows)
            execution_status = "success"

    _log_query(
        payload,
        sql=sql,
        summary=summary,
        raw_response=raw_response,
        warning=warning,
        sanitized_sql=sanitized_sql,
        row_count=row_count,
        execution_status=execution_status,
    )

    return QueryResponse(
        sql=sql,
        summary=summary,
        sanitized_sql=sanitized_sql,
        rows=rows,
        columns=columns,
        row_count=row_count,
        llm_warning=warning,
        execution_status=execution_status,
        # Execution failures are reported as HTTP errors above, so a normal
        # response never carries an execution error.
        execution_error=None,
        user_feedback=payload.feedback,
    )


@app.post("/api/feedback")
def record_feedback(payload: FeedbackRequest) -> dict[str, str]:
    """Record a user's correct/incorrect tag for a previously generated answer."""
    log_interaction(
        question=payload.question,
        generated_sql=payload.sql,
        summary=payload.summary or "",
        sanitized_sql=payload.sanitized_sql,
        row_count=None,
        execution_status=None,
        execution_error=None,
        raw_response=None,
        # The Literal field only admits the lowercase tags, so no
        # normalization is needed here.
        user_feedback=payload.feedback,
        metadata={
            "source": "ui-feedback",
            "feedback_only": True,
        },
    )
    return {"status": "recorded"}


@app.get("/health")
def healthcheck() -> dict[str, str]:
    """Liveness probe; always reports ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.get("/")
def index() -> FileResponse:
    """Serve the frontend entry page.

    Raises:
        HTTPException: 404 when ``index.html`` is not present in the frontend
            directory.
    """
    index_file = frontend_dir / "index.html"
    if not index_file.is_file():
        # Answer with an explicit 404 rather than failing while the response
        # is being sent.
        raise HTTPException(status_code=404, detail="Frontend index.html not found")
    return FileResponse(index_file)