embedding_workflow.py

import os
import psycopg2
import subprocess
# Database configuration
DB_NAME = "wikivi"
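
# Assumed schema (not created by this script): the pgvector extension is
# installed and embedding_column is a `vector` column sized to the model's
# embedding width, e.g. 3200 for OpenLLaMA 3B:
#
#   CREATE EXTENSION IF NOT EXISTS vector;
#   ALTER TABLE wikipedia_data ADD COLUMN embedding_column vector(3200);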
# Connect to the PostgreSQL database
conn = psycopg2.connect(database=DB_NAME)
cursor = conn.cursor()
# Fetch articles from the wikipedia_data table
cursor.execute("SELECT id, parsed_content FROM wikipedia_data;")
articles = cursor.fetchall()
for article_id, content in articles:
    # Truncate to the first 512 whitespace-separated words, a rough
    # stand-in for a 512-token limit (a real tokenizer would count
    # model tokens rather than words)
    tokens = content.split()[:512]
    tokenized_content = " ".join(tokens)

    # Save the truncated content to a temporary file
    with open("temp.txt", "w") as f:
        f.write(tokenized_content)

    # Run llama.cpp's `embedding` example on the file; the binary and
    # model paths are placeholders for the local build and checkpoint
    result = subprocess.run(
        ["./embedding", "-m", "path_to_openllama_3B_model", "-f", "temp.txt"],
        capture_output=True, text=True, check=True,
    )

    # Assuming the build prints the vector as whitespace-separated
    # floats on stdout, render it as a pgvector literal like '[1,2,3]'
    embedding = "[" + ",".join(result.stdout.split()) + "]"

    # Store the embedding in the wikipedia_data table
    cursor.execute(
        "UPDATE wikipedia_data SET embedding_column = %s WHERE id = %s;",
        (embedding, article_id),
    )
# Index the embedding column with pgvector's ivfflat access method so
# similarity queries can use approximate nearest-neighbour search
cursor.execute(
    "CREATE INDEX IF NOT EXISTS wikipedia_data_embedding_idx "
    "ON wikipedia_data USING ivfflat (embedding_column vector_cosine_ops);"
)
# Commit the changes to the database
conn.commit()
# Close the database connection
cursor.close()
conn.close()
# Clean up the temporary prompt file
if os.path.exists("temp.txt"):
    os.remove("temp.txt")