diff --git a/application/llm/anthropic.py b/application/llm/anthropic.py index 31e45633..a078fc82 100644 --- a/application/llm/anthropic.py +++ b/application/llm/anthropic.py @@ -16,13 +16,13 @@ class AnthropicLLM(BaseLLM): self.AI_PROMPT = AI_PROMPT def _raw_gen( - self, baseself, model, messages, max_tokens=300, stream=False, **kwargs + self, baseself, model, messages, stream=False, max_tokens=300, **kwargs ): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}" if stream: - return self.gen_stream(model, prompt, max_tokens, **kwargs) + return self.gen_stream(model, prompt, stream, max_tokens, **kwargs) completion = self.anthropic.completions.create( model=model, @@ -32,7 +32,9 @@ class AnthropicLLM(BaseLLM): ) return completion.completion - def _raw_gen_stream(self, baseself, model, messages, max_tokens=300, **kwargs): + def _raw_gen_stream( + self, baseself, model, messages, stream=True, max_tokens=300, **kwargs + ): context = messages[0]["content"] user_question = messages[-1]["content"] prompt = f"### Context \n {context} \n ### Question \n {user_question}"