240 lines
7.8 KiB
Python
240 lines
7.8 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import json
|
|
import base64
|
|
import os
|
|
import secrets
|
|
from pathlib import Path
|
|
from typing import Iterable, Literal, Optional
|
|
|
|
from fastapi import FastAPI, HTTPException, Request, Response
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import FileResponse
|
|
from fastapi.staticfiles import StaticFiles
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from pydantic import BaseModel, Field
|
|
|
|
from db_agent.client import SqlAgentClient, default_client
|
|
from db_agent.log_utils import log_interaction
|
|
from db_agent.sql_executor import SqlExecutionResult, SqlValidationError, execute_sql
|
|
|
|
# Named logger used by every handler and middleware setup step in this module.
logger = logging.getLogger(name="db_agent_ui")
|
|
|
|
|
|
class BasicAuthMiddleware(BaseHTTPMiddleware):
    """Protect routes with HTTP Basic auth when credentials are configured.

    Requests whose path starts with one of ``exclude_prefixes`` bypass the
    check. Every other request must carry an ``Authorization: Basic <token>``
    header whose base64-decoded ``user:password`` matches the configured
    credentials. Otherwise a 401 response with a ``WWW-Authenticate``
    challenge is returned.
    """

    def __init__(
        self,
        app: FastAPI,
        *,
        username: str,
        password: str,
        exclude_prefixes: Iterable[str] = (),
    ) -> None:
        """Wrap *app*.

        Args:
            app: The ASGI application being protected.
            username: Expected Basic-auth user name.
            password: Expected Basic-auth password.
            exclude_prefixes: Path prefixes served without authentication.
        """
        super().__init__(app)
        self.username = username
        self.password = password
        # str.startswith accepts a tuple, so store the prefixes as one.
        self.exclude_prefixes = tuple(exclude_prefixes)

    async def dispatch(self, request: Request, call_next):  # type: ignore[override]
        """Authenticate *request*, or pass it through when its path is excluded."""
        if self.exclude_prefixes and request.url.path.startswith(self.exclude_prefixes):
            return await call_next(request)

        credentials = self._parse_credentials(request.headers.get("Authorization"))
        if credentials is None:
            return self._unauthorized()
        provided_user, provided_password = credentials

        # Compare bytes, not str: secrets.compare_digest raises TypeError for
        # non-ASCII str input, which would turn a bad login into a 500.
        # Both comparisons are always evaluated (no short-circuit), so the
        # password check runs whether or not the user name matched.
        user_ok = secrets.compare_digest(provided_user, self.username.encode("utf-8"))
        password_ok = secrets.compare_digest(provided_password, self.password.encode("utf-8"))
        if not (user_ok and password_ok):
            return self._unauthorized()

        return await call_next(request)

    @staticmethod
    def _parse_credentials(header: Optional[str]) -> Optional[tuple[bytes, bytes]]:
        """Decode a Basic ``Authorization`` header into raw ``(user, password)`` bytes.

        Returns None when the header is missing, uses another scheme, or the
        token is not valid base64.
        """
        if not header:
            return None
        scheme, _, token = header.partition(" ")
        # RFC 7617: the authentication scheme name is case-insensitive.
        if scheme.lower() != "basic":
            return None
        try:
            # validate=True rejects tokens containing non-base64 characters
            # instead of silently dropping them.
            decoded = base64.b64decode(token.strip(), validate=True)
        except ValueError:  # binascii.Error is a ValueError subclass
            return None
        user, _, password = decoded.partition(b":")
        return user, password

    @staticmethod
    def _unauthorized() -> Response:
        """Build the 401 challenge response."""
        return Response(status_code=401, headers={"WWW-Authenticate": "Basic realm=\"SQL Agent UI\""})
|
|
|
|
|
|
class QueryRequest(BaseModel):
    """Request body for ``POST /api/query``."""

    # Free-text question forwarded to the SQL agent client.
    question: str = Field(..., description="Natural language question")
    # Passed to client.query() as table_hints; None means no hints.
    tables: Optional[list[str]] = Field(default=None, description="Optional table hints")
    # When True, query_endpoint runs the generated SQL (only if it starts with SELECT).
    execute: bool = Field(default=False, description="Run the generated SQL against MSSQL")
    # Forwarded to execute_sql(max_rows=...). Only a lower bound (ge=1) is enforced;
    # NOTE(review): no upper bound, so a client can request an arbitrarily large
    # result set. Consider adding le=... once an acceptable cap is agreed.
    max_rows: int = Field(default=500, ge=1, description="Row cap applied via TOP clause")
    # Echoed in QueryResponse.user_feedback and logged as user_feedback.
    feedback: Optional[str] = Field(default=None, description="Optional user feedback tag")
|
|
|
|
|
|
class QueryResponse(BaseModel):
    """Response body for ``POST /api/query``.

    query_endpoint fills the execution-related fields (``sanitized_sql``,
    ``rows``, ``columns``, ``row_count``, ``execution_status``) only after the
    generated SQL has executed successfully. Otherwise they stay None.
    """

    # Generated SQL from the agent, whitespace-stripped (may be empty).
    sql: str
    # Agent's natural-language summary, whitespace-stripped.
    summary: str
    # SQL reported back by execute_sql (SqlExecutionResult.sql).
    sanitized_sql: Optional[str] = None
    # Result rows from execute_sql, one dict per row.
    rows: Optional[list[dict[str, object]]] = None
    # Result column names from execute_sql.
    columns: Optional[list[str]] = None
    # len(rows) when the SQL was executed.
    row_count: Optional[int] = None
    # Warning from the agent, or an explanation of why execution was skipped.
    llm_warning: Optional[str] = None
    # "success" after a successful execution; None otherwise.
    execution_status: Optional[str] = None
    # NOTE(review): query_endpoint raises HTTPException on execution failures,
    # so it never returns a response with this field populated.
    execution_error: Optional[str] = None
    # Echo of QueryRequest.feedback.
    user_feedback: Optional[str] = None
|
|
|
|
|
|
app = FastAPI(title="SugarScale SQL Agent UI", version="0.1.0")

# Comma-separated list of origins allowed to call the API cross-origin, e.g.
# "https://a.example,https://b.example". Defaults to "*" (any origin).
_cors_origins = [
    origin.strip()
    for origin in os.getenv("UI_CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# A wildcard origin must not be combined with credentials: browsers reject
# "Access-Control-Allow-Origin: *" on credentialed requests, and Starlette
# works around that by echoing the caller's Origin. That would let any site
# send authenticated requests. Credentials are therefore only allowed for an
# explicit allow-list. The bundled UI is served from this same app, so it
# does not depend on CORS at all.
_cors_allow_credentials = "*" not in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Optional HTTP Basic gate: enabled only when both variables are non-empty.
basic_user = os.getenv("UI_BASIC_USER")
basic_password = os.getenv("UI_BASIC_PASSWORD")

if not (basic_user and basic_password):
    # Exactly one of the two being set is almost certainly a deployment
    # mistake, so say so instead of silently running without auth.
    if basic_user or basic_password:
        logger.warning("UI basic auth is partially configured; set both UI_BASIC_USER and UI_BASIC_PASSWORD to enable it.")
else:
    logger.info("Enabling HTTP Basic auth for UI endpoints")
    # /health stays reachable without credentials for liveness probes.
    app.add_middleware(
        BasicAuthMiddleware,
        exclude_prefixes=("/health",),
        password=basic_password,
        username=basic_user,
    )
|
|
|
|
# The UI assets live in a "frontend" directory beside the directory that
# contains this module.
frontend_dir = Path(__file__).resolve().parent.parent.joinpath("frontend")
static_dir = frontend_dir / "static"

# Serve /static only when the assets directory is present, so the API can
# still start without a built frontend.
if static_dir.exists():
    app.mount(
        "/static",
        StaticFiles(directory=static_dir),
        name="static",
    )
|
|
|
|
# Module-level cache for the agent client; populated on the first get_client() call.
_agent_client: Optional[SqlAgentClient] = None


def get_client() -> SqlAgentClient:
    """Return the shared SqlAgentClient, creating it via default_client() on first use.

    NOTE(review): creation is not guarded by a lock, so concurrent first calls
    could each build a client (the last assignment wins).
    """
    global _agent_client
    if _agent_client is not None:
        return _agent_client
    _agent_client = default_client()
    return _agent_client
|
|
|
|
|
|
class FeedbackRequest(BaseModel):
    """Request body for ``POST /api/feedback``: a user's verdict on an earlier answer."""

    # Question that produced the answer; logged as the interaction's question.
    question: str = Field(..., description="Question that produced the answer")
    # SQL the user reviewed; logged as generated_sql.
    sql: str = Field(..., description="Generated SQL that was reviewed")
    # Optional; record_feedback logs a missing summary as an empty string.
    summary: Optional[str] = None
    # Optional executed SQL, logged unchanged.
    sanitized_sql: Optional[str] = None
    # Only these two tags are accepted.
    feedback: Literal["correct", "incorrect"] = Field(..., description="User feedback tag")
|
|
|
|
|
|
def _log_query_interaction(
    payload: QueryRequest,
    *,
    sql: str,
    summary: str,
    raw_response: str,
    warning: Optional[str],
    sanitized_sql: Optional[str] = None,
    row_count: Optional[int] = None,
    execution_status: Optional[str] = None,
    execution_error: Optional[str] = None,
) -> None:
    """Record one ``/api/query`` interaction (successful or failed) via log_interaction."""
    log_interaction(
        question=payload.question,
        generated_sql=sql,
        summary=summary,
        sanitized_sql=sanitized_sql,
        row_count=row_count,
        execution_status=execution_status,
        execution_error=execution_error,
        raw_response=raw_response,
        user_feedback=payload.feedback,
        metadata={
            "source": "ui",
            "execute": payload.execute,
            "tables": payload.tables,
            "llm_warning": warning,
        },
    )


@app.post("/api/query", response_model=QueryResponse)
def query_endpoint(payload: QueryRequest) -> QueryResponse:
    """Generate SQL for a natural-language question and optionally execute it.

    The question and optional table hints are sent to the shared
    SqlAgentClient. When ``payload.execute`` is set and the model returned a
    statement starting with SELECT, it is run through execute_sql with
    ``payload.max_rows``. Every interaction is recorded with log_interaction.
    This includes failed executions, which are logged with
    ``execution_status="error"`` before the HTTP error is raised.

    Raises:
        HTTPException: 400 when execute_sql rejects the SQL
            (SqlValidationError); 500 for any other execution failure. Details
            of a 500 are only written to the logs.
    """
    client = get_client()
    response = client.query(question=payload.question, table_hints=payload.tables)

    # `or ""` also covers keys present with an explicit None, which would
    # otherwise crash on .strip().
    sql = (response.get("sql") or "").strip()
    summary = (response.get("summary") or "").strip()
    warning = response.get("llm_warning")
    raw_response = json.dumps(response, ensure_ascii=False)

    sanitized_sql = None
    rows = None
    columns = None
    row_count = None
    execution_status = None

    if payload.execute:
        if not sql:
            warning = warning or "Model did not return executable SQL; nothing was run."
        elif not sql.lower().startswith("select"):
            # Cheap pre-check only; execute_sql can still reject the statement
            # with SqlValidationError.
            warning = warning or "Model did not return a valid SELECT statement; execution skipped."
        else:
            try:
                exec_result: SqlExecutionResult = execute_sql(
                    sql,
                    max_rows=payload.max_rows,
                )
            except SqlValidationError as exc:
                _log_query_interaction(
                    payload,
                    sql=sql,
                    summary=summary,
                    raw_response=raw_response,
                    warning=warning,
                    execution_status="error",
                    execution_error=str(exc),
                )
                raise HTTPException(status_code=400, detail=str(exc)) from exc
            except Exception as exc:  # pylint: disable=broad-except
                logger.exception("SQL execution error")
                _log_query_interaction(
                    payload,
                    sql=sql,
                    summary=summary,
                    raw_response=raw_response,
                    warning=warning,
                    execution_status="error",
                    execution_error=str(exc),
                )
                # Generic detail: driver errors may expose schema/connection info.
                raise HTTPException(status_code=500, detail="SQL execution failed") from exc

            sanitized_sql = exec_result.sql
            rows = exec_result.rows
            columns = exec_result.columns
            row_count = len(rows)
            execution_status = "success"

    _log_query_interaction(
        payload,
        sql=sql,
        summary=summary,
        raw_response=raw_response,
        warning=warning,
        sanitized_sql=sanitized_sql,
        row_count=row_count,
        execution_status=execution_status,
    )

    return QueryResponse(
        sql=sql,
        summary=summary,
        sanitized_sql=sanitized_sql,
        rows=rows,
        columns=columns,
        row_count=row_count,
        llm_warning=warning,
        execution_status=execution_status,
        execution_error=None,
        user_feedback=payload.feedback,
    )
|
|
|
|
|
|
@app.post("/api/feedback")
def record_feedback(payload: FeedbackRequest) -> dict[str, str]:
    """Log a user's verdict on a previously generated answer.

    Execution fields and raw_response are logged as None, and the metadata
    marks the entry as ``feedback_only``.
    """
    # FeedbackRequest.feedback is a lowercase Literal; lowering it is a
    # defensive normalization.
    verdict = payload.feedback.lower()
    entry_metadata = {"source": "ui-feedback", "feedback_only": True}

    log_interaction(
        question=payload.question,
        generated_sql=payload.sql,
        summary=payload.summary if payload.summary is not None else "",
        sanitized_sql=payload.sanitized_sql,
        row_count=None,
        execution_status=None,
        execution_error=None,
        raw_response=None,
        user_feedback=verdict,
        metadata=entry_metadata,
    )
    return dict(status="recorded")
|
|
|
|
|
|
@app.get("/health")
def healthcheck() -> dict[str, str]:
    """Liveness probe with a static payload.

    Basic auth, when enabled, is configured to skip the /health prefix.
    """
    return dict(status="ok")
|
|
|
|
|
|
@app.get("/")
def index() -> FileResponse:
    """Serve the single-page UI entry point, ``frontend/index.html``.

    Raises:
        HTTPException: 404 when the frontend has not been deployed. This
            avoids handing FileResponse a path that does not exist, and
            mirrors the existence check done before mounting /static.
    """
    index_file = frontend_dir / "index.html"
    if not index_file.is_file():
        raise HTTPException(status_code=404, detail="Frontend not found")
    return FileResponse(index_file)
|