"""Validation and execution helpers for generated SQL queries.""" from __future__ import annotations import re from dataclasses import dataclass from typing import Any, Dict, List, Sequence import sqlalchemy as sa from sqlalchemy.engine import Engine from .extractor.settings import DbSettings class SqlValidationError(RuntimeError): """Raised when a generated SQL statement fails safety checks.""" PROHIBITED_KEYWORDS: Sequence[str] = ( "INSERT", "UPDATE", "DELETE", "MERGE", "DROP", "ALTER", "TRUNCATE", "CREATE", "EXEC", "EXECUTE", "GRANT", "REVOKE", "BEGIN", "COMMIT", "ROLLBACK", ) UNSUPPORTED_PATTERNS: Sequence[str] = ( r"\bILIKE\b", r"NULLS\s+LAST", r"NULLS\s+FIRST", r"\bLIMIT\b", r"\bOFFSET\b", r"\bUSING\b", ) @dataclass class SqlExecutionResult: sql: str rows: List[Dict[str, Any]] columns: List[str] def _build_engine(settings: DbSettings) -> Engine: connection_string = settings.connection_string() uri = "mssql+pyodbc:///?odbc_connect=" + sa.engine.url.quote_plus(connection_string) return sa.create_engine(uri, fast_executemany=True) def _clean_sql(sql: str) -> str: return sql.strip().rstrip(";") def _contains_multiple_statements(sql: str) -> bool: body = sql.rstrip() if ";" not in body: return False # Allow a single trailing semicolon but forbid anything after it last_semicolon = body.rfind(";") trailing = body[last_semicolon + 1 :].strip() return trailing != "" def _enforce_top_clause(sql: str, max_rows: int) -> str: pattern = re.compile(r"select\s+(distinct\s+)?", re.IGNORECASE) leading_whitespace = len(sql) - len(sql.lstrip()) body = sql.lstrip() match = pattern.match(body) if not match: return sql # Detect existing TOP clause in the initial projection leading_segment = body[: match.end() + 20].lower() if "top" in leading_segment: return sql prefix = match.group(0) remainder = body[match.end() :] if match.group(1): # prefix already includes "SELECT DISTINCT " replacement = prefix + f"TOP ({max_rows}) " + remainder else: replacement = "SELECT TOP ({}) ".format(max_rows) + body[len("SELECT ") :] return sql[:leading_whitespace] + replacement def sanitize_sql(sql: str, max_rows: int = 500) -> str: """Ensure the SQL is safe to execute and apply a row limit.""" candidate = _clean_sql(sql) if not candidate: raise SqlValidationError("Generated SQL is empty.") lowered = candidate.lstrip().lower() if not lowered.startswith("select"): raise SqlValidationError("Only SELECT statements are allowed.") if _contains_multiple_statements(candidate): raise SqlValidationError("Multiple statements detected; reject for safety.") upper = candidate.upper() for keyword in PROHIBITED_KEYWORDS: if re.search(rf"\b{keyword}\b", upper): raise SqlValidationError(f"Statement contains disallowed keyword: {keyword}") for pattern in UNSUPPORTED_PATTERNS: if re.search(pattern, candidate, flags=re.IGNORECASE): raise SqlValidationError("Statement contains non-SQL Server syntax (e.g., ILIKE, NULLS LAST, LIMIT, OFFSET, USING).") if max_rows <= 0: raise SqlValidationError("max_rows must be greater than zero.") limited = _enforce_top_clause(candidate, max_rows=max_rows) return limited def execute_sql(sql: str, max_rows: int = 500) -> SqlExecutionResult: """Validate and run the provided SQL against the configured MSSQL database.""" sanitized = sanitize_sql(sql, max_rows=max_rows) settings = DbSettings.from_env() engine = _build_engine(settings) with engine.connect() as conn: result = conn.execute(sa.text(sanitized)) columns = list(result.keys()) rows = [dict(row._mapping) for row in result] return SqlExecutionResult(sql=sanitized, rows=rows, columns=columns)