feat: stream thinking tokens

This commit is contained in:
Alex
2026-02-09 11:53:25 +00:00
parent e602d941ca
commit cd8029840f
8 changed files with 320 additions and 19 deletions

View File

@@ -254,6 +254,22 @@ class TestGoogleLLMHandler:
assert result == []
def test_iterate_stream_preserves_thought_events(self):
"""Test stream iteration preserves provider-emitted thought events."""
handler = GoogleLLMHandler()
mock_chunks = [
{"type": "thought", "thought": "first thought"},
"answer token",
]
result = list(handler._iterate_stream(mock_chunks))
assert result == [
{"type": "thought", "thought": "first thought"},
"answer token",
]
def test_parse_response_parts_without_function_call_attribute(self):
"""Test parsing response with parts missing function_call attribute."""
handler = GoogleLLMHandler()

View File

@@ -188,6 +188,22 @@ class TestOpenAILLMHandler:
assert result == []
def test_iterate_stream_preserves_thought_events(self):
"""Test stream iteration preserves provider-emitted thought events."""
handler = OpenAILLMHandler()
mock_chunks = [
{"type": "thought", "thought": "first thought"},
"answer token",
]
result = list(handler._iterate_stream(mock_chunks))
assert result == [
{"type": "thought", "thought": "first thought"},
"answer token",
]
def test_parse_response_tool_call_missing_attributes(self):
"""Test parsing tool calls with missing attributes."""
handler = OpenAILLMHandler()

View File

@@ -4,10 +4,11 @@ import pytest
from application.llm.google_ai import GoogleLLM
class _FakePart:
def __init__(self, text=None, function_call=None, file_data=None):
def __init__(self, text=None, function_call=None, file_data=None, thought=False):
self.text = text
self.function_call = function_call
self.file_data = file_data
self.thought = thought
@staticmethod
def from_text(text):
@@ -38,10 +39,22 @@ class FakeTypesModule:
Part = _FakePart
Content = _FakeContent
class ThinkingConfig:
def __init__(
self,
include_thoughts=None,
thinking_budget=None,
thinking_level=None,
):
self.include_thoughts = include_thoughts
self.thinking_budget = thinking_budget
self.thinking_level = thinking_level
class GenerateContentConfig:
def __init__(self):
self.system_instruction = None
self.tools = None
self.thinking_config = None
self.response_schema = None
self.response_mime_type = None
@@ -112,6 +125,111 @@ def test_raw_gen_stream_yields_chunks():
assert list(gen) == ["a", "b"]
def test_raw_gen_stream_does_not_set_thinking_config_by_default(monkeypatch):
captured = {}
def fake_stream(self, *args, **kwargs):
captured["config"] = kwargs.get("config")
return [types.SimpleNamespace(text="a", candidates=None)]
monkeypatch.setattr(FakeModels, "generate_content_stream", fake_stream)
llm = GoogleLLM(api_key="key")
msgs = [{"role": "user", "content": "hello"}]
list(llm._raw_gen_stream(llm, model="gemini", messages=msgs, stream=True))
assert captured["config"].thinking_config is None
def test_raw_gen_stream_sets_thinking_config_when_explicitly_requested(monkeypatch):
captured = {}
def fake_stream(self, *args, **kwargs):
captured["config"] = kwargs.get("config")
return [types.SimpleNamespace(text="a", candidates=None)]
monkeypatch.setattr(FakeModels, "generate_content_stream", fake_stream)
llm = GoogleLLM(api_key="key")
msgs = [{"role": "user", "content": "hello"}]
list(
llm._raw_gen_stream(
llm,
model="gemini",
messages=msgs,
stream=True,
include_thoughts=True,
)
)
assert captured["config"].thinking_config is not None
assert captured["config"].thinking_config.include_thoughts is True
def test_raw_gen_stream_emits_thought_events(monkeypatch):
llm = GoogleLLM(api_key="key")
msgs = [{"role": "user", "content": "hello"}]
thought_part = types.SimpleNamespace(
text="thinking token",
function_call=None,
thought=True,
)
answer_part = types.SimpleNamespace(
text="answer token",
function_call=None,
thought=False,
)
chunk = types.SimpleNamespace(
candidates=[
types.SimpleNamespace(
content=types.SimpleNamespace(parts=[thought_part, answer_part])
)
]
)
monkeypatch.setattr(
FakeModels,
"generate_content_stream",
lambda self, *args, **kwargs: [chunk],
)
out = list(llm._raw_gen_stream(llm, model="gemini", messages=msgs, stream=True))
assert {"type": "thought", "thought": "thinking token"} in out
assert "answer token" in out
def test_raw_gen_stream_keeps_prefix_like_text_as_answer(monkeypatch):
llm = GoogleLLM(api_key="key")
msgs = [{"role": "user", "content": "hello"}]
prefixed_answer = "[[DOCSGPT_GOOGLE_REASONING]]this is answer text"
answer_part = types.SimpleNamespace(
text=prefixed_answer,
function_call=None,
thought=False,
)
chunk = types.SimpleNamespace(
candidates=[
types.SimpleNamespace(
content=types.SimpleNamespace(parts=[answer_part])
)
]
)
monkeypatch.setattr(
FakeModels,
"generate_content_stream",
lambda self, *args, **kwargs: [chunk],
)
out = list(llm._raw_gen_stream(llm, model="gemini", messages=msgs, stream=True))
assert prefixed_answer in out
assert not any(isinstance(item, dict) and item.get("type") == "thought" for item in out)
def test_prepare_structured_output_format_type_mapping():
llm = GoogleLLM(api_key="key")
schema = {
@@ -148,4 +266,3 @@ def test_prepare_messages_with_attachments_appends_files(monkeypatch):
files_entry = next((p for p in user_msg["content"] if isinstance(p, dict) and "files" in p), None)
assert files_entry is not None
assert isinstance(files_entry["files"], list) and len(files_entry["files"]) == 2

View File

@@ -14,18 +14,44 @@ class FakeChatCompletions:
self.tool_calls = tool_calls
class _Delta:
def __init__(self, content=None):
def __init__(self, content=None, reasoning_content=None, tool_calls=None):
self.content = content
self.reasoning_content = reasoning_content
self.tool_calls = tool_calls
class _Choice:
def __init__(self, content=None, delta=None, finish_reason="stop"):
def __init__(
self,
content=None,
delta=None,
reasoning_content=None,
tool_calls=None,
finish_reason="stop",
):
self.message = FakeChatCompletions._Msg(content=content)
self.delta = FakeChatCompletions._Delta(content=delta)
self.delta = FakeChatCompletions._Delta(
content=delta,
reasoning_content=reasoning_content,
tool_calls=tool_calls,
)
self.finish_reason = finish_reason
class _StreamLine:
def __init__(self, deltas):
self.choices = [FakeChatCompletions._Choice(delta=d) for d in deltas]
choices = []
for delta in deltas:
if isinstance(delta, dict):
choices.append(
FakeChatCompletions._Choice(
delta=delta.get("content"),
reasoning_content=delta.get("reasoning_content"),
tool_calls=delta.get("tool_calls"),
finish_reason=delta.get("finish_reason", "stop"),
)
)
else:
choices.append(FakeChatCompletions._Choice(delta=delta))
self.choices = choices
class _Response:
def __init__(self, choices=None, lines=None):
@@ -144,6 +170,44 @@ def test_raw_gen_stream_yields_chunks(openai_llm):
assert "part2" in "".join(chunks)
@pytest.mark.unit
def test_raw_gen_stream_emits_thought_events(openai_llm):
msgs = [{"role": "user", "content": "think first"}]
openai_llm.client.chat.completions.create = lambda **kwargs: FakeChatCompletions._Response(
lines=[
FakeChatCompletions._StreamLine(
[{"reasoning_content": "internal thought"}]
),
FakeChatCompletions._StreamLine([{"content": "final answer"}]),
FakeChatCompletions._StreamLine([{"finish_reason": "stop"}]),
]
)
chunks = list(openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=msgs))
assert {"type": "thought", "thought": "internal thought"} in chunks
assert "final answer" in chunks
@pytest.mark.unit
def test_raw_gen_stream_keeps_prefix_like_text_as_answer(openai_llm):
msgs = [{"role": "user", "content": "return literal marker"}]
prefixed_answer = "[[DOCSGPT_OPENAI_REASONING]]this is answer text"
openai_llm.client.chat.completions.create = lambda **kwargs: FakeChatCompletions._Response(
lines=[
FakeChatCompletions._StreamLine([{"content": prefixed_answer}]),
FakeChatCompletions._StreamLine([{"finish_reason": "stop"}]),
]
)
chunks = list(openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=msgs))
assert prefixed_answer in chunks
assert not any(isinstance(chunk, dict) and chunk.get("type") == "thought" for chunk in chunks)
@pytest.mark.unit
def test_prepare_structured_output_format_enforces_required_and_strict(openai_llm):
schema = {