"""Client utilities for calling the local TGI-hosted SQL model.""" from __future__ import annotations import json import logging import re from dataclasses import dataclass from typing import Any, Dict, Optional import requests from .prompting import AgentContext, build_prompt, load_context logger = logging.getLogger(__name__) @dataclass class LlmConfig: base_url: str = "http://192.168.0.30:8080" model: str = "defog/sqlcoder-7b-2" temperature: float = 0.1 max_new_tokens: Optional[int] = 512 class SqlAgentClient: """Wrapper that builds prompts and fetches SQL from the LLM service.""" _FORMAT_REMINDER = ( "Reminder: Respond with valid JSON exactly as {\"sql\": \"\", " "\"summary\": \"\"}. The sql value must start with SELECT, contain only one " "statement, use SQL Server syntax (e.g., LIKE instead of ILIKE, ORDER BY without NULLS LAST/FIRST), and " "avoid fabricated literals or extra filters. Only add WHERE conditions or time windows that the user explicitly mentioned. If clarification is required, set sql to an empty string and explain why in summary." ) def __init__(self, context: AgentContext, llm_config: Optional[LlmConfig] = None) -> None: self.context = context self.config = llm_config or LlmConfig() def build_request_payload( self, question: str, table_hints: Optional[list[str]] = None, *, format_hint: Optional[str] = None, ) -> Dict[str, Any]: prompt = build_prompt( question=question, context=self.context, table_hints=table_hints, format_hint=format_hint, ) parameters: Dict[str, Any] = { "temperature": max(self.config.temperature, 1e-5), "return_full_text": False, } if self.config.max_new_tokens is not None: parameters["max_new_tokens"] = self.config.max_new_tokens return { "inputs": prompt, "parameters": parameters, } def query(self, question: str, table_hints: Optional[list[str]] = None) -> Dict[str, Any]: warnings: list[str] = [] question_allows_timeframe = self._question_mentions_timeframe(question) parsed: Optional[Dict[str, Any]] = None for attempt in range(2): if attempt == 1: format_hint = ( self._FORMAT_REMINDER + f' The user question is: "{question}". Do not add filters beyond what is explicitly stated.' ) else: format_hint = None payload = self.build_request_payload(question, table_hints, format_hint=format_hint) parsed = self._invoke_generate(payload) sql_text = parsed.get("sql", "") if isinstance(parsed, dict) else "" normalized_sql = self._normalize_sql(sql_text) if sql_text else "" parsed["sql"] = normalized_sql if normalized_sql and not question_allows_timeframe: normalized_sql, stripped_temporal = self._strip_unrequested_temporal_filters(normalized_sql) if stripped_temporal: parsed["sql"] = normalized_sql warnings.append("Postprocess: Removed unrequested date/time filters from SQL output.") if normalized_sql: normalized_sql, stripped_compat = self._replace_non_sqlserver_syntax(normalized_sql) if stripped_compat: parsed["sql"] = normalized_sql warnings.append("Postprocess: Normalized syntax to SQL Server equivalents (e.g., LIKE).") if normalized_sql and normalized_sql.lstrip().lower().startswith("select"): if warnings: parsed["llm_warning"] = "; ".join(warnings) return parsed warning_message = parsed.get("llm_warning") if not warning_message: warning_message = "Model response did not include a valid SELECT statement." # Capture repeated addition of unrequested filters for auditing sql_text = parsed.get("sql", "") if sql_text: lowered_sql = sql_text.lower() if any(phrase in lowered_sql for phrase in ("last 24", "last24", "24 hour", "24-hour")): warning_message = ( "Model added an unrequested time filter. Ensure future responses only add filters present in the question." ) warnings.append(f"Attempt {attempt + 1}: {warning_message}") if attempt == 0: logger.info("Retrying LLM request with explicit formatting reminder.") continue # Second attempt also failed warnings.append("Attempt 2: Second attempt also failed to produce a valid SELECT statement.") parsed["llm_warning"] = "; ".join(warnings) return parsed # Should not reach here, but provide a safe fallback if parsed: parsed.setdefault("sql", "") parsed.setdefault("summary", "") parsed["llm_warning"] = "; ".join(warnings) if warnings else "LLM did not produce a usable response." return parsed return { "sql": "", "summary": "", "llm_warning": "LLM did not produce a usable response.", } def _invoke_generate(self, payload: Dict[str, Any]) -> Dict[str, Any]: headers = {"Content-Type": "application/json"} response = requests.post( f"{self.config.base_url}/generate", headers=headers, data=json.dumps(payload), timeout=120, ) response.raise_for_status() try: data = response.json() except ValueError: body = response.text.strip() return self._fallback_response( raw=body, message="LLM HTTP response was not JSON; returning raw text.", ) content = data.get("generated_text") if isinstance(data, dict) else None if not content: raw = response.text.strip() return self._fallback_response( raw=raw or str(data), message="LLM response missing generated_text field; returning raw text.", ) return self._parse_response(content) @staticmethod def _strip_code_fence(text: str) -> str: stripped = text.strip() if stripped.startswith("```") and stripped.endswith("```"): without_ticks = stripped.strip("`") parts = without_ticks.split("\n", 1) if len(parts) == 2: body = parts[1] else: body = parts[0] return body.rsplit("```", 1)[0] if "```" in body else body return stripped def _parse_response(self, content: str) -> Dict[str, Any]: """Attempt to decode JSON, handling common formatting wrappers.""" raw = content.strip() candidates = [raw] candidates.append(self._strip_code_fence(raw)) # Extract substring between first and last braces if needed start = raw.find("{") end = raw.rfind("}") if start != -1 and end != -1 and end > start: candidates.append(raw[start : end + 1]) for candidate in candidates: candidate = candidate.strip() if not candidate: continue try: return json.loads(candidate) except json.JSONDecodeError: continue # Heuristic: parse patterns like "SQL query: ... Summary: ..." lower = raw.lower() if "sql query:" in lower: sql_start = lower.find("sql query:") + len("sql query:") summary_start = lower.find("summary:") if summary_start != -1: sql_text = raw[sql_start:summary_start].strip() summary_text = raw[summary_start + len("summary:") :].strip() return {"sql": sql_text, "summary": summary_text} if lower.startswith("sql:"): sql_text = raw.split(":", 1)[1].strip() return {"sql": sql_text, "summary": ""} # Handle outputs like "... Summary: ..." (with or without explicit SQL label) summary_idx = lower.rfind("summary:") if summary_idx != -1: sql_section = raw[:summary_idx].strip() summary_text = raw[summary_idx + len("summary:") :].strip() # Remove wrapping quotes if present if summary_text.startswith("\"") and summary_text.endswith("\""): summary_text = summary_text[1:-1] if sql_section.lower().startswith("sql generated:"): sql_section = sql_section[len("sql generated:") :].strip() return {"sql": sql_section, "summary": summary_text} # Final fallback: treat entire content as SQL and leave summary blank. if raw.upper().startswith("SELECT") or "SELECT" in raw.upper(): return {"sql": raw.strip(), "summary": ""} return self._fallback_response( raw=raw, message="LLM output could not be parsed; returning raw text.", ) @staticmethod def _fallback_response(raw: str, message: str) -> Dict[str, Any]: preview = raw[:200] if raw else "(empty response)" logger.warning("%s preview=%s", message, preview) return { "sql": "", "summary": raw, "llm_warning": f"{message} Preview: {preview}", } @staticmethod def _question_mentions_timeframe(question: str) -> bool: lowered = question.lower() if re.search( r"\b(today|yesterday|tonight|this|current|previous|prior|recent|last|past|since|ago|between|before|after|during)\b", lowered, ): return True if re.search( r"\b(day|days|week|weeks|month|months|quarter|quarters|year|years|hour|hours|minute|minutes)\b", lowered, ): return True if re.search(r"\b(january|february|march|april|may|june|july|august|september|october|november|december)\b", lowered): return True if re.search(r"\b(mon|tue|wed|thu|fri|sat|sun)(day)?\b", lowered): return True if re.search(r"\b20\d{2}\b", lowered): return True return False @staticmethod def _strip_unrequested_temporal_filters(sql: str) -> tuple[str, bool]: candidate = sql temporal_regex = r"(DATEADD|DATEDIFF|DATEPART|SYSDATETIME|GETDATE|CURRENT_TIMESTAMP)" changed = False clause_pattern = re.compile( rf"\s+(AND|OR)\s+[^;]*?{temporal_regex}[^;]*?(?=(\bAND\b|\bOR\b|\bGROUP BY\b|\bORDER BY\b|\bHAVING\b|$))", re.IGNORECASE | re.DOTALL, ) candidate, replaced = clause_pattern.subn("", candidate) if replaced: changed = True where_pattern = re.compile( rf"\bWHERE\b\s+[^;]*?{temporal_regex}[^;]*?(?=(\bGROUP BY\b|\bORDER BY\b|\bHAVING\b|$))", re.IGNORECASE | re.DOTALL, ) candidate, replaced = where_pattern.subn("", candidate) if replaced: changed = True dangling_where_pattern = re.compile(r"\bWHERE\b\s*(?=(GROUP BY|ORDER BY|HAVING|$))", re.IGNORECASE) candidate, replaced = dangling_where_pattern.subn("", candidate) if replaced: changed = True if changed: candidate = re.sub(r"\s{2,}", " ", candidate) candidate = re.sub(r"\s+(GROUP BY|ORDER BY|HAVING)\b", r" \1", candidate, flags=re.IGNORECASE) candidate = re.sub(r"\s+;", ";", candidate) return candidate.strip(), changed @staticmethod def _normalize_sql(sql_text: str) -> str: text = sql_text.strip() # Remove markdown fences if still present if text.startswith("```") and text.endswith("```"): inner = text.strip("`") parts = inner.split("\n", 1) text = parts[1] if len(parts) == 2 else parts[0] text = text.rsplit("```", 1)[0] prefix_pattern = re.compile(r"^[\s\.\?\-:\*•]*sql(?:\s+(?:query|statement))?\s*:\s*", re.IGNORECASE) text = prefix_pattern.sub("", text, count=1) alt_prefix_pattern = re.compile(r"^[\s\.\?\-:\*•]*(?:generated sql|query|answer)\s*:\s*", re.IGNORECASE) text = alt_prefix_pattern.sub("", text, count=1) # Drop leading punctuation or bullet artifacts before the SELECT keyword lowered = text.lower() select_idx = lowered.find("select") if select_idx > 0: leading = text[:select_idx] if leading.strip(" \n\t\r:-*•?.") == "": text = text[select_idx:] # Remove any trailing summary section that may still be attached summary_split = re.split(r"(?i)\bsummary\s*:", text, maxsplit=1) if summary_split: text = summary_split[0] return text.strip() @staticmethod def _replace_non_sqlserver_syntax(sql_text: str) -> tuple[str, bool]: changed = False sql = sql_text # Replace ILIKE with LIKE + UPPER for case-insensitive matching def replace_ilike(match: re.Match[str]) -> str: nonlocal changed column = match.group("column") value = match.group("value") if column.startswith('"') and column.endswith('"'): column = f"[{column[1:-1]}]" if value.startswith('"') and value.endswith('"'): value = f"'{value[1:-1]}'" changed = True return f"UPPER({column}) LIKE UPPER({value})" ilike_pattern = re.compile( r"(?P(?:[\w\.\[\]]+|\"[^\"]+\"))\s+ILIKE\s+(?P'[^']*'|\"[^\"]*\"|:[\w_]+)", re.IGNORECASE, ) sql = ilike_pattern.sub(replace_ilike, sql) # Remove NULLS FIRST/LAST clauses nulls_pattern = re.compile(r"\s+NULLS\s+(?:FIRST|LAST)", re.IGNORECASE) if nulls_pattern.search(sql): sql = nulls_pattern.sub("", sql) changed = True return sql.strip(), changed def default_client() -> SqlAgentClient: context = load_context() config = LlmConfig() return SqlAgentClient(context=context, llm_config=config)