// TelosDB: src/backend/mcp-handlers.js
import {
    CallToolRequestSchema,
    ListToolsRequestSchema,
} from "@modelcontextprotocol/sdk/types.js";
import { getDb, getEmbeddingDim, getKnexDb } from "./db.js";
import { llamaCompletion, llamaEmbedding } from "./llama-client.js";
import { Logger } from "./logger.js";
import { getToolDefinitions, TOOL_NAMES } from "./mcp-tools.js";

/**
 * Validate the dimensionality of an embedding vector
 * @param {number[]} vector - Vector to validate
 * @throws {Error} If the dimension is invalid
 */
function validateEmbeddingDimension(vector) {
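  // Called from every add/search path so a vector with the wrong length never
  // reaches the vec_items table.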
  const expectedDim = getEmbeddingDim();
  if (!Array.isArray(vector) || vector.length !== expectedDim) {
    const error = `Invalid embedding dimension. Expected ${expectedDim}, got ${vector?.length ?? 'undefined'}`;
    Logger.error(error);
    throw new Error(error);
  }
}

/**
 * Add an item (the embedding is generated automatically from the text)
 * @param {string} content - Text content
 * @param {string} path - Metadata path
 * @returns {Promise<Object>} Result of the operation
 */
async function handleAddItemText(content, path) {
  try {
    Logger.debug('Adding item from text', { contentLength: content.length });
    const embedding = await llamaEmbedding(content);
    validateEmbeddingDimension(embedding);

    const db = getDb();
    const knexDb = getKnexDb();

    const insertIds = await knexDb("items").insert({ content, path });
    const id = Array.isArray(insertIds) ? insertIds[0] : insertIds;

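    // Mirror the embedding into the vec_items virtual table under the same id,
    // binding the vector as a Float32Array.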
    db.prepare("INSERT INTO vec_items(id, embedding) VALUES (?, ?)")
      .run(id, new Float32Array(embedding));

    Logger.info(`Item added successfully with id: ${id}`);
    return { content: [{ type: "text", text: `Added item with id ${id}` }] };
  } catch (error) {
    Logger.error('Failed to add item from text', error);
    throw error;
  }
}

/**
 * Add an item (with an explicitly supplied embedding vector)
 * @param {string} content - Text content
 * @param {number[]} vector - Embedding vector
 * @param {string} path - Metadata path
 * @returns {Promise<Object>} Result of the operation
 */
async function handleAddItem(content, vector, path) {
  try {
    Logger.debug('Adding item with vector', { contentLength: content.length });
    validateEmbeddingDimension(vector);

    const db = getDb();
    const knexDb = getKnexDb();

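    // Same two-step write as handleAddItemText, but with the caller-supplied vector.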
    const insertIds = await knexDb("items").insert({ content, path });
    const id = Array.isArray(insertIds) ? insertIds[0] : insertIds;

    db.prepare("INSERT INTO vec_items(id, embedding) VALUES (?, ?)")
      .run(id, new Float32Array(vector));

    Logger.info(`Item added successfully with id: ${id}`);
    return { content: [{ type: "text", text: `Added item with id ${id}` }] };
  } catch (error) {
    Logger.error('Failed to add item', error);
    throw error;
  }
}

/**
 * Generate an embedding from the text and search with it
 * @param {string} content - Search query text
 * @param {number} limit - Maximum number of results
 * @returns {Promise<Object>} Search results
 */
async function handleSearchText(content, limit = 10) {
  try {
    Logger.debug('Searching by text', { contentLength: content.length, limit });
    const embedding = await llamaEmbedding(content);
    validateEmbeddingDimension(embedding);

    const db = getDb();
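    // KNN query: the query vector is matched against vec_items and rows come back
    // ordered by distance; the join on items attaches the stored content, path,
    // and timestamps for each hit.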
    const results = db.prepare(`
      SELECT
        i.id,
        i.content,
        i.path,
        i.created_at,
        i.updated_at,
        v.distance
      FROM vec_items v
      JOIN items i ON v.id = i.id
      WHERE embedding MATCH ?
      ORDER BY distance
      LIMIT ?
    `).all(new Float32Array(embedding), limit);

    Logger.info(`Text search completed, found ${results.length} results`);
    return {
      content: [{ type: "text", text: JSON.stringify(results, null, 2) }],
    };
  } catch (error) {
    Logger.error('Failed to search by text', error);
    throw error;
  }
}

/**
 * Search directly with a vector
 * @param {number[]} vector - Query vector
 * @param {number} limit - Maximum number of results
 * @returns {Object} Search results
 */
function handleSearchVector(vector, limit = 10) {
  try {
    Logger.debug('Searching by vector', { limit });
    validateEmbeddingDimension(vector);

    const db = getDb();
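    // Same KNN query as handleSearchText, minus the embedding call, since the
    // caller supplies the vector directly.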
    const results = db.prepare(`
      SELECT
        i.id,
        i.content,
        i.path,
        i.created_at,
        i.updated_at,
        v.distance
      FROM vec_items v
      JOIN items i ON v.id = i.id
      WHERE embedding MATCH ?
      ORDER BY distance
      LIMIT ?
    `).all(new Float32Array(vector), limit);

    Logger.info(`Vector search completed, found ${results.length} results`);
    return {
      content: [{ type: "text", text: JSON.stringify(results, null, 2) }],
    };
  } catch (error) {
    Logger.error('Failed to search by vector', error);
    throw error;
  }
}

/**
 * Run text generation with the LLM
 * @param {string} prompt - Prompt
 * @param {number} n_predict - Number of tokens to generate
 * @param {number} temperature - Temperature parameter
 * @returns {Promise<Object>} Generation result
 */
async function handleLlmGenerate(prompt, n_predict, temperature) {
  try {
    Logger.debug('Generating text via LLM', { promptLength: prompt.length, n_predict, temperature });
    const text = await llamaCompletion(prompt, { n_predict, temperature });
    Logger.info(`LLM generation completed, result length: ${text.length}`);
    return { content: [{ type: "text", text }] };
  } catch (error) {
    Logger.error('Failed to generate text via LLM', error);
    throw error;
  }
}

/**
 * Register the MCP handlers
 * @param {Server} server - MCP server instance
 */
export function registerMcpHandlers(server) {
  /**
   * Handler that returns the list of available tools
   */
  server.setRequestHandler(ListToolsRequestSchema, async () => {
    Logger.debug('Listing available tools');
    return {
      tools: getToolDefinitions(),
    };
  });

  /**
   * Tool invocation handler
   */
  server.setRequestHandler(CallToolRequestSchema, async (request) => {
    const toolName = request.params.name;
    const args = request.params.arguments;

    try {
      Logger.debug(`Tool called: ${toolName}`);

      switch (toolName) {
        case TOOL_NAMES.ADD_ITEM_TEXT:
          return await handleAddItemText(args.content, args.path);

        case TOOL_NAMES.ADD_ITEM:
          return await handleAddItem(args.content, args.vector, args.path);

        case TOOL_NAMES.SEARCH_TEXT:
          return await handleSearchText(args.content, args.limit);

        case TOOL_NAMES.SEARCH_VECTOR:
          return handleSearchVector(args.vector, args.limit);

        case TOOL_NAMES.LLM_GENERATE:
          return await handleLlmGenerate(args.prompt, args.n_predict, args.temperature);

        default: {
          const error = `Unknown tool: ${toolName}`;
          Logger.error(error);
          throw new Error(error);
        }
      }
    } catch (error) {
      Logger.error(`Tool execution failed: ${toolName}`, error);
      throw error;
    }
  });

  Logger.info('MCP handlers registered successfully');
}
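
// Example wiring (a sketch; the server name and version below are illustrative,
// and the imports follow the same @modelcontextprotocol/sdk package used above):
//
//   import { Server } from "@modelcontextprotocol/sdk/server/index.js";
//   import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
//
//   const server = new Server(
//     { name: "telosdb", version: "0.1.0" },
//     { capabilities: { tools: {} } },
//   );
//   registerMcpHandlers(server);
//   await server.connect(new StdioServerTransport());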