"""Example entry point for running a natural-language SQL query.""" from __future__ import annotations import argparse import json from pathlib import Path from typing import Optional from .client import SqlAgentClient, default_client from .log_utils import log_interaction from .sql_executor import SqlExecutionResult, SqlValidationError, execute_sql def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run a question through the SQL agent.") parser.add_argument("question", help="Natural language question to transform into SQL.") parser.add_argument( "--tables", nargs="*", default=None, help="Optional list of tables to prioritize (e.g., dbo.SugarLoadData).", ) parser.add_argument( "--output", type=str, default=None, help="Optional path to write the raw LLM JSON response.", ) parser.add_argument( "--execute", action="store_true", help="Validate and run the generated SQL against MSSQL.", ) parser.add_argument( "--max-rows", type=int, default=500, help="Maximum rows to fetch; enforced via TOP clause when executing.", ) parser.add_argument( "--log-path", type=str, default=None, help="Append interaction logs to this JSONL file (defaults to db_agent/logs/query_log.jsonl).", ) parser.add_argument( "--no-log", action="store_true", help="Disable interaction logging.", ) return parser.parse_args() def main() -> None: args = parse_args() client: SqlAgentClient = default_client() response = client.query(question=args.question, table_hints=args.tables) if args.output: with open(args.output, "w", encoding="utf-8") as fp: json.dump(response, fp, indent=2) print(f"Response written to {args.output}") else: print(json.dumps(response, indent=2)) generated_sql = response.get("sql", "").strip() summary = response.get("summary", "").strip() execution_result: Optional[SqlExecutionResult] = None normalized_sql = generated_sql.lstrip() sql_is_select = normalized_sql.lower().startswith("select") if normalized_sql else False execution_warning: Optional[str] = None execution_status: Optional[str] = None execution_error: Optional[str] = None if args.execute: if not normalized_sql: execution_warning = "No SQL returned; skipping execution." print(execution_warning) elif not sql_is_select: execution_warning = "Generated SQL is not a SELECT statement; skipping execution." print(execution_warning) else: try: execution_result = execute_sql(normalized_sql, max_rows=args.max_rows) except SqlValidationError as exc: execution_warning = f"SQL validation failed: {exc}" execution_error = str(exc) print(f"SQL validation failed: {exc}") except Exception as exc: execution_error = str(exc) print(f"SQL execution failed: {exc}") raise else: execution_status = "success" print("Sanitized SQL:") print(execution_result.sql) print(f"Rows returned: {len(execution_result.rows)}") if execution_result.rows: preview = execution_result.rows[: min(5, len(execution_result.rows))] print("Preview (up to 5 rows):") print(json.dumps(preview, indent=2, default=str)) if not args.no_log: log_path = Path(args.log_path) if args.log_path else None metadata = { "tables_hint": args.tables, "execution_requested": args.execute, } if execution_result: metadata.update( { "columns": execution_result.columns, "preview_rows": min(5, len(execution_result.rows)), } ) if execution_warning: metadata["execution_warning"] = execution_warning log_interaction( question=args.question, generated_sql=generated_sql, summary=summary, sanitized_sql=execution_result.sql if execution_result else None, row_count=len(execution_result.rows) if execution_result else None, execution_status=execution_status, execution_error=execution_error, raw_response=json.dumps(response, ensure_ascii=False), log_path=log_path, metadata=metadata, ) if log_path: print(f"Interaction appended to {log_path}") else: print("Interaction appended to db_agent/logs/query_log.jsonl") if __name__ == "__main__": main()