diff --git a/.github/dependabot.yml b/.github/dependabot.yml index a00cd334..dd0799c6 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -13,3 +13,7 @@ updates: directory: "/frontend" # Location of package manifests schedule: interval: "weekly" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/holopin.yml b/.github/holopin.yml index d7487e9a..cc313f22 100644 --- a/.github/holopin.yml +++ b/.github/holopin.yml @@ -1,5 +1,11 @@ -organization: arc53 -defaultSticker: clqmdf0ed34290glbvqh0kzxd +organization: docsgpt +defaultSticker: cm1ulwkkl180570cl82rtzympu stickers: - - id: clqmdf0ed34290glbvqh0kzxd - alias: festive + - id: cm1ulwkkl180570cl82rtzympu + alias: contributor2024 + - id: cm1ureg8o130450cl8c1po6mil + alias: api + - id: cm1urhmag148240cl8yvqxkthx + alias: lpc + - id: cm1urlcpq622090cl2tvu4w71y + alias: lexeu diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2ea8961f..be0263ff 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,22 +12,22 @@ jobs: contents: read packages: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up QEMU uses: docker/setup-qemu-action@v1 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to DockerHub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Login to ghcr.io - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} diff --git a/.github/workflows/cife.yml b/.github/workflows/cife.yml index 73a97755..4b1cbf3b 100644 --- a/.github/workflows/cife.yml +++ b/.github/workflows/cife.yml @@ -12,22 +12,22 @@ jobs: contents: read packages: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up QEMU uses: 
docker/setup-qemu-action@v1 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to DockerHub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Login to ghcr.io - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} diff --git a/.github/workflows/docker-develop-build.yml b/.github/workflows/docker-develop-build.yml index 5edc69d7..0bfc7e70 100644 --- a/.github/workflows/docker-develop-build.yml +++ b/.github/workflows/docker-develop-build.yml @@ -14,22 +14,22 @@ jobs: contents: read packages: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up QEMU uses: docker/setup-qemu-action@v1 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to DockerHub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_PASSWORD }} - name: Login to ghcr.io - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} diff --git a/.github/workflows/docker-develop-fe-build.yml b/.github/workflows/docker-develop-fe-build.yml index 29ad4524..14dbccc5 100644 --- a/.github/workflows/docker-develop-fe-build.yml +++ b/.github/workflows/docker-develop-fe-build.yml @@ -14,22 +14,22 @@ jobs: contents: read packages: write steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up QEMU uses: docker/setup-qemu-action@v1 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v3 - name: Login to DockerHub - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ 
secrets.DOCKER_PASSWORD }} - name: Login to ghcr.io - uses: docker/login-action@v2 + uses: docker/login-action@v3 with: registry: ghcr.io username: ${{ github.repository_owner }} diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 7ee31ebe..a36f529b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,7 +11,7 @@ jobs: ruff: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Lint with Ruff uses: chartboost/ruff-action@v1 diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index c6615e56..b858a0f7 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -8,9 +8,9 @@ jobs: matrix: python-version: ["3.11"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -24,7 +24,7 @@ jobs: python -m pytest --cov=application --cov-report=xml - name: Upload coverage reports to Codecov if: github.event_name == 'pull_request' && matrix.python-version == '3.11' - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/sync_fork.yaml b/.github/workflows/sync_fork.yaml index 81f222bb..a108daf6 100644 --- a/.github/workflows/sync_fork.yaml +++ b/.github/workflows/sync_fork.yaml @@ -17,7 +17,7 @@ jobs: steps: # Step 1: run a standard checkout action - name: Checkout target repo - uses: actions/checkout@v3 + uses: actions/checkout@v4 # Step 2: run the sync action - name: Sync upstream changes diff --git a/README.md b/README.md index f1942dc1..ee9a1af6 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,8 @@ We're eager to provide personalized assistance when deploying your DocsGPT to a [Send Email 
:email:](mailto:contact@arc53.com?subject=DocsGPT%20support%2Fsolutions) -![video-example-of-docs-gpt](https://d3dg1063dc54p9.cloudfront.net/videos/demov3.gif) + +video-example-of-docs-gpt ## Roadmap diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 9a22db84..17eb5cc3 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -292,6 +292,7 @@ class Stream(Resource): def post(self): data = request.get_json() required_fields = ["question"] + missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields @@ -422,7 +423,7 @@ class Answer(Resource): @api.doc(description="Provide an answer based on the question and retriever") def post(self): data = request.get_json() - required_fields = ["question"] + required_fields = ["question"] missing_fields = check_required_fields(data, required_fields) if missing_fields: return missing_fields diff --git a/application/api/user/routes.py b/application/api/user/routes.py index c409e69a..2ead8ef1 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -7,7 +7,7 @@ from bson.binary import Binary, UuidRepresentation from bson.dbref import DBRef from bson.objectid import ObjectId from flask import Blueprint, jsonify, make_response, request -from flask_restx import fields, Namespace, Resource +from flask_restx import inputs, fields, Namespace, Resource from pymongo import MongoClient from werkzeug.utils import secure_filename @@ -802,7 +802,7 @@ class ShareConversation(Resource): if missing_fields: return missing_fields - is_promptable = request.args.get("isPromptable") + is_promptable = request.args.get("isPromptable", type=inputs.boolean) if is_promptable is None: return make_response( jsonify({"success": False, "message": "isPromptable is required"}), 400 @@ -831,7 +831,7 @@ class ShareConversation(Resource): uuid.uuid4(), UuidRepresentation.STANDARD ) - if is_promptable.lower() == "true": + if 
is_promptable: prompt_id = data.get("prompt_id", "default") chunks = data.get("chunks", "2") @@ -859,7 +859,7 @@ class ShareConversation(Resource): "conversation_id": DBRef( "conversations", ObjectId(conversation_id) ), - "isPromptable": is_promptable.lower() == "true", + "isPromptable": is_promptable, "first_n_queries": current_n_queries, "user": user, "api_key": api_uuid, @@ -883,7 +883,7 @@ class ShareConversation(Resource): "$ref": "conversations", "$id": ObjectId(conversation_id), }, - "isPromptable": is_promptable.lower() == "true", + "isPromptable": is_promptable, "first_n_queries": current_n_queries, "user": user, "api_key": api_uuid, @@ -918,7 +918,7 @@ class ShareConversation(Resource): "$ref": "conversations", "$id": ObjectId(conversation_id), }, - "isPromptable": is_promptable.lower() == "true", + "isPromptable": is_promptable, "first_n_queries": current_n_queries, "user": user, "api_key": api_uuid, @@ -939,7 +939,7 @@ class ShareConversation(Resource): "conversation_id": DBRef( "conversations", ObjectId(conversation_id) ), - "isPromptable": is_promptable.lower() == "false", + "isPromptable": is_promptable, "first_n_queries": current_n_queries, "user": user, } @@ -962,7 +962,7 @@ class ShareConversation(Resource): "$ref": "conversations", "$id": ObjectId(conversation_id), }, - "isPromptable": is_promptable.lower() == "false", + "isPromptable": is_promptable, "first_n_queries": current_n_queries, "user": user, } diff --git a/application/cache.py b/application/cache.py new file mode 100644 index 00000000..33022e45 --- /dev/null +++ b/application/cache.py @@ -0,0 +1,93 @@ +import redis +import time +import json +import logging +from threading import Lock +from application.core.settings import settings +from application.utils import get_hash + +logger = logging.getLogger(__name__) + +_redis_instance = None +_instance_lock = Lock() + +def get_redis_instance(): + global _redis_instance + if _redis_instance is None: + with _instance_lock: + if _redis_instance 
is None: + try: + _redis_instance = redis.Redis.from_url(settings.CACHE_REDIS_URL, socket_connect_timeout=2) + except redis.ConnectionError as e: + logger.error(f"Redis connection error: {e}") + _redis_instance = None + return _redis_instance + +def gen_cache_key(*messages, model="docgpt"): + if not all(isinstance(msg, dict) for msg in messages): + raise ValueError("All messages must be dictionaries.") + messages_str = json.dumps(list(messages), sort_keys=True) + combined = f"{model}_{messages_str}" + cache_key = get_hash(combined) + return cache_key + +def gen_cache(func): + def wrapper(self, model, messages, *args, **kwargs): + try: + cache_key = gen_cache_key(*messages) + redis_client = get_redis_instance() + if redis_client: + try: + cached_response = redis_client.get(cache_key) + if cached_response: + return cached_response.decode('utf-8') + except redis.ConnectionError as e: + logger.error(f"Redis connection error: {e}") + + result = func(self, model, messages, *args, **kwargs) + if redis_client: + try: + redis_client.set(cache_key, result, ex=1800) + except redis.ConnectionError as e: + logger.error(f"Redis connection error: {e}") + + return result + except ValueError as e: + logger.error(e) + return "Error: No user message found in the conversation to generate a cache key." 
+ return wrapper + +def stream_cache(func): + def wrapper(self, model, messages, stream, *args, **kwargs): + cache_key = gen_cache_key(*messages) + logger.info(f"Stream cache key: {cache_key}") + + redis_client = get_redis_instance() + if redis_client: + try: + cached_response = redis_client.get(cache_key) + if cached_response: + logger.info(f"Cache hit for stream key: {cache_key}") + cached_response = json.loads(cached_response.decode('utf-8')) + for chunk in cached_response: + yield chunk + time.sleep(0.03) + return + except redis.ConnectionError as e: + logger.error(f"Redis connection error: {e}") + + result = func(self, model, messages, stream, *args, **kwargs) + stream_cache_data = [] + + for chunk in result: + stream_cache_data.append(chunk) + yield chunk + + if redis_client: + try: + redis_client.set(cache_key, json.dumps(stream_cache_data), ex=1800) + logger.info(f"Stream cache saved for key: {cache_key}") + except redis.ConnectionError as e: + logger.error(f"Redis connection error: {e}") + + return wrapper \ No newline at end of file diff --git a/application/core/settings.py b/application/core/settings.py index e6173be4..d4b02481 100644 --- a/application/core/settings.py +++ b/application/core/settings.py @@ -18,9 +18,12 @@ class Settings(BaseSettings): DEFAULT_MAX_HISTORY: int = 150 MODEL_TOKEN_LIMITS: dict = {"gpt-3.5-turbo": 4096, "claude-2": 1e5} UPLOAD_FOLDER: str = "inputs" - VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" + VECTOR_STORE: str = "faiss" # "faiss" or "elasticsearch" or "qdrant" or "milvus" or "lancedb" RETRIEVERS_ENABLED: list = ["classic_rag", "duckduck_search"] # also brave_search + # LLM Cache + CACHE_REDIS_URL: str = "redis://localhost:6379/2" + API_URL: str = "http://localhost:7091" # backend url for celery worker API_KEY: Optional[str] = None # LLM api key @@ -67,6 +70,9 @@ class Settings(BaseSettings): MILVUS_URI: Optional[str] = "./milvus_local.db" # milvus lite version as default MILVUS_TOKEN: 
Optional[str] = "" + # LanceDB vectorstore config + LANCEDB_PATH: str = "/tmp/lancedb" # Path where LanceDB stores its local data + LANCEDB_TABLE_NAME: Optional[str] = "docsgpts" # Name of the table to use for storing vectors BRAVE_SEARCH_API_KEY: Optional[str] = None FLASK_DEBUG_MODE: bool = False diff --git a/application/llm/base.py b/application/llm/base.py index 475b7937..1caab5d3 100644 --- a/application/llm/base.py +++ b/application/llm/base.py @@ -1,28 +1,29 @@ from abc import ABC, abstractmethod from application.usage import gen_token_usage, stream_token_usage +from application.cache import stream_cache, gen_cache class BaseLLM(ABC): def __init__(self): self.token_usage = {"prompt_tokens": 0, "generated_tokens": 0} - def _apply_decorator(self, method, decorator, *args, **kwargs): - return decorator(method, *args, **kwargs) + def _apply_decorator(self, method, decorators, *args, **kwargs): + for decorator in decorators: + method = decorator(method) + return method(self, *args, **kwargs) @abstractmethod def _raw_gen(self, model, messages, stream, *args, **kwargs): pass def gen(self, model, messages, stream=False, *args, **kwargs): - return self._apply_decorator(self._raw_gen, gen_token_usage)( - self, model=model, messages=messages, stream=stream, *args, **kwargs - ) + decorators = [gen_token_usage, gen_cache] + return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) @abstractmethod def _raw_gen_stream(self, model, messages, stream, *args, **kwargs): pass def gen_stream(self, model, messages, stream=True, *args, **kwargs): - return self._apply_decorator(self._raw_gen_stream, stream_token_usage)( - self, model=model, messages=messages, stream=stream, *args, **kwargs - ) + decorators = [stream_cache, stream_token_usage] + return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs) \ No newline at end of file 
diff --git a/application/requirements.txt b/application/requirements.txt index 6a57dd12..6ea1d1ba 100644 --- a/application/requirements.txt +++ b/application/requirements.txt @@ -4,7 +4,7 @@ beautifulsoup4==4.12.3 celery==5.3.6 dataclasses-json==0.6.7 docx2txt==0.8 -duckduckgo-search==6.2.6 +duckduckgo-search==6.3.0 ebooklib==0.18 elastic-transport==8.15.0 elasticsearch==8.15.1 @@ -54,7 +54,7 @@ pathable==0.4.3 pillow==10.4.0 portalocker==2.10.1 prance==23.6.21.0 -primp==0.6.2 +primp==0.6.3 prompt-toolkit==3.0.47 protobuf==5.28.2 py==1.11.0 diff --git a/application/utils.py b/application/utils.py index f0802c39..1fc9e329 100644 --- a/application/utils.py +++ b/application/utils.py @@ -1,6 +1,8 @@ import tiktoken +import hashlib from flask import jsonify, make_response + _encoding = None @@ -39,3 +41,8 @@ def check_required_fields(data, required_fields): 400, ) return None + + +def get_hash(data): + return hashlib.md5(data.encode()).hexdigest() + diff --git a/application/vectorstore/lancedb.py b/application/vectorstore/lancedb.py new file mode 100644 index 00000000..25d62318 --- /dev/null +++ b/application/vectorstore/lancedb.py @@ -0,0 +1,119 @@ +from typing import List, Optional +import importlib +from application.vectorstore.base import BaseVectorStore +from application.core.settings import settings + +class LanceDBVectorStore(BaseVectorStore): + """Class for LanceDB Vector Store integration.""" + + def __init__(self, path: str = settings.LANCEDB_PATH, + table_name_prefix: str = settings.LANCEDB_TABLE_NAME, + source_id: str = None, + embeddings_key: str = "embeddings"): + """Initialize the LanceDB vector store.""" + super().__init__() + self.path = path + self.table_name = f"{table_name_prefix}_{source_id}" if source_id else table_name_prefix + self.embeddings_key = embeddings_key + self._lance_db = None + self.docsearch = None + self._pa = None # PyArrow (pa) will be lazy loaded + + @property + def pa(self): + """Lazy load pyarrow module.""" + if self._pa is 
None: + self._pa = importlib.import_module("pyarrow") + return self._pa + + @property + def lancedb(self): + """Lazy load lancedb module.""" + if not hasattr(self, "_lancedb_module"): + self._lancedb_module = importlib.import_module("lancedb") + return self._lancedb_module + + @property + def lance_db(self): + """Lazy load the LanceDB connection.""" + if self._lance_db is None: + self._lance_db = self.lancedb.connect(self.path) + return self._lance_db + + @property + def table(self): + """Lazy load the LanceDB table.""" + if self.docsearch is None: + if self.table_name in self.lance_db.table_names(): + self.docsearch = self.lance_db.open_table(self.table_name) + else: + self.docsearch = None + return self.docsearch + + def ensure_table_exists(self): + """Ensure the table exists before performing operations.""" + if self.table is None: + embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key) + schema = self.pa.schema([ + self.pa.field("vector", self.pa.list_(self.pa.float32(), list_size=embeddings.dimension)), + self.pa.field("text", self.pa.string()), + self.pa.field("metadata", self.pa.struct([ + self.pa.field("key", self.pa.string()), + self.pa.field("value", self.pa.string()) + ])) + ]) + self.docsearch = self.lance_db.create_table(self.table_name, schema=schema) + + def add_texts(self, texts: List[str], metadatas: Optional[List[dict]] = None, source_id: str = None): + """Add texts with metadata and their embeddings to the LanceDB table.""" + embeddings = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key).embed_documents(texts) + vectors = [] + for embedding, text, metadata in zip(embeddings, texts, metadatas or [{}] * len(texts)): + if source_id: + metadata["source_id"] = source_id + metadata_struct = [{"key": k, "value": str(v)} for k, v in metadata.items()] + vectors.append({ + "vector": embedding, + "text": text, + "metadata": metadata_struct + }) + self.ensure_table_exists() + self.docsearch.add(vectors) + + def 
search(self, query: str, k: int = 2, *args, **kwargs): + """Search LanceDB for the top k most similar vectors.""" + self.ensure_table_exists() + query_embedding = self._get_embeddings(settings.EMBEDDINGS_NAME, self.embeddings_key).embed_query(query) + results = self.docsearch.search(query_embedding).limit(k).to_list() + return [(result["_distance"], result["text"], result["metadata"]) for result in results] + + def delete_index(self): + """Delete the entire LanceDB index (table).""" + if self.table: + self.lance_db.drop_table(self.table_name) + + def assert_embedding_dimensions(self, embeddings): + """Ensure that embedding dimensions match the table index dimensions.""" + word_embedding_dimension = embeddings.dimension + if self.table: + table_index_dimension = len(self.docsearch.schema["vector"].type.value_type) + if word_embedding_dimension != table_index_dimension: + raise ValueError( + f"Embedding dimension mismatch: embeddings.dimension ({word_embedding_dimension}) " + f"!= table index dimension ({table_index_dimension})" + ) + + def filter_documents(self, filter_condition: dict) -> List[dict]: + """Filter documents based on certain conditions.""" + self.ensure_table_exists() + + # Ensure source_id exists in the filter condition + if 'source_id' not in filter_condition: + raise ValueError("filter_condition must contain 'source_id'") + + source_id = filter_condition["source_id"] + + # Use LanceDB's native filtering if supported, otherwise filter manually + filtered_data = self.docsearch.filter(lambda x: x.metadata and x.metadata.get("source_id") == source_id).to_list() + + return filtered_data \ No newline at end of file diff --git a/docker-compose.yaml b/docker-compose.yaml index f3b8a363..d3f3421a 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -20,6 +20,7 @@ services: - CELERY_BROKER_URL=redis://redis:6379/0 - CELERY_RESULT_BACKEND=redis://redis:6379/1 - MONGO_URI=mongodb://mongo:27017/docsgpt + - CACHE_REDIS_URL=redis://redis:6379/2 ports: - 
"7091:7091" volumes: @@ -41,6 +42,7 @@ services: - CELERY_RESULT_BACKEND=redis://redis:6379/1 - MONGO_URI=mongodb://mongo:27017/docsgpt - API_URL=http://backend:7091 + - CACHE_REDIS_URL=redis://redis:6379/2 depends_on: - redis - mongo diff --git a/docs/package-lock.json b/docs/package-lock.json index 99836cc6..78206570 100644 --- a/docs/package-lock.json +++ b/docs/package-lock.json @@ -7,7 +7,7 @@ "license": "MIT", "dependencies": { "@vercel/analytics": "^1.1.1", - "docsgpt": "^0.4.1", + "docsgpt": "^0.4.3", "next": "^14.2.12", "nextra": "^2.13.2", "nextra-theme-docs": "^2.13.2", @@ -422,11 +422,6 @@ "node": ">=6.9.0" } }, - "node_modules/@bpmn-io/snarkdown": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/@bpmn-io/snarkdown/-/snarkdown-2.2.0.tgz", - "integrity": "sha512-bVD7FIoaBDZeCJkMRgnBPDeptPlto87wt2qaCjf5t8iLaevDmTPaREd6FpBEGsHlUdHFFZWRk4qAoEC5So2M0Q==" - }, "node_modules/@braintree/sanitize-url": { "version": "6.0.4", "resolved": "https://registry.npmjs.org/@braintree/sanitize-url/-/sanitize-url-6.0.4.tgz", @@ -3162,30 +3157,6 @@ "cytoscape": "^3.2.0" } }, - "node_modules/cytoscape-fcose": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/cytoscape-fcose/-/cytoscape-fcose-2.2.0.tgz", - "integrity": "sha512-ki1/VuRIHFCzxWNrsshHYPs6L7TvLu3DL+TyIGEsRcvVERmxokbf5Gdk7mFxZnTdiGtnA4cfSmjZJMviqSuZrQ==", - "dependencies": { - "cose-base": "^2.2.0" - }, - "peerDependencies": { - "cytoscape": "^3.2.0" - } - }, - "node_modules/cytoscape-fcose/node_modules/cose-base": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/cose-base/-/cose-base-2.2.0.tgz", - "integrity": "sha512-AzlgcsCbUMymkADOJtQm3wO9S3ltPfYOFD5033keQn9NJzIbtnZj+UdBJe7DYml/8TdbtHJW3j58SOnKhWY/5g==", - "dependencies": { - "layout-base": "^2.0.0" - } - }, - "node_modules/cytoscape-fcose/node_modules/layout-base": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/layout-base/-/layout-base-2.0.1.tgz", - "integrity": 
"sha512-dp3s92+uNI1hWIpPGH3jK2kxE2lMjdXdr+DH8ynZHpd6PUlH6x6cbuXnoMmiNumznqaNO31xu9e79F0uuZ0JFg==" - }, "node_modules/d3": { "version": "7.8.5", "resolved": "https://registry.npmjs.org/d3/-/d3-7.8.5.tgz", @@ -3697,13 +3668,12 @@ } }, "node_modules/docsgpt": { - "version": "0.4.1", - "resolved": "https://registry.npmjs.org/docsgpt/-/docsgpt-0.4.1.tgz", - "integrity": "sha512-9oH638vIg8I+zsjLV5Rp21yYniAtiTcyuBSByqWl2KoBdF/8vDSmr491l8n+ikbaTLiCW4uRU0p0r3BvRizy2Q==", + "version": "0.4.3", + "resolved": "https://registry.npmjs.org/docsgpt/-/docsgpt-0.4.3.tgz", + "integrity": "sha512-svLM6xEg4iUtb7HuR1qwC95K4ctvTky8gXRXgqtDIUC5Fg4zeHwivbmaFkBbP3N+bcqWVWCJK9DfJfW+OjTeuA==", "license": "Apache-2.0", "dependencies": { "@babel/plugin-transform-flow-strip-types": "^7.23.3", - "@bpmn-io/snarkdown": "^2.2.0", "@parcel/resolver-glob": "^2.12.0", "@parcel/transformer-svg-react": "^2.12.0", "@parcel/transformer-typescript-tsc": "^2.12.0", @@ -3715,6 +3685,7 @@ "flow-bin": "^0.229.2", "i": "^0.3.7", "install": "^0.13.0", + "markdown-it": "^14.1.0", "npm": "^10.5.0", "react": "^18.2.0", "react-dom": "^18.2.0", @@ -3807,9 +3778,9 @@ "integrity": "sha512-/if4Ueg0GUQlhCrW2ZlXwDAm40ipuKo+OgeHInlL8sbjt+hzISxZK949fZeJaVsheamrzANXvw1zQTvbxTvSHw==" }, "node_modules/elkjs": { - "version": "0.8.2", - "resolved": "https://registry.npmjs.org/elkjs/-/elkjs-0.8.2.tgz", - "integrity": "sha512-L6uRgvZTH+4OF5NE/MBbzQx/WYpru1xCBE9respNj6qznEewGUIfhzmm7horWWxbNO2M0WckQypGctR8lH79xQ==" + "version": "0.9.3", + "resolved": "https://registry.npmjs.org/elkjs/-/elkjs-0.9.3.tgz", + "integrity": "sha512-f/ZeWvW/BCXbhGEf1Ujp29EASo/lk1FDnETgNKwJrsVvGZhUWCZyg3xLJjAsxfOmt8KjswHmI5EwCQcPMpOYhQ==" }, "node_modules/entities": { "version": "4.5.0", @@ -4859,6 +4830,15 @@ "resolved": "https://registry.npmjs.org/lines-and-columns/-/lines-and-columns-1.2.4.tgz", "integrity": "sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==" }, + "node_modules/linkify-it": { + "version": 
"5.0.0", + "resolved": "https://registry.npmjs.org/linkify-it/-/linkify-it-5.0.0.tgz", + "integrity": "sha512-5aHCbzQRADcdP+ATqnDuhhJ/MRIqDkZX5pyjFHRRysS8vZ5AbqGEoFIb6pYHPZ+L/OC2Lc+xT8uHVVR5CAK/wQ==", + "license": "MIT", + "dependencies": { + "uc.micro": "^2.0.0" + } + }, "node_modules/lmdb": { "version": "2.8.5", "resolved": "https://registry.npmjs.org/lmdb/-/lmdb-2.8.5.tgz", @@ -4944,6 +4924,29 @@ "node": ">=0.10.0" } }, + "node_modules/markdown-it": { + "version": "14.1.0", + "resolved": "https://registry.npmjs.org/markdown-it/-/markdown-it-14.1.0.tgz", + "integrity": "sha512-a54IwgWPaeBCAAsv13YgmALOF1elABB08FxO9i+r4VFk5Vl4pKokRPeX8u5TCgSsPi6ec1otfLjdOpVcgbpshg==", + "license": "MIT", + "dependencies": { + "argparse": "^2.0.1", + "entities": "^4.4.0", + "linkify-it": "^5.0.0", + "mdurl": "^2.0.0", + "punycode.js": "^2.3.1", + "uc.micro": "^2.1.0" + }, + "bin": { + "markdown-it": "bin/markdown-it.mjs" + } + }, + "node_modules/markdown-it/node_modules/argparse": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", + "integrity": "sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==", + "license": "Python-2.0" + }, "node_modules/markdown-table": { "version": "3.0.3", "resolved": "https://registry.npmjs.org/markdown-table/-/markdown-table-3.0.3.tgz", @@ -5487,23 +5490,29 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/mdurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mdurl/-/mdurl-2.0.0.tgz", + "integrity": "sha512-Lf+9+2r+Tdp5wXDXC4PcIBjTDtq4UKjCPMQhKIuzpJNW0b96kVqSwW0bT7FhRSfmAiFYgP+SCRvdrDozfh0U5w==", + "license": "MIT" + }, "node_modules/mermaid": { - "version": "10.6.1", - "resolved": "https://registry.npmjs.org/mermaid/-/mermaid-10.6.1.tgz", - "integrity": "sha512-Hky0/RpOw/1il9X8AvzOEChfJtVvmXm+y7JML5C//ePYMy0/9jCEmW1E1g86x9oDfW9+iVEdTV/i+M6KWRNs4A==", + "version": "10.9.3", + "resolved": 
"https://registry.npmjs.org/mermaid/-/mermaid-10.9.3.tgz", + "integrity": "sha512-V80X1isSEvAewIL3xhmz/rVmc27CVljcsbWxkxlWJWY/1kQa4XOABqpDl2qQLGKzpKm6WbTfUEKImBlUfFYArw==", "dependencies": { "@braintree/sanitize-url": "^6.0.1", "@types/d3-scale": "^4.0.3", "@types/d3-scale-chromatic": "^3.0.0", - "cytoscape": "^3.23.0", + "cytoscape": "^3.28.1", "cytoscape-cose-bilkent": "^4.1.0", - "cytoscape-fcose": "^2.1.0", "d3": "^7.4.0", "d3-sankey": "^0.12.3", "dagre-d3-es": "7.0.10", "dayjs": "^1.11.7", - "dompurify": "^3.0.5", - "elkjs": "^0.8.2", + "dompurify": "^3.0.5 <3.1.7", + "elkjs": "^0.9.0", + "katex": "^0.16.9", "khroma": "^2.0.0", "lodash-es": "^4.17.21", "mdast-util-from-markdown": "^1.3.0", @@ -9338,6 +9347,15 @@ "resolved": "https://registry.npmjs.org/pseudomap/-/pseudomap-1.0.2.tgz", "integrity": "sha512-b/YwNhb8lk1Zz2+bXXpS/LK9OisiZZ1SNsSLxN1x2OXVEhW2Ckr/7mWE5vrC1ZTiJlD9g19jWszTmJsB+oEpFQ==" }, + "node_modules/punycode.js": { + "version": "2.3.1", + "resolved": "https://registry.npmjs.org/punycode.js/-/punycode.js-2.3.1.tgz", + "integrity": "sha512-uxFIHU0YlHYhDQtV4R9J6a52SLx28BCjT+4ieh7IGbgwVJWO+km431c4yRlREUAsAmt/uMjQUyQHNEPf0M39CA==", + "license": "MIT", + "engines": { + "node": ">=6" + } + }, "node_modules/react": { "version": "18.2.0", "resolved": "https://registry.npmjs.org/react/-/react-18.2.0.tgz", @@ -10052,9 +10070,9 @@ } }, "node_modules/typescript": { - "version": "5.6.2", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.6.2.tgz", - "integrity": "sha512-NW8ByodCSNCwZeghjN3o+JX5OFH0Ojg6sadjEKY4huZ52TqbJTJnDo5+Tw98lSy63NZvi4n+ez5m2u5d4PkZyw==", + "version": "5.6.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.6.3.tgz", + "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "peer": true, "bin": { "tsc": "bin/tsc", @@ -10064,6 +10082,12 @@ "node": ">=14.17" } }, + "node_modules/uc.micro": { + "version": "2.1.0", + "resolved": 
"https://registry.npmjs.org/uc.micro/-/uc.micro-2.1.0.tgz", + "integrity": "sha512-ARDJmphmdvUk6Glw7y9DQ2bFkKBHwQHLi2lsaH6PPmz/Ka9sFOBsBluozhDltWmnv9u/cF6Rt87znRTPV+yp/A==", + "license": "MIT" + }, "node_modules/unified": { "version": "10.1.2", "resolved": "https://registry.npmjs.org/unified/-/unified-10.1.2.tgz", diff --git a/docs/package.json b/docs/package.json index 9acf9b2a..ddd06c03 100644 --- a/docs/package.json +++ b/docs/package.json @@ -7,7 +7,7 @@ "license": "MIT", "dependencies": { "@vercel/analytics": "^1.1.1", - "docsgpt": "^0.4.1", + "docsgpt": "^0.4.3", "next": "^14.2.12", "nextra": "^2.13.2", "nextra-theme-docs": "^2.13.2", diff --git a/extensions/react-widget/package.json b/extensions/react-widget/package.json index d449d0a3..baf62aca 100644 --- a/extensions/react-widget/package.json +++ b/extensions/react-widget/package.json @@ -1,6 +1,6 @@ { "name": "docsgpt", - "version": "0.4.2", + "version": "0.4.3", "private": false, "description": "DocsGPT 🦖 is an innovative open-source tool designed to simplify the retrieval of information from project documentation using advanced GPT models 🤖.", "source": "./src/index.html", diff --git a/extensions/react-widget/src/components/DocsGPTWidget.tsx b/extensions/react-widget/src/components/DocsGPTWidget.tsx index 83defbcf..01861274 100644 --- a/extensions/react-widget/src/components/DocsGPTWidget.tsx +++ b/extensions/react-widget/src/components/DocsGPTWidget.tsx @@ -453,8 +453,11 @@ export const DocsGPTWidget = ({ setQueries(updatedQueries); setStatus('idle') } + else if (data.type === 'source') { + // handle the case where data type === 'source' + } else { - const result = data.answer; + const result = data.answer ? data.answer : ''; //Fallback to an empty string if data.answer is undefined const streamingResponse = queries[queries.length - 1].response ? 
queries[queries.length - 1].response : ''; const updatedQueries = [...queries]; updatedQueries[updatedQueries.length - 1].response = streamingResponse + result; diff --git a/extensions/slack-bot/.gitignore b/extensions/slack-bot/.gitignore new file mode 100644 index 00000000..1d8e58b2 --- /dev/null +++ b/extensions/slack-bot/.gitignore @@ -0,0 +1,3 @@ +.env +.venv/ +get-pip.py \ No newline at end of file diff --git a/extensions/slack-bot/Readme.md b/extensions/slack-bot/Readme.md new file mode 100644 index 00000000..704184a2 --- /dev/null +++ b/extensions/slack-bot/Readme.md @@ -0,0 +1,84 @@ + +# Slack Bot Configuration Guide + +> **Note:** The following guidelines must be followed on the [Slack API website](https://api.slack.com/) for setting up your Slack app and generating the necessary tokens. + +## Step-by-Step Instructions + +### 1. Navigate to Your Apps +- Go to the Slack API page for apps, select **Create an App**, and then choose the **From scratch** option. + +### 2. App Creation +- Name your app and choose the workspace where you wish to add the assistant. + +### 3. Enabling Socket Mode +- Navigate to **Settings > Socket Mode** and enable **Socket Mode**. +- This action will generate an App-level token. Select the `connections:write` scope and copy the App-level token for future use. + +### 4. Socket Naming +- Assign a name to your socket as per your preference. + +### 5. Basic Information Setup +- Go to **Basic Information** (under **Settings**) and configure the following: + - Assistant name + - App icon + - Background color + +### 6. Bot Token and Permissions +- In the **OAuth & Permissions** option found under the **Features** section, retrieve the Bot Token. Save it for later use. +- You will also need to add specific bot token scopes: + - `app_mentions:read` + - `assistant:write` + - `chat:write` + - `chat:write.public` + - `im:history` + +### 7. 
Enable Events +- From **Event Subscriptions**, enable events and add the following Bot User events: + - `app_mention` + - `assistant_thread_context_changed` + - `assistant_thread_started` + - `message.im` + +### 8. Agent/Assistant Toggle +- In the **Features > Agent & Assistants** section, toggle on the Agent or Assistant option. +- In the **Suggested Prompts** setting, leave it as `dynamic` (this is the default setting). + +--- + +## Code-Side Configuration Guide + +This section focuses on generating and setting up the necessary tokens in the `.env` file, using the `.env-example` as a template. + +### Step 1: Generating Required Keys + +1. **SLACK_APP_TOKEN** + - Navigate to **Settings > Socket Mode** in the Slack API and enable **Socket Mode**. + - Copy the App-level token generated (usually starts with `xapp-`). + +2. **SLACK_BOT_TOKEN** + - Go to **OAuth & Permissions** (under the **Features** section in Slack API). + - Retrieve the **Bot Token** (starts with `xoxb-`). + +3. **DOCSGPT_API_KEY** + - Go to the **DocsGPT website**. + - Navigate to **Settings > Chatbots > Create New** to generate a DocsGPT API Key. + - Copy the generated key for use. + +### Step 2: Creating and Updating the `.env` File + +1. Create a new `.env` file in the root of your project (if it doesn’t already exist). +2. Use the `.env-example` as a reference and update the file with the following keys and values: + +```bash +# .env file +SLACK_APP_TOKEN=xapp-your-generated-app-token +SLACK_BOT_TOKEN=xoxb-your-generated-bot-token +DOCSGPT_API_KEY=your-docsgpt-generated-api-key +``` + +Replace the placeholder values with the actual tokens generated from the Slack API and DocsGPT as per the steps outlined above. + +--- + +This concludes the guide for both setting up the Slack API and configuring the `.env` file on the code side. 
diff --git a/extensions/slack-bot/app.py b/extensions/slack-bot/app.py new file mode 100644 index 00000000..d4f522fd --- /dev/null +++ b/extensions/slack-bot/app.py @@ -0,0 +1,112 @@ +import os +import hashlib +import httpx +import re +from slack_bolt.async_app import AsyncApp +from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler +from dotenv import load_dotenv + +load_dotenv() +API_BASE = os.getenv("API_BASE", "https://gptcloud.arc53.com") +API_URL = API_BASE + "/api/answer" + +# Slack bot token and signing secret +SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN") +SLACK_APP_TOKEN = os.getenv("SLACK_APP_TOKEN") + +# OpenAI API key for DocsGPT (replace this with your actual API key) +DOCSGPT_API_KEY = os.getenv("DOCSGPT_API_KEY") + +# Initialize Slack app +app = AsyncApp(token=SLACK_BOT_TOKEN) + +def encode_conversation_id(conversation_id: str) -> str: + """ + Encodes 11 length Slack conversation_id to 12 length string + Args: + conversation_id (str): The 11 digit slack conversation_id. + Returns: + str: Hashed id. 
+ """ + # Create a SHA-256 hash of the string + hashed_id = hashlib.sha256(conversation_id.encode()).hexdigest() + + # Take the first 24 characters of the hash + hashed_24_char_id = hashed_id[:24] + return hashed_24_char_id + +async def generate_answer(question: str, messages: list, conversation_id: str | None) -> dict: + """Generates an answer using the external API.""" + payload = { + "question": question, + "api_key": DOCSGPT_API_KEY, + "history": messages, + "conversation_id": conversation_id, + } + headers = { + "Content-Type": "application/json; charset=utf-8" + } + timeout = 60.0 + async with httpx.AsyncClient() as client: + response = await client.post(API_URL, json=payload, headers=headers, timeout=timeout) + + if response.status_code == 200: + data = response.json() + conversation_id = data.get("conversation_id") + answer = data.get("answer", "Sorry, I couldn't find an answer.") + return {"answer": answer, "conversation_id": conversation_id} + else: + print(response.json()) + return {"answer": "Sorry, I couldn't find an answer.", "conversation_id": None} + +@app.message(".*") +async def message_docs(message, say): + client = app.client + channel = message['channel'] + thread_ts = message['thread_ts'] + user_query = message['text'] + await client.assistant_threads_setStatus( + channel_id = channel, + thread_ts = thread_ts, + status = "is generating your answer...", + ) + + docs_gpt_channel_id = encode_conversation_id(thread_ts) + + # Get response from DocsGPT + response = await generate_answer(user_query,[], docs_gpt_channel_id) + answer = convert_to_slack_markdown(response['answer']) + + # Respond in Slack + await client.chat_postMessage(text = answer, mrkdwn= True, channel= message['channel'], + thread_ts = message['thread_ts'],) + +def convert_to_slack_markdown(markdown_text: str): + # Convert bold **text** to *text* for Slack + slack_text = re.sub(r'\*\*(.*?)\*\*', r'*\1*', markdown_text) # **text** to *text* + + # Convert italics _text_ to _text_ for 
Slack + slack_text = re.sub(r'_(.*?)_', r'_\1_', slack_text) # _text_ to _text_ + + # Convert inline code `code` to `code` (Slack supports backticks for inline code) + slack_text = re.sub(r'`(.*?)`', r'`\1`', slack_text) + + # Convert bullet points with single or no spaces to filled bullets (•) + slack_text = re.sub(r'^\s{0,1}[-*]\s+', ' • ', slack_text, flags=re.MULTILINE) + + # Convert bullet points with multiple spaces to hollow bullets (◦) + slack_text = re.sub(r'^\s{2,}[-*]\s+', '\t◦ ', slack_text, flags=re.MULTILINE) + + # Convert headers (##) to bold in Slack + slack_text = re.sub(r'^\s*#{1,6}\s*(.*?)$', r'*\1*', slack_text, flags=re.MULTILINE) + + return slack_text + +async def main(): + handler = AsyncSocketModeHandler(app, os.environ["SLACK_APP_TOKEN"]) + await handler.start_async() + +# Start the app +if __name__ == "__main__": + import asyncio + asyncio.run(main()) \ No newline at end of file diff --git a/extensions/slack-bot/requirements.txt b/extensions/slack-bot/requirements.txt new file mode 100644 index 00000000..0c588b43 --- /dev/null +++ b/extensions/slack-bot/requirements.txt @@ -0,0 +1,10 @@ +aiohttp>=3,<4 +certifi==2024.7.4 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.0 +idna==3.7 +python-dotenv==1.0.1 +sniffio==1.3.1 +slack-bolt==1.21.0 +bson==0.5.10 diff --git a/frontend/signal-desktop-keyring.gpg b/frontend/signal-desktop-keyring.gpg new file mode 100644 index 00000000..b5e68a04 Binary files /dev/null and b/frontend/signal-desktop-keyring.gpg differ diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 1455f495..ba0a4bd7 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -19,7 +19,7 @@ function MainLayout() {
+ const [, , componentMounted] = useDarkTheme(); + if (!componentMounted) { + return
; } return (
diff --git a/frontend/src/Navigation.tsx b/frontend/src/Navigation.tsx index 052a9643..ca12df54 100644 --- a/frontend/src/Navigation.tsx +++ b/frontend/src/Navigation.tsx @@ -2,15 +2,15 @@ import { useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useDispatch, useSelector } from 'react-redux'; import { NavLink, useNavigate } from 'react-router-dom'; - import conversationService from './api/services/conversationService'; import userService from './api/services/userService'; import Add from './assets/add.svg'; +import openNewChat from './assets/openNewChat.svg'; +import Hamburger from './assets/hamburger.svg'; import DocsGPT3 from './assets/cute_docsgpt3.svg'; import Discord from './assets/discord.svg'; import Expand from './assets/expand.svg'; import Github from './assets/github.svg'; -import Hamburger from './assets/hamburger.svg'; import Info from './assets/info.svg'; import SettingGear from './assets/settingGear.svg'; import Twitter from './assets/TwitterX.svg'; @@ -40,7 +40,11 @@ import { setSelectedDocs, setSourceDocs, } from './preferences/preferenceSlice'; +import { selectQueries } from './conversation/conversationSlice'; import Upload from './upload/Upload'; +import ShareButton from './components/ShareButton'; +import Help from './components/Help'; + interface NavigationProps { navOpen: boolean; @@ -63,6 +67,7 @@ NavImage.propTypes = { }; */ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const dispatch = useDispatch(); + const queries = useSelector(selectQueries); const docs = useSelector(selectSourceDocs); const selectedDocs = useSelector(selectSelectedDocs); const conversations = useSelector(selectConversations); @@ -73,7 +78,6 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { const [isDarkTheme] = useDarkTheme(); const [isDocsListOpen, setIsDocsListOpen] = useState(false); const { t } = useTranslation(); - const isApiKeySet = 
useSelector(selectApiKeyStatus); const [apiKeyModalState, setApiKeyModalState] = useState('INACTIVE'); @@ -93,6 +97,9 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { if (!conversations) { fetchConversations(); } + if (queries.length === 0) { + resetConversation(); + } }, [conversations, dispatch]); async function fetchConversations() { @@ -164,7 +171,11 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { }), ); }; - + const newChat = () => { + if (queries && queries?.length > 0) { + resetConversation(); + } + }; async function updateConversationName(updatedConversation: { name: string; id: string; @@ -200,26 +211,45 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { return ( <> {!navOpen && ( - +
+
+ + {queries?.length > 0 && ( + + )} +
+ DocsGPT +
+
+
)}
-
+
{conversations && conversations.length > 0 ? (
@@ -304,7 +337,6 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) { <> )}
-
@@ -359,83 +391,69 @@ export default function Navigation({ navOpen, setNavOpen }: NavigationProps) {

-
- { - if (isMobile) { - setNavOpen(!navOpen); - } - resetConversation(); - }} - to="/about" - className={({ isActive }) => - `my-auto mx-4 flex h-9 cursor-pointer gap-4 rounded-3xl hover:bg-gray-100 dark:hover:bg-[#28292E] ${ - isActive ? 'bg-gray-3000 dark:bg-[#28292E]' : '' - }` - } - > - icon -

{t('about')}

-
-
- - discord - - - x - - - github - +
+
+ + +
+ + discord + + + x + + + github + +
- +
+ +
DocsGPT
+
- + + + diff --git a/frontend/src/assets/documentation.svg b/frontend/src/assets/documentation.svg index f9f7c596..955d392f 100644 --- a/frontend/src/assets/documentation.svg +++ b/frontend/src/assets/documentation.svg @@ -1,3 +1,4 @@ - - + + + diff --git a/frontend/src/assets/envelope-dark.svg b/frontend/src/assets/envelope-dark.svg new file mode 100644 index 00000000..a61bec4f --- /dev/null +++ b/frontend/src/assets/envelope-dark.svg @@ -0,0 +1,3 @@ + + + diff --git a/frontend/src/assets/envelope.svg b/frontend/src/assets/envelope.svg new file mode 100644 index 00000000..a4c25032 --- /dev/null +++ b/frontend/src/assets/envelope.svg @@ -0,0 +1,3 @@ + + + diff --git a/frontend/src/assets/openNewChat.svg b/frontend/src/assets/openNewChat.svg new file mode 100644 index 00000000..0749ff17 --- /dev/null +++ b/frontend/src/assets/openNewChat.svg @@ -0,0 +1,4 @@ + + + + diff --git a/frontend/src/components/Help.tsx b/frontend/src/components/Help.tsx new file mode 100644 index 00000000..44bcb057 --- /dev/null +++ b/frontend/src/components/Help.tsx @@ -0,0 +1,80 @@ +import { useState, useRef, useEffect } from 'react'; +import Info from '../assets/info.svg'; +import PageIcon from '../assets/documentation.svg'; +import EmailIcon from '../assets/envelope.svg'; +import { useTranslation } from 'react-i18next'; +const Help = () => { + const [isOpen, setIsOpen] = useState(false); + const dropdownRef = useRef(null); + const buttonRef = useRef(null); + const { t } = useTranslation(); + + const toggleDropdown = () => { + setIsOpen((prev) => !prev); + }; + + const handleClickOutside = (event: MouseEvent) => { + if ( + dropdownRef.current && + !dropdownRef.current.contains(event.target as Node) && + buttonRef.current && + !buttonRef.current.contains(event.target as Node) + ) { + setIsOpen(false); + } + }; + + useEffect(() => { + document.addEventListener('mousedown', handleClickOutside); + return () => { + document.removeEventListener('mousedown', handleClickOutside); + }; + }, []); + 
+ return ( +
+ + {isOpen && ( + + )} +
+ ); +}; + +export default Help; diff --git a/frontend/src/components/SettingsBar.tsx b/frontend/src/components/SettingsBar.tsx index 2b7c2a33..f617c6e8 100644 --- a/frontend/src/components/SettingsBar.tsx +++ b/frontend/src/components/SettingsBar.tsx @@ -71,9 +71,9 @@ const SettingsBar = ({ setActiveTab, activeTab }: SettingsBarProps) => { + {isShareModalOpen && ( + { + setShareModalState(false); + }} + conversationId={conversationId} + /> + )} + + ); +} diff --git a/frontend/src/components/SkeletonLoader.tsx b/frontend/src/components/SkeletonLoader.tsx new file mode 100644 index 00000000..e9a136e4 --- /dev/null +++ b/frontend/src/components/SkeletonLoader.tsx @@ -0,0 +1,138 @@ +import React, { useState, useEffect } from 'react'; + +interface SkeletonLoaderProps { + count?: number; + component?: 'default' | 'analysis' | 'chatbot' | 'logs'; +} + +const SkeletonLoader: React.FC = ({ + count = 1, + component = 'default', +}) => { + const [skeletonCount, setSkeletonCount] = useState(count); + + useEffect(() => { + const handleResize = () => { + const windowWidth = window.innerWidth; + + if (windowWidth > 1024) { + setSkeletonCount(1); + } else if (windowWidth > 768) { + setSkeletonCount(count); + } else { + setSkeletonCount(Math.min(count, 2)); + } + }; + + handleResize(); + window.addEventListener('resize', handleResize); + + return () => { + window.removeEventListener('resize', handleResize); + }; + }, [count]); + + return ( +
+ {component === 'default' ? ( + [...Array(skeletonCount)].map((_, idx) => ( +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ )) + ) : component === 'analysis' ? ( + [...Array(skeletonCount)].map((_, idx) => ( +
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ )) + ) : component === 'chatbot' ? ( +
+
+
+
+
+
+
+
+ + {[...Array(skeletonCount * 6)].map((_, idx) => ( +
+
+
+
+
+
+ ))} +
+ ) : ( + [...Array(skeletonCount)].map((_, idx) => ( +
+
+
+
+
+
+
+
+
+ )) + )} +
+ ); +}; + +export default SkeletonLoader; diff --git a/frontend/src/components/SourceDropdown.tsx b/frontend/src/components/SourceDropdown.tsx index 6a348161..f92173a0 100644 --- a/frontend/src/components/SourceDropdown.tsx +++ b/frontend/src/components/SourceDropdown.tsx @@ -121,9 +121,12 @@ function SourceDropdown({ className="flex cursor-pointer items-center justify-between hover:bg-gray-100 dark:text-bright-gray dark:hover:bg-purple-taupe" onClick={handleEmptyDocumentSelect} > - { - handlePostDocumentSelect(null); - }}> + { + handlePostDocumentSelect(null); + }} + > {t('none')}
diff --git a/frontend/src/conversation/Conversation.tsx b/frontend/src/conversation/Conversation.tsx index fb819922..a8326da7 100644 --- a/frontend/src/conversation/Conversation.tsx +++ b/frontend/src/conversation/Conversation.tsx @@ -1,22 +1,24 @@ import { Fragment, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useDispatch, useSelector } from 'react-redux'; - +import newChatIcon from '../assets/openNewChat.svg'; import ArrowDown from '../assets/arrow-down.svg'; import Send from '../assets/send.svg'; import SendDark from '../assets/send_dark.svg'; -import ShareIcon from '../assets/share.svg'; import SpinnerDark from '../assets/spinner-dark.svg'; import Spinner from '../assets/spinner.svg'; import RetryIcon from '../components/RetryIcon'; +import { useNavigate } from 'react-router-dom'; import Hero from '../Hero'; import { useDarkTheme, useMediaQuery } from '../hooks'; import { ShareConversationModal } from '../modals/ShareConversationModal'; +import { setConversation, updateConversationId } from './conversationSlice'; import { selectConversationId } from '../preferences/preferenceSlice'; import { AppDispatch } from '../store'; import ConversationBubble from './ConversationBubble'; import { handleSendFeedback } from './conversationHandlers'; import { FEEDBACK, Query } from './conversationModels'; +import ShareIcon from '../assets/share.svg'; import { addQuery, fetchAnswer, @@ -27,6 +29,7 @@ import { export default function Conversation() { const queries = useSelector(selectQueries); + const navigate = useNavigate(); const status = useSelector(selectStatus); const conversationId = useSelector(selectConversationId); const dispatch = useDispatch(); @@ -46,6 +49,9 @@ export default function Conversation() { }; useEffect(() => { !eventInterrupt && scrollIntoView(); + if (queries.length == 0) { + resetConversation(); + } }, [queries.length, queries[queries.length - 1]]); useEffect(() => { @@ -120,6 +126,17 @@ export 
default function Conversation() { handleInput(); } }; + const resetConversation = () => { + dispatch(setConversation([])); + dispatch( + updateConversationId({ + query: { conversationId: null }, + }), + ); + }; + const newChat = () => { + if (queries && queries.length > 0) resetConversation(); + }; const prepResponseView = (query: Query, index: number) => { let responseView; @@ -197,23 +214,41 @@ export default function Conversation() { }; }, []); return ( -
- {conversationId && ( - <> +
+ {conversationId && queries.length > 0 && ( +
{' '} - +
+ {isMobile && queries.length > 0 && ( + + )} + + +
{isShareModalOpen && ( { @@ -222,7 +257,7 @@ export default function Conversation() { conversationId={conversationId} /> )} - +
)}
{ + event.preventDefault(); + }, []); + + useEffect(() => { + const conversationsMainDiv = document.getElementById( + 'conversationsMainDiv', + ); + + if (conversationsMainDiv) { + if (isOpen) { + conversationsMainDiv.addEventListener('wheel', preventScroll, { + passive: false, + }); + conversationsMainDiv.addEventListener('touchmove', preventScroll, { + passive: false, + }); + } else { + conversationsMainDiv.removeEventListener('wheel', preventScroll); + conversationsMainDiv.removeEventListener('touchmove', preventScroll); + } + + return () => { + conversationsMainDiv.removeEventListener('wheel', preventScroll); + conversationsMainDiv.removeEventListener('touchmove', preventScroll); + }; + } + }, [isOpen]); + function onClear() { setConversationsName(conversation.name); setIsEdit(false); @@ -147,7 +183,7 @@ export default function ConversationTile({
diff --git a/frontend/src/settings/Analytics.tsx b/frontend/src/settings/Analytics.tsx index 5ddab2cb..8baad361 100644 --- a/frontend/src/settings/Analytics.tsx +++ b/frontend/src/settings/Analytics.tsx @@ -1,3 +1,4 @@ +import React, { useState, useEffect } from 'react'; import { BarElement, CategoryScale, @@ -7,7 +8,6 @@ import { Title, Tooltip, } from 'chart.js'; -import React from 'react'; import { Bar } from 'react-chartjs-2'; import userService from '../api/services/userService'; @@ -17,6 +17,8 @@ import { formatDate } from '../utils/dateTimeUtils'; import { APIKeyData } from './types'; import type { ChartData } from 'chart.js'; +import SkeletonLoader from '../components/SkeletonLoader'; + ChartJS.register( CategoryScale, LinearScale, @@ -35,37 +37,37 @@ const filterOptions = [ ]; export default function Analytics() { - const [messagesData, setMessagesData] = React.useState | null>(null); - const [tokenUsageData, setTokenUsageData] = React.useState | null>(null); - const [feedbackData, setFeedbackData] = React.useState | null>(null); - const [chatbots, setChatbots] = React.useState([]); - const [selectedChatbot, setSelectedChatbot] = - React.useState(); - const [messagesFilter, setMessagesFilter] = React.useState<{ + const [chatbots, setChatbots] = useState([]); + const [selectedChatbot, setSelectedChatbot] = useState(); + const [messagesFilter, setMessagesFilter] = useState<{ label: string; value: string; }>({ label: '30 Days', value: 'last_30_days' }); - const [tokenUsageFilter, setTokenUsageFilter] = React.useState<{ + const [tokenUsageFilter, setTokenUsageFilter] = useState<{ label: string; value: string; }>({ label: '30 Days', value: 'last_30_days' }); - const [feedbackFilter, setFeedbackFilter] = React.useState<{ + const [feedbackFilter, setFeedbackFilter] = useState<{ label: string; value: string; }>({ label: '30 Days', value: 'last_30_days' }); + const [loadingMessages, setLoadingMessages] = useState(true); + const [loadingTokens, setLoadingTokens] = 
useState(true); + const [loadingFeedback, setLoadingFeedback] = useState(true); + const fetchChatbots = async () => { try { const response = await userService.getAPIKeys(); @@ -80,6 +82,7 @@ export default function Analytics() { }; const fetchMessagesData = async (chatbot_id?: string, filter?: string) => { + setLoadingMessages(true); try { const response = await userService.getMessageAnalytics({ api_key_id: chatbot_id, @@ -92,10 +95,13 @@ export default function Analytics() { setMessagesData(data.messages); } catch (error) { console.error(error); + } finally { + setLoadingMessages(false); } }; const fetchTokenData = async (chatbot_id?: string, filter?: string) => { + setLoadingTokens(true); try { const response = await userService.getTokenAnalytics({ api_key_id: chatbot_id, @@ -108,10 +114,13 @@ export default function Analytics() { setTokenUsageData(data.token_usage); } catch (error) { console.error(error); + } finally { + setLoadingTokens(false); } }; const fetchFeedbackData = async (chatbot_id?: string, filter?: string) => { + setLoadingFeedback(true); try { const response = await userService.getFeedbackAnalytics({ api_key_id: chatbot_id, @@ -124,30 +133,33 @@ export default function Analytics() { setFeedbackData(data.feedback); } catch (error) { console.error(error); + } finally { + setLoadingFeedback(false); } }; - React.useEffect(() => { + useEffect(() => { fetchChatbots(); }, []); - React.useEffect(() => { + useEffect(() => { const id = selectedChatbot?.id; const filter = messagesFilter; fetchMessagesData(id, filter?.value); }, [selectedChatbot, messagesFilter]); - React.useEffect(() => { + useEffect(() => { const id = selectedChatbot?.id; const filter = tokenUsageFilter; fetchTokenData(id, filter?.value); }, [selectedChatbot, tokenUsageFilter]); - React.useEffect(() => { + useEffect(() => { const id = selectedChatbot?.id; const filter = feedbackFilter; fetchFeedbackData(id, filter?.value); }, [selectedChatbot, feedbackFilter]); + return (
@@ -181,8 +193,10 @@ export default function Analytics() { border="border" />
+ + {/* Messages Analytics */}
-
+

Messages @@ -208,26 +222,32 @@ export default function Analytics() { id="legend-container-1" className="flex flex-row items-center justify-end" >

- - formatDate(item), - ), - datasets: [ - { - label: 'Messages', - data: Object.values(messagesData || {}), - backgroundColor: '#7D54D1', - }, - ], - }} - legendID="legend-container-1" - maxTicksLimitInX={8} - isStacked={false} - /> + {loadingMessages ? ( + + ) : ( + + formatDate(item), + ), + datasets: [ + { + label: 'Messages', + data: Object.values(messagesData || {}), + backgroundColor: '#7D54D1', + }, + ], + }} + legendID="legend-container-1" + maxTicksLimitInX={8} + isStacked={false} + /> + )}
-
+ + {/* Token Usage Analytics */} +

Token Usage @@ -253,31 +273,37 @@ export default function Analytics() { id="legend-container-2" className="flex flex-row items-center justify-end" >

- - formatDate(item), - ), - datasets: [ - { - label: 'Tokens', - data: Object.values(tokenUsageData || {}), - backgroundColor: '#7D54D1', - }, - ], - }} - legendID="legend-container-2" - maxTicksLimitInX={8} - isStacked={false} - /> + {loadingTokens ? ( + + ) : ( + + formatDate(item), + ), + datasets: [ + { + label: 'Tokens', + data: Object.values(tokenUsageData || {}), + backgroundColor: '#7D54D1', + }, + ], + }} + legendID="legend-container-2" + maxTicksLimitInX={8} + isStacked={false} + /> + )}
-
-
+ + {/* Feedback Analytics */} +
+

- User Feedback + Feedback

- - formatDate(item), - ), - datasets: [ - { - label: 'Positive', - data: Object.values(feedbackData || {}).map( - (item) => item.positive, - ), - backgroundColor: '#8BD154', - }, - { - label: 'Negative', - data: Object.values(feedbackData || {}).map( - (item) => item.negative, - ), - backgroundColor: '#D15454', - }, - ], - }} - legendID="legend-container-3" - maxTicksLimitInX={10} - isStacked={true} - /> + {loadingFeedback ? ( + + ) : ( + + formatDate(item), + ), + datasets: [ + { + label: 'Positive Feedback', + data: Object.values(feedbackData || {}).map( + (item) => item.positive, + ), + backgroundColor: '#7D54D1', + }, + { + label: 'Negative Feedback', + data: Object.values(feedbackData || {}).map( + (item) => item.negative, + ), + backgroundColor: '#FF6384', + }, + ], + }} + legendID="legend-container-3" + maxTicksLimitInX={8} + isStacked={false} + /> + )}
diff --git a/frontend/src/settings/Documents.tsx b/frontend/src/settings/Documents.tsx index 67cdb6d0..004b6f48 100644 --- a/frontend/src/settings/Documents.tsx +++ b/frontend/src/settings/Documents.tsx @@ -9,6 +9,7 @@ import Trash from '../assets/trash.svg'; import caretSort from '../assets/caret-sort.svg'; import DropdownMenu from '../components/DropdownMenu'; import { Doc, DocumentsProps, ActiveState } from '../models/misc'; // Ensure ActiveState type is imported +import SkeletonLoader from '../components/SkeletonLoader'; import { getDocs } from '../preferences/preferenceApi'; import { setSourceDocs } from '../preferences/preferenceSlice'; import Input from '../components/Input'; @@ -43,6 +44,7 @@ const Documents: React.FC = ({ // State for modal: active/inactive const [modalState, setModalState] = useState('INACTIVE'); // Initialize with inactive state const [isOnboarding, setIsOnboarding] = useState(false); // State for onboarding flag + const [loading, setLoading] = useState(false); const syncOptions = [ { label: 'Never', value: 'never' }, @@ -52,6 +54,7 @@ const Documents: React.FC = ({ ]; const handleManageSync = (doc: Doc, sync_frequency: string) => { + setLoading(true); userService .manageSync({ source_id: doc.id, sync_frequency }) .then(() => { @@ -60,7 +63,10 @@ const Documents: React.FC = ({ .then((data) => { dispatch(setSourceDocs(data)); }) - .catch((error) => console.error(error)); + .catch((error) => console.error(error)) + .finally(() => { + setLoading(false); + }); }; // Filter documents based on the search term @@ -72,7 +78,7 @@ const Documents: React.FC = ({
-
+
= ({ Add New
- - - - - - - - - - - - {!filteredDocuments?.length && ( + {loading ? ( + + ) : ( +
{t('settings.documents.name')} -
- {t('settings.documents.date')}{' '} - {' '} -
-
-
- {t('settings.documents.tokenUsage')}{' '} - -
-
-
- {t('settings.documents.type')}{' '} - -
-
+ - + + + + + - )} - {filteredDocuments && - filteredDocuments.map((document, index) => ( - - - - - - + + {!filteredDocuments?.length && ( + + - ))} - -
- {t('settings.documents.noData')} - {t('settings.documents.name')} +
+ {t('settings.documents.date')} + +
+
+
+ {t('settings.documents.tokenUsage')} + +
+
+
+ {t('settings.documents.type')} + +
+
{document.name}{document.date} - {document.tokens ? formatTokens(+document.tokens) : ''} - - {document.type === 'remote' ? 'Pre-loaded' : 'Private'} - -
- {document.type !== 'remote' && ( - Delete { - event.stopPropagation(); - handleDeleteDocument(index, document); - }} - /> - )} - {document.syncFrequency && ( -
- { - handleManageSync(document, value); - }} - defaultValue={document.syncFrequency} - icon={SyncIcon} - /> -
- )} -
+
+ {t('settings.documents.noData')}
+ )} + {filteredDocuments && + filteredDocuments.map((document, index) => ( + + {document.name} + {document.date} + + {document.tokens ? formatTokens(+document.tokens) : ''} + + + {document.type === 'remote' ? 'Pre-loaded' : 'Private'} + + +
+ {document.type !== 'remote' && ( + Delete { + event.stopPropagation(); + handleDeleteDocument(index, document); + }} + /> + )} + {document.syncFrequency && ( +
+ { + handleManageSync(document, value); + }} + defaultValue={document.syncFrequency} + icon={SyncIcon} + /> +
+ )} +
+ + + ))} + + + )}
{/* Conditionally render the Upload modal based on modalState */} {modalState === 'ACTIVE' && ( diff --git a/frontend/src/settings/Logs.tsx b/frontend/src/settings/Logs.tsx index 58ab930d..1e248d46 100644 --- a/frontend/src/settings/Logs.tsx +++ b/frontend/src/settings/Logs.tsx @@ -1,20 +1,23 @@ -import React from 'react'; +import React, { useState, useEffect, useRef, useCallback } from 'react'; import userService from '../api/services/userService'; import ChevronRight from '../assets/chevron-right.svg'; import Dropdown from '../components/Dropdown'; +import SkeletonLoader from '../components/SkeletonLoader'; import { APIKeyData, LogData } from './types'; import CoppyButton from '../components/CopyButton'; export default function Logs() { - const [chatbots, setChatbots] = React.useState([]); - const [selectedChatbot, setSelectedChatbot] = - React.useState(); - const [logs, setLogs] = React.useState([]); - const [page, setPage] = React.useState(1); - const [hasMore, setHasMore] = React.useState(true); + const [chatbots, setChatbots] = useState([]); + const [selectedChatbot, setSelectedChatbot] = useState(); + const [logs, setLogs] = useState([]); + const [page, setPage] = useState(1); + const [hasMore, setHasMore] = useState(true); + const [loadingChatbots, setLoadingChatbots] = useState(true); + const [loadingLogs, setLoadingLogs] = useState(true); const fetchChatbots = async () => { + setLoadingChatbots(true); try { const response = await userService.getAPIKeys(); if (!response.ok) { @@ -24,10 +27,13 @@ export default function Logs() { setChatbots(chatbots); } catch (error) { console.error(error); + } finally { + setLoadingChatbots(false); } }; const fetchLogs = async () => { + setLoadingLogs(true); try { const response = await userService.getLogs({ page: page, @@ -38,20 +44,23 @@ export default function Logs() { throw new Error('Failed to fetch logs'); } const olderLogs = await response.json(); - setLogs([...logs, ...olderLogs.logs]); + setLogs((prevLogs) => 
[...prevLogs, ...olderLogs.logs]); setHasMore(olderLogs.has_more); } catch (error) { console.error(error); + } finally { + setLoadingLogs(false); } }; - React.useEffect(() => { + useEffect(() => { fetchChatbots(); }, []); - React.useEffect(() => { + useEffect(() => { if (hasMore) fetchLogs(); }, [page, selectedChatbot]); + return (
@@ -59,38 +68,47 @@ export default function Logs() {

Filter by chatbot

- ({ - label: chatbot.name, - value: chatbot.id, - })), - { label: 'None', value: '' }, - ]} - placeholder="Select chatbot" - onSelect={(chatbot: { label: string; value: string }) => { - setSelectedChatbot( - chatbots.find((item) => item.id === chatbot.value), - ); - setLogs([]); - setPage(1); - setHasMore(true); - }} - selectedValue={ - (selectedChatbot && { - label: selectedChatbot.name, - value: selectedChatbot.id, - }) || - null - } - rounded="3xl" - border="border" - /> + {loadingChatbots ? ( + + ) : ( + ({ + label: chatbot.name, + value: chatbot.id, + })), + { label: 'None', value: '' }, + ]} + placeholder="Select chatbot" + onSelect={(chatbot: { label: string; value: string }) => { + setSelectedChatbot( + chatbots.find((item) => item.id === chatbot.value), + ); + setLogs([]); + setPage(1); + setHasMore(true); + }} + selectedValue={ + (selectedChatbot && { + label: selectedChatbot.name, + value: selectedChatbot.id, + }) || + null + } + rounded="3xl" + border="border" + /> + )}
+
- + {loadingLogs ? ( + + ) : ( + + )}
); @@ -102,15 +120,16 @@ type LogsTableProps = { }; function LogsTable({ logs, setPage }: LogsTableProps) { - const observerRef = React.useRef(); - const firstObserver = React.useCallback((node: HTMLDivElement) => { + const observerRef = useRef(); + const firstObserver = useCallback((node: HTMLDivElement) => { if (observerRef.current) { - observerRef.current = new IntersectionObserver((enteries) => { - if (enteries[0].isIntersecting) setPage((prev) => prev + 1); + observerRef.current = new IntersectionObserver((entries) => { + if (entries[0].isIntersecting) setPage((prev) => prev + 1); }); } if (node && observerRef.current) observerRef.current.observe(node); }, []); + return (
diff --git a/setup.sh b/setup.sh index 988841da..7980461b 100755 --- a/setup.sh +++ b/setup.sh @@ -9,6 +9,39 @@ prompt_user() { read -p "Enter your choice (1, 2 or 3): " choice } +check_and_start_docker() { + # Check if Docker is running + if ! docker info > /dev/null 2>&1; then + echo "Docker is not running. Starting Docker..." + + # Check the operating system + case "$(uname -s)" in + Darwin) + open -a Docker + ;; + Linux) + sudo systemctl start docker + ;; + *) + echo "Unsupported platform. Please start Docker manually." + exit 1 + ;; + esac + + # Wait for Docker to be fully operational with animated dots + echo -n "Waiting for Docker to start" + while ! docker system info > /dev/null 2>&1; do + for i in {1..3}; do + echo -n "." + sleep 1 + done + echo -ne "\rWaiting for Docker to start " # Reset to overwrite previous dots + done + + echo -e "\nDocker has started!" + fi +} + # Function to handle the choice to download the model locally download_locally() { echo "LLM_NAME=llama.cpp" > .env @@ -30,6 +63,9 @@ download_locally() { echo "Model already exists." fi + # Call the function to check and start Docker if needed + check_and_start_docker + docker-compose -f docker-compose-local.yaml build && docker-compose -f docker-compose-local.yaml up -d #python -m venv venv #source venv/bin/activate @@ -59,10 +95,11 @@ use_openai() { echo "VITE_API_STREAMING=true" >> .env echo "The .env file has been created with API_KEY set to your provided key." + # Call the function to check and start Docker if needed + check_and_start_docker + docker-compose build && docker-compose up -d - - echo "The application will run on http://localhost:5173" echo "You can stop the application by running the following command:" echo "docker-compose down" @@ -73,6 +110,9 @@ use_docsgpt() { echo "VITE_API_STREAMING=true" >> .env echo "The .env file has been created with API_KEY set to your provided key." 
+ # Call the function to check and start Docker if needed + check_and_start_docker + docker-compose build && docker-compose up -d echo "The application will run on http://localhost:5173" diff --git a/tests/llm/test_anthropic.py b/tests/llm/test_anthropic.py index ee4ba15f..689013c0 100644 --- a/tests/llm/test_anthropic.py +++ b/tests/llm/test_anthropic.py @@ -22,17 +22,23 @@ class TestAnthropicLLM(unittest.TestCase): mock_response = Mock() mock_response.completion = "test completion" - with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create: - response = self.llm.gen("test_model", messages) - self.assertEqual(response, "test completion") + with patch("application.cache.get_redis_instance") as mock_make_redis: + mock_redis_instance = mock_make_redis.return_value + mock_redis_instance.get.return_value = None + mock_redis_instance.set = Mock() - prompt_expected = "### Context \n context \n ### Question \n question" - mock_create.assert_called_with( - model="test_model", - max_tokens_to_sample=300, - stream=False, - prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}" - ) + with patch.object(self.llm.anthropic.completions, "create", return_value=mock_response) as mock_create: + response = self.llm.gen("test_model", messages) + self.assertEqual(response, "test completion") + + prompt_expected = "### Context \n context \n ### Question \n question" + mock_create.assert_called_with( + model="test_model", + max_tokens_to_sample=300, + stream=False, + prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}" + ) + mock_redis_instance.set.assert_called_once() def test_gen_stream(self): messages = [ @@ -41,17 +47,23 @@ class TestAnthropicLLM(unittest.TestCase): ] mock_responses = [Mock(completion="response_1"), Mock(completion="response_2")] - with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create: - responses = 
list(self.llm.gen_stream("test_model", messages)) - self.assertListEqual(responses, ["response_1", "response_2"]) + with patch("application.cache.get_redis_instance") as mock_make_redis: + mock_redis_instance = mock_make_redis.return_value + mock_redis_instance.get.return_value = None + mock_redis_instance.set = Mock() - prompt_expected = "### Context \n context \n ### Question \n question" - mock_create.assert_called_with( - model="test_model", - prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}", - max_tokens_to_sample=300, - stream=True - ) + with patch.object(self.llm.anthropic.completions, "create", return_value=iter(mock_responses)) as mock_create: + responses = list(self.llm.gen_stream("test_model", messages)) + self.assertListEqual(responses, ["response_1", "response_2"]) + + prompt_expected = "### Context \n context \n ### Question \n question" + mock_create.assert_called_with( + model="test_model", + prompt=f"{self.llm.HUMAN_PROMPT} {prompt_expected}{self.llm.AI_PROMPT}", + max_tokens_to_sample=300, + stream=True + ) + mock_redis_instance.set.assert_called_once() if __name__ == "__main__": unittest.main() diff --git a/tests/llm/test_sagemaker.py b/tests/llm/test_sagemaker.py index 0602f597..d659d498 100644 --- a/tests/llm/test_sagemaker.py +++ b/tests/llm/test_sagemaker.py @@ -52,28 +52,38 @@ class TestSagemakerAPILLM(unittest.TestCase): self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result) def test_gen(self): - with patch.object(self.sagemaker.runtime, 'invoke_endpoint', - return_value=self.response) as mock_invoke_endpoint: - output = self.sagemaker.gen(None, self.messages) - mock_invoke_endpoint.assert_called_once_with( - EndpointName=self.sagemaker.endpoint, - ContentType='application/json', - Body=self.body_bytes - ) - self.assertEqual(output, - self.result[0]['generated_text'][len(self.prompt):]) + with patch('application.cache.get_redis_instance') as mock_make_redis: + mock_redis_instance = 
mock_make_redis.return_value + mock_redis_instance.get.return_value = None + + with patch.object(self.sagemaker.runtime, 'invoke_endpoint', + return_value=self.response) as mock_invoke_endpoint: + output = self.sagemaker.gen(None, self.messages) + mock_invoke_endpoint.assert_called_once_with( + EndpointName=self.sagemaker.endpoint, + ContentType='application/json', + Body=self.body_bytes + ) + self.assertEqual(output, + self.result[0]['generated_text'][len(self.prompt):]) + mock_make_redis.assert_called_once() + mock_redis_instance.set.assert_called_once() def test_gen_stream(self): - with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', - return_value=self.response) as mock_invoke_endpoint: - output = list(self.sagemaker.gen_stream(None, self.messages)) - mock_invoke_endpoint.assert_called_once_with( - EndpointName=self.sagemaker.endpoint, - ContentType='application/json', - Body=self.body_bytes_stream - ) - self.assertEqual(output, []) - + with patch('application.cache.get_redis_instance') as mock_make_redis: + mock_redis_instance = mock_make_redis.return_value + mock_redis_instance.get.return_value = None + + with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream', + return_value=self.response) as mock_invoke_endpoint: + output = list(self.sagemaker.gen_stream(None, self.messages)) + mock_invoke_endpoint.assert_called_once_with( + EndpointName=self.sagemaker.endpoint, + ContentType='application/json', + Body=self.body_bytes_stream + ) + self.assertEqual(output, []) + mock_redis_instance.set.assert_called_once() class TestLineIterator(unittest.TestCase): def setUp(self): diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 00000000..4270a181 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,131 @@ +import unittest +import json +from unittest.mock import patch, MagicMock +from application.cache import gen_cache_key, stream_cache, gen_cache +from application.utils import get_hash + + 
+# Test for gen_cache_key function +def test_make_gen_cache_key(): + messages = [ + {'role': 'user', 'content': 'test_user_message'}, + {'role': 'system', 'content': 'test_system_message'}, + ] + model = "test_docgpt" + + # Manually calculate the expected hash + expected_combined = f"{model}_{json.dumps(messages, sort_keys=True)}" + expected_hash = get_hash(expected_combined) + cache_key = gen_cache_key(*messages, model=model) + + assert cache_key == expected_hash + +def test_gen_cache_key_invalid_message_format(): + # Test when messages is not a list + with unittest.TestCase.assertRaises(unittest.TestCase, ValueError) as context: + gen_cache_key("This is not a list", model="docgpt") + assert str(context.exception) == "All messages must be dictionaries." + +# Test for gen_cache decorator +@patch('application.cache.get_redis_instance') # Mock the Redis client +def test_gen_cache_hit(mock_make_redis): + # Arrange + mock_redis_instance = MagicMock() + mock_make_redis.return_value = mock_redis_instance + mock_redis_instance.get.return_value = b"cached_result" # Simulate a cache hit + + @gen_cache + def mock_function(self, model, messages): + return "new_result" + + messages = [{'role': 'user', 'content': 'test_user_message'}] + model = "test_docgpt" + + # Act + result = mock_function(None, model, messages) + + # Assert + assert result == "cached_result" # Should return cached result + mock_redis_instance.get.assert_called_once() # Ensure Redis get was called + mock_redis_instance.set.assert_not_called() # Ensure the function result is not cached again + + +@patch('application.cache.get_redis_instance') # Mock the Redis client +def test_gen_cache_miss(mock_make_redis): + # Arrange + mock_redis_instance = MagicMock() + mock_make_redis.return_value = mock_redis_instance + mock_redis_instance.get.return_value = None # Simulate a cache miss + + @gen_cache + def mock_function(self, model, messages): + return "new_result" + + messages = [ + {'role': 'user', 'content': 
'test_user_message'}, + {'role': 'system', 'content': 'test_system_message'}, + ] + model = "test_docgpt" + # Act + result = mock_function(None, model, messages) + + # Assert + assert result == "new_result" + mock_redis_instance.get.assert_called_once() + +@patch('application.cache.get_redis_instance') +def test_stream_cache_hit(mock_make_redis): + # Arrange + mock_redis_instance = MagicMock() + mock_make_redis.return_value = mock_redis_instance + + cached_chunk = json.dumps(["chunk1", "chunk2"]).encode('utf-8') + mock_redis_instance.get.return_value = cached_chunk + + @stream_cache + def mock_function(self, model, messages, stream): + yield "new_chunk" + + messages = [{'role': 'user', 'content': 'test_user_message'}] + model = "test_docgpt" + + # Act + result = list(mock_function(None, model, messages, stream=True)) + + # Assert + assert result == ["chunk1", "chunk2"] # Should return cached chunks + mock_redis_instance.get.assert_called_once() + mock_redis_instance.set.assert_not_called() + + +@patch('application.cache.get_redis_instance') +def test_stream_cache_miss(mock_make_redis): + # Arrange + mock_redis_instance = MagicMock() + mock_make_redis.return_value = mock_redis_instance + mock_redis_instance.get.return_value = None # Simulate a cache miss + + @stream_cache + def mock_function(self, model, messages, stream): + yield "new_chunk" + + messages = [ + {'role': 'user', 'content': 'This is the context'}, + {'role': 'system', 'content': 'Some other message'}, + {'role': 'user', 'content': 'What is the answer?'} + ] + model = "test_docgpt" + + # Act + result = list(mock_function(None, model, messages, stream=True)) + + # Assert + assert result == ["new_chunk"] + mock_redis_instance.get.assert_called_once() + mock_redis_instance.set.assert_called_once() + + + + + +