Files
controls-web/ai_agents/db_agent/sql_executor.py
2026-02-17 09:29:34 -06:00

140 lines
3.9 KiB
Python

"""Validation and execution helpers for generated SQL queries."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence
import sqlalchemy as sa
from sqlalchemy.engine import Engine
from .extractor.settings import DbSettings
class SqlValidationError(RuntimeError):
    """Signals that a generated SQL statement was rejected by a safety check.

    ``sanitize_sql`` raises it for empty input, non-SELECT statements,
    multiple statements, prohibited keywords, unsupported dialect syntax, and
    a non-positive ``max_rows``.
    """
# Upper-case T-SQL keywords that sanitize_sql rejects wherever they appear as
# whole words. The check runs on the upper-cased statement, so string literals
# and identifiers are scanned too, which errs on the side of rejecting.
# T-SQL accepts several statements in one batch without ``;`` separators, so
# this list -- not the semicolon check -- is what stops payloads such as
# ``SELECT 1 WAITFOR DELAY '00:10'``.
PROHIBITED_KEYWORDS: Sequence[str] = (
    # Data modification, DDL and transaction control.
    "INSERT",
    "UPDATE",
    "DELETE",
    "MERGE",
    "DROP",
    "ALTER",
    "TRUNCATE",
    "CREATE",
    "EXEC",
    "EXECUTE",
    "GRANT",
    "REVOKE",
    "BEGIN",
    "COMMIT",
    "ROLLBACK",
    # ``SELECT ... INTO new_table`` creates and fills a table.
    "INTO",
    "DENY",
    # Statements that can follow a SELECT in the same batch without a ``;``.
    "WAITFOR",
    "SHUTDOWN",
    "KILL",
    "DBCC",
    "BACKUP",
    "RESTORE",
    # Ad hoc / linked-server data access (remote sources, files).
    "OPENROWSET",
    "OPENQUERY",
    "OPENDATASOURCE",
)
# Regexes (matched case-insensitively by sanitize_sql) for syntax that the
# generator borrows from other SQL dialects. Any hit rejects the statement.
UNSUPPORTED_PATTERNS: Sequence[str] = (
    r"\bILIKE\b",       # case-insensitive LIKE (PostgreSQL)
    r"NULLS\s+LAST",    # explicit NULL ordering in ORDER BY
    r"NULLS\s+FIRST",
    r"\bLIMIT\b",       # row limiting; T-SQL uses TOP instead
    # NOTE(review): OFFSET ... FETCH is valid T-SQL; presumably rejected
    # because it cannot be combined with the TOP clause that is injected.
    r"\bOFFSET\b",
    r"\bUSING\b",       # e.g. JOIN ... USING (...)
)
@dataclass
class SqlExecutionResult:
    """Outcome of running one validated query.

    Fields are positional in the order ``sql``, ``rows``, ``columns``.
    """

    # The statement that was actually sent to the database (post-sanitizing).
    sql: str
    # One dict per returned row, keyed by column name.
    rows: List[Dict[str, Any]]
    # Column names as reported by the result set.
    columns: List[str]
def _build_engine(settings: DbSettings) -> Engine:
    """Create a SQLAlchemy engine for the MSSQL database described by *settings*.

    The raw ODBC connection string from ``settings.connection_string()`` is
    handed to pyodbc verbatim through the ``odbc_connect`` query parameter.
    ``URL.create`` performs the URL escaping itself. This replaces the previous
    reliance on ``sa.engine.url.quote_plus``, which is only an incidental
    re-export of ``urllib.parse.quote_plus`` and not a public SQLAlchemy API.

    Args:
        settings: Database settings providing ``connection_string()``.

    Returns:
        A new ``Engine``; the caller owns it and should ``dispose()`` it.
    """
    connection_url = sa.engine.URL.create(
        "mssql+pyodbc",
        query={"odbc_connect": settings.connection_string()},
    )
    return sa.create_engine(connection_url, fast_executemany=True)
def _clean_sql(sql: str) -> str:
return sql.strip().rstrip(";")
def _contains_multiple_statements(sql: str) -> bool:
body = sql.rstrip()
if ";" not in body:
return False
# Allow a single trailing semicolon but forbid anything after it
last_semicolon = body.rfind(";")
trailing = body[last_semicolon + 1 :].strip()
return trailing != ""
def _enforce_top_clause(sql: str, max_rows: int) -> str:
pattern = re.compile(r"select\s+(distinct\s+)?", re.IGNORECASE)
leading_whitespace = len(sql) - len(sql.lstrip())
body = sql.lstrip()
match = pattern.match(body)
if not match:
return sql
# Detect existing TOP clause in the initial projection
leading_segment = body[: match.end() + 20].lower()
if "top" in leading_segment:
return sql
prefix = match.group(0)
remainder = body[match.end() :]
if match.group(1):
# prefix already includes "SELECT DISTINCT "
replacement = prefix + f"TOP ({max_rows}) " + remainder
else:
replacement = "SELECT TOP ({}) ".format(max_rows) + body[len("SELECT ") :]
return sql[:leading_whitespace] + replacement
def sanitize_sql(sql: str, max_rows: int = 500) -> str:
    """Validate a generated SELECT statement and apply a row limit.

    The text is first trimmed of surrounding whitespace and trailing
    semicolons. It must then be non-empty, start with ``select``, contain a
    single statement, avoid every prohibited keyword, and avoid the
    unsupported dialect patterns. Finally ``_enforce_top_clause`` adds the
    ``TOP (max_rows)`` limit.

    Args:
        sql: Raw SQL text to check.
        max_rows: Row limit to apply; must be greater than zero.

    Returns:
        The cleaned, row-limited statement.

    Raises:
        SqlValidationError: If any check fails or ``max_rows`` is not positive.
    """
    statement = _clean_sql(sql)
    if not statement:
        raise SqlValidationError("Generated SQL is empty.")
    if not statement.lstrip().lower().startswith("select"):
        raise SqlValidationError("Only SELECT statements are allowed.")
    if _contains_multiple_statements(statement):
        raise SqlValidationError("Multiple statements detected; reject for safety.")

    # Keywords are matched as whole words against an upper-cased copy, so the
    # check is case-insensitive (and also scans literals and comments).
    shouted = statement.upper()
    offending = next(
        (kw for kw in PROHIBITED_KEYWORDS if re.search(rf"\b{kw}\b", shouted)),
        None,
    )
    if offending is not None:
        raise SqlValidationError(f"Statement contains disallowed keyword: {offending}")

    if any(re.search(pat, statement, flags=re.IGNORECASE) for pat in UNSUPPORTED_PATTERNS):
        raise SqlValidationError("Statement contains non-SQL Server syntax (e.g., ILIKE, NULLS LAST, LIMIT, OFFSET, USING).")

    if max_rows <= 0:
        raise SqlValidationError("max_rows must be greater than zero.")
    return _enforce_top_clause(statement, max_rows=max_rows)
def execute_sql(sql: str, max_rows: int = 500) -> SqlExecutionResult:
    """Validate *sql* and run it against the MSSQL database configured in the environment.

    Args:
        sql: Generated SQL; it must pass ``sanitize_sql``.
        max_rows: Maximum number of rows returned. It is used for the
            injected ``TOP`` clause and also as a hard cap when fetching, so a
            statement that already carries a larger ``TOP`` cannot exceed it.

    Returns:
        The executed (sanitized) SQL, the fetched rows as column->value dicts,
        and the column names reported by the result.

    Raises:
        SqlValidationError: If the statement fails validation. Errors raised
            by SQLAlchemy or the driver propagate unchanged.
    """
    sanitized = sanitize_sql(sql, max_rows=max_rows)
    settings = DbSettings.from_env()
    engine = _build_engine(settings)
    try:
        with engine.connect() as conn:
            # exec_driver_sql sends the text verbatim. sa.text() would treat
            # ``:name`` tokens -- even inside string literals -- as bind
            # parameters and then fail for lack of values.
            result = conn.exec_driver_sql(sanitized)
            columns = list(result.keys())
            rows = [dict(row._mapping) for row in result.fetchmany(max_rows)]
            # Discard any rows beyond the cap before the connection is released.
            result.close()
    finally:
        # The engine (and its connection pool) is created per call. Release it
        # now instead of leaving pooled connections open until garbage collection.
        engine.dispose()
    return SqlExecutionResult(sql=sanitized, rows=rows, columns=columns)