fix: api_key capturing + pytest errors

This commit is contained in:
Siddhant Rai
2024-04-15 22:32:24 +05:30
parent 60a670ce29
commit 77991896b4
11 changed files with 276 additions and 145 deletions

View File

@@ -4,7 +4,7 @@ from application.core.settings import settings
class PremAILLM(BaseLLM):
def __init__(self, api_key, *args, **kwargs):
def __init__(self, api_key=None, *args, **kwargs):
from premai import Prem
super().__init__(*args, **kwargs)
@@ -12,7 +12,7 @@ class PremAILLM(BaseLLM):
self.api_key = api_key
self.project_id = settings.PREMAI_PROJECT_ID
def _raw_gen(self, model, messages, stream=False, **kwargs):
def _raw_gen(self, baseself, model, messages, stream=False, **kwargs):
response = self.client.chat.completions.create(
model=model,
project_id=self.project_id,
@@ -23,7 +23,7 @@ class PremAILLM(BaseLLM):
return response.choices[0].message["content"]
def _raw_gen_stream(self, model, messages, stream=True, **kwargs):
def _raw_gen_stream(self, baseself, model, messages, stream=True, **kwargs):
response = self.client.chat.completions.create(
model=model,
project_id=self.project_id,