# This is a simple standalone implementation showing a RAG pipeline using NVIDIA AI Foundation models.
# It uses a simple Streamlit UI and a one-file implementation of a minimalistic RAG pipeline.
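#
# Rough local setup (a sketch; the exact package list may differ):
#   pip install streamlit langchain langchain-community langchain-nvidia-ai-endpoints faiss-cpu
#   export NVIDIA_API_KEY=<your NVIDIA API key>
#   streamlit run <this file>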

import streamlit as st
import os
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.vectorstores import FAISS
import pickle
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

st.set_page_config(layout="wide")

# Component #1 - Document Upload
with st.sidebar:
    DOCS_DIR = os.path.abspath("./uploaded_docs")
    if not os.path.exists(DOCS_DIR):
        os.makedirs(DOCS_DIR)
    st.subheader("Add to the Knowledge Base")
    with st.form("my-form", clear_on_submit=True):
        uploaded_files = st.file_uploader("Upload a file to the Knowledge Base:", accept_multiple_files=True)
        submitted = st.form_submit_button("Upload!")

    if uploaded_files and submitted:
        for uploaded_file in uploaded_files:
            st.success(f"File {uploaded_file.name} uploaded successfully!")
            with open(os.path.join(DOCS_DIR, uploaded_file.name), "wb") as f:
                f.write(uploaded_file.read())

# Component #2 - Embedding Model and LLM
# Make sure to export your NVIDIA API key as NVIDIA_API_KEY!
llm = ChatNVIDIA(model="meta/llama3-70b-instruct")
document_embedder = NVIDIAEmbeddings(model="nvidia/nv-embedqa-e5-v5", model_type="passage")

# Component #3 - Vector Database Store
with st.sidebar:
    # Option for using an existing vector store
    use_existing_vector_store = st.radio("Use existing vector store if available", ["Yes", "No"], horizontal=True)

# Path to the vector store file
vector_store_path = "vectorstore.pkl"

# Load raw documents from the directory
raw_documents = DirectoryLoader(DOCS_DIR).load()

# Check for an existing vector store file
vector_store_exists = os.path.exists(vector_store_path)
vectorstore = None
if use_existing_vector_store == "Yes" and vector_store_exists:
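    # Collapsed in the diff; presumably the pickled vector store is loaded here, e.g.:
    with open(vector_store_path, "rb") as f:
        vectorstore = pickle.load(f)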
    st.success("Existing vector store loaded successfully.")
else:
    with st.sidebar:
        if raw_documents and use_existing_vector_store == "Yes":
            with st.spinner("Splitting documents into chunks..."):
                text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=200)
                documents = text_splitter.split_documents(raw_documents)

            with st.spinner("Adding document chunks to vector database..."):
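                # Collapsed in the diff; presumably the FAISS index is built from the
                # document chunks with the document embedder, e.g.:
                vectorstore = FAISS.from_documents(documents, document_embedder)

            # Also collapsed: the new index is presumably pickled to vector_store_path
            # and a success message is shown, e.g.:
            with open(vector_store_path, "wb") as f:
                pickle.dump(vectorstore, f)
            st.success("Vector store created and saved.")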
        else:
            st.warning("No documents available to process!", icon="⚠️")

# Component #4 - LLM Response Generation and Chat
st.subheader("Chat with your AI Assistant, Envie!")

if "messages" not in st.session_state:
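    # Collapsed in the diff; presumably the chat history is initialized and replayed, e.g.:
    st.session_state.messages = []

for message in st.session_state.messages: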
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

prompt_template = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful AI assistant named Envie. If provided with context, use it to inform your responses. If no context is available, use your general knowledge to provide a helpful response."),
    ("human", "{input}")
])

chain = prompt_template | llm | StrOutputParser()

user_input = st.chat_input("Can you tell me what NVIDIA is known for?")

if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""

        # Retrieve context from the vector store when one is available;
        # otherwise fall back to the model's general knowledge.
        if vectorstore is not None and use_existing_vector_store == "Yes":
            retriever = vectorstore.as_retriever()
            docs = retriever.invoke(user_input)
            context = "\n\n".join([doc.page_content for doc in docs])
            augmented_user_input = f"Context: {context}\n\nQuestion: {user_input}\n"
        else:
            augmented_user_input = f"Question: {user_input}\n"

        for response in chain.stream({"input": augmented_user_input}):
            full_response += response
            message_placeholder.markdown(full_response + "▌")
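        # Not shown in the diff: presumably the final text replaces the streaming
        # cursor and the reply is stored in the chat history, e.g.:
        message_placeholder.markdown(full_response)
    st.session_state.messages.append({"role": "assistant", "content": full_response})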