diff --git a/mcp_clickhouse/__init__.py b/mcp_clickhouse/__init__.py
index 21931d0..f06f828 100644
--- a/mcp_clickhouse/__init__.py
+++ b/mcp_clickhouse/__init__.py
@@ -3,6 +3,11 @@
     list_databases,
     list_tables,
     run_select_query,
+    fetch_table_names,
+    fetch_table_metadata,
+    get_paginated_tables,
+    create_page_token,
+    table_pagination_cache,
 )

 __all__ = [
@@ -10,4 +15,9 @@
     "list_tables",
     "run_select_query",
     "create_clickhouse_client",
+    "fetch_table_names",
+    "fetch_table_metadata",
+    "get_paginated_tables",
+    "create_page_token",
+    "table_pagination_cache",
 ]
diff --git a/mcp_clickhouse/mcp_env.py b/mcp_clickhouse/mcp_env.py
index 40f2291..29f1996 100644
--- a/mcp_clickhouse/mcp_env.py
+++ b/mcp_clickhouse/mcp_env.py
@@ -133,9 +133,7 @@ def _validate_required_vars(self) -> None:
                 missing_vars.append(var)

         if missing_vars:
-            raise ValueError(
-                f"Missing required environment variables: {', '.join(missing_vars)}"
-            )
+            raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}")


 # Global instance placeholder for the singleton pattern
diff --git a/mcp_clickhouse/mcp_server.py b/mcp_clickhouse/mcp_server.py
index d1ec7fb..7273fd6 100644
--- a/mcp_clickhouse/mcp_server.py
+++ b/mcp_clickhouse/mcp_server.py
@@ -1,12 +1,14 @@
 import logging
-from typing import Sequence
+from typing import Sequence, Dict, Any, Optional, List, TypedDict
 import concurrent.futures
 import atexit
+import uuid

 import clickhouse_connect
 from clickhouse_connect.driver.binding import quote_identifier, format_query_value
 from dotenv import load_dotenv
 from mcp.server.fastmcp import FastMCP
+from cachetools import TTLCache

 from mcp_clickhouse.mcp_env import get_config

@@ -29,11 +31,94 @@
     "python-dotenv",
     "uvicorn",
     "pip-system-certs",
+    "cachetools",
 ]

 mcp = FastMCP(MCP_SERVER_NAME, dependencies=deps)


+# Define types for table information and pagination cache
+class TableInfo(TypedDict):
+    database: str
+    name: str
+    comment: Optional[str]
+    columns: List[Dict[str, Any]]
+    create_table_query: str
+    row_count: int
+    column_count: int
+
+
+class PaginationCacheEntry(TypedDict):
+    database: str
+    like: Optional[str]
+    table_names: List[str]
+    start_idx: int
+
+
+# Store pagination state for list_tables with 1-hour expiry
+# Using TTLCache from cachetools to automatically expire entries after 1 hour
+table_pagination_cache = TTLCache(maxsize=100, ttl=3600)  # 3600 seconds = 1 hour
+
+
+def get_table_info(
+    client,
+    database: str,
+    table: str,
+    table_comments: Dict[str, str],
+    column_comments: Dict[str, Dict[str, str]],
+) -> TableInfo:
+    """Get detailed information about a specific table.
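+
+    Runs DESCRIBE TABLE, a count() query, and SHOW CREATE TABLE, merging in
+    the table and column comments pre-fetched via fetch_table_metadata.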
+
+    Args:
+        client: ClickHouse client
+        database: Database name
+        table: Table name
+        table_comments: Dictionary of table comments
+        column_comments: Dictionary of column comments
+
+    Returns:
+        TableInfo object with table details
+    """
+    logger.info(f"Getting schema info for table {database}.{table}")
+    schema_query = f"DESCRIBE TABLE {quote_identifier(database)}.{quote_identifier(table)}"
+    schema_result = client.query(schema_query)
+
+    columns = []
+    column_names = schema_result.column_names
+    for row in schema_result.result_rows:
+        column_dict = {}
+        for i, col_name in enumerate(column_names):
+            column_dict[col_name] = row[i]
+        # Add comment from our pre-fetched comments
+        if table in column_comments and column_dict["name"] in column_comments[table]:
+            column_dict["comment"] = column_comments[table][column_dict["name"]]
+        else:
+            column_dict["comment"] = None
+        columns.append(column_dict)
+
+    # Get row count and column count from the table
+    row_count_query = f"SELECT count() FROM {quote_identifier(database)}.{quote_identifier(table)}"
+    row_count_result = client.query(row_count_query)
+    row_count = row_count_result.result_rows[0][0] if row_count_result.result_rows else 0
+    column_count = len(columns)
+
+    create_table_query = f"SHOW CREATE TABLE {quote_identifier(database)}.{quote_identifier(table)}"
+    create_table_result = client.command(create_table_query)
+
+    return {
+        "database": database,
+        "name": table,
+        "comment": table_comments.get(table),
+        "columns": columns,
+        "create_table_query": create_table_result,
+        "row_count": row_count,
+        "column_count": column_count,
+    }
+
+
 @mcp.tool()
 def list_databases():
     """List available ClickHouse databases"""
@@ -44,24 +129,60 @@ def list_databases():
     return result


-@mcp.tool()
-def list_tables(database: str, like: str = None):
-    """List available ClickHouse tables in a database, including schema, comment,
-    row count, and column count."""
-    logger.info(f"Listing tables in database '{database}'")
-    client = create_clickhouse_client()
+def fetch_table_names(client, database: str, like: Optional[str] = None) -> List[str]:
+    """Get list of table names from the database.
+
+    Args:
+        client: ClickHouse client
+        database: Database name
+        like: Optional pattern to filter table names
+
+    Returns:
+        List of table names
+    """
     query = f"SHOW TABLES FROM {quote_identifier(database)}"
     if like:
         query += f" LIKE {format_query_value(like)}"
     result = client.command(query)

-    # Get all table comments in one query
-    table_comments_query = f"SELECT name, comment FROM system.tables WHERE database = {format_query_value(database)}"
+    # Convert result to a list of table names
+    table_names = []
+    if isinstance(result, str):
+        # Single table result
+        table_names = [t.strip() for t in result.split() if t.strip()]
+    elif isinstance(result, Sequence):
+        # Multiple table results
+        table_names = list(result)
+
+    return table_names
+
+
+def fetch_table_metadata(
+    client, database: str, table_names: List[str]
+) -> tuple[Dict[str, str], Dict[str, Dict[str, str]]]:
+    """Fetch table and column comments for a list of tables.
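+
+    Both lookups filter system.tables and system.columns with an IN clause,
+    so only metadata for the requested tables is fetched.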
+
+    Args:
+        client: ClickHouse client
+        database: Database name
+        table_names: List of table names
+
+    Returns:
+        Tuple of (table_comments, column_comments)
+    """
+    if not table_names:
+        return {}, {}
+
+    # Get table comments
+    table_comments_query = f"SELECT name, comment FROM system.tables WHERE database = {format_query_value(database)} AND name IN ({', '.join(format_query_value(name) for name in table_names)})"
     table_comments_result = client.query(table_comments_query)
     table_comments = {row[0]: row[1] for row in table_comments_result.result_rows}

-    # Get all column comments in one query
-    column_comments_query = f"SELECT table, name, comment FROM system.columns WHERE database = {format_query_value(database)}"
+    # Get column comments
+    column_comments_query = f"SELECT table, name, comment FROM system.columns WHERE database = {format_query_value(database)} AND table IN ({', '.join(format_query_value(name) for name in table_names)})"
+    column_comments_result = client.query(column_comments_query)
+    column_comments = {}
     for row in column_comments_result.result_rows:
@@ -70,56 +191,134 @@ def list_tables(database: str, like: str = None):
             column_comments[table] = {}
         column_comments[table][col_name] = comment

-    def get_table_info(table):
-        logger.info(f"Getting schema info for table {database}.{table}")
-        schema_query = f"DESCRIBE TABLE {quote_identifier(database)}.{quote_identifier(table)}"
-        schema_result = client.query(schema_query)
+    return table_comments, column_comments

-        columns = []
-        column_names = schema_result.column_names
-        for row in schema_result.result_rows:
-            column_dict = {}
-            for i, col_name in enumerate(column_names):
-                column_dict[col_name] = row[i]
-            # Add comment from our pre-fetched comments
-            if table in column_comments and column_dict['name'] in column_comments[table]:
-                column_dict['comment'] = column_comments[table][column_dict['name']]
-            else:
-                column_dict['comment'] = None
-            columns.append(column_dict)
-
-        # Get row count and column count from the table
-        row_count_query = f"SELECT count() FROM {quote_identifier(database)}.{quote_identifier(table)}"
-        row_count_result = client.query(row_count_query)
-        row_count = row_count_result.result_rows[0][0] if row_count_result.result_rows else 0
-        column_count = len(columns)
-
-        create_table_query = f"SHOW CREATE TABLE {database}.`{table}`"
-        create_table_result = client.command(create_table_query)
-
-        return {
-            "database": database,
-            "name": table,
-            "comment": table_comments.get(table),
-            "columns": columns,
-            "create_table_query": create_table_result,
-            "row_count": row_count,
-            "column_count": column_count,
-        }

+
+def get_paginated_tables(
+    client, database: str, table_names: List[str], start_idx: int, page_size: int
+) -> tuple[List[TableInfo], int, bool]:
+    """Get detailed information for a page of tables.
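+
+    Resolves the slice table_names[start_idx:start_idx + page_size] to full
+    TableInfo records and reports whether more pages remain.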
+
+    Args:
+        client: ClickHouse client
+        database: Database name
+        table_names: List of all table names
+        start_idx: Starting index for pagination
+        page_size: Number of tables per page
+
+    Returns:
+        Tuple of (list of table info, end index, has more pages)
+    """
+    end_idx = min(start_idx + page_size, len(table_names))
+    current_page_table_names = table_names[start_idx:end_idx]
+
+    # Get metadata for current page
+    table_comments, column_comments = fetch_table_metadata(
+        client, database, current_page_table_names
+    )
+
+    # Get detailed information for each table
     tables = []
-    if isinstance(result, str):
-        # Single table result
-        for table in (t.strip() for t in result.split()):
-            if table:
-                tables.append(get_table_info(table))
-    elif isinstance(result, Sequence):
-        # Multiple table results
-        for table in result:
-            tables.append(get_table_info(table))
+    for table_name in current_page_table_names:
+        tables.append(get_table_info(client, database, table_name, table_comments, column_comments))
+
+    return tables, end_idx, end_idx < len(table_names)
+
+
+def create_page_token(database: str, like: Optional[str], table_names: List[str], end_idx: int) -> str:
+    """Create a new page token and store it in the cache.
+
+    Args:
+        database: Database name
+        like: Pattern used to filter tables
+        table_names: List of all table names
+        end_idx: Index to start from for the next page
+
+    Returns:
+        New page token
+    """
+    token = str(uuid.uuid4())
+    table_pagination_cache[token] = {
+        "database": database,
+        "like": like,
+        "table_names": table_names,
+        "start_idx": end_idx,
+    }
+    return token
+
+
+@mcp.tool()
+def list_tables(database: str, like: Optional[str] = None, page_token: Optional[str] = None, page_size: int = 50):
+    """List available ClickHouse tables in a database, including schema, comment,
+    row count, and column count.
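+
+    Results are paginated: pass the returned next_page_token back in to
+    retrieve the next page. Page tokens are single-use and expire after one hour.
+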
-    logger.info(f"Found {len(tables)} tables")
-    return tables
+    Args:
+        database: The database to list tables from
+        like: Optional pattern to filter table names
+        page_token: Token for pagination, obtained from a previous call
+        page_size: Number of tables to return per page (default: 50)
+
+    Returns:
+        A dictionary containing:
+        - tables: List of table information
+        - next_page_token: Token for the next page, or None if no more pages
+    """
+    logger.info(
+        f"Listing tables in database '{database}' with page_token={page_token}, page_size={page_size}"
+    )
+    client = create_clickhouse_client()
+
+    # If we have a page token, retrieve the cached state
+    if page_token and page_token in table_pagination_cache:
+        cached_state = table_pagination_cache[page_token]
+        if cached_state["database"] != database or cached_state["like"] != like:
+            logger.warning(f"Page token {page_token} is for a different database or filter")
+            page_token = None
+        else:
+            # Use the cached state
+            table_names = cached_state["table_names"]
+            start_idx = cached_state["start_idx"]
+
+            # Get tables for current page
+            tables, end_idx, has_more = get_paginated_tables(
+                client, database, table_names, start_idx, page_size
+            )
+
+            # Generate next page token if there are more tables
+            next_page_token = None
+            if has_more:
+                next_page_token = create_page_token(database, like, table_names, end_idx)
+
+            # Clean up the used page token
+            del table_pagination_cache[page_token]
+
+            return {"tables": tables, "next_page_token": next_page_token}
+
+    # If no valid page token, fetch all table names
+    table_names = fetch_table_names(client, database, like)
+
+    # Apply pagination
+    start_idx = 0
+    tables, end_idx, has_more = get_paginated_tables(
+        client, database, table_names, start_idx, page_size
+    )
+
+    # Generate next page token if there are more tables
+    next_page_token = None
+    if has_more:
+        next_page_token = create_page_token(database, like, table_names, end_idx)
+
+    logger.info(
+        f"Found {len(table_names)} tables, returning {len(tables)} with next_page_token={next_page_token}"
+    )
+    return {"tables": tables, "next_page_token": next_page_token}


 def execute_query(query: str):
@@ -161,7 +360,10 @@
             logger.warning(f"Query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds: {query}")
             future.cancel()
             # Return a properly structured response for timeout errors
-            return {"status": "error", "message": f"Query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds"}
+            return {
+                "status": "error",
+                "message": f"Query timed out after {SELECT_QUERY_TIMEOUT_SECS} seconds",
+            }
     except Exception as e:
         logger.error(f"Unexpected error in run_select_query: {str(e)}")
         # Catch all other exceptions and return them in a structured format
diff --git a/pyproject.toml b/pyproject.toml
index f8dd5d0..72e7664 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
     "uvicorn>=0.34.0",
     "clickhouse-connect>=0.8.16",
     "pip-system-certs>=4.0",
+    "cachetools>=5.5.2",
 ]

 [project.scripts]
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
new file mode 100644
index 0000000..784073a
--- /dev/null
+++ b/tests/test_pagination.py
@@ -0,0 +1,205 @@
+import unittest
+
+from dotenv import load_dotenv
+
+from mcp_clickhouse import create_clickhouse_client, list_tables
+from mcp_clickhouse.mcp_server import (
+    table_pagination_cache,
+    create_page_token,
+    fetch_table_names,
+    fetch_table_metadata,
+    get_paginated_tables,
+)
+
+load_dotenv()
+
+
+class TestPagination(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        """Set up the environment before tests."""
+        cls.client = create_clickhouse_client()
+
+        # Prepare test database
+        cls.test_db = "test_pagination_db"
+        cls.client.command(f"CREATE DATABASE IF NOT EXISTS {cls.test_db}")
+
+        # Create 10 test tables to test pagination
+        for i in range(1, 11):
+            table_name = f"test_table_{i}"
+            # Drop table if exists to ensure clean state
+            cls.client.command(f"DROP TABLE IF EXISTS {cls.test_db}.{table_name}")
+
+            # Create table with comments
+            cls.client.command(f"""
+                CREATE TABLE {cls.test_db}.{table_name} (
+                    id UInt32 COMMENT 'ID field {i}',
+                    name String COMMENT 'Name field {i}'
+                ) ENGINE = MergeTree()
+                ORDER BY id
+                COMMENT 'Test table {i} for pagination testing'
+            """)
+            cls.client.command(f"""
+                INSERT INTO {cls.test_db}.{table_name} (id, name) VALUES ({i}, 'Test {i}')
+            """)
+
+    @classmethod
+    def tearDownClass(cls):
+        """Clean up the environment after tests."""
+        cls.client.command(f"DROP DATABASE IF EXISTS {cls.test_db}")
+
+    def test_list_tables_pagination(self):
+        """Test that list_tables returns paginated results."""
+        # Test with page_size 3, should get 3 tables and a next_page_token
+        result = list_tables(self.test_db, page_size=3)
+        self.assertIsInstance(result, dict)
+        self.assertIn("tables", result)
+        self.assertIn("next_page_token", result)
+        self.assertEqual(len(result["tables"]), 3)
+        self.assertIsNotNone(result["next_page_token"])
+
+        # Get the next page using the token
+        page_token = result["next_page_token"]
+        result2 = list_tables(self.test_db, page_token=page_token, page_size=3)
+        self.assertEqual(len(result2["tables"]), 3)
+        self.assertIsNotNone(result2["next_page_token"])
+
+        # The tables in the second page should be different from the first page
+        page1_table_names = {table["name"] for table in result["tables"]}
+        page2_table_names = {table["name"] for table in result2["tables"]}
+        self.assertEqual(len(page1_table_names.intersection(page2_table_names)), 0)
+
+        # Get the third page
+        page_token = result2["next_page_token"]
+        result3 = list_tables(self.test_db, page_token=page_token, page_size=3)
+        self.assertEqual(len(result3["tables"]), 3)
+        self.assertIsNotNone(result3["next_page_token"])
+
+        # Get the fourth (last) page
+        page_token = result3["next_page_token"]
+        result4 = list_tables(self.test_db, page_token=page_token, page_size=3)
+        self.assertEqual(len(result4["tables"]), 1)  # Only 1 table left
+        self.assertIsNone(result4["next_page_token"])  # No more pages
+
+    def test_invalid_page_token(self):
+        """Test that list_tables handles invalid page tokens gracefully."""
+        # Test with an invalid page token
+        result = list_tables(self.test_db, page_token="invalid_token", page_size=3)
+        self.assertIsInstance(result, dict)
+        self.assertIn("tables", result)
+        self.assertIn("next_page_token", result)
+        self.assertEqual(len(result["tables"]), 3)  # Should return first page as fallback
+
+    def test_token_for_different_database(self):
+        """Test handling a token for a different database."""
+        # Get first page and token for test_db
+        result = list_tables(self.test_db, page_size=3)
+        page_token = result["next_page_token"]
+
+        # Try to use the token with a different database name
+        # It should recognize the mismatch and fall back to first page
+
+        # First, create another test database to use
+        test_db2 = "test_pagination_db2"
+        try:
+            self.client.command(f"CREATE DATABASE IF NOT EXISTS {test_db2}")
+            self.client.command(f"""
+                CREATE TABLE {test_db2}.test_table (
+                    id UInt32,
+                    name String
+                ) ENGINE = MergeTree()
+                ORDER BY id
+            """)
+
+            # Use the token with a different database
+            result2 = list_tables(test_db2, page_token=page_token, page_size=3)
+            self.assertIsInstance(result2, dict)
+            self.assertIn("tables", result2)
+        finally:
+            self.client.command(f"DROP DATABASE IF EXISTS {test_db2}")
+
+    def test_different_page_sizes(self):
+        """Test pagination with different page sizes."""
+        # Get all tables in one page
+        result = list_tables(self.test_db, page_size=20)
+        self.assertEqual(len(result["tables"]), 10)  # All 10 tables
+        self.assertIsNone(result["next_page_token"])  # No more pages
+
+        # Get 5 tables per page
+        result = list_tables(self.test_db, page_size=5)
+        self.assertEqual(len(result["tables"]), 5)
+        self.assertIsNotNone(result["next_page_token"])
+
+        # Get second page with 5 tables
+        page_token = result["next_page_token"]
+        result2 = list_tables(self.test_db, page_token=page_token, page_size=5)
+        self.assertEqual(len(result2["tables"]), 5)
+        self.assertIsNone(result2["next_page_token"])  # No more pages
+
+    def test_page_token_expiry(self):
+        """Test that page tokens expire after their TTL."""
+        # The token TTL is one hour, which is too long to wait for in a test,
+        # so we simulate expiration by removing the token from the cache below
+
+        # Get the first page with a token for the next page
+        result = list_tables(self.test_db, page_size=3)
+        page_token = result["next_page_token"]
+
+        # Verify the token exists in the cache
+        self.assertIn(page_token, table_pagination_cache)
+
+        # For this test, we'll manually remove the token from the cache to simulate expiration
+        # since we can't easily wait for the actual TTL (1 hour) to expire
+        if page_token in table_pagination_cache:
+            del table_pagination_cache[page_token]
+
+        # Try to use the expired token
+        result2 = list_tables(self.test_db, page_token=page_token, page_size=3)
+        # Should fall back to first page
+        self.assertEqual(len(result2["tables"]), 3)
+        self.assertIsNotNone(result2["next_page_token"])
+
+    def test_helper_functions(self):
+        """Test the individual helper functions used for pagination."""
+        client = create_clickhouse_client()
+
+        # Test fetch_table_names
+        table_names = fetch_table_names(client, self.test_db)
+        self.assertEqual(len(table_names), 10)
+        for i in range(1, 11):
+            self.assertIn(f"test_table_{i}", table_names)
+
+        # Test fetch_table_metadata
+        sample_tables = table_names[:3]  # Get first 3 tables
+        table_comments, column_comments = fetch_table_metadata(client, self.test_db, sample_tables)
+
+        # Check table comments
+        self.assertEqual(len(table_comments), 3)
+        for table in sample_tables:
+            self.assertIn(table, table_comments)
+            self.assertIn("Test table", table_comments[table])
+
+        # Check column comments
+        self.assertEqual(len(column_comments), 3)
+        for table in sample_tables:
+            self.assertIn(table, column_comments)
+            self.assertIn("id", column_comments[table])
+            self.assertIn("name", column_comments[table])
+
+        # Test get_paginated_tables
+        tables, end_idx, has_more = get_paginated_tables(client, self.test_db, table_names, 0, 3)
+        self.assertEqual(len(tables), 3)
+        self.assertEqual(end_idx, 3)
+        self.assertTrue(has_more)
+
+        # Test create_page_token
+        token = create_page_token(self.test_db, None, table_names, 3)
+        self.assertIn(token, table_pagination_cache)
+        cached_state = table_pagination_cache[token]
+        self.assertEqual(cached_state["database"], self.test_db)
+        self.assertEqual(cached_state["start_idx"], 3)
+        self.assertEqual(cached_state["table_names"], table_names)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_tool.py b/tests/test_tool.py
index e931c3f..fd651fd 100644
--- a/tests/test_tool.py
+++ b/tests/test_tool.py
@@ -47,16 +47,25 @@ def test_list_databases(self):
     def test_list_tables_without_like(self):
         """Test listing tables without a 'LIKE' filter."""
         result = list_tables(self.test_db)
-        self.assertIsInstance(result, list)
-        self.assertEqual(len(result), 1)
-        self.assertEqual(result[0]["name"], self.test_table)
+        self.assertIsInstance(result, dict)
+        self.assertIn("tables", result)
+
+        tables = result["tables"]
+        self.assertEqual(len(tables), 1)
+        self.assertEqual(tables[0]["name"], self.test_table)
+
+        # Since we only have one table, there should be no next page token
+        self.assertIsNone(result["next_page_token"])

     def test_list_tables_with_like(self):
         """Test listing tables with a 'LIKE' filter."""
         result = list_tables(self.test_db, like=f"{self.test_table}%")
-        self.assertIsInstance(result, list)
-        self.assertEqual(len(result), 1)
-        self.assertEqual(result[0]["name"], self.test_table)
+        self.assertIsInstance(result, dict)
+        self.assertIn("tables", result)
+
+        tables = result["tables"]
+        self.assertEqual(len(tables), 1)
+        self.assertEqual(tables[0]["name"], self.test_table)

     def test_run_select_query_success(self):
         """Test running a SELECT query successfully."""
@@ -78,10 +87,13 @@ def test_run_select_query_failure(self):
     def test_table_and_column_comments(self):
         """Test that table and column comments are correctly retrieved."""
         result = list_tables(self.test_db)
-        self.assertIsInstance(result, list)
-        self.assertEqual(len(result), 1)
+        self.assertIsInstance(result, dict)
+        self.assertIn("tables", result)
+
+        tables = result["tables"]
+        self.assertEqual(len(tables), 1)

-        table_info = result[0]
+        table_info = tables[0]

         # Verify table comment
         self.assertEqual(table_info["comment"], "Test table for unit testing")
diff --git a/uv.lock b/uv.lock
index a41395f..a899a3c 100644
--- a/uv.lock
+++ b/uv.lock
@@ -24,6 +24,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/a0/7a/4daaf3b6c08ad7ceffea4634ec206faeff697526421c20f07628c7372156/anyio-4.7.0-py3-none-any.whl", hash = "sha256:ea60c3723ab42ba6fff7e8ccb0488c898ec538ff4df1f1d5e642c3601d07e352", size = 93052 },
 ]

+[[package]]
+name = "cachetools"
+version = "5.5.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080 },
+]
+
 [[package]]
 name = "certifi"
 version = "2024.12.14"
@@ -210,9 +219,10 @@ cli = [

 [[package]]
 name = "mcp-clickhouse"
-version = "0.1.3"
+version = "0.1.6"
 source = { editable = "." }
 dependencies = [
+    { name = "cachetools" },
     { name = "clickhouse-connect" },
     { name = "mcp", extra = ["cli"] },
     { name = "pip-system-certs" },
@@ -228,6 +238,7 @@ dev = [

 [package.metadata]
 requires-dist = [
+    { name = "cachetools", specifier = ">=5.5.2" },
     { name = "clickhouse-connect", specifier = ">=0.8.16" },
     { name = "mcp", extras = ["cli"], specifier = ">=1.3.0" },
     { name = "pip-system-certs", specifier = ">=4.0" },