# postgres_connector.py
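"""Connect to a PostgreSQL database and generate markdown schema documentation.

Provides PostgresConnector, a small context-manager wrapper around psycopg2
with helpers for running queries, reconstructing table DDL, and dumping
structure, sample data, and metadata for every user table.
"""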
import os

import psycopg2
from dotenv import load_dotenv


class PostgresConnector:
    def __init__(self):
        load_dotenv()
        self.db_name = os.getenv('DB_NAME')
        self.db_user = os.getenv('DB_USER')
        self.db_password = os.getenv('DB_PASSWORD')
        # Host and port fall back to a local instance when the optional
        # DB_HOST / DB_PORT variables (names assumed here) are not set.
        self.db_host = os.getenv('DB_HOST', 'localhost')
        self.db_port = os.getenv('DB_PORT', '5432')
        self.connection = None

    def connect(self):
        """Establishes a connection to the PostgreSQL database."""
        try:
            self.connection = psycopg2.connect(
                dbname=self.db_name,
                user=self.db_user,
                password=self.db_password,
                host=self.db_host,
                port=self.db_port,
            )
            return self.connection
        except psycopg2.Error as e:
            print(f"Error connecting to PostgreSQL database: {e}")
            raise
    def execute_query(self, query, params=None):
        """Executes a SQL query and returns (columns, rows), or (None, None)
        for statements that produce no result set."""
        try:
            if not self.connection or self.connection.closed:
                self.connect()
            with self.connection.cursor() as cursor:
                cursor.execute(query, params)
                if cursor.description:  # The query returns data
                    columns = [desc[0] for desc in cursor.description]
                    results = cursor.fetchall()
                    self.connection.commit()  # Commit successful queries
                    return columns, results
                self.connection.commit()
                return None, None
        except psycopg2.Error as e:
            # Only roll back if there is a live connection to roll back on
            if self.connection and not self.connection.closed:
                self.connection.rollback()
            print(f"Error executing query: {e}")
            raise
    def get_table_ddl(self, schema, table):
        """Gets the DDL for a table."""
        try:
            # Get column information including constraints
            ddl_query = """
                WITH columns AS (
                    SELECT
                        column_name,
                        data_type,
                        CASE
                            WHEN character_maximum_length IS NOT NULL THEN '(' || character_maximum_length || ')'
                            WHEN datetime_precision IS NOT NULL THEN '(' || datetime_precision || ')'
                            WHEN numeric_precision IS NOT NULL AND numeric_scale IS NOT NULL
                                THEN '(' || numeric_precision || ',' || numeric_scale || ')'
                            ELSE ''
                        END AS type_length,
                        is_nullable,
                        column_default,
                        ordinal_position
                    FROM information_schema.columns
                    WHERE table_schema = %s AND table_name = %s
                ),
                constraints AS (
                    SELECT
                        tc.constraint_name,
                        tc.constraint_type,
                        STRING_AGG(kcu.column_name, ', ') AS columns,
                        tc.table_name,
                        ccu.table_name AS foreign_table_name,
                        ccu.table_schema AS foreign_table_schema
                    FROM information_schema.table_constraints tc
                    JOIN information_schema.key_column_usage kcu
                        ON tc.constraint_name = kcu.constraint_name
                        AND tc.table_schema = kcu.table_schema
                    LEFT JOIN information_schema.constraint_column_usage ccu
                        ON ccu.constraint_name = tc.constraint_name
                        AND ccu.table_schema = tc.table_schema
                    WHERE tc.table_schema = %s
                        AND tc.table_name = %s
                    GROUP BY tc.constraint_name, tc.constraint_type, tc.table_name,
                        ccu.table_name, ccu.table_schema
                )
                SELECT
                    c.column_name,
                    c.data_type || c.type_length AS full_data_type,
                    c.is_nullable,
                    c.column_default,
                    c.ordinal_position,
                    con.constraint_type,
                    con.constraint_name,
                    con.columns AS constraint_columns,
                    con.foreign_table_name,
                    con.foreign_table_schema
                FROM columns c
                -- NOTE: this LIKE match is approximate; a column named "id" also
                -- matches constraints on "user_id". Exact matching would require
                -- splitting the aggregated column list.
                LEFT JOIN constraints con ON con.columns LIKE '%%' || c.column_name || '%%'
                ORDER BY c.ordinal_position;
            """
            _, ddl_info = self.execute_query(ddl_query, (schema, table, schema, table))
            # Start building the CREATE TABLE statement
            ddl_lines = [f"CREATE TABLE {schema}.{table} ("]
            # Track processed columns to avoid duplicates
            processed_columns = set()
            # Track constraints to add after the column definitions
            constraints = []

            # Process columns and their direct constraints
            for row in ddl_info:
                (col_name, data_type, nullable, default, pos,
                 con_type, con_name, con_cols, f_table, f_schema) = row
                # Skip columns we have already processed
                if col_name in processed_columns:
                    continue
                # Build the column definition
                col_def = f"    {col_name} {data_type}"
                if nullable == 'NO':
                    col_def += " NOT NULL"
                if default:
                    col_def += f" DEFAULT {default}"
                ddl_lines.append(col_def + ",")
                processed_columns.add(col_name)

                # Collect each constraint only once
                if con_type and con_name and con_name not in [c[0] for c in constraints]:
                    if con_type == 'PRIMARY KEY':
                        constraints.append((con_name, f"    CONSTRAINT {con_name} PRIMARY KEY ({con_cols})"))
                    elif con_type == 'FOREIGN KEY':
                        # Simplification: the referenced column list is assumed to
                        # mirror the local columns; this query does not recover the
                        # actual local-to-foreign column pairing.
                        fk_def = f"    CONSTRAINT {con_name} FOREIGN KEY ({con_cols}) "
                        fk_def += f"REFERENCES {f_schema}.{f_table} ({con_cols})"
                        constraints.append((con_name, fk_def))
                    elif con_type == 'UNIQUE':
                        constraints.append((con_name, f"    CONSTRAINT {con_name} UNIQUE ({con_cols})"))

            # Remove the trailing comma from the last column definition
            if ddl_lines[-1].endswith(","):
                ddl_lines[-1] = ddl_lines[-1][:-1]

            # Add the collected constraints
            if constraints:
                ddl_lines[-1] += ","  # Restore the comma on the last column
                for _, constraint in constraints:
                    ddl_lines.append(constraint + ",")
                # Remove the trailing comma from the last constraint
                ddl_lines[-1] = ddl_lines[-1][:-1]
            ddl_lines.append(");")
            # Get indexes. The LIKE filter below is approximate: indexdef text
            # rarely contains the word 'constraint', so indexes backing PRIMARY
            # KEY / UNIQUE constraints may still be listed.
            index_query = """
                SELECT indexdef
                FROM pg_indexes
                WHERE schemaname = %s
                    AND tablename = %s
                    AND indexdef NOT LIKE '%%constraint%%';
            """
            _, indexes = self.execute_query(index_query, (schema, table))
            if indexes:
                ddl_lines.append("\n-- Indexes:")
                for idx in indexes:
                    ddl_lines.append(idx[0] + ";")

            return "\n".join(ddl_lines)
        except psycopg2.Error as e:
            print(f"Error getting table DDL: {e}")
            return f"-- Error getting DDL for {schema}.{table}: {str(e)}"
    def fetch_table_data(self):
        """
        Fetches comprehensive information about all tables in the database
        and returns it in markdown format.

        Returns:
            str: Markdown-formatted string containing table information
        """
        markdown_output = []
        try:
            # Add the introductory section
            markdown_output.append("# Database Schema Documentation")
            markdown_output.append("## Overview")
            markdown_output.append(f"This document provides a comprehensive overview of the database `{self.db_name}`. "
                                   "It includes detailed information about each table, including structure, sample data, and metadata.")
            markdown_output.append("")

            # Get the list of all user tables with basic info
            table_query = """
                SELECT
                    schemaname,
                    tablename,
                    tableowner,
                    pg_size_pretty(pg_total_relation_size(quote_ident(schemaname) || '.' || quote_ident(tablename))) AS size
                FROM pg_tables
                WHERE schemaname NOT IN ('pg_catalog', 'information_schema')
                ORDER BY schemaname, tablename;
            """
            _, tables = self.execute_query(table_query)

            for schema, table, owner, size in tables:
                markdown_output.append(f"## Table: {schema}.{table}")
                markdown_output.append(f"**Database:** {self.db_name}")
                markdown_output.append(f"**Owner:** {owner}")
                markdown_output.append(f"**Size:** {size}")
                markdown_output.append("")

                # Get column information
                column_query = """
                    SELECT
                        column_name,
                        data_type,
                        is_nullable,
                        column_default
                    FROM information_schema.columns
                    WHERE table_schema = %s AND table_name = %s
                    ORDER BY ordinal_position;
                """
                markdown_output.append("### Columns")
                markdown_output.append("| Column | Type | Nullable | Default |")
                markdown_output.append("|--------|------|----------|---------|")
                _, columns = self.execute_query(column_query, (schema, table))
                for col_name, data_type, nullable, default in columns:
                    markdown_output.append(f"| {col_name} | {data_type} | {nullable} | {default or 'NULL'} |")
                markdown_output.append("")

                # Get sample table data
                data_query = f'SELECT * FROM "{schema}"."{table}" LIMIT 10;'
                cols, rows = self.execute_query(data_query)
                if rows:
                    markdown_output.append("### Sample Data (First 10 rows)")
                    markdown_output.append("| " + " | ".join(cols) + " |")
                    markdown_output.append("|" + "|".join(["---"] * len(cols)) + "|")
                    for row in rows:
                        # Escape pipes so cell values cannot break the table layout
                        markdown_output.append("| " + " | ".join(str(val).replace("|", "\\|") for val in row) + " |")
                    markdown_output.append("")

                # Get the CREATE TABLE statement
                ddl = self.get_table_ddl(schema, table)
                markdown_output.append("### Create Table SQL")
                markdown_output.append("```sql")
                markdown_output.append(ddl)
                markdown_output.append("```")

                # Get basic table metadata (the planner's row estimate)
                metadata_query = """
                    SELECT
                        c.reltuples::bigint AS estimated_row_count
                    FROM pg_class c
                    JOIN pg_namespace n ON n.oid = c.relnamespace
                    WHERE n.nspname = %s AND c.relname = %s;
                """
                try:
                    _, metadata = self.execute_query(metadata_query, (schema, table))
                    if metadata and metadata[0][0] is not None:  # Valid metadata?
                        row_count = metadata[0][0]
                        markdown_output.append("\n### Additional Metadata")
                        markdown_output.append(f"- **Estimated Row Count:** {row_count}")
                except psycopg2.Error:
                    pass  # Skip metadata if the query fails

                markdown_output.append("\n---\n")

            return "\n".join(markdown_output)
        except psycopg2.Error as e:
            print(f"Error fetching table data: {e}")
            raise
    def close(self):
        """Closes the database connection."""
        if self.connection:
            self.connection.close()

    def __enter__(self):
        """Context manager entry."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit."""
        self.close()
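

# A minimal usage sketch, assuming a .env file that defines DB_NAME, DB_USER,
# and DB_PASSWORD for a reachable database. The query text and the output
# path "schema_docs.md" are illustrative, not part of the class above.
if __name__ == "__main__":
    with PostgresConnector() as db:
        # Run an ad-hoc query and print the column names and rows
        cols, rows = db.execute_query("SELECT version();")
        print(cols, rows)

        # Generate markdown documentation for every user table
        docs = db.fetch_table_data()
        with open("schema_docs.md", "w") as f:
            f.write(docs)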