From cd8029840feeae82f5af3fde6c537872d0ea18db Mon Sep 17 00:00:00 2001 From: Alex Date: Mon, 9 Feb 2026 11:53:25 +0000 Subject: [PATCH] feat: stream thinking tokens --- application/llm/google_ai.py | 36 +++++- application/llm/handlers/base.py | 3 + application/llm/openai.py | 73 +++++++++-- .../src/conversation/ConversationBubble.tsx | 2 +- tests/llm/handlers/test_google.py | 16 +++ tests/llm/handlers/test_openai.py | 16 +++ tests/llm/test_google_llm.py | 121 +++++++++++++++++- tests/llm/test_openai_llm.py | 72 ++++++++++- 8 files changed, 320 insertions(+), 19 deletions(-) diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py index 00c609ec..bd262a87 100644 --- a/application/llm/google_ai.py +++ b/application/llm/google_ai.py @@ -378,6 +378,22 @@ class GoogleLLM(BaseLLM): last_preview = f"{last_preview[:preview_chars]}..." return f"count={message_count}, last='{last_preview}'" + @staticmethod + def _get_text_value(part): + """Get text from both SDK objects and dict-shaped test doubles.""" + if isinstance(part, dict): + value = part.get("text") + return value if isinstance(value, str) else "" + value = getattr(part, "text", None) + return value if isinstance(value, str) else "" + + @staticmethod + def _is_thought_part(part): + """Detect Gemini thinking parts when available.""" + if isinstance(part, dict): + return bool(part.get("thought")) + return bool(getattr(part, "thought", False)) + def _raw_gen( self, baseself, @@ -438,7 +454,6 @@ class GoogleLLM(BaseLLM): if tools: cleaned_tools = self._clean_tools_format(tools) config.tools = cleaned_tools - # Add response schema for structured output if provided if response_schema: config.response_schema = response_schema @@ -475,10 +490,23 @@ class GoogleLLM(BaseLLM): for part in candidate.content.parts: if part.function_call: yield part - elif part.text: - yield part.text + continue + + part_text = self._get_text_value(part) + if not part_text: + continue + + if self._is_thought_part(part): + yield {"type": "thought", "thought": part_text} + else: + yield part_text elif hasattr(chunk, "text"): - yield chunk.text + chunk_text = self._get_text_value(chunk) + if chunk_text: + if self._is_thought_part(chunk): + yield {"type": "thought", "thought": chunk_text} + else: + yield chunk_text finally: if hasattr(response, "close"): response.close() diff --git a/application/llm/handlers/base.py b/application/llm/handlers/base.py index b673a604..e33bd18e 100644 --- a/application/llm/handlers/base.py +++ b/application/llm/handlers/base.py @@ -878,6 +878,9 @@ class LLMHandler(ABC): tool_calls = {} for chunk in self._iterate_stream(response): + if isinstance(chunk, dict) and chunk.get("type") == "thought": + yield chunk + continue if isinstance(chunk, str): yield chunk continue diff --git a/application/llm/openai.py b/application/llm/openai.py index 263b4b5a..d3ece394 100644 --- a/application/llm/openai.py +++ b/application/llm/openai.py @@ -151,6 +151,51 @@ class OpenAILLM(BaseLLM): raise ValueError(f"Unexpected content type: {type(content)}") return cleaned_messages + @staticmethod + def _normalize_reasoning_value(value): + """Normalize reasoning payloads from OpenAI-compatible stream chunks.""" + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, list): + return "".join( + OpenAILLM._normalize_reasoning_value(item) for item in value + ) + if isinstance(value, dict): + for key in ("text", "content", "value", "reasoning_content", "reasoning"): + normalized = OpenAILLM._normalize_reasoning_value(value.get(key)) + if normalized: + return normalized + return "" + + for attr in ("text", "content", "value"): + if hasattr(value, attr): + normalized = OpenAILLM._normalize_reasoning_value(getattr(value, attr)) + if normalized: + return normalized + return "" + + @classmethod + def _extract_reasoning_text(cls, delta): + """Extract reasoning/thinking tokens from OpenAI-compatible delta chunks.""" + if delta is None: + return "" + + for key in ( + "reasoning_content", + "reasoning", + "thinking", + "thinking_content", + ): + value = getattr(delta, key, None) + if value is None and isinstance(delta, dict): + value = delta.get(key) + normalized = cls._normalize_reasoning_value(value) + if normalized: + return normalized + return "" + def _raw_gen( self, baseself, @@ -221,14 +266,26 @@ class OpenAILLM(BaseLLM): try: for line in response: logging.debug(f"OpenAI stream line: {line}") - if ( - len(line.choices) > 0 - and line.choices[0].delta.content is not None - and len(line.choices[0].delta.content) > 0 - ): - yield line.choices[0].delta.content - elif len(line.choices) > 0: - yield line.choices[0] + if not getattr(line, "choices", None): + continue + + choice = line.choices[0] + delta = getattr(choice, "delta", None) + reasoning_text = self._extract_reasoning_text(delta) + if reasoning_text: + yield {"type": "thought", "thought": reasoning_text} + + content = getattr(delta, "content", None) + if isinstance(content, str) and content: + yield content + continue + + has_tool_calls = bool(getattr(delta, "tool_calls", None)) + finish_reason = getattr(choice, "finish_reason", None) + + # Yield non-content chunks only when needed for tool-call handling. + if has_tool_calls or finish_reason == "tool_calls": + yield choice finally: if hasattr(response, "close"): response.close() diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index c647168d..d35de34b 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -805,7 +805,7 @@ function Thought({ }) { const { t } = useTranslation(); const [isDarkTheme] = useDarkTheme(); - const [isThoughtOpen, setIsThoughtOpen] = useState(true); + const [isThoughtOpen, setIsThoughtOpen] = useState(false); return (
diff --git a/tests/llm/handlers/test_google.py b/tests/llm/handlers/test_google.py index 8fc00e1e..900b2d5f 100644 --- a/tests/llm/handlers/test_google.py +++ b/tests/llm/handlers/test_google.py @@ -254,6 +254,22 @@ class TestGoogleLLMHandler: assert result == [] + def test_iterate_stream_preserves_thought_events(self): + """Test stream iteration preserves provider-emitted thought events.""" + handler = GoogleLLMHandler() + + mock_chunks = [ + {"type": "thought", "thought": "first thought"}, + "answer token", + ] + + result = list(handler._iterate_stream(mock_chunks)) + + assert result == [ + {"type": "thought", "thought": "first thought"}, + "answer token", + ] + def test_parse_response_parts_without_function_call_attribute(self): """Test parsing response with parts missing function_call attribute.""" handler = GoogleLLMHandler() diff --git a/tests/llm/handlers/test_openai.py b/tests/llm/handlers/test_openai.py index 86ad5096..64c89f6c 100644 --- a/tests/llm/handlers/test_openai.py +++ b/tests/llm/handlers/test_openai.py @@ -188,6 +188,22 @@ class TestOpenAILLMHandler: assert result == [] + def test_iterate_stream_preserves_thought_events(self): + """Test stream iteration preserves provider-emitted thought events.""" + handler = OpenAILLMHandler() + + mock_chunks = [ + {"type": "thought", "thought": "first thought"}, + "answer token", + ] + + result = list(handler._iterate_stream(mock_chunks)) + + assert result == [ + {"type": "thought", "thought": "first thought"}, + "answer token", + ] + def test_parse_response_tool_call_missing_attributes(self): """Test parsing tool calls with missing attributes.""" handler = OpenAILLMHandler() diff --git a/tests/llm/test_google_llm.py b/tests/llm/test_google_llm.py index 80434e98..8addfaa5 100644 --- a/tests/llm/test_google_llm.py +++ b/tests/llm/test_google_llm.py @@ -4,10 +4,11 @@ import pytest from application.llm.google_ai import GoogleLLM class _FakePart: - def __init__(self, text=None, function_call=None, file_data=None): + def __init__(self, text=None, function_call=None, file_data=None, thought=False): self.text = text self.function_call = function_call self.file_data = file_data + self.thought = thought @staticmethod def from_text(text): @@ -38,10 +39,22 @@ class FakeTypesModule: Part = _FakePart Content = _FakeContent + class ThinkingConfig: + def __init__( + self, + include_thoughts=None, + thinking_budget=None, + thinking_level=None, + ): + self.include_thoughts = include_thoughts + self.thinking_budget = thinking_budget + self.thinking_level = thinking_level + class GenerateContentConfig: def __init__(self): self.system_instruction = None self.tools = None + self.thinking_config = None self.response_schema = None self.response_mime_type = None @@ -112,6 +125,111 @@ def test_raw_gen_stream_yields_chunks(): assert list(gen) == ["a", "b"] +def test_raw_gen_stream_does_not_set_thinking_config_by_default(monkeypatch): + captured = {} + + def fake_stream(self, *args, **kwargs): + captured["config"] = kwargs.get("config") + return [types.SimpleNamespace(text="a", candidates=None)] + + monkeypatch.setattr(FakeModels, "generate_content_stream", fake_stream) + + llm = GoogleLLM(api_key="key") + msgs = [{"role": "user", "content": "hello"}] + list(llm._raw_gen_stream(llm, model="gemini", messages=msgs, stream=True)) + + assert captured["config"].thinking_config is None + + +def test_raw_gen_stream_sets_thinking_config_when_explicitly_requested(monkeypatch): + captured = {} + + def fake_stream(self, *args, **kwargs): + captured["config"] = kwargs.get("config") + return [types.SimpleNamespace(text="a", candidates=None)] + + monkeypatch.setattr(FakeModels, "generate_content_stream", fake_stream) + + llm = GoogleLLM(api_key="key") + msgs = [{"role": "user", "content": "hello"}] + list( + llm._raw_gen_stream( + llm, + model="gemini", + messages=msgs, + stream=True, + include_thoughts=True, + ) + ) + + assert captured["config"].thinking_config is not None + assert captured["config"].thinking_config.include_thoughts is True + + +def test_raw_gen_stream_emits_thought_events(monkeypatch): + llm = GoogleLLM(api_key="key") + msgs = [{"role": "user", "content": "hello"}] + + thought_part = types.SimpleNamespace( + text="thinking token", + function_call=None, + thought=True, + ) + answer_part = types.SimpleNamespace( + text="answer token", + function_call=None, + thought=False, + ) + chunk = types.SimpleNamespace( + candidates=[ + types.SimpleNamespace( + content=types.SimpleNamespace(parts=[thought_part, answer_part]) + ) + ] + ) + + monkeypatch.setattr( + FakeModels, + "generate_content_stream", + lambda self, *args, **kwargs: [chunk], + ) + + out = list(llm._raw_gen_stream(llm, model="gemini", messages=msgs, stream=True)) + + assert {"type": "thought", "thought": "thinking token"} in out + assert "answer token" in out + + +def test_raw_gen_stream_keeps_prefix_like_text_as_answer(monkeypatch): + llm = GoogleLLM(api_key="key") + msgs = [{"role": "user", "content": "hello"}] + prefixed_answer = "[[DOCSGPT_GOOGLE_REASONING]]this is answer text" + + answer_part = types.SimpleNamespace( + text=prefixed_answer, + function_call=None, + thought=False, + ) + chunk = types.SimpleNamespace( + candidates=[ + types.SimpleNamespace( + content=types.SimpleNamespace(parts=[answer_part]) + ) + ] + ) + + monkeypatch.setattr( + FakeModels, + "generate_content_stream", + lambda self, *args, **kwargs: [chunk], + ) + + out = list(llm._raw_gen_stream(llm, model="gemini", messages=msgs, stream=True)) + + assert prefixed_answer in out + assert not any(isinstance(item, dict) and item.get("type") == "thought" for item in out) + + def test_prepare_structured_output_format_type_mapping(): llm = GoogleLLM(api_key="key") schema = { @@ -148,4 +266,3 @@ def test_prepare_messages_with_attachments_appends_files(monkeypatch): files_entry = next((p for p in user_msg["content"] if isinstance(p, dict) and "files" in p), None) assert files_entry is not None assert isinstance(files_entry["files"], list) and len(files_entry["files"]) == 2 - diff --git a/tests/llm/test_openai_llm.py b/tests/llm/test_openai_llm.py index 77ffe14c..eb19ce22 100644 --- a/tests/llm/test_openai_llm.py +++ b/tests/llm/test_openai_llm.py @@ -14,18 +14,44 @@ class FakeChatCompletions: self.tool_calls = tool_calls class _Delta: - def __init__(self, content=None): + def __init__(self, content=None, reasoning_content=None, tool_calls=None): self.content = content + self.reasoning_content = reasoning_content + self.tool_calls = tool_calls class _Choice: - def __init__(self, content=None, delta=None, finish_reason="stop"): + def __init__( + self, + content=None, + delta=None, + reasoning_content=None, + tool_calls=None, + finish_reason="stop", + ): self.message = FakeChatCompletions._Msg(content=content) - self.delta = FakeChatCompletions._Delta(content=delta) + self.delta = FakeChatCompletions._Delta( + content=delta, + reasoning_content=reasoning_content, + tool_calls=tool_calls, + ) self.finish_reason = finish_reason class _StreamLine: def __init__(self, deltas): - self.choices = [FakeChatCompletions._Choice(delta=d) for d in deltas] + choices = [] + for delta in deltas: + if isinstance(delta, dict): + choices.append( + FakeChatCompletions._Choice( + delta=delta.get("content"), + reasoning_content=delta.get("reasoning_content"), + tool_calls=delta.get("tool_calls"), + finish_reason=delta.get("finish_reason", "stop"), + ) + ) + else: + choices.append(FakeChatCompletions._Choice(delta=delta)) + self.choices = choices class _Response: def __init__(self, choices=None, lines=None): @@ -144,6 +170,44 @@ def test_raw_gen_stream_yields_chunks(openai_llm): assert "part2" in "".join(chunks) +@pytest.mark.unit +def test_raw_gen_stream_emits_thought_events(openai_llm): + msgs = [{"role": "user", "content": "think first"}] + + openai_llm.client.chat.completions.create = lambda **kwargs: FakeChatCompletions._Response( + lines=[ + FakeChatCompletions._StreamLine( + [{"reasoning_content": "internal thought"}] + ), + FakeChatCompletions._StreamLine([{"content": "final answer"}]), + FakeChatCompletions._StreamLine([{"finish_reason": "stop"}]), + ] + ) + + chunks = list(openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=msgs)) + + assert {"type": "thought", "thought": "internal thought"} in chunks + assert "final answer" in chunks + + +@pytest.mark.unit +def test_raw_gen_stream_keeps_prefix_like_text_as_answer(openai_llm): + msgs = [{"role": "user", "content": "return literal marker"}] + prefixed_answer = "[[DOCSGPT_OPENAI_REASONING]]this is answer text" + + openai_llm.client.chat.completions.create = lambda **kwargs: FakeChatCompletions._Response( + lines=[ + FakeChatCompletions._StreamLine([{"content": prefixed_answer}]), + FakeChatCompletions._StreamLine([{"finish_reason": "stop"}]), + ] + ) + + chunks = list(openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=msgs)) + + assert prefixed_answer in chunks + assert not any(isinstance(chunk, dict) and chunk.get("type") == "thought" for chunk in chunks) + + @pytest.mark.unit def test_prepare_structured_output_format_enforces_required_and_strict(openai_llm): schema = {