"""Client utilities for calling the local TGI-hosted SQL model."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional
|
|
|
|
import requests
|
|
|
|
from .prompting import AgentContext, build_prompt, load_context
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class LlmConfig:
    """Connection and sampling settings for the TGI-hosted SQL model."""

    # Base URL of the local TGI server (no trailing slash); /generate is appended.
    base_url: str = "http://192.168.0.30:8080"
    # Model identifier served at that endpoint.
    model: str = "defog/sqlcoder-7b-2"
    # Sampling temperature; clamped to a tiny positive value before sending
    # because the payload builder never passes exactly 0.
    temperature: float = 0.1
    # Cap on generated tokens; None omits the parameter from the request.
    max_new_tokens: Optional[int] = 512
|
|
|
|
|
|
class SqlAgentClient:
    """Wrapper that builds prompts and fetches SQL from the LLM service."""

    # Appended to the prompt on the retry attempt to re-state the required
    # JSON / single-SELECT / SQL Server output contract.
    _FORMAT_REMINDER = (
        "Reminder: Respond with valid JSON exactly as {\"sql\": \"<single T-SQL SELECT statement>\", "
        "\"summary\": \"<one sentence summary>\"}. The sql value must start with SELECT, contain only one "
        "statement, use SQL Server syntax (e.g., LIKE instead of ILIKE, ORDER BY without NULLS LAST/FIRST), and "
        "avoid fabricated literals or extra filters. Only add WHERE conditions or time windows that the user explicitly mentioned. If clarification is required, set sql to an empty string and explain why in summary."
    )

    def __init__(self, context: AgentContext, llm_config: Optional[LlmConfig] = None) -> None:
        """Store the prompt context and the LLM endpoint configuration."""
        self.context = context
        self.config = llm_config or LlmConfig()

    def build_request_payload(
        self,
        question: str,
        table_hints: Optional[list[str]] = None,
        *,
        format_hint: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Build the TGI ``/generate`` payload for *question*.

        ``format_hint``, when given, is threaded into the prompt (used on the
        retry attempt to remind the model of the expected output format).
        """
        prompt = build_prompt(
            question=question,
            context=self.context,
            table_hints=table_hints,
            format_hint=format_hint,
        )
        parameters: Dict[str, Any] = {
            # Clamp to a tiny positive value: a temperature of exactly 0 is
            # never sent to the server.
            "temperature": max(self.config.temperature, 1e-5),
            "return_full_text": False,
        }
        if self.config.max_new_tokens is not None:
            parameters["max_new_tokens"] = self.config.max_new_tokens
        return {
            "inputs": prompt,
            "parameters": parameters,
        }

    def query(self, question: str, table_hints: Optional[list[str]] = None) -> Dict[str, Any]:
        """Ask the model for SQL, retrying once with a format reminder.

        Returns a dict with ``sql`` and ``summary`` keys, plus ``llm_warning``
        whenever the output needed post-processing or never became a valid
        SELECT statement.
        """
        warnings: list[str] = []
        question_allows_timeframe = self._question_mentions_timeframe(question)
        parsed: Optional[Dict[str, Any]] = None

        for attempt in range(2):
            if attempt == 1:
                format_hint = (
                    self._FORMAT_REMINDER
                    + f' The user question is: "{question}". Do not add filters beyond what is explicitly stated.'
                )
            else:
                format_hint = None
            payload = self.build_request_payload(question, table_hints, format_hint=format_hint)
            parsed = self._invoke_generate(payload)

            # Bug fix: the model can emit valid JSON that is not an object
            # (e.g. a bare list). The old code indexed into it anyway and
            # raised TypeError; coerce to the dict shape callers expect.
            if not isinstance(parsed, dict):
                parsed = {"sql": "", "summary": str(parsed)}

            sql_text = parsed.get("sql", "")
            if not isinstance(sql_text, str):
                # Defensive: a non-string "sql" value (e.g. a JSON number)
                # would break the normalisation helpers below.
                sql_text = str(sql_text)
            normalized_sql = self._normalize_sql(sql_text) if sql_text else ""
            parsed["sql"] = normalized_sql

            if normalized_sql and not question_allows_timeframe:
                normalized_sql, stripped_temporal = self._strip_unrequested_temporal_filters(normalized_sql)
                if stripped_temporal:
                    parsed["sql"] = normalized_sql
                    warnings.append("Postprocess: Removed unrequested date/time filters from SQL output.")

            if normalized_sql:
                normalized_sql, stripped_compat = self._replace_non_sqlserver_syntax(normalized_sql)
                if stripped_compat:
                    parsed["sql"] = normalized_sql
                    warnings.append("Postprocess: Normalized syntax to SQL Server equivalents (e.g., LIKE).")

            # Success path: a single statement starting with SELECT.
            if normalized_sql and normalized_sql.lstrip().lower().startswith("select"):
                if warnings:
                    parsed["llm_warning"] = "; ".join(warnings)
                return parsed

            warning_message = parsed.get("llm_warning")
            if not warning_message:
                warning_message = "Model response did not include a valid SELECT statement."
                # Capture repeated addition of unrequested filters for auditing
                sql_text = parsed.get("sql", "")
                if sql_text:
                    lowered_sql = sql_text.lower()
                    if any(phrase in lowered_sql for phrase in ("last 24", "last24", "24 hour", "24-hour")):
                        warning_message = (
                            "Model added an unrequested time filter. Ensure future responses only add filters present in the question."
                        )
            warnings.append(f"Attempt {attempt + 1}: {warning_message}")

            if attempt == 0:
                logger.info("Retrying LLM request with explicit formatting reminder.")
                continue

            # Second attempt also failed
            warnings.append("Attempt 2: Second attempt also failed to produce a valid SELECT statement.")
            parsed["llm_warning"] = "; ".join(warnings)
            return parsed

        # Should not reach here, but provide a safe fallback
        if parsed:
            parsed.setdefault("sql", "")
            parsed.setdefault("summary", "")
            parsed["llm_warning"] = "; ".join(warnings) if warnings else "LLM did not produce a usable response."
            return parsed

        return {
            "sql": "",
            "summary": "",
            "llm_warning": "LLM did not produce a usable response.",
        }

    def _invoke_generate(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* to the TGI ``/generate`` endpoint and parse the reply.

        Always returns a dict; HTTP errors propagate as ``requests`` exceptions,
        malformed bodies degrade to a fallback response.
        """
        response = requests.post(
            f"{self.config.base_url}/generate",
            # json= serialises the payload and sets the Content-Type header,
            # replacing the hand-rolled json.dumps + headers of the old code.
            json=payload,
            timeout=120,
        )
        response.raise_for_status()
        try:
            data = response.json()
        except ValueError:
            body = response.text.strip()
            return self._fallback_response(
                raw=body,
                message="LLM HTTP response was not JSON; returning raw text.",
            )

        # TGI can wrap the generation in a single-element list
        # ([{"generated_text": ...}]); unwrap it before reading the field.
        if isinstance(data, list) and data and isinstance(data[0], dict):
            data = data[0]

        content = data.get("generated_text") if isinstance(data, dict) else None
        if not content:
            raw = response.text.strip()
            return self._fallback_response(
                raw=raw or str(data),
                message="LLM response missing generated_text field; returning raw text.",
            )

        return self._parse_response(content)

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding ``` fenced block (with optional language tag)."""
        stripped = text.strip()
        if stripped.startswith("```") and stripped.endswith("```"):
            without_ticks = stripped.strip("`")
            parts = without_ticks.split("\n", 1)
            if len(parts) == 2:
                # Drop the language tag line (e.g. "sql") before the body.
                body = parts[1]
            else:
                body = parts[0]
            return body.rsplit("```", 1)[0] if "```" in body else body
        return stripped

    def _parse_response(self, content: str) -> Dict[str, Any]:
        """Attempt to decode JSON, handling common formatting wrappers.

        Tries raw JSON, fence-stripped JSON, and the brace-delimited substring,
        then falls back to label-based heuristics. Always returns a dict.
        """
        raw = content.strip()

        candidates = [raw]
        candidates.append(self._strip_code_fence(raw))

        # Extract substring between first and last braces if needed
        start = raw.find("{")
        end = raw.rfind("}")
        if start != -1 and end != -1 and end > start:
            candidates.append(raw[start : end + 1])

        for candidate in candidates:
            candidate = candidate.strip()
            if not candidate:
                continue
            try:
                decoded = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            # Bug fix: json.loads can yield non-dict values (lists, strings,
            # numbers). Only accept JSON objects so callers always get a dict
            # they can index and mutate.
            if isinstance(decoded, dict):
                return decoded

        # Heuristic: parse patterns like "SQL query: ... Summary: ..."
        lower = raw.lower()
        if "sql query:" in lower:
            sql_start = lower.find("sql query:") + len("sql query:")
            summary_start = lower.find("summary:")
            if summary_start != -1:
                sql_text = raw[sql_start:summary_start].strip()
                summary_text = raw[summary_start + len("summary:") :].strip()
                return {"sql": sql_text, "summary": summary_text}
        if lower.startswith("sql:"):
            sql_text = raw.split(":", 1)[1].strip()
            return {"sql": sql_text, "summary": ""}

        # Handle outputs like "... Summary: ..." (with or without explicit SQL label)
        summary_idx = lower.rfind("summary:")
        if summary_idx != -1:
            sql_section = raw[:summary_idx].strip()
            summary_text = raw[summary_idx + len("summary:") :].strip()
            # Remove wrapping quotes if present
            if summary_text.startswith("\"") and summary_text.endswith("\""):
                summary_text = summary_text[1:-1]
            if sql_section.lower().startswith("sql generated:"):
                sql_section = sql_section[len("sql generated:") :].strip()
            return {"sql": sql_section, "summary": summary_text}

        # Final fallback: treat entire content as SQL and leave summary blank.
        if raw.upper().startswith("SELECT") or "SELECT" in raw.upper():
            return {"sql": raw.strip(), "summary": ""}

        return self._fallback_response(
            raw=raw,
            message="LLM output could not be parsed; returning raw text.",
        )

    @staticmethod
    def _fallback_response(raw: str, message: str) -> Dict[str, Any]:
        """Build the degraded response dict used when the model output is unusable."""
        preview = raw[:200] if raw else "(empty response)"
        logger.warning("%s preview=%s", message, preview)
        return {
            "sql": "",
            "summary": raw,
            "llm_warning": f"{message} Preview: {preview}",
        }

    @staticmethod
    def _question_mentions_timeframe(question: str) -> bool:
        """Return True when the question explicitly references any time frame.

        Used to decide whether date/time filters in the generated SQL were
        actually requested by the user.
        """
        lowered = question.lower()
        # Relative/positional time words.
        if re.search(
            r"\b(today|yesterday|tonight|this|current|previous|prior|recent|last|past|since|ago|between|before|after|during)\b",
            lowered,
        ):
            return True
        # Time units.
        if re.search(
            r"\b(day|days|week|weeks|month|months|quarter|quarters|year|years|hour|hours|minute|minutes)\b",
            lowered,
        ):
            return True
        # Month names.
        if re.search(r"\b(january|february|march|april|may|june|july|august|september|october|november|december)\b", lowered):
            return True
        # Weekday names (abbreviated or full).
        if re.search(r"\b(mon|tue|wed|thu|fri|sat|sun)(day)?\b", lowered):
            return True
        # Four-digit years 2000-2099.
        if re.search(r"\b20\d{2}\b", lowered):
            return True
        return False

    @staticmethod
    def _strip_unrequested_temporal_filters(sql: str) -> tuple[str, bool]:
        """Remove WHERE/AND/OR clauses built on SQL Server date functions.

        Returns the (possibly rewritten) SQL and a flag saying whether
        anything was stripped. Regexes are clause-boundary heuristics, not a
        full SQL parser.
        """
        candidate = sql
        temporal_regex = r"(DATEADD|DATEDIFF|DATEPART|SYSDATETIME|GETDATE|CURRENT_TIMESTAMP)"
        changed = False

        # Drop "AND/OR <temporal condition>" up to the next clause boundary.
        clause_pattern = re.compile(
            rf"\s+(AND|OR)\s+[^;]*?{temporal_regex}[^;]*?(?=(\bAND\b|\bOR\b|\bGROUP BY\b|\bORDER BY\b|\bHAVING\b|$))",
            re.IGNORECASE | re.DOTALL,
        )
        candidate, replaced = clause_pattern.subn("", candidate)
        if replaced:
            changed = True

        # Drop a whole WHERE clause whose condition is temporal.
        where_pattern = re.compile(
            rf"\bWHERE\b\s+[^;]*?{temporal_regex}[^;]*?(?=(\bGROUP BY\b|\bORDER BY\b|\bHAVING\b|$))",
            re.IGNORECASE | re.DOTALL,
        )
        candidate, replaced = where_pattern.subn("", candidate)
        if replaced:
            changed = True

        # Clean up a WHERE left with no condition after the removals above.
        dangling_where_pattern = re.compile(r"\bWHERE\b\s*(?=(GROUP BY|ORDER BY|HAVING|$))", re.IGNORECASE)
        candidate, replaced = dangling_where_pattern.subn("", candidate)
        if replaced:
            changed = True

        if changed:
            # Normalise whitespace artifacts introduced by the substitutions.
            candidate = re.sub(r"\s{2,}", " ", candidate)
            candidate = re.sub(r"\s+(GROUP BY|ORDER BY|HAVING)\b", r" \1", candidate, flags=re.IGNORECASE)
            candidate = re.sub(r"\s+;", ";", candidate)

        return candidate.strip(), changed

    @staticmethod
    def _normalize_sql(sql_text: str) -> str:
        """Strip fences, label prefixes, and trailing summaries from raw SQL text."""
        text = sql_text.strip()
        # Remove markdown fences if still present
        if text.startswith("```") and text.endswith("```"):
            inner = text.strip("`")
            parts = inner.split("\n", 1)
            text = parts[1] if len(parts) == 2 else parts[0]
            text = text.rsplit("```", 1)[0]

        # Drop labels such as "SQL query:" / "Generated SQL:" / "Answer:".
        prefix_pattern = re.compile(r"^[\s\.\?\-:\*•]*sql(?:\s+(?:query|statement))?\s*:\s*", re.IGNORECASE)
        text = prefix_pattern.sub("", text, count=1)
        alt_prefix_pattern = re.compile(r"^[\s\.\?\-:\*•]*(?:generated sql|query|answer)\s*:\s*", re.IGNORECASE)
        text = alt_prefix_pattern.sub("", text, count=1)

        # Drop leading punctuation or bullet artifacts before the SELECT keyword
        lowered = text.lower()
        select_idx = lowered.find("select")
        if select_idx > 0:
            leading = text[:select_idx]
            if leading.strip(" \n\t\r:-*•?.") == "":
                text = text[select_idx:]

        # Remove any trailing summary section that may still be attached
        summary_split = re.split(r"(?i)\bsummary\s*:", text, maxsplit=1)
        if summary_split:
            text = summary_split[0]

        return text.strip()

    @staticmethod
    def _replace_non_sqlserver_syntax(sql_text: str) -> tuple[str, bool]:
        """Translate common Postgres-isms to SQL Server syntax.

        Currently rewrites ILIKE comparisons and drops NULLS FIRST/LAST.
        Returns the rewritten SQL and whether anything changed.
        """
        changed = False
        sql = sql_text

        # Replace ILIKE with LIKE + UPPER for case-insensitive matching
        def replace_ilike(match: re.Match[str]) -> str:
            nonlocal changed
            column = match.group("column")
            value = match.group("value")

            # Double-quoted identifiers become bracketed SQL Server identifiers.
            if column.startswith('"') and column.endswith('"'):
                column = f"[{column[1:-1]}]"

            # Double-quoted string literals become single-quoted literals.
            if value.startswith('"') and value.endswith('"'):
                value = f"'{value[1:-1]}'"

            changed = True
            return f"UPPER({column}) LIKE UPPER({value})"

        ilike_pattern = re.compile(
            r"(?P<column>(?:[\w\.\[\]]+|\"[^\"]+\"))\s+ILIKE\s+(?P<value>'[^']*'|\"[^\"]*\"|:[\w_]+)",
            re.IGNORECASE,
        )
        sql = ilike_pattern.sub(replace_ilike, sql)

        # Remove NULLS FIRST/LAST clauses
        nulls_pattern = re.compile(r"\s+NULLS\s+(?:FIRST|LAST)", re.IGNORECASE)
        if nulls_pattern.search(sql):
            sql = nulls_pattern.sub("", sql)
            changed = True

        return sql.strip(), changed
|
|
|
|
|
|
def default_client() -> SqlAgentClient:
    """Build a client from the on-disk context and default LLM settings."""
    return SqlAgentClient(context=load_context(), llm_config=LlmConfig())
|