add all files

This commit is contained in:
Rucus
2026-02-17 09:29:34 -06:00
parent b8c8d67c67
commit 782d203799
21925 changed files with 2433086 additions and 0 deletions

View File

@@ -0,0 +1,239 @@
from __future__ import annotations
import logging
import json
import base64
import os
import secrets
from pathlib import Path
from typing import Iterable, Literal, Optional
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from starlette.middleware.base import BaseHTTPMiddleware
from pydantic import BaseModel, Field
from db_agent.client import SqlAgentClient, default_client
from db_agent.log_utils import log_interaction
from db_agent.sql_executor import SqlExecutionResult, SqlValidationError, execute_sql
logger = logging.getLogger("db_agent_ui")
class BasicAuthMiddleware(BaseHTTPMiddleware):
"""Protect routes with HTTP Basic auth when credentials are configured."""
def __init__(
self,
app: FastAPI,
*,
username: str,
password: str,
exclude_prefixes: Iterable[str] = (),
) -> None:
super().__init__(app)
self.username = username
self.password = password
self.exclude_prefixes = tuple(exclude_prefixes)
async def dispatch(self, request: Request, call_next): # type: ignore[override]
if self.exclude_prefixes and request.url.path.startswith(self.exclude_prefixes):
return await call_next(request)
header = request.headers.get("Authorization")
if not header or not header.startswith("Basic "):
return self._unauthorized()
try:
decoded = base64.b64decode(header[6:]).decode("utf-8")
except (ValueError, UnicodeDecodeError):
return self._unauthorized()
provided_user, _, provided_password = decoded.partition(":")
if not (secrets.compare_digest(provided_user, self.username) and secrets.compare_digest(provided_password, self.password)):
return self._unauthorized()
return await call_next(request)
@staticmethod
def _unauthorized() -> Response:
return Response(status_code=401, headers={"WWW-Authenticate": "Basic realm=\"SQL Agent UI\""})
class QueryRequest(BaseModel):
question: str = Field(..., description="Natural language question")
tables: Optional[list[str]] = Field(default=None, description="Optional table hints")
execute: bool = Field(default=False, description="Run the generated SQL against MSSQL")
max_rows: int = Field(default=500, ge=1, description="Row cap applied via TOP clause")
feedback: Optional[str] = Field(default=None, description="Optional user feedback tag")
class QueryResponse(BaseModel):
sql: str
summary: str
sanitized_sql: Optional[str] = None
rows: Optional[list[dict[str, object]]] = None
columns: Optional[list[str]] = None
row_count: Optional[int] = None
llm_warning: Optional[str] = None
execution_status: Optional[str] = None
execution_error: Optional[str] = None
user_feedback: Optional[str] = None
app = FastAPI(title="SugarScale SQL Agent UI", version="0.1.0")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
basic_user = os.getenv("UI_BASIC_USER")
basic_password = os.getenv("UI_BASIC_PASSWORD")
if basic_user and basic_password:
logger.info("Enabling HTTP Basic auth for UI endpoints")
app.add_middleware(
BasicAuthMiddleware,
username=basic_user,
password=basic_password,
exclude_prefixes=("/health",),
)
elif basic_user or basic_password:
logger.warning("UI basic auth is partially configured; set both UI_BASIC_USER and UI_BASIC_PASSWORD to enable it.")
frontend_dir = Path(__file__).resolve().parent.parent / "frontend"
static_dir = frontend_dir / "static"
if static_dir.exists():
app.mount("/static", StaticFiles(directory=static_dir), name="static")
_agent_client: Optional[SqlAgentClient] = None
def get_client() -> SqlAgentClient:
global _agent_client
if _agent_client is None:
_agent_client = default_client()
return _agent_client
class FeedbackRequest(BaseModel):
question: str = Field(..., description="Question that produced the answer")
sql: str = Field(..., description="Generated SQL that was reviewed")
summary: Optional[str] = None
sanitized_sql: Optional[str] = None
feedback: Literal["correct", "incorrect"] = Field(..., description="User feedback tag")
@app.post("/api/query", response_model=QueryResponse)
def query_endpoint(payload: QueryRequest) -> QueryResponse:
client = get_client()
response = client.query(question=payload.question, table_hints=payload.tables)
sql = response.get("sql", "").strip()
summary = response.get("summary", "").strip()
warning = response.get("llm_warning")
raw_response = json.dumps(response, ensure_ascii=False)
sanitized_sql = None
rows = None
columns = None
row_count = None
execution_status = None
execution_error = None
normalized_sql = sql.lstrip()
has_sql = bool(normalized_sql)
sql_is_select = normalized_sql.lower().startswith("select") if has_sql else False
if payload.execute:
if not has_sql:
warning = warning or "Model did not return executable SQL; nothing was run."
elif not sql_is_select:
warning = warning or "Model did not return a valid SELECT statement; execution skipped."
else:
try:
exec_result: SqlExecutionResult = execute_sql(
normalized_sql,
max_rows=payload.max_rows,
)
except SqlValidationError as exc:
execution_error = str(exc)
raise HTTPException(status_code=400, detail=str(exc)) from exc
except Exception as exc: # pylint: disable=broad-except
logger.exception("SQL execution error")
execution_error = str(exc)
raise HTTPException(status_code=500, detail="SQL execution failed") from exc
else:
sanitized_sql = exec_result.sql
rows = exec_result.rows
columns = exec_result.columns
row_count = len(rows)
execution_status = "success"
log_interaction(
question=payload.question,
generated_sql=sql,
summary=summary,
sanitized_sql=sanitized_sql,
row_count=row_count,
execution_status=execution_status,
execution_error=execution_error,
raw_response=raw_response,
user_feedback=payload.feedback,
metadata={
"source": "ui",
"execute": payload.execute,
"tables": payload.tables,
"llm_warning": warning,
},
)
return QueryResponse(
sql=sql,
summary=summary,
sanitized_sql=sanitized_sql,
rows=rows,
columns=columns,
row_count=row_count,
llm_warning=warning,
execution_status=execution_status,
execution_error=execution_error,
user_feedback=payload.feedback,
)
@app.post("/api/feedback")
def record_feedback(payload: FeedbackRequest) -> dict[str, str]:
feedback_value = payload.feedback.lower()
log_interaction(
question=payload.question,
generated_sql=payload.sql,
summary=payload.summary or "",
sanitized_sql=payload.sanitized_sql,
row_count=None,
execution_status=None,
execution_error=None,
raw_response=None,
user_feedback=feedback_value,
metadata={
"source": "ui-feedback",
"feedback_only": True,
},
)
return {"status": "recorded"}
@app.get("/health")
def healthcheck() -> dict[str, str]:
return {"status": "ok"}
@app.get("/")
def index() -> FileResponse:
return FileResponse(frontend_dir / "index.html")