141 lines
4.8 KiB
Python
141 lines
4.8 KiB
Python
"""Example entry point for running a natural-language SQL query."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
from .client import SqlAgentClient, default_client
|
|
from .log_utils import log_interaction
|
|
from .sql_executor import SqlExecutionResult, SqlValidationError, execute_sql
|
|
|
|
|
|
def _positive_int(value: str) -> int:
    """argparse ``type=`` callable that accepts only integers >= 1.

    Used for ``--max-rows``: a zero or negative row cap is never meaningful, so it
    is rejected at parse time rather than being handed on to the SQL executor.

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is below 1,
            which argparse turns into a normal usage error (exit status 2).
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before.

    Returns:
        A namespace with ``question``, ``tables``, ``output``, ``execute``,
        ``max_rows``, ``log_path`` and ``no_log`` attributes.
    """
    parser = argparse.ArgumentParser(description="Run a question through the SQL agent.")
    parser.add_argument("question", help="Natural language question to transform into SQL.")
    parser.add_argument(
        "--tables",
        nargs="*",
        default=None,
        help="Optional list of tables to prioritize (e.g., dbo.SugarLoadData).",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Optional path to write the raw LLM JSON response.",
    )
    parser.add_argument(
        "--execute",
        action="store_true",
        help="Validate and run the generated SQL against MSSQL.",
    )
    parser.add_argument(
        "--max-rows",
        type=_positive_int,
        default=500,
        help="Maximum rows to fetch; enforced via TOP clause when executing.",
    )
    parser.add_argument(
        "--log-path",
        type=str,
        default=None,
        help="Append interaction logs to this JSONL file (defaults to db_agent/logs/query_log.jsonl).",
    )
    parser.add_argument(
        "--no-log",
        action="store_true",
        help="Disable interaction logging.",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
# Number of result rows echoed to stdout and recorded as ``preview_rows`` in logs.
_PREVIEW_LIMIT = 5
# Location reported to the user when no --log-path is given; it mirrors the
# --log-path help text. The actual default is chosen by ``log_interaction``.
_DEFAULT_LOG_LOCATION = "db_agent/logs/query_log.jsonl"


def _response_text(response: dict, key: str) -> str:
    """Return ``response[key]`` with surrounding whitespace stripped.

    A missing key, an explicit ``None`` (JSON ``null``) or any non-string value
    yields ``""`` instead of raising ``AttributeError`` on ``.strip()``.
    """
    value = response.get(key)
    return value.strip() if isinstance(value, str) else ""


def _emit_response(response: dict, output: Optional[str]) -> None:
    """Write *response* as indented JSON to *output*, or print it when no path is given."""
    if output:
        with open(output, "w", encoding="utf-8") as fp:
            json.dump(response, fp, indent=2)
        print(f"Response written to {output}")
    else:
        print(json.dumps(response, indent=2))


def _print_execution_result(result: "SqlExecutionResult") -> None:
    """Print the sanitized SQL, the row count and a preview of the first rows."""
    print("Sanitized SQL:")
    print(result.sql)
    print(f"Rows returned: {len(result.rows)}")
    if result.rows:
        print(f"Preview (up to {_PREVIEW_LIMIT} rows):")
        # default=str so values json cannot encode natively are still printed.
        print(json.dumps(result.rows[:_PREVIEW_LIMIT], indent=2, default=str))


def _build_log_metadata(
    tables: Optional[list[str]],
    execution_requested: bool,
    execution_result: Optional["SqlExecutionResult"],
    execution_warning: Optional[str],
) -> dict:
    """Assemble the ``metadata`` mapping passed to ``log_interaction``.

    Always contains ``tables_hint`` and ``execution_requested``.
    ``columns``/``preview_rows`` are added when a result exists, and
    ``execution_warning`` is added when one was raised.
    """
    metadata = {
        "tables_hint": tables,
        "execution_requested": execution_requested,
    }
    if execution_result is not None:
        metadata["columns"] = execution_result.columns
        metadata["preview_rows"] = min(_PREVIEW_LIMIT, len(execution_result.rows))
    if execution_warning:
        metadata["execution_warning"] = execution_warning
    return metadata


def main() -> None:
    """Run one question through the SQL agent, optionally execute the SQL, and log it.

    Steps:
      1. Parse CLI arguments and send the question (plus optional table hints)
         to the default client.
      2. Write the raw response to ``--output``, or print it to stdout.
      3. With ``--execute``, pass the generated SQL to ``execute_sql`` when it is
         non-empty and starts with SELECT. Then print the sanitized SQL, the row
         count and a short preview.
      4. Unless ``--no-log`` is given, append the interaction via
         ``log_interaction``. The ``execution_status`` recorded is "success",
         "validation_failed", "error", or ``None`` when nothing was executed.

    Raises:
        Exception: any non-validation error from ``execute_sql`` is re-raised,
            after the interaction (including its ``execution_error``) is logged.
    """
    args = parse_args()

    client: SqlAgentClient = default_client()
    response = client.query(question=args.question, table_hints=args.tables)
    _emit_response(response, args.output)

    generated_sql = _response_text(response, "sql")
    summary = _response_text(response, "summary")

    execution_result: Optional[SqlExecutionResult] = None
    execution_warning: Optional[str] = None
    execution_status: Optional[str] = None
    execution_error: Optional[str] = None

    try:
        if args.execute:
            if not generated_sql:
                execution_warning = "No SQL returned; skipping execution."
                print(execution_warning)
            # NOTE: plain prefix check, so CTE queries (WITH ... SELECT) are skipped as well.
            elif not generated_sql.lower().startswith("select"):
                execution_warning = "Generated SQL is not a SELECT statement; skipping execution."
                print(execution_warning)
            else:
                try:
                    execution_result = execute_sql(generated_sql, max_rows=args.max_rows)
                except SqlValidationError as exc:
                    execution_status = "validation_failed"
                    execution_error = str(exc)
                    execution_warning = f"SQL validation failed: {exc}"
                    print(execution_warning)
                except Exception as exc:
                    execution_status = "error"
                    execution_error = str(exc)
                    print(f"SQL execution failed: {exc}")
                    raise
                else:
                    execution_status = "success"
                    _print_execution_result(execution_result)
    finally:
        # Logging lives in ``finally`` so that an unexpected execution failure,
        # re-raised above, still leaves a log record carrying its error text.
        if not args.no_log:
            log_path = Path(args.log_path) if args.log_path else None
            log_interaction(
                question=args.question,
                generated_sql=generated_sql,
                summary=summary,
                sanitized_sql=execution_result.sql if execution_result is not None else None,
                row_count=len(execution_result.rows) if execution_result is not None else None,
                execution_status=execution_status,
                execution_error=execution_error,
                raw_response=json.dumps(response, ensure_ascii=False),
                log_path=log_path,
                metadata=_build_log_metadata(
                    args.tables, args.execute, execution_result, execution_warning
                ),
            )
            print(f"Interaction appended to {log_path or _DEFAULT_LOG_LOCATION}")
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point. This module uses relative imports (``from .client ...``),
    # so it has to be run as part of its package, e.g. ``python -m <package>.<module>``.
    main()
|