Files
2026-02-17 09:29:34 -06:00

368 lines
14 KiB
Python

"""Client utilities for calling the local TGI-hosted SQL model."""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, Optional
import requests
from .prompting import AgentContext, build_prompt, load_context
logger = logging.getLogger(__name__)
@dataclass
class LlmConfig:
    """Connection and sampling settings for the LLM text-generation service."""

    # Root URL of the inference server (no trailing slash); the client
    # appends the endpoint path when issuing requests.
    base_url: str = "http://192.168.0.30:8080"
    # Identifier of the model presumed to be served at ``base_url``.
    model: str = "defog/sqlcoder-7b-2"
    # Sampling temperature forwarded to the service.
    temperature: float = 0.1
    # Upper bound on generated tokens; ``None`` omits the limit from requests.
    max_new_tokens: Optional[int] = 512
class SqlAgentClient:
    """Wrapper that builds prompts and fetches SQL from the LLM service.

    The client sends a prompt to the ``/generate`` endpoint, parses the model
    output into a ``{"sql": ..., "summary": ...}`` dictionary, post-processes
    the SQL (strips unrequested date filters, rewrites non-SQL-Server syntax)
    and retries once with an explicit formatting reminder when no usable
    SELECT statement was produced.
    """

    _FORMAT_REMINDER = (
        "Reminder: Respond with valid JSON exactly as {\"sql\": \"<single T-SQL SELECT statement>\", "
        "\"summary\": \"<one sentence summary>\"}. The sql value must start with SELECT, contain only one "
        "statement, use SQL Server syntax (e.g., LIKE instead of ILIKE, ORDER BY without NULLS LAST/FIRST), and "
        "avoid fabricated literals or extra filters. Only add WHERE conditions or time windows that the user explicitly mentioned. If clarification is required, set sql to an empty string and explain why in summary."
    )

    def __init__(self, context: AgentContext, llm_config: Optional[LlmConfig] = None) -> None:
        """Store the prompt context and the LLM settings (defaults if omitted)."""
        self.context = context
        self.config = llm_config or LlmConfig()

    def build_request_payload(
        self,
        question: str,
        table_hints: Optional[list[str]] = None,
        *,
        format_hint: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Return the JSON-serialisable body for a ``/generate`` request.

        Args:
            question: Natural-language question from the user.
            table_hints: Optional table names forwarded to the prompt builder.
            format_hint: Optional extra instruction forwarded to the prompt
                builder (used on the retry attempt).
        """
        prompt = build_prompt(
            question=question,
            context=self.context,
            table_hints=table_hints,
            format_hint=format_hint,
        )
        parameters: Dict[str, Any] = {
            # Clamp to a tiny positive value so a configured 0 is never sent.
            "temperature": max(self.config.temperature, 1e-5),
            "return_full_text": False,
        }
        if self.config.max_new_tokens is not None:
            parameters["max_new_tokens"] = self.config.max_new_tokens
        return {
            "inputs": prompt,
            "parameters": parameters,
        }

    def query(self, question: str, table_hints: Optional[list[str]] = None) -> Dict[str, Any]:
        """Ask the model for SQL answering *question*, retrying once on failure.

        Returns:
            The parsed model response. ``"sql"`` always holds the normalised
            SQL string (possibly empty). ``"llm_warning"`` is set when
            post-processing changed the SQL or when an attempt failed.

        Raises:
            requests.RequestException: Propagated from the HTTP call.
        """
        warnings: list[str] = []
        question_allows_timeframe = self._question_mentions_timeframe(question)
        parsed: Dict[str, Any] = {}

        for attempt in range(2):
            format_hint: Optional[str] = None
            if attempt == 1:
                # Second try: restate the output contract and the question.
                format_hint = (
                    self._FORMAT_REMINDER
                    + f' The user question is: "{question}". Do not add filters beyond what is explicitly stated.'
                )
            payload = self.build_request_payload(question, table_hints, format_hint=format_hint)
            response = self._invoke_generate(payload)
            # Defensive: never index into a non-dict result.
            parsed = response if isinstance(response, dict) else {"sql": "", "summary": ""}

            # The model may emit a non-string "sql" value (null, number, ...);
            # treat anything that is not text as "no SQL produced".
            sql_text = parsed.get("sql")
            if not isinstance(sql_text, str):
                sql_text = ""
            normalized_sql = self._normalize_sql(sql_text) if sql_text else ""
            parsed["sql"] = normalized_sql

            if normalized_sql and not question_allows_timeframe:
                normalized_sql, stripped_temporal = self._strip_unrequested_temporal_filters(normalized_sql)
                if stripped_temporal:
                    parsed["sql"] = normalized_sql
                    warnings.append("Postprocess: Removed unrequested date/time filters from SQL output.")

            if normalized_sql:
                normalized_sql, stripped_compat = self._replace_non_sqlserver_syntax(normalized_sql)
                if stripped_compat:
                    parsed["sql"] = normalized_sql
                    warnings.append("Postprocess: Normalized syntax to SQL Server equivalents (e.g., LIKE).")

            if normalized_sql.lstrip().lower().startswith("select"):
                if warnings:
                    parsed["llm_warning"] = "; ".join(warnings)
                return parsed

            warnings.append(f"Attempt {attempt + 1}: {self._describe_failure(parsed)}")
            if attempt == 0:
                logger.info("Retrying LLM request with explicit formatting reminder.")

        # Both attempts failed: return the last response annotated with
        # everything that went wrong along the way.
        warnings.append("Attempt 2: Second attempt also failed to produce a valid SELECT statement.")
        parsed["llm_warning"] = "; ".join(warnings)
        return parsed

    @staticmethod
    def _describe_failure(parsed: Dict[str, Any]) -> str:
        """Pick the warning text for an attempt that yielded no valid SELECT."""
        message = parsed.get("llm_warning") or "Model response did not include a valid SELECT statement."
        # Capture repeated addition of unrequested filters for auditing.
        sql_text = parsed.get("sql", "")
        if sql_text:
            lowered_sql = sql_text.lower()
            if any(phrase in lowered_sql for phrase in ("last 24", "last24", "24 hour", "24-hour")):
                message = (
                    "Model added an unrequested time filter. Ensure future responses only add filters present in the question."
                )
        return message

    def _invoke_generate(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* to ``/generate`` and parse the generated text.

        Raises:
            requests.HTTPError: If the service answers with an error status.
        """
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            f"{self.config.base_url}/generate",
            headers=headers,
            data=json.dumps(payload),
            timeout=120,
        )
        response.raise_for_status()
        try:
            data = response.json()
        except ValueError:
            body = response.text.strip()
            return self._fallback_response(
                raw=body,
                message="LLM HTTP response was not JSON; returning raw text.",
            )
        content = data.get("generated_text") if isinstance(data, dict) else None
        if not content:
            raw = response.text.strip()
            return self._fallback_response(
                raw=raw or str(data),
                message="LLM response missing generated_text field; returning raw text.",
            )
        return self._parse_response(content)

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (and its language tag)."""
        stripped = text.strip()
        if stripped.startswith("```") and stripped.endswith("```"):
            without_ticks = stripped.strip("`")
            # The first line after the opening fence is the language tag.
            parts = without_ticks.split("\n", 1)
            body = parts[1] if len(parts) == 2 else parts[0]
            return body.rsplit("```", 1)[0] if "```" in body else body
        return stripped

    def _parse_response(self, content: str) -> Dict[str, Any]:
        """Attempt to decode JSON, handling common formatting wrappers.

        Only JSON *objects* are accepted; a JSON string, array or number is
        ignored so callers can rely on receiving a dictionary. When no JSON
        object is found, a few plain-text layouts ("SQL query: ... Summary:
        ...", "sql: ...", "... Summary: ...", bare SELECT) are recognised.
        """
        raw = content.strip()
        candidates = [raw, self._strip_code_fence(raw)]
        # Extract substring between first and last braces if needed.
        start = raw.find("{")
        end = raw.rfind("}")
        if start != -1 and end > start:
            candidates.append(raw[start : end + 1])

        for candidate in candidates:
            candidate = candidate.strip()
            if not candidate:
                continue
            try:
                decoded = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(decoded, dict):
                return decoded

        # Heuristic: parse patterns like "SQL query: ... Summary: ..."
        lower = raw.lower()
        if "sql query:" in lower:
            sql_start = lower.find("sql query:") + len("sql query:")
            summary_start = lower.find("summary:")
            if summary_start != -1:
                sql_text = raw[sql_start:summary_start].strip()
                summary_text = raw[summary_start + len("summary:") :].strip()
                return {"sql": sql_text, "summary": summary_text}
        if lower.startswith("sql:"):
            sql_text = raw.split(":", 1)[1].strip()
            return {"sql": sql_text, "summary": ""}

        # Handle outputs like "... Summary: ..." (with or without explicit SQL label)
        summary_idx = lower.rfind("summary:")
        if summary_idx != -1:
            sql_section = raw[:summary_idx].strip()
            summary_text = raw[summary_idx + len("summary:") :].strip()
            # Remove wrapping quotes if present
            if summary_text.startswith("\"") and summary_text.endswith("\""):
                summary_text = summary_text[1:-1]
            if sql_section.lower().startswith("sql generated:"):
                sql_section = sql_section[len("sql generated:") :].strip()
            return {"sql": sql_section, "summary": summary_text}

        # Final fallback: treat entire content as SQL and leave summary blank.
        if "SELECT" in raw.upper():
            return {"sql": raw.strip(), "summary": ""}
        return self._fallback_response(
            raw=raw,
            message="LLM output could not be parsed; returning raw text.",
        )

    @staticmethod
    def _fallback_response(raw: str, message: str) -> Dict[str, Any]:
        """Log *message* and wrap unparseable output in the standard shape."""
        preview = raw[:200] if raw else "(empty response)"
        logger.warning("%s preview=%s", message, preview)
        return {
            "sql": "",
            "summary": raw,
            "llm_warning": f"{message} Preview: {preview}",
        }

    @staticmethod
    def _question_mentions_timeframe(question: str) -> bool:
        """Return True when *question* contains any time-related wording.

        The check is deliberately broad (relative words, units, month names,
        weekday names and abbreviations, years 2000-2099): a match only means
        date filters in the generated SQL are left alone.
        """
        lowered = question.lower()
        patterns = (
            r"\b(today|yesterday|tonight|this|current|previous|prior|recent|last|past|since|ago|between|before|after|during)\b",
            r"\b(day|days|week|weeks|month|months|quarter|quarters|year|years|hour|hours|minute|minutes)\b",
            r"\b(january|february|march|april|may|june|july|august|september|october|november|december)\b",
            # Abbreviations plus full (optionally plural) weekday names; a
            # plain "<abbr>(day)?" misses tuesday/wednesday/thursday/saturday.
            r"\b(?:mon|tue|tues|wed|thu|thur|thurs|fri|sat|sun"
            r"|(?:mon|tues|wednes|thurs|fri|satur|sun)days?)\b",
            r"\b20\d{2}\b",
        )
        return any(re.search(pattern, lowered) for pattern in patterns)

    @staticmethod
    def _strip_unrequested_temporal_filters(sql: str) -> tuple[str, bool]:
        """Remove predicates that use date/time functions from *sql*.

        Returns:
            ``(sql, changed)`` where ``changed`` tells whether anything was
            removed.

        NOTE(review): this is regex based, not a SQL parser. A WHERE clause
        that still references a temporal function after the AND/OR pass is
        dropped wholesale, including any non-temporal predicates inside it.
        """
        candidate = sql
        temporal_regex = r"(DATEADD|DATEDIFF|DATEPART|SYSDATETIME|GETDATE|CURRENT_TIMESTAMP)"
        changed = False

        # "AND/OR <predicate using a temporal function>". The predicate ends
        # at the next connective, clause keyword, statement terminator or end
        # of input. ";" must be listed explicitly because "[^;]*?" can never
        # reach "$" across a trailing semicolon.
        clause_pattern = re.compile(
            rf"\s+(AND|OR)\s+[^;]*?{temporal_regex}[^;]*?(?=(\bAND\b|\bOR\b|\bGROUP BY\b|\bORDER BY\b|\bHAVING\b|;|$))",
            re.IGNORECASE | re.DOTALL,
        )
        candidate, replaced = clause_pattern.subn("", candidate)
        if replaced:
            changed = True

        # A WHERE clause that still mentions a temporal function.
        where_pattern = re.compile(
            rf"\bWHERE\b\s+[^;]*?{temporal_regex}[^;]*?(?=(\bGROUP BY\b|\bORDER BY\b|\bHAVING\b|;|$))",
            re.IGNORECASE | re.DOTALL,
        )
        candidate, replaced = where_pattern.subn("", candidate)
        if replaced:
            changed = True

        # A WHERE keyword left without any predicate.
        dangling_where_pattern = re.compile(r"\bWHERE\b\s*(?=(GROUP BY|ORDER BY|HAVING|;|$))", re.IGNORECASE)
        candidate, replaced = dangling_where_pattern.subn("", candidate)
        if replaced:
            changed = True

        if changed:
            # Tidy whitespace left behind by the removals.
            candidate = re.sub(r"\s{2,}", " ", candidate)
            candidate = re.sub(r"\s+(GROUP BY|ORDER BY|HAVING)\b", r" \1", candidate, flags=re.IGNORECASE)
            candidate = re.sub(r"\s+;", ";", candidate)
        return candidate.strip(), changed

    @staticmethod
    def _normalize_sql(sql_text: str) -> str:
        """Strip fences, labels, leading bullets and a trailing summary."""
        text = sql_text.strip()
        # Remove markdown fences if still present
        if text.startswith("```") and text.endswith("```"):
            inner = text.strip("`")
            parts = inner.split("\n", 1)
            text = parts[1] if len(parts) == 2 else parts[0]
            text = text.rsplit("```", 1)[0]

        # Drop a leading "SQL:", "SQL query:", "Generated SQL:", ... label.
        prefix_pattern = re.compile(r"^[\s\.\?\-:\*•]*sql(?:\s+(?:query|statement))?\s*:\s*", re.IGNORECASE)
        text = prefix_pattern.sub("", text, count=1)
        alt_prefix_pattern = re.compile(r"^[\s\.\?\-:\*•]*(?:generated sql|query|answer)\s*:\s*", re.IGNORECASE)
        text = alt_prefix_pattern.sub("", text, count=1)

        # Drop leading punctuation or bullet artifacts before the SELECT keyword
        select_idx = text.lower().find("select")
        if select_idx > 0 and text[:select_idx].strip(" \n\t\r:-*•?.") == "":
            text = text[select_idx:]

        # Remove any trailing summary section that may still be attached
        text = re.split(r"(?i)\bsummary\s*:", text, maxsplit=1)[0]
        return text.strip()

    @staticmethod
    def _replace_non_sqlserver_syntax(sql_text: str) -> tuple[str, bool]:
        """Rewrite ``ILIKE`` and ``NULLS FIRST/LAST`` for SQL Server.

        Returns:
            ``(sql, changed)`` where ``changed`` tells whether a rewrite
            happened.
        """
        changed = False

        # Replace ILIKE with LIKE + UPPER for case-insensitive matching
        def replace_ilike(match: re.Match[str]) -> str:
            nonlocal changed
            column = match.group("column")
            value = match.group("value")
            # Double-quoted identifiers become [bracketed] identifiers and
            # double-quoted literals become single-quoted literals.
            if column.startswith('"') and column.endswith('"'):
                column = f"[{column[1:-1]}]"
            if value.startswith('"') and value.endswith('"'):
                value = f"'{value[1:-1]}'"
            changed = True
            return f"UPPER({column}) LIKE UPPER({value})"

        ilike_pattern = re.compile(
            r"(?P<column>(?:[\w\.\[\]]+|\"[^\"]+\"))\s+ILIKE\s+(?P<value>'[^']*'|\"[^\"]*\"|:[\w_]+)",
            re.IGNORECASE,
        )
        sql = ilike_pattern.sub(replace_ilike, sql_text)

        # Remove NULLS FIRST/LAST clauses
        nulls_pattern = re.compile(r"\s+NULLS\s+(?:FIRST|LAST)", re.IGNORECASE)
        sql, removed = nulls_pattern.subn("", sql)
        if removed:
            changed = True
        return sql.strip(), changed
def default_client() -> SqlAgentClient:
    """Build a client from the loaded agent context and default LLM settings."""
    return SqlAgentClient(context=load_context(), llm_config=LlmConfig())