Files
2026-02-17 09:29:34 -06:00

141 lines
4.8 KiB
Python

"""Example entry point for running a natural-language SQL query."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Optional
from .client import SqlAgentClient, default_client
from .log_utils import log_interaction
from .sql_executor import SqlExecutionResult, SqlValidationError, execute_sql
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for the SQL agent example.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            omitted (the default), ``argparse`` reads ``sys.argv[1:]``, which
            is what the script entry point relies on. Passing an explicit
            list makes the parser usable from tests and other callers.

    Returns:
        A namespace with the attributes ``question``, ``tables``, ``output``,
        ``execute``, ``max_rows``, ``log_path`` and ``no_log``.
    """
    parser = argparse.ArgumentParser(description="Run a question through the SQL agent.")
    parser.add_argument("question", help="Natural language question to transform into SQL.")

    # Zero or more table names; stays None (not []) when the flag is absent.
    parser.add_argument(
        "--tables",
        nargs="*",
        default=None,
        help="Optional list of tables to prioritize (e.g., dbo.SugarLoadData).",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Optional path to write the raw LLM JSON response.",
    )
    parser.add_argument(
        "--execute",
        action="store_true",
        help="Validate and run the generated SQL against MSSQL.",
    )
    parser.add_argument(
        "--max-rows",
        type=int,
        default=500,
        help="Maximum rows to fetch; enforced via TOP clause when executing.",
    )
    parser.add_argument(
        "--log-path",
        type=str,
        default=None,
        help="Append interaction logs to this JSONL file (defaults to db_agent/logs/query_log.jsonl).",
    )
    parser.add_argument(
        "--no-log",
        action="store_true",
        help="Disable interaction logging.",
    )

    # argv=None makes argparse fall back to sys.argv[1:], preserving the
    # behaviour of the original zero-argument call.
    return parser.parse_args(argv)
def _print_execution_result(result: SqlExecutionResult) -> None:
    """Print the sanitized SQL, the row count, and a preview of at most five rows."""
    print("Sanitized SQL:")
    print(result.sql)
    print(f"Rows returned: {len(result.rows)}")
    if result.rows:
        print("Preview (up to 5 rows):")
        # default=str keeps the preview printable when a cell is not JSON-native.
        print(json.dumps(result.rows[:5], indent=2, default=str))


def _record_interaction(
    args: argparse.Namespace,
    response,
    generated_sql: str,
    summary: str,
    execution_result: Optional[SqlExecutionResult],
    execution_warning: Optional[str],
    execution_status: Optional[str],
    execution_error: Optional[str],
) -> None:
    """Hand one record describing this run to ``log_interaction`` and report where it went."""
    # None lets log_interaction pick its own default location.
    log_path = Path(args.log_path) if args.log_path else None

    metadata = {
        "tables_hint": args.tables,
        "execution_requested": args.execute,
    }
    if execution_result is not None:
        metadata["columns"] = execution_result.columns
        metadata["preview_rows"] = min(5, len(execution_result.rows))
    if execution_warning:
        metadata["execution_warning"] = execution_warning

    log_interaction(
        question=args.question,
        generated_sql=generated_sql,
        summary=summary,
        sanitized_sql=execution_result.sql if execution_result is not None else None,
        row_count=len(execution_result.rows) if execution_result is not None else None,
        execution_status=execution_status,
        execution_error=execution_error,
        raw_response=json.dumps(response, ensure_ascii=False),
        log_path=log_path,
        metadata=metadata,
    )

    if log_path:
        print(f"Interaction appended to {log_path}")
    else:
        print("Interaction appended to db_agent/logs/query_log.jsonl")


def main() -> None:
    """Run one question through the SQL agent and report the outcome.

    Steps, in order:

    1. Parse the command line with ``parse_args``.
    2. Send the question (and optional table hints) to the default client.
    3. Write the raw response to ``--output`` if given, otherwise print it.
    4. With ``--execute``, pass the generated SQL to ``execute_sql`` when it
       starts with ``SELECT``; otherwise print why execution was skipped.
    5. Unless ``--no-log`` is set, record the interaction via
       ``log_interaction``.

    Raises:
        Exception: Any error raised by ``execute_sql`` other than
            ``SqlValidationError`` is printed and re-raised. The interaction
            is still logged first, so the failure shows up in the log.
    """
    args = parse_args()
    client: SqlAgentClient = default_client()
    response = client.query(question=args.question, table_hints=args.tables)

    if args.output:
        with open(args.output, "w", encoding="utf-8") as fp:
            json.dump(response, fp, indent=2)
        print(f"Response written to {args.output}")
    else:
        print(json.dumps(response, indent=2))

    # `or ""` covers a key that is present but null, which `.get(key, "")`
    # alone would pass through as None and crash on `.strip()`.
    generated_sql = (response.get("sql") or "").strip()
    summary = (response.get("summary") or "").strip()

    execution_result: Optional[SqlExecutionResult] = None
    execution_warning: Optional[str] = None
    execution_status: Optional[str] = None
    execution_error: Optional[str] = None

    try:
        if args.execute:
            if not generated_sql:
                execution_warning = "No SQL returned; skipping execution."
                print(execution_warning)
            elif not generated_sql.lower().startswith("select"):
                # Cheap pre-filter only; execute_sql does its own validation.
                execution_warning = "Generated SQL is not a SELECT statement; skipping execution."
                print(execution_warning)
            else:
                try:
                    execution_result = execute_sql(generated_sql, max_rows=args.max_rows)
                except SqlValidationError as exc:
                    # Validation failures are reported but do not abort the run.
                    execution_warning = f"SQL validation failed: {exc}"
                    execution_error = str(exc)
                    print(execution_warning)
                except Exception as exc:
                    execution_error = str(exc)
                    print(f"SQL execution failed: {exc}")
                    raise
                else:
                    execution_status = "success"
                    _print_execution_result(execution_result)
    finally:
        # Runs on the failure path too, so an error captured just before the
        # re-raise above is written to the log instead of being lost.
        if not args.no_log:
            _record_interaction(
                args,
                response,
                generated_sql,
                summary,
                execution_result,
                execution_warning,
                execution_status,
                execution_error,
            )
# Run the CLI only when this file is executed as a script (or via
# ``python -m``), not when it is imported as a module.
if __name__ == "__main__":
    main()