diff --git a/application/api/answer/routes.py b/application/api/answer/routes.py index 17eb5cc3..40c4d8cb 100644 --- a/application/api/answer/routes.py +++ b/application/api/answer/routes.py @@ -11,7 +11,7 @@ from bson.objectid import ObjectId from flask import Blueprint, current_app, make_response, request, Response from flask_restx import fields, Namespace, Resource -from pymongo import MongoClient +from core.mongo_db import MongoDB from application.core.settings import settings from application.error import bad_request @@ -22,7 +22,7 @@ from application.utils import check_required_fields logger = logging.getLogger(__name__) -mongo = MongoClient(settings.MONGO_URI) +mongo = MongoDB.get_client() db = mongo["docsgpt"] conversations_collection = db["conversations"] sources_collection = db["sources"] diff --git a/application/api/internal/routes.py b/application/api/internal/routes.py index 6ecb4346..f004cf97 100755 --- a/application/api/internal/routes.py +++ b/application/api/internal/routes.py @@ -1,13 +1,13 @@ import os import datetime from flask import Blueprint, request, send_from_directory -from pymongo import MongoClient +from core.mongo_db import MongoDB from werkzeug.utils import secure_filename from bson.objectid import ObjectId from application.core.settings import settings -mongo = MongoClient(settings.MONGO_URI) +mongo = MongoDB.get_client() db = mongo["docsgpt"] conversations_collection = db["conversations"] sources_collection = db["sources"] diff --git a/application/api/user/routes.py b/application/api/user/routes.py index feee91cc..3469c800 100644 --- a/application/api/user/routes.py +++ b/application/api/user/routes.py @@ -8,7 +8,7 @@ from bson.dbref import DBRef from bson.objectid import ObjectId from flask import Blueprint, jsonify, make_response, request from flask_restx import inputs, fields, Namespace, Resource -from pymongo import MongoClient +from core.mongo_db import MongoDB from werkzeug.utils import secure_filename from application.api.user.tasks import ingest, ingest_remote @@ -18,7 +18,7 @@ from application.extensions import api from application.utils import check_required_fields from application.vectorstore.vector_creator import VectorCreator -mongo = MongoClient(settings.MONGO_URI) +mongo = MongoDB.get_client() db = mongo["docsgpt"] conversations_collection = db["conversations"] sources_collection = db["sources"] diff --git a/application/core/mongo_db.py b/application/core/mongo_db.py new file mode 100644 index 00000000..ffb55d7f --- /dev/null +++ b/application/core/mongo_db.py @@ -0,0 +1,25 @@ +from application.core import settings +from pymongo import MongoClient +from flask import current_app, g + + +class MongoDB: + _client = None + + @classmethod + def get_client(cls): + """ + Get the MongoDB client instance, creating it if necessary. + """ + if cls._client is None: + cls._client = MongoClient(settings.MONGO_URI) + return cls._client + + @classmethod + def close_client(cls): + """ + Close the MongoDB client connection. + """ + if cls._client is not None: + cls._client.close() + cls._client = None diff --git a/application/usage.py b/application/usage.py index aba0ec77..21797817 100644 --- a/application/usage.py +++ b/application/usage.py @@ -1,10 +1,10 @@ import sys -from pymongo import MongoClient +from core.mongo_db import MongoDB from datetime import datetime from application.core.settings import settings from application.utils import num_tokens_from_string -mongo = MongoClient(settings.MONGO_URI) +mongo = MongoDB.get_client() db = mongo["docsgpt"] usage_collection = db["token_usage"] diff --git a/application/worker.py b/application/worker.py index f8f38afa..fc780d61 100755 --- a/application/worker.py +++ b/application/worker.py @@ -8,7 +8,7 @@ from urllib.parse import urljoin import requests from bson.objectid import ObjectId -from pymongo import MongoClient +from core.mongo_db import MongoDB from application.core.settings import settings from application.parser.file.bulk import SimpleDirectoryReader @@ -18,7 +18,7 @@ from application.parser.schema.base import Document from application.parser.token_func import group_split from application.utils import count_tokens_docs -mongo = MongoClient(settings.MONGO_URI) +mongo = MongoDB.get_client() db = mongo["docsgpt"] sources_collection = db["sources"]