-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
168 lines (146 loc) · 6.13 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import requests
import hashlib
import pandas as pd
import streamlit as st
from concurrent.futures import ThreadPoolExecutor
from langchain_community.llms import LlamaCpp
from langchain.prompts.prompt import PromptTemplate
from langchain.sql_database import SQLDatabase
from sqlalchemy import create_engine
import logging
logging.basicConfig(level=logging.ERROR)  # Root logger threshold: only ERROR and above are emitted
def calculate_md5(file_path, chunk_size=65536):
    """Return the hexadecimal MD5 digest of the file at *file_path*.

    The file is opened in binary mode and read in fixed-size chunks, so the
    whole file is never held in memory at once.

    Args:
        file_path: Path of the file to hash.
        chunk_size: Number of bytes to read per iteration. Must be positive;
            the digest does not depend on this value.

    Returns:
        The MD5 digest as a hexadecimal string.

    Raises:
        ValueError: If *chunk_size* is not positive.
        OSError: If the file cannot be opened or read.
    """
    # read(0) returns b"" immediately (hashing nothing) and read(-1) slurps
    # the entire file, so reject both instead of silently misbehaving.
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()
def download_file(url, filename, expected_md5):
    """Download *url* to *filename* and verify its MD5 checksum.

    The payload is streamed into a temporary ``<filename>.part`` file and is
    only moved to *filename* once its checksum matches *expected_md5*. An
    interrupted, failed, or corrupted download therefore never leaves a file
    at *filename* that a later ``os.path.exists`` check could mistake for a
    complete download. The outcome is reported through Streamlit messages.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path for the verified file.
        expected_md5: Hex MD5 digest the downloaded file must match.

    Returns:
        True if the file was downloaded and verified, False otherwise.
    """
    partial_path = f"{filename}.part"
    try:
        # (connect, read) timeouts: without them a stalled server blocks the
        # app indefinitely. The context manager releases the connection.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            if response.status_code != 200:
                st.error(f"Failed to download file: {response.status_code}")
                return False
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)

        if calculate_md5(partial_path) != expected_md5:
            st.error("Downloaded file is corrupted. Please try again.")
            return False

        # Publish the file only after verification succeeded.
        os.replace(partial_path, filename)
        st.success("Download complete and verified!")
        return True
    except Exception as e:
        # UI boundary: report any network/filesystem failure to the user.
        st.error(f"Error downloading file: {e}")
        return False
    finally:
        # Best-effort cleanup of the temporary file; it is already gone after
        # a successful os.replace, which is fine.
        try:
            os.remove(partial_path)
        except OSError:
            pass
@st.cache_resource(ttl=3600)  # Cache the loaded model for an hour
def _load_llama_model(model_file):
    """Build the LlamaCpp client for *model_file*.

    Errors propagate as exceptions rather than being turned into a ``None``
    return value, so that a failed load is not stored in the cache.
    """
    return LlamaCpp(model_path=model_file, temperature=0)


def load_model(model_file):
    """Load the LlamaCpp model, ensuring it is a ``.gguf`` file.

    Successful loads are cached (via the cached helper) for an hour. Failures
    are reported with ``st.error`` and are NOT cached, so the next call
    retries instead of returning a stale ``None`` until the cache expires.

    Args:
        model_file: Path to the model file; must end with ``.gguf``.

    Returns:
        The LlamaCpp client, or None if the file name is invalid or loading
        failed.
    """
    try:
        if not model_file.endswith(".gguf"):
            st.error("Invalid model file format. Please provide a .gguf file.")
            return None
        return _load_llama_model(model_file)
    except Exception as e:
        # UI boundary: surface the failure and let the caller handle None.
        st.error(f"Error loading model: {e}")
        return None
def get_database():
    """Open the example SQLite database.

    A single SQLAlchemy engine is created and shared between the returned
    LangChain ``SQLDatabase`` wrapper and the caller, instead of building two
    independent engines for the same URI. Sample rows are disabled through
    the public constructor argument rather than by mutating a private
    attribute after construction.

    Returns:
        A ``(db, engine)`` tuple, or ``(None, None)`` if the database could
        not be opened (the error is reported with ``st.error``).
    """
    try:
        db_path = "sqlite:///example.db"
        engine = create_engine(db_path)
        # 0 sample rows: table info handed to the model contains schema only.
        db = SQLDatabase(engine, sample_rows_in_table_info=0)
        return db, engine
    except Exception as e:
        st.error(f"Error connecting to database: {e}")
        return None, None
def _show_table_previews(db, engine):
    """Render one tab per table showing its first five rows."""
    table_names = db.get_table_names()
    if not table_names:
        st.write("No tables found in the database.")
        return

    st.write("Tables:")
    tabs = st.tabs(table_names)
    for tab, table_name in zip(tabs, table_names):
        with tab:
            st.write(f"Table: {table_name}")
            # Table names come from database introspection, not user input.
            query = f"SELECT * FROM {table_name} LIMIT 5"  # Limit to 5 rows for display
            try:
                with engine.connect() as connection:
                    df = pd.read_sql_query(query, connection)
                st.write(df)
            except Exception as e:
                st.error(f"Error retrieving data from {table_name}: {e}")


def _answer_question(db, engine, question):
    """Translate *question* to SQL with the local model and show the result."""
    if not question.strip():
        st.warning("Please enter a query.")
        return

    model_file = "phi-3-sql.Q4_K_M.gguf"
    # "/resolve/" serves the raw file; the "/blob/" URL returns the HTML
    # viewer page, which can never be a valid model file.
    model_url = "https://huggingface.co/omeryentur/phi-3-sql/resolve/main/phi-3-sql.Q4_K_M.gguf"
    # FIXME: placeholder value (this is the MD5 of empty input). Until it is
    # replaced with the real digest of the model file, verification of a
    # genuine download will always fail.
    expected_md5 = "d41d8cd98f00b204e9800998ecf8427e"

    # Download the model file if it doesn't exist
    if not os.path.exists(model_file):
        st.write(f"Downloading {model_file}...")
        download_file(model_url, model_file, expected_md5)
        if not os.path.exists(model_file):
            # download_file has already reported the failure; don't try to
            # load a model that is not there.
            return

    client = load_model(model_file)
    if not client:
        return

    table_info = db.get_table_info()
    # NOTE(review): the model may expect a specific chat-style prompt format;
    # confirm this template against the model card.
    template = "\n{table_info}\n{question}\n"
    prompt = PromptTemplate.from_template(template)
    prompt_text = prompt.format(table_info=table_info, question=question)
    try:
        res = client(prompt_text)
        sql_query = res.strip()
        logging.debug("Prompt sent to model:\n%s", prompt_text)
        # Show the generated SQL before running it so that a failing query
        # can still be inspected.
        st.write(f"SQL Query: {sql_query}")
        # SECURITY NOTE: model-generated SQL is executed as-is; it is not
        # restricted to read-only statements.
        with engine.connect() as connection:
            df = pd.read_sql_query(sql_query, connection)
        st.write("Result:")
        st.write(df)
    except Exception as e:
        st.error(f"Error executing query: {e}")


def _render_add_data_section(engine):
    """Render the form that executes a user-supplied INSERT statement."""
    # Local import keeps this block self-contained.
    from sqlalchemy import text

    st.subheader("Add New Data to Database")
    new_data = st.text_area("Enter new data (SQL INSERT statement):", "")
    if not st.button("Add Data"):
        return
    if not new_data.strip():
        st.warning("Please enter a valid SQL INSERT statement.")
        return
    try:
        # engine.begin() commits on success and rolls back on error; a plain
        # connect() without commit would discard the insert. Raw strings must
        # be wrapped in text() to be executable on SQLAlchemy 2.x.
        # SECURITY NOTE: this runs arbitrary SQL typed by the user.
        with engine.begin() as connection:
            connection.execute(text(new_data))
        st.success("Data added successfully!")
    except Exception as e:
        st.error(f"Error adding data: {e}")


def main():
    """Render the Streamlit page.

    Shows a user guide, previews every table of the database, lets the user
    ask a natural-language question that the local model turns into SQL,
    offers a form for inserting data, and provides a button that clears
    Streamlit's caches.
    """
    st.title("SQL Query Interface")

    # User guide
    with st.expander("User Guide"):
        st.write(
            "\n"
            "This interface allows you to query an SQL database using natural language.\n"
            "- Enter your query in the input box and press 'Query' to get the results.\n"
            "- The tables and their first 5 rows are displayed upon loading the page.\n"
        )

    # Retrieve database and engine
    db, engine = get_database()
    if db and engine:
        # Display tables and contents upon page load
        _show_table_previews(db, engine)

        question = st.text_area("Enter your query:", value="Courses containing Introduction")
        if st.button("Query"):
            _answer_question(db, engine, question)
        else:
            st.write("Please enter your query and press 'Query' to get results.")

        # Add New Data to Database section
        _render_add_data_section(engine)
    else:
        st.error("Database connection not established.")

    # Button to clear cache
    if st.button("Clear Cache"):
        st.cache_data.clear()
        st.cache_resource.clear()
        st.success("Cache cleared!")
# Entry point: build the page only when this file is run as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()