Files
controls-web/ai_agents/cane_agent/prompting/prompt_builder.py
2026-02-17 09:29:34 -06:00

191 lines
7.1 KiB
Python

"""Assemble LLM prompts for natural-language SQL generation."""
from __future__ import annotations
import json
import textwrap
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence
import yaml
@dataclass(frozen=True)
class PromptConfig:
    """Describe where the prompt builder finds its context artifacts.

    Every ``*_filename`` is joined onto ``context_dir`` by ``load_context``.
    The default ``context_dir`` is a relative path, so it is resolved against
    the process's current working directory, not this module's location.
    """

    # Directory containing all context files (relative to the CWD by default).
    context_dir: Path = field(default_factory=lambda: Path("db_agent", "context"))
    # JSON file with the table/column description.
    schema_filename: str = "schema.json"
    # Markdown file with the business glossary.
    glossary_filename: str = "glossary.md"
    # YAML file with value hints.
    value_hints_filename: str = "value_hints.yaml"
    # JSON file with worked question/SQL examples.
    examples_filename: str = "examples.json"
@dataclass
class AgentContext:
    """Bundle of context assets that ``build_prompt`` renders into a prompt."""

    # Parsed contents of the schema JSON file.
    schema: Dict[str, Any]
    # Raw markdown text of the business glossary.
    glossary_markdown: str
    # Parsed value-hint mapping (load_context substitutes {} for an empty YAML file).
    value_hints: Dict[str, Any]
    # Worked example records parsed from the examples JSON file.
    examples: List[Dict[str, Any]]
def load_context(config: PromptConfig | None = None) -> AgentContext:
    """Read the schema, glossary, value hints and examples from disk.

    Args:
        config: File layout to use; when falsy (``None``), a default
            ``PromptConfig()`` is used.

    Returns:
        An ``AgentContext`` built from the four context files, each decoded
        as UTF-8.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) if a context file cannot be read.
        json.JSONDecodeError: if the schema or examples file is not valid JSON.
        yaml.YAMLError: if the value-hints file is not valid YAML.
    """
    settings = config or PromptConfig()
    folder = settings.context_dir

    def read(filename: str) -> str:
        return (folder / filename).read_text(encoding="utf-8")

    # Keyword arguments are evaluated left to right, so the files are read in
    # a fixed order: schema, glossary, value hints, examples.
    return AgentContext(
        schema=json.loads(read(settings.schema_filename)),
        glossary_markdown=read(settings.glossary_filename),
        # An empty YAML document loads as None; normalise it to an empty mapping.
        value_hints=yaml.safe_load(read(settings.value_hints_filename)) or {},
        examples=json.loads(read(settings.examples_filename)),
    )
def _format_table(table_key: str, table_doc: Dict[str, Any], max_columns: int) -> str:
"""Render a single table definition block."""
lines: List[str] = [f"Table {table_key}: {table_doc.get('schema')}.{table_doc.get('name')}"]
column_docs = table_doc.get("columns", [])
for idx, column in enumerate(column_docs):
if idx >= max_columns:
lines.append(f" ... ({len(column_docs) - max_columns} more columns omitted)")
break
col_line = f" - {column['name']} ({column['type']})"
if column.get("nullable") is False:
col_line += " NOT NULL"
if column.get("default") not in (None, ""):
col_line += f" DEFAULT {column['default']}"
lines.append(col_line)
relationships = table_doc.get("relationships", [])
if relationships:
lines.append(" Relationships:")
for rel in relationships:
target = rel.get("target")
src_col = rel.get("source_column")
tgt_col = rel.get("target_column")
lines.append(f" * {src_col} -> {target}.{tgt_col}")
return "\n".join(lines)
def _format_schema(schema_doc: Dict[str, Any], include_tables: Optional[Sequence[str]], max_columns: int) -> str:
"""Render the subset of tables to include in the prompt."""
table_section: List[str] = []
tables = schema_doc.get("tables", {})
def table_selected(table_key: str, table_doc: Dict[str, Any]) -> bool:
if not include_tables:
return True
candidates = {table_key, table_doc.get("name"), f"{table_doc.get('schema')}.{table_doc.get('name')}"}
return any(name in include_tables for name in candidates if name)
for table_key, table_doc in tables.items():
if table_selected(table_key, table_doc):
table_section.append(_format_table(table_key, table_doc, max_columns=max_columns))
if not table_section:
return "(No tables selected; broaden include_tables to see schema context.)"
return "\n\n".join(table_section)
def _format_examples(examples: Iterable[Dict[str, Any]]) -> str:
"""Render examples as numbered items for clarity."""
lines: List[str] = []
for idx, example in enumerate(examples, start=1):
question = example.get("question", "(missing question)")
sql = example.get("sql", "-- SQL not provided")
notes = example.get("notes")
lines.append(f"Example {idx}: {question}")
lines.append("SQL:\n" + sql)
if notes:
lines.append(f"Notes: {notes}")
lines.append("")
return "\n".join(lines).strip()
def _format_value_hints(hints: Dict[str, Any]) -> str:
"""Dump value hints to YAML for prompt readability."""
if not hints:
return "(No value hints available.)"
return yaml.safe_dump(hints, sort_keys=False, allow_unicode=False)
def build_prompt(
    question: str,
    context: AgentContext,
    table_hints: Optional[Sequence[str]] = None,
    max_columns: int = 12,
    format_hint: Optional[str] = None,
) -> str:
    """Compose the full system + user prompt for the SQL agent.

    Args:
        question: The operator's natural-language question. It is embedded
            both in the filter rule and as the final "User question" line.
        context: Loaded schema, glossary, value hints and examples.
        table_hints: Optional table names restricting the schema section.
        max_columns: Per-table column cap for the schema section.
        format_hint: Optional extra formatting instructions, added as their own
            section when non-blank.

    Returns:
        The prompt text. Sections are separated by blank lines, and every line
        starts at column 0.

    Note:
        The prompt is assembled from explicit lines rather than with
        ``textwrap.dedent`` on an f-string. That dedent could never strip
        anything: the first line sat on the ``f\"\"\"`` line and the
        interpolated sections start at column 0. As a result, source
        indentation, including the indent baked into the
        additional-instructions snippet, leaked into the prompt.
    """
    schema_section = _format_schema(context.schema, include_tables=table_hints, max_columns=max_columns)
    examples_section = _format_examples(context.examples)
    value_hints_section = _format_value_hints(context.value_hints)
    glossary_section = context.glossary_markdown.strip() or "(No glossary available.)"

    rules = "\n".join([
        "You are an assistant that translates operator questions into safe T-SQL for SQL Server.",
        "Follow these rules:",
        "- Use SELECT statements only; never modify data.",
        "- Apply TOP or date filters only when explicitly requested or when needed to respect the max_rows limit; never infer time windows that the user did not mention.",
        "- Reference fully qualified table names with schema (e.g., dbo.TableName).",
        "- Use SQL Server date/time functions: prefer CONVERT(date, ...), CAST(... AS date), DATEPART, DATEADD, DATEDIFF. Do not use MySQL/Postgres functions like DATE(), EXTRACT(), or INTERVAL literals.",
        "- Stick to SQL Server syntax: use LIKE (and optionally UPPER/LTRIM/RTRIM) instead of ILIKE; avoid LIMIT/OFFSET, NULLS FIRST/LAST, USING, or other non-T-SQL constructs.",
        "- Only include literals or filters grounded in the question or provided hints; do not fabricate IDs, random suffixes, or placeholder strings.",
        f'- Every filter must be justified by the user question: "{question}". If the question does not reference a timeframe, do not add one. If unsure, return an empty sql value and ask for clarification in the summary.',
        "- If information is missing, ask for clarification instead of guessing.",
    ])
    task = "\n".join([
        "### Task",
        "Generate a SQL query that answers the user question and summarize the answer in natural language.",
        "Return ONLY a JSON object with the following structure (no code fences, no extra commentary):",
        '{"sql": "<the SQL query>", "summary": "<one sentence summary>"}',
    ])

    sections: List[str] = [
        rules,
        "### Database Schema\n" + schema_section.strip(),
        "### Business Glossary\n" + glossary_section,
        "### Value Hints\n" + value_hints_section.strip(),
        "### Worked Examples\n" + examples_section.strip(),
        task,
    ]
    if format_hint and format_hint.strip():
        sections.append("### Additional Formatting Instructions\n" + format_hint.strip())
    sections.append(f"User question: {question}")
    return "\n\n".join(sections).strip()
# Public API of this module; the _format_* helpers are internal.
__all__ = ["PromptConfig", "AgentContext", "load_context", "build_prompt"]