mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 08:33:20 +00:00
This commit is contained in:
@@ -11,11 +11,11 @@ from retry import retry
|
||||
# from langchain.embeddings import CohereEmbeddings
|
||||
|
||||
|
||||
def num_tokens_from_string(string: str, encoding_name: str) -> int:
|
||||
def num_tokens_from_string(string: str, encoding_name: str) -> tuple[int, float]:
|
||||
# Function to convert string to tokens and estimate user cost.
|
||||
encoding = tiktoken.get_encoding(encoding_name)
|
||||
num_tokens = len(encoding.encode(string))
|
||||
total_price = ((num_tokens / 1000) * 0.0004)
|
||||
total_price = (num_tokens / 1000) * 0.0004
|
||||
return num_tokens, total_price
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ def call_openai_api(docs, folder_name):
|
||||
os.makedirs(f"outputs/{folder_name}")
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
docs_test = [docs[0]]
|
||||
# remove the first element from docs
|
||||
docs.pop(0)
|
||||
@@ -44,15 +45,25 @@ def call_openai_api(docs, folder_name):
|
||||
# environment="us-east1-gcp" # next to api key in console
|
||||
# )
|
||||
# index_name = "pandas"
|
||||
store = FAISS.from_documents(docs_test, OpenAIEmbeddings())
|
||||
if ( # azure
|
||||
os.environ.get("OPENAI_API_BASE")
|
||||
and os.environ.get("OPENAI_API_VERSION")
|
||||
and os.environ.get("AZURE_DEPLOYMENT_NAME")
|
||||
):
|
||||
os.environ["OPENAI_API_TYPE"] = "azure"
|
||||
openai_embeddings = OpenAIEmbeddings(model=os.environ.get("AZURE_EMBEDDINGS_DEPLOYMENT_NAME"))
|
||||
else:
|
||||
openai_embeddings = OpenAIEmbeddings()
|
||||
store = FAISS.from_documents(docs_test, openai_embeddings)
|
||||
# store_pine = Pinecone.from_documents(docs_test, OpenAIEmbeddings(), index_name=index_name)
|
||||
|
||||
# Uncomment for MPNet embeddings
|
||||
# model_name = "sentence-transformers/all-mpnet-base-v2"
|
||||
# hf = HuggingFaceEmbeddings(model_name=model_name)
|
||||
# store = FAISS.from_documents(docs_test, hf)
|
||||
for i in tqdm(docs, desc="Embedding 🦖", unit="docs", total=len(docs),
|
||||
bar_format='{l_bar}{bar}| Time Left: {remaining}'):
|
||||
for i in tqdm(
|
||||
docs, desc="Embedding 🦖", unit="docs", total=len(docs), bar_format="{l_bar}{bar}| Time Left: {remaining}"
|
||||
):
|
||||
try:
|
||||
store_add_texts_with_retry(store, i)
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user