"""Assemble LLM prompts for natural-language SQL generation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import textwrap
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Iterable, List, Optional, Sequence
|
|
|
|
import yaml
|
|
|
|
|
|
def _default_context_dir() -> Path:
    """Return the default directory that holds the context artifacts."""
    return Path("db_agent/context")


@dataclass(frozen=True)
class PromptConfig:
    """Immutable description of where the prompt builder's context files live.

    Each ``*_filename`` is joined onto ``context_dir`` when the context is
    loaded from disk.
    """

    # Directory containing every artifact named below.
    context_dir: Path = field(default_factory=_default_context_dir)
    # JSON document describing the database schema.
    schema_filename: str = "schema.json"
    # Markdown text of the business glossary.
    glossary_filename: str = "glossary.md"
    # YAML document with value hints.
    value_hints_filename: str = "value_hints.yaml"
    # JSON document with worked question/SQL examples.
    examples_filename: str = "examples.json"
|
|
|
|
|
|
@dataclass
class AgentContext:
    """In-memory bundle of the context assets used to render a prompt."""

    # Parsed schema document; the prompt builder reads its "tables" mapping.
    schema: Dict[str, Any]
    # Raw Markdown text of the business glossary.
    glossary_markdown: str
    # Parsed value hints, dumped back to YAML when the prompt is rendered.
    value_hints: Dict[str, Any]
    # Worked examples; entries may carry "question", "sql" and "notes" keys.
    examples: List[Dict[str, Any]]
|
|
|
|
|
|
def load_context(config: PromptConfig | None = None) -> AgentContext:
    """Read every context artifact from disk and bundle it into an AgentContext.

    Args:
        config: File layout to read from. A default ``PromptConfig`` is used
            when this is omitted.

    Returns:
        The parsed schema, the glossary text, the value hints (an empty dict
        when the YAML document is empty) and the worked examples.
    """
    layout = config or PromptConfig()
    root = layout.context_dir

    def read(filename: str) -> str:
        # Every artifact is read as UTF-8 text relative to the context dir.
        return (root / filename).read_text(encoding="utf-8")

    # Files are read and parsed one after another, in this order, so the
    # first missing or malformed artifact is the one that raises.
    parsed_schema = json.loads(read(layout.schema_filename))
    glossary_text = read(layout.glossary_filename)
    parsed_hints = yaml.safe_load(read(layout.value_hints_filename)) or {}
    parsed_examples = json.loads(read(layout.examples_filename))

    return AgentContext(
        schema=parsed_schema,
        glossary_markdown=glossary_text,
        value_hints=parsed_hints,
        examples=parsed_examples,
    )
|
|
|
|
|
|
def _format_table(table_key: str, table_doc: Dict[str, Any], max_columns: int) -> str:
|
|
"""Render a single table definition block."""
|
|
|
|
lines: List[str] = [f"Table {table_key}: {table_doc.get('schema')}.{table_doc.get('name')}"]
|
|
column_docs = table_doc.get("columns", [])
|
|
for idx, column in enumerate(column_docs):
|
|
if idx >= max_columns:
|
|
lines.append(f" ... ({len(column_docs) - max_columns} more columns omitted)")
|
|
break
|
|
col_line = f" - {column['name']} ({column['type']})"
|
|
if column.get("nullable") is False:
|
|
col_line += " NOT NULL"
|
|
if column.get("default") not in (None, ""):
|
|
col_line += f" DEFAULT {column['default']}"
|
|
lines.append(col_line)
|
|
|
|
relationships = table_doc.get("relationships", [])
|
|
if relationships:
|
|
lines.append(" Relationships:")
|
|
for rel in relationships:
|
|
target = rel.get("target")
|
|
src_col = rel.get("source_column")
|
|
tgt_col = rel.get("target_column")
|
|
lines.append(f" * {src_col} -> {target}.{tgt_col}")
|
|
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _format_schema(schema_doc: Dict[str, Any], include_tables: Optional[Sequence[str]], max_columns: int) -> str:
|
|
"""Render the subset of tables to include in the prompt."""
|
|
|
|
table_section: List[str] = []
|
|
tables = schema_doc.get("tables", {})
|
|
|
|
def table_selected(table_key: str, table_doc: Dict[str, Any]) -> bool:
|
|
if not include_tables:
|
|
return True
|
|
candidates = {table_key, table_doc.get("name"), f"{table_doc.get('schema')}.{table_doc.get('name')}"}
|
|
return any(name in include_tables for name in candidates if name)
|
|
|
|
for table_key, table_doc in tables.items():
|
|
if table_selected(table_key, table_doc):
|
|
table_section.append(_format_table(table_key, table_doc, max_columns=max_columns))
|
|
|
|
if not table_section:
|
|
return "(No tables selected; broaden include_tables to see schema context.)"
|
|
|
|
return "\n\n".join(table_section)
|
|
|
|
|
|
def _format_examples(examples: Iterable[Dict[str, Any]]) -> str:
|
|
"""Render examples as numbered items for clarity."""
|
|
|
|
lines: List[str] = []
|
|
for idx, example in enumerate(examples, start=1):
|
|
question = example.get("question", "(missing question)")
|
|
sql = example.get("sql", "-- SQL not provided")
|
|
notes = example.get("notes")
|
|
lines.append(f"Example {idx}: {question}")
|
|
lines.append("SQL:\n" + sql)
|
|
if notes:
|
|
lines.append(f"Notes: {notes}")
|
|
lines.append("")
|
|
return "\n".join(lines).strip()
|
|
|
|
|
|
def _format_value_hints(hints: Dict[str, Any]) -> str:
|
|
"""Dump value hints to YAML for prompt readability."""
|
|
|
|
if not hints:
|
|
return "(No value hints available.)"
|
|
return yaml.safe_dump(hints, sort_keys=False, allow_unicode=False)
|
|
|
|
|
|
def build_prompt(
    question: str,
    context: AgentContext,
    table_hints: Optional[Sequence[str]] = None,
    max_columns: int = 12,
    format_hint: Optional[str] = None,
) -> str:
    """Compose the full system + user prompt for the SQL agent.

    Args:
        question: The user's natural-language question. It is quoted inside
            the rules and repeated at the end of the prompt.
        context: Loaded context assets (schema, glossary, value hints and
            worked examples).
        table_hints: Optional table names restricting which tables are
            rendered in the schema section.
        max_columns: Maximum number of columns rendered per table.
        format_hint: Optional extra formatting instructions. When given they
            are added under their own heading before the user question.

    Returns:
        The prompt text: the rule list followed by the schema, glossary,
        value hints, examples and task sections, each separated by a blank
        line and free of leading indentation.
    """
    schema_section = _format_schema(context.schema, include_tables=table_hints, max_columns=max_columns)
    examples_section = _format_examples(context.examples)
    value_hints_section = _format_value_hints(context.value_hints)

    rules = [
        "- Use SELECT statements only; never modify data.",
        (
            "- Apply TOP or date filters only when explicitly requested or when needed to respect "
            "the max_rows limit; never infer time windows that the user did not mention."
        ),
        "- Reference fully qualified table names with schema (e.g., dbo.TableName).",
        (
            "- Use SQL Server date/time functions: prefer CONVERT(date, ...), CAST(... AS date), "
            "DATEPART, DATEADD, DATEDIFF. Do not use MySQL/Postgres functions like DATE(), "
            "EXTRACT(), or INTERVAL literals."
        ),
        (
            "- Stick to SQL Server syntax: use LIKE (and optionally UPPER/LTRIM/RTRIM) instead of "
            "ILIKE; avoid LIMIT/OFFSET, NULLS FIRST/LAST, USING, or other non-T-SQL constructs."
        ),
        (
            "- Only include literals or filters grounded in the question or provided hints; "
            "do not fabricate IDs, random suffixes, or placeholder strings."
        ),
        (
            f'- Every filter must be justified by the user question: "{question}". '
            "If the question does not reference a timeframe, do not add one. "
            "If unsure, return an empty sql value and ask for clarification in the summary."
        ),
        "- If information is missing, ask for clarification instead of guessing.",
    ]

    # The sections are joined explicitly instead of being written as one
    # indented f-string passed through textwrap.dedent. dedent runs after
    # interpolation, so multi-line sections that carry no indentation (and a
    # first line starting at column 0) leave no common margin to remove, and
    # the template's own indentation would end up inside the prompt.
    sections = [
        "\n".join(
            [
                "You are an assistant that translates operator questions into safe T-SQL for SQL Server.",
                "Follow these rules:",
                *rules,
            ]
        ),
        "### Database Schema\n" + schema_section,
        "### Business Glossary\n" + context.glossary_markdown.strip(),
        "### Value Hints\n" + value_hints_section.strip(),
        "### Worked Examples\n" + examples_section,
        "\n".join(
            [
                "### Task",
                "Generate a SQL query that answers the user question and summarize the answer in natural language.",
                "Return ONLY a JSON object with the following structure (no code fences, no extra commentary):",
                '{"sql": "<the SQL query>", "summary": "<one sentence summary>"}',
            ]
        ),
    ]
    if format_hint:
        # Only emitted when requested, so no stray blank section remains.
        sections.append("### Additional Formatting Instructions\n" + format_hint)
    sections.append(f"User question: {question}")

    return "\n\n".join(sections).strip()
|
|
|
|
|
|
# Public API of this module; the underscore-prefixed helpers stay private.
__all__ = [
    "PromptConfig",
    "AgentContext",
    "load_context",
    "build_prompt",
]
|