"""
Memory Vector Search Module

Uses the sqlite-vector extension for local embeddings.
"""


import json
import os
import sqlite3
import struct
from typing import List, Optional, Tuple


# Default locations. Set OPENCLAW_MEMORY_DB / SQLITE_VECTOR_DLL in the
# environment to point at another database file or extension build; when the
# variables are unset the original hard-coded paths are used.
db_path = os.environ.get(
    "OPENCLAW_MEMORY_DB",
    r"C:\Users\admin\.openclaw\memory.db",
)

dll_path = os.environ.get(
    "SQLITE_VECTOR_DLL",
    r"C:\Users\admin\AppData\Local\Programs\Python\Python313\Lib\site-packages\sqlite_vector\binaries\vector.dll",
)

# Vector length passed to vector_init as the column's dimension.
DIMENSION = 768


class MemoryVectorDB:
    """Text memories with float32 embeddings stored in SQLite via sqlite-vector.

    Each instance owns one sqlite3 connection with the sqlite-vector extension
    loaded. It can be used as a context manager; the connection is closed on
    exit.
    """

    def __init__(self, db_file: Optional[str] = None,
                 extension_path: Optional[str] = None,
                 dimension: Optional[int] = None):
        """Open the database, load sqlite-vector and initialise the vector column.

        Args:
            db_file: SQLite database file; defaults to the module-level
                ``db_path``.
            extension_path: sqlite-vector loadable library; defaults to the
                module-level ``dll_path``.
            dimension: embedding length passed to ``vector_init``; defaults to
                the module-level ``DIMENSION``.

        Raises:
            sqlite3.Error: if the database cannot be opened, the extension
                cannot be loaded, or ``vector_init`` fails for a reason other
                than the column already being initialised.
        """
        self.dimension = DIMENSION if dimension is None else dimension
        # Flipped to True by _ensure_vector_init once vector_init has run on
        # this connection.
        self._vector_ready = False
        self.conn = sqlite3.connect(db_path if db_file is None else db_file)
        try:
            self.conn.enable_load_extension(True)
            self.conn.load_extension(
                dll_path if extension_path is None else extension_path)
            # The library stays loaded; turning loading back off stops later
            # SQL on this connection from loading arbitrary extensions.
            self.conn.enable_load_extension(False)
            self.cursor = self.conn.cursor()
            # vector_init targets memory_embeddings, which does not exist yet
            # on a fresh database; in that case setup() runs it after creating
            # the table.
            if self._table_exists():
                self._ensure_vector_init()
        except Exception:
            # Don't leak the connection if construction fails part-way.
            self.conn.close()
            raise

    def _table_exists(self) -> bool:
        """Return True if the memory_embeddings table exists in the database."""
        row = self.conn.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
            ('memory_embeddings',),
        ).fetchone()
        return row is not None

    def _ensure_vector_init(self):
        """Run vector_init for memory_embeddings.embedding on this connection.

        The column is registered as FLOAT32 with ``self.dimension`` elements.
        An OperationalError whose message contains "already initialized" is
        tolerated; any other OperationalError is re-raised.
        """
        options = f'type=FLOAT32,dimension={self.dimension}'
        try:
            self.cursor.execute('SELECT vector_init(?, ?, ?)',
                                ('memory_embeddings', 'embedding', options))
        except sqlite3.OperationalError as e:
            # Already initialized is OK
            if 'already initialized' not in str(e).lower():
                raise
        self._vector_ready = True

    def close(self):
        """Close the underlying sqlite3 connection."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def setup(self):
        """Create the memory_embeddings table and initialise its vector column.

        Safe to run multiple times: the table uses CREATE TABLE IF NOT EXISTS
        and vector_init runs at most once per connection.

        Returns:
            True.
        """
        self.cursor.execute('''
            CREATE TABLE IF NOT EXISTS memory_embeddings (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_type TEXT,
                source_path TEXT,
                content_text TEXT,
                embedding BLOB,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        self.conn.commit()
        # On a fresh database __init__ skipped vector_init (no table yet), so
        # do it now; insert/quantize/search on this connection depend on it.
        if not self._vector_ready:
            self._ensure_vector_init()
        return True

    def insert(self, source_type: str, source_path: str, content_text: str,
               embedding: List[float]) -> int:
        """Store one memory row and commit.

        The embedding is serialised as a JSON array and converted in SQL with
        vector_as_f32.

        Returns:
            The new row's id (cursor.lastrowid).
        """
        vec_json = json.dumps(embedding)
        self.cursor.execute('''
            INSERT INTO memory_embeddings (source_type, source_path, content_text, embedding)
            VALUES (?, ?, ?, vector_as_f32(?))
        ''', (source_type, source_path, content_text, vec_json))
        self.conn.commit()
        return self.cursor.lastrowid

    def quantize(self):
        """Quantize the embedding column and preload it for search().

        Runs vector_quantize then vector_quantize_preload, committing after
        each. Call after batch inserts; search() scans the quantized data.
        """
        self.cursor.execute("SELECT vector_quantize('memory_embeddings', 'embedding')")
        self.conn.commit()
        self.cursor.execute("SELECT vector_quantize_preload('memory_embeddings', 'embedding')")
        self.conn.commit()

    def search(self, query_embedding: List[float], k: int = 5,
               source_type: Optional[str] = None) -> List[Tuple[str, str, float]]:
        """Search similar memories. Must quantize() first!

        Args:
            query_embedding: query vector, packed as native-order float32.
            k: number of nearest rows requested from vector_quantize_scan.
            source_type: if not None, keep only rows with this source_type.

        Returns:
            (source_path, content_text, distance) rows, nearest first.
        """
        values = list(query_embedding)
        # One pack call; same bytes as packing each value with 'f' and joining.
        query_blob = struct.pack(f'{len(values)}f', *values)

        sql = (
            'SELECT e.source_path, e.content_text, v.distance '
            'FROM memory_embeddings AS e '
            "JOIN vector_quantize_scan('memory_embeddings', 'embedding', ?, ?) AS v "
            'ON e.id = v.rowid'
        )
        params = [query_blob, k]
        if source_type is not None:
            # NOTE(review): the scan presumably picks the k nearest rows before
            # this filter applies, so a filtered search can return fewer than
            # k rows even when more matches exist — confirm this is intended.
            sql += ' WHERE e.source_type = ?'
            params.append(source_type)
        # Without ORDER BY the join's output order is unspecified.
        sql += ' ORDER BY v.distance'

        self.cursor.execute(sql, params)
        return self.cursor.fetchall()


def setup_memory_vectors():
    """Create the memory table (safe to repeat) on a short-lived connection.

    Returns:
        Whatever MemoryVectorDB.setup() returns.
    """
    db = MemoryVectorDB()
    with db:
        outcome = db.setup()
    return outcome


def store_memory(source_type: str, source_path: str, content: str,
                 embedding: List[float], *, requantize: bool = True):
    """Store one memory on a short-lived connection and return its row id.

    Args:
        source_type: category label stored in the source_type column.
        source_path: origin of the memory, stored in source_path.
        content: text stored in content_text.
        embedding: vector stored in the embedding column.
        requantize: when True (the default) the column is re-quantized after
            the insert, which search() requires before it sees the new row.
            Pass False when storing many memories back to back, then call
            MemoryVectorDB.quantize() once at the end; quantizing after every
            insert reprocesses the whole column each time.

    Returns:
        The id of the inserted row.
    """
    with MemoryVectorDB() as db:
        rowid = db.insert(source_type, source_path, content, embedding)
        if requantize:
            db.quantize()
    return rowid


def search_memories(query_embedding: List[float], k: int = 5,
                    source_type: Optional[str] = None):
    """Search stored memories on a short-lived connection.

    Args:
        query_embedding: query vector.
        k: number of nearest rows requested from the vector scan.
        source_type: if given, restrict results to this source_type (forwarded
            to MemoryVectorDB.search).

    Returns:
        (source_path, content_text, distance) rows from MemoryVectorDB.search.
    """
    with MemoryVectorDB() as db:
        return db.search(query_embedding, k, source_type)


def _smoke_test() -> None:
    """Create the table, store an all-zero vector, then query it back."""
    setup_memory_vectors()
    print("[OK] Memory vector DB ready")

    probe = [0.0 for _ in range(DIMENSION)]
    store_memory("test", "test.txt", "Hello world", probe)
    print("[OK] Test memory stored")

    hits = search_memories(probe, k=1)
    print(f"[OK] Found {len(hits)} result(s)")
    for hit in hits:
        print(f" - {hit}")


if __name__ == "__main__":
    _smoke_test()