Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cli/internal/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ func ToolNames() []string {
"memory_delete_context",
"memory_reindex",
"memory_status",
"memory_get",
"memory_query",
}
}

Expand Down Expand Up @@ -155,6 +157,12 @@ func tools() []map[string]any {
}, []string{"context"}),
tool("memory_reindex", map[string]any{}, nil),
tool("memory_status", map[string]any{}, nil),
tool("memory_get", map[string]any{
"hash": stringSchema("Content hash from a memory_search result"),
}, []string{"hash"}),
tool("memory_query", map[string]any{
"query": stringSchema("Read-only SELECT statement to run against the memory database"),
}, []string{"query"}),
}
}

Expand Down Expand Up @@ -198,6 +206,10 @@ func (s Server) callTool(ctx context.Context, name string, args map[string]any)
return "ok", memory.Delete(ctx, s.DB, strArg(args, "hash"))
case "memory_delete_context":
return "ok", memory.DeleteContext(ctx, s.DB, strArg(args, "context"))
case "memory_get":
return memory.Get(ctx, s.DB, strArg(args, "hash"))
case "memory_query":
return memory.Query(ctx, s.DB, strArg(args, "query"))
case "memory_reindex":
return "ok", memory.Reindex(ctx, s.DB)
case "memory_status":
Expand Down
2 changes: 2 additions & 0 deletions cli/internal/mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ func TestToolNames(t *testing.T) {
"memory_delete_context": true,
"memory_reindex": true,
"memory_status": true,
"memory_get": true,
"memory_query": true,
}
for _, name := range names {
delete(want, name)
Expand Down
66 changes: 66 additions & 0 deletions cli/internal/memory/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,72 @@ func Status(ctx context.Context, db *sql.DB) (map[string]any, error) {
return out, nil
}

type ContentResult struct {
Hash string `json:"hash"`
Path string `json:"path"`
Context *string `json:"context"`
Value *string `json:"value"`
CreatedAt int64 `json:"created_at"`
LastAccessed int64 `json:"last_accessed"`
}

func Get(ctx context.Context, db *sql.DB, hash string) (string, error) {
var r ContentResult
err := db.QueryRowContext(ctx,
"SELECT hash, path, context, value, created_at, last_accessed FROM dbmem_content WHERE hash = ?",
hash,
).Scan(&r.Hash, &r.Path, &r.Context, &r.Value, &r.CreatedAt, &r.LastAccessed)
if errors.Is(err, sql.ErrNoRows) {
return "", fmt.Errorf("hash not found: %s", hash)
}
if err != nil {
return "", err
}
data, _ := json.MarshalIndent(r, "", " ")
return string(data), nil
}

func Query(ctx context.Context, db *sql.DB, query string) (string, error) {
conn, err := db.Conn(ctx)
if err != nil {
return "", err
}
defer conn.Close()
if _, err := conn.ExecContext(ctx, "PRAGMA query_only = ON"); err != nil {
return "", err
}
rows, err := conn.QueryContext(ctx, query)
if err != nil {
return "", err
}
defer rows.Close()
cols, err := rows.Columns()
if err != nil {
return "", err
}
var results []map[string]any
for rows.Next() {
vals := make([]any, len(cols))
ptrs := make([]any, len(cols))
for i := range vals {
ptrs[i] = &vals[i]
}
if err := rows.Scan(ptrs...); err != nil {
return "", err
}
row := make(map[string]any, len(cols))
for i, col := range cols {
row[col] = vals[i]
}
results = append(results, row)
}
if err := rows.Err(); err != nil {
return "", err
}
data, _ := json.MarshalIndent(results, "", " ")
return string(data), nil
}

func ResultsJSON(results []SearchResult) string {
data, _ := json.MarshalIndent(results, "", " ")
return string(data)
Expand Down
89 changes: 89 additions & 0 deletions cli/internal/memory/memory_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,100 @@
package memory

import (
"context"
"database/sql"
"testing"

_ "github.com/mattn/go-sqlite3"

"github.com/sqliteai/sqlite-memory/cli/internal/config"
)

func openTestDB(t *testing.T) *sql.DB {
t.Helper()
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { db.Close() })
_, err = db.Exec(`CREATE TABLE dbmem_content (
hash TEXT PRIMARY KEY,
path TEXT NOT NULL,
context TEXT,
value TEXT,
created_at INTEGER NOT NULL DEFAULT 0,
last_accessed INTEGER NOT NULL DEFAULT 0
)`)
if err != nil {
t.Fatal(err)
}
return db
}

func TestGetReturnsFullContent(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
_, err := db.ExecContext(ctx, `INSERT INTO dbmem_content (hash, path, context, value, created_at, last_accessed)
VALUES ('abc123', '/docs/test.md', 'test-ctx', 'hello world', 1000, 2000)`)
if err != nil {
t.Fatal(err)
}
out, err := Get(ctx, db, "abc123")
if err != nil {
t.Fatal(err)
}
for _, want := range []string{"abc123", "/docs/test.md", "test-ctx", "hello world"} {
if !contains(out, want) {
t.Errorf("output missing %q:\n%s", want, out)
}
}
}

func TestGetNotFound(t *testing.T) {
db := openTestDB(t)
_, err := Get(context.Background(), db, "nope")
if err == nil {
t.Fatal("expected error for missing hash")
}
}

func TestQuerySelectWorks(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
_, err := db.ExecContext(ctx, `INSERT INTO dbmem_content (hash, path, value, created_at, last_accessed)
VALUES ('h1', '/a.md', 'content', 0, 0)`)
if err != nil {
t.Fatal(err)
}
out, err := Query(ctx, db, "SELECT hash, path FROM dbmem_content")
if err != nil {
t.Fatal(err)
}
if !contains(out, "h1") || !contains(out, "/a.md") {
t.Errorf("unexpected output: %s", out)
}
}

func TestQueryRejectsWrites(t *testing.T) {
db := openTestDB(t)
_, err := Query(context.Background(), db, "INSERT INTO dbmem_content (hash, path, created_at, last_accessed) VALUES ('x', '/x', 0, 0)")
if err == nil {
t.Fatal("expected error for write statement under query_only")
}
}

func contains(s, sub string) bool {
return len(s) >= len(sub) && (s == sub || len(sub) == 0 ||
func() bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}())
}

func TestResolveModelLocalWithoutAPIKey(t *testing.T) {
cfg := config.Default()
cfg.Embedding.Model = "/models/local.gguf"
Expand Down