140 lines
3.9 KiB
Python
140 lines
3.9 KiB
Python
"""Validation and execution helpers for generated SQL queries."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, List, Sequence
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy.engine import Engine
|
|
|
|
from .extractor.settings import DbSettings
|
|
|
|
|
|
class SqlValidationError(RuntimeError):
    """Signal that a generated SQL statement was rejected by the safety checks.

    ``sanitize_sql`` raises this for empty input, non-SELECT statements,
    multiple statements, prohibited keywords, unsupported dialect syntax and
    a non-positive ``max_rows``.
    """
|
|
|
|
|
|
# Whole-word keywords whose presence anywhere in a generated statement causes
# it to be rejected (sanitize_sql upper-cases the text and searches for
# r"\b<KEYWORD>\b"). T-SQL does not require ";" between statements, so besides
# DML/DDL this also lists keywords that begin standalone statements which could
# otherwise be appended after a SELECT without any separator.
PROHIBITED_KEYWORDS: Sequence[str] = (
    # Data modification and DDL.
    "INSERT",
    "UPDATE",
    "DELETE",
    "MERGE",
    "DROP",
    "ALTER",
    "TRUNCATE",
    "CREATE",
    # SELECT ... INTO creates and populates a new table.
    "INTO",
    # Legacy text/image write statements; not matched by \bUPDATE\b.
    "UPDATETEXT",
    "WRITETEXT",
    # Procedure / dynamic SQL execution.
    "EXEC",
    "EXECUTE",
    # Permission changes.
    "GRANT",
    "REVOKE",
    "DENY",
    # Transaction control.
    "BEGIN",
    "COMMIT",
    "ROLLBACK",
    # Server administration, denial of service and external data access.
    "SHUTDOWN",
    "KILL",
    "DBCC",
    "RECONFIGURE",
    "BACKUP",
    "RESTORE",
    "WAITFOR",
    "OPENROWSET",
    "OPENDATASOURCE",
    "OPENQUERY",
)
|
|
|
|
# Regular expressions (searched case-insensitively by sanitize_sql) for syntax
# that belongs to other SQL dialects; any match rejects the statement.
UNSUPPORTED_PATTERNS: Sequence[str] = (
    # Case-insensitive LIKE operator.
    r"\bILIKE\b",
    # Explicit NULL placement in ORDER BY.
    r"NULLS\s+LAST",
    r"NULLS\s+FIRST",
    # Row limiting / paging keywords. This module applies its own row limit
    # with TOP. NOTE(review): SQL Server does accept OFFSET ... FETCH; rejecting
    # it presumably avoids clashing with the injected TOP clause.
    r"\bLIMIT\b",
    r"\bOFFSET\b",
    # Any standalone USING word (presumably aimed at JOIN ... USING (...)).
    r"\bUSING\b",
)
|
|
|
|
|
|
@dataclass
class SqlExecutionResult:
    """Outcome of running a sanitized query through ``execute_sql``."""

    # Statement text that was actually executed (the sanitized form).
    sql: str
    # Fetched rows, each a dict mapping column name to value.
    rows: List[Dict[str, Any]]
    # Column names in the order reported by the result.
    columns: List[str]
|
|
|
|
|
|
def _build_engine(settings: DbSettings) -> Engine:
    """Create a SQLAlchemy engine for the MSSQL database described by *settings*.

    The raw ODBC connection string from ``settings.connection_string()`` is
    URL-encoded and passed through the ``odbc_connect`` query parameter of a
    ``mssql+pyodbc`` URL.

    Args:
        settings: Database settings providing ``connection_string()``.

    Returns:
        A new engine; callers own it and should dispose of it when done.
    """
    # Encode with the standard library instead of ``sa.engine.url.quote_plus``,
    # which is not part of SQLAlchemy's documented public API.
    from urllib.parse import quote_plus

    odbc_connect = quote_plus(settings.connection_string())
    return sa.create_engine(
        "mssql+pyodbc:///?odbc_connect=" + odbc_connect,
        fast_executemany=True,
    )
|
|
|
|
|
|
def _clean_sql(sql: str) -> str:
|
|
return sql.strip().rstrip(";")
|
|
|
|
|
|
def _contains_multiple_statements(sql: str) -> bool:
|
|
body = sql.rstrip()
|
|
if ";" not in body:
|
|
return False
|
|
# Allow a single trailing semicolon but forbid anything after it
|
|
last_semicolon = body.rfind(";")
|
|
trailing = body[last_semicolon + 1 :].strip()
|
|
return trailing != ""
|
|
|
|
|
|
def _enforce_top_clause(sql: str, max_rows: int) -> str:
|
|
pattern = re.compile(r"select\s+(distinct\s+)?", re.IGNORECASE)
|
|
leading_whitespace = len(sql) - len(sql.lstrip())
|
|
body = sql.lstrip()
|
|
match = pattern.match(body)
|
|
if not match:
|
|
return sql
|
|
|
|
# Detect existing TOP clause in the initial projection
|
|
leading_segment = body[: match.end() + 20].lower()
|
|
if "top" in leading_segment:
|
|
return sql
|
|
|
|
prefix = match.group(0)
|
|
remainder = body[match.end() :]
|
|
if match.group(1):
|
|
# prefix already includes "SELECT DISTINCT "
|
|
replacement = prefix + f"TOP ({max_rows}) " + remainder
|
|
else:
|
|
replacement = "SELECT TOP ({}) ".format(max_rows) + body[len("SELECT ") :]
|
|
|
|
return sql[:leading_whitespace] + replacement
|
|
|
|
|
|
def sanitize_sql(sql: str, max_rows: int = 500) -> str:
    """Ensure the SQL is safe to execute and apply a row limit.

    Args:
        sql: Generated SQL text. It is cleaned with ``_clean_sql`` before the
            checks run.
        max_rows: Row limit handed to ``_enforce_top_clause``; must be > 0.

    Returns:
        The cleaned statement after ``_enforce_top_clause`` has been applied.

    Raises:
        SqlValidationError: If ``max_rows`` is not positive, the statement is
            empty, does not start with the ``SELECT`` keyword, contains more
            than one statement, uses a keyword from ``PROHIBITED_KEYWORDS``,
            or matches one of ``UNSUPPORTED_PATTERNS``.
    """
    # Validate the argument first: it is cheap and independent of the SQL text.
    if max_rows <= 0:
        raise SqlValidationError("max_rows must be greater than zero.")

    candidate = _clean_sql(sql)
    if not candidate:
        raise SqlValidationError("Generated SQL is empty.")

    # Require SELECT as a whole word. A plain prefix test would also accept a
    # bare identifier such as ``selectAndPurge``, and T-SQL runs a procedure
    # named as the first statement of a batch even without EXEC.
    if not re.match(r"select\b", candidate.lstrip(), flags=re.IGNORECASE):
        raise SqlValidationError("Only SELECT statements are allowed.")

    if _contains_multiple_statements(candidate):
        raise SqlValidationError("Multiple statements detected; reject for safety.")

    # Keywords are matched as whole words anywhere in the text, including in
    # string literals and comments, which errs on the side of rejecting.
    upper = candidate.upper()
    for keyword in PROHIBITED_KEYWORDS:
        if re.search(rf"\b{keyword}\b", upper):
            raise SqlValidationError(f"Statement contains disallowed keyword: {keyword}")

    for pattern in UNSUPPORTED_PATTERNS:
        if re.search(pattern, candidate, flags=re.IGNORECASE):
            raise SqlValidationError(
                "Statement contains non-SQL Server syntax (e.g., ILIKE, NULLS LAST, LIMIT, OFFSET, USING)."
            )

    return _enforce_top_clause(candidate, max_rows=max_rows)
|
|
|
|
|
|
def execute_sql(sql: str, max_rows: int = 500) -> SqlExecutionResult:
    """Validate and run the provided SQL against the configured MSSQL database.

    Args:
        sql: Generated SQL text; it is passed through ``sanitize_sql`` first.
        max_rows: Maximum number of rows to return (also used for the TOP
            clause injected by ``sanitize_sql``).

    Returns:
        The sanitized statement, up to ``max_rows`` rows as dicts keyed by
        column name, and the result's column names.

    Raises:
        SqlValidationError: If ``sanitize_sql`` rejects the statement.
    """
    sanitized = sanitize_sql(sql, max_rows=max_rows)
    engine = _build_engine(DbSettings.from_env())

    try:
        with engine.connect() as conn:
            result = conn.execute(sa.text(sanitized))
            columns = list(result.keys())
            # Cap the fetch too. The injected TOP only covers the leading
            # SELECT, and an existing larger TOP is kept, so the cursor may
            # still hold more than max_rows rows.
            rows = [dict(row._mapping) for row in result.fetchmany(max_rows)]
    finally:
        # A fresh engine (with its own connection pool) is built per call;
        # dispose of it so pooled connections are not left open.
        engine.dispose()

    return SqlExecutionResult(sql=sanitized, rows=rows, columns=columns)