Mirror of https://github.com/arc53/DocsGPT.git, synced 2026-03-07 06:15:10 +00:00.
feat: stream thinking tokens
This commit is contained in:
@@ -378,6 +378,22 @@ class GoogleLLM(BaseLLM):
|
||||
last_preview = f"{last_preview[:preview_chars]}..."
|
||||
return f"count={message_count}, last='{last_preview}'"
|
||||
|
||||
@staticmethod
|
||||
def _get_text_value(part):
|
||||
"""Get text from both SDK objects and dict-shaped test doubles."""
|
||||
if isinstance(part, dict):
|
||||
value = part.get("text")
|
||||
return value if isinstance(value, str) else ""
|
||||
value = getattr(part, "text", None)
|
||||
return value if isinstance(value, str) else ""
|
||||
|
||||
@staticmethod
|
||||
def _is_thought_part(part):
|
||||
"""Detect Gemini thinking parts when available."""
|
||||
if isinstance(part, dict):
|
||||
return bool(part.get("thought"))
|
||||
return bool(getattr(part, "thought", False))
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
@@ -438,7 +454,6 @@ class GoogleLLM(BaseLLM):
|
||||
if tools:
|
||||
cleaned_tools = self._clean_tools_format(tools)
|
||||
config.tools = cleaned_tools
|
||||
# Add response schema for structured output if provided
|
||||
|
||||
if response_schema:
|
||||
config.response_schema = response_schema
|
||||
@@ -475,10 +490,23 @@ class GoogleLLM(BaseLLM):
|
||||
for part in candidate.content.parts:
|
||||
if part.function_call:
|
||||
yield part
|
||||
elif part.text:
|
||||
yield part.text
|
||||
continue
|
||||
|
||||
part_text = self._get_text_value(part)
|
||||
if not part_text:
|
||||
continue
|
||||
|
||||
if self._is_thought_part(part):
|
||||
yield {"type": "thought", "thought": part_text}
|
||||
else:
|
||||
yield part_text
|
||||
elif hasattr(chunk, "text"):
|
||||
yield chunk.text
|
||||
chunk_text = self._get_text_value(chunk)
|
||||
if chunk_text:
|
||||
if self._is_thought_part(chunk):
|
||||
yield {"type": "thought", "thought": chunk_text}
|
||||
else:
|
||||
yield chunk_text
|
||||
finally:
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
|
||||
@@ -878,6 +878,9 @@ class LLMHandler(ABC):
|
||||
tool_calls = {}
|
||||
|
||||
for chunk in self._iterate_stream(response):
|
||||
if isinstance(chunk, dict) and chunk.get("type") == "thought":
|
||||
yield chunk
|
||||
continue
|
||||
if isinstance(chunk, str):
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
@@ -151,6 +151,51 @@ class OpenAILLM(BaseLLM):
|
||||
raise ValueError(f"Unexpected content type: {type(content)}")
|
||||
return cleaned_messages
|
||||
|
||||
@staticmethod
def _normalize_reasoning_value(value):
    """Coerce a reasoning payload chunk into a plain string.

    OpenAI-compatible providers ship reasoning deltas as strings, lists of
    fragments, dicts, or SDK objects; this recurses through those shapes
    and returns the first non-empty text found, or "" when there is none.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return value

    recurse = OpenAILLM._normalize_reasoning_value
    if isinstance(value, list):
        fragments = [recurse(item) for item in value]
        return "".join(fragments)
    if isinstance(value, dict):
        for key in ("text", "content", "value", "reasoning_content", "reasoning"):
            text = recurse(value.get(key))
            if text:
                return text
        return ""

    # SDK objects: probe the common attribute names for nested payloads.
    for attr in ("text", "content", "value"):
        if hasattr(value, attr):
            text = recurse(getattr(value, attr))
            if text:
                return text
    return ""
|
||||
|
||||
@classmethod
def _extract_reasoning_text(cls, delta):
    """Return the first non-empty reasoning string found on *delta*.

    Probes the field names used by OpenAI-compatible providers for
    thinking/reasoning tokens, checking attribute access first and then
    dict access; returns "" when no field carries content.
    """
    if delta is None:
        return ""

    candidate_keys = ("reasoning_content", "reasoning", "thinking", "thinking_content")
    for key in candidate_keys:
        raw = getattr(delta, key, None)
        if raw is None and isinstance(delta, dict):
            raw = delta.get(key)
        text = cls._normalize_reasoning_value(raw)
        if text:
            return text
    return ""
|
||||
|
||||
def _raw_gen(
|
||||
self,
|
||||
baseself,
|
||||
@@ -221,14 +266,26 @@ class OpenAILLM(BaseLLM):
|
||||
try:
|
||||
for line in response:
|
||||
logging.debug(f"OpenAI stream line: {line}")
|
||||
if (
|
||||
len(line.choices) > 0
|
||||
and line.choices[0].delta.content is not None
|
||||
and len(line.choices[0].delta.content) > 0
|
||||
):
|
||||
yield line.choices[0].delta.content
|
||||
elif len(line.choices) > 0:
|
||||
yield line.choices[0]
|
||||
if not getattr(line, "choices", None):
|
||||
continue
|
||||
|
||||
choice = line.choices[0]
|
||||
delta = getattr(choice, "delta", None)
|
||||
reasoning_text = self._extract_reasoning_text(delta)
|
||||
if reasoning_text:
|
||||
yield {"type": "thought", "thought": reasoning_text}
|
||||
|
||||
content = getattr(delta, "content", None)
|
||||
if isinstance(content, str) and content:
|
||||
yield content
|
||||
continue
|
||||
|
||||
has_tool_calls = bool(getattr(delta, "tool_calls", None))
|
||||
finish_reason = getattr(choice, "finish_reason", None)
|
||||
|
||||
# Yield non-content chunks only when needed for tool-call handling.
|
||||
if has_tool_calls or finish_reason == "tool_calls":
|
||||
yield choice
|
||||
finally:
|
||||
if hasattr(response, "close"):
|
||||
response.close()
|
||||
|
||||
@@ -805,7 +805,7 @@ function Thought({
|
||||
}) {
|
||||
const { t } = useTranslation();
|
||||
const [isDarkTheme] = useDarkTheme();
|
||||
const [isThoughtOpen, setIsThoughtOpen] = useState(true);
|
||||
const [isThoughtOpen, setIsThoughtOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<div className="mb-4 flex w-full flex-col flex-wrap items-start self-start lg:flex-nowrap">
|
||||
|
||||
@@ -254,6 +254,22 @@ class TestGoogleLLMHandler:
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_iterate_stream_preserves_thought_events(self):
    """_iterate_stream must pass provider-emitted thought events through unchanged."""
    handler = GoogleLLMHandler()
    thought_event = {"type": "thought", "thought": "first thought"}
    stream = [thought_event, "answer token"]

    output = list(handler._iterate_stream(stream))

    assert output == [thought_event, "answer token"]
|
||||
|
||||
def test_parse_response_parts_without_function_call_attribute(self):
|
||||
"""Test parsing response with parts missing function_call attribute."""
|
||||
handler = GoogleLLMHandler()
|
||||
|
||||
@@ -188,6 +188,22 @@ class TestOpenAILLMHandler:
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_iterate_stream_preserves_thought_events(self):
    """_iterate_stream must pass provider-emitted thought events through unchanged."""
    handler = OpenAILLMHandler()
    thought_event = {"type": "thought", "thought": "first thought"}
    stream = [thought_event, "answer token"]

    output = list(handler._iterate_stream(stream))

    assert output == [thought_event, "answer token"]
|
||||
|
||||
def test_parse_response_tool_call_missing_attributes(self):
|
||||
"""Test parsing tool calls with missing attributes."""
|
||||
handler = OpenAILLMHandler()
|
||||
|
||||
@@ -4,10 +4,11 @@ import pytest
|
||||
from application.llm.google_ai import GoogleLLM
|
||||
|
||||
class _FakePart:
|
||||
def __init__(self, text=None, function_call=None, file_data=None):
|
||||
def __init__(self, text=None, function_call=None, file_data=None, thought=False):
|
||||
self.text = text
|
||||
self.function_call = function_call
|
||||
self.file_data = file_data
|
||||
self.thought = thought
|
||||
|
||||
@staticmethod
|
||||
def from_text(text):
|
||||
@@ -38,10 +39,22 @@ class FakeTypesModule:
|
||||
Part = _FakePart
|
||||
Content = _FakeContent
|
||||
|
||||
class ThinkingConfig:
    """Test double mirroring the SDK's ThinkingConfig: stores kwargs as attributes."""

    def __init__(self, include_thoughts=None, thinking_budget=None, thinking_level=None):
        # Attribute names intentionally match the real SDK config object.
        self.include_thoughts = include_thoughts
        self.thinking_budget = thinking_budget
        self.thinking_level = thinking_level
|
||||
|
||||
class GenerateContentConfig:
    """Test double for the SDK's GenerateContentConfig; every field starts unset."""

    # Fields the code under test assigns during streaming setup.
    _FIELDS = (
        "system_instruction",
        "tools",
        "thinking_config",
        "response_schema",
        "response_mime_type",
    )

    def __init__(self):
        for field in self._FIELDS:
            setattr(self, field, None)
|
||||
|
||||
@@ -112,6 +125,111 @@ def test_raw_gen_stream_yields_chunks():
|
||||
assert list(gen) == ["a", "b"]
|
||||
|
||||
|
||||
def test_raw_gen_stream_does_not_set_thinking_config_by_default(monkeypatch):
    """Without an explicit request, streaming must leave thinking_config unset."""
    seen_configs = []

    def record_config(self, *args, **kwargs):
        seen_configs.append(kwargs.get("config"))
        return [types.SimpleNamespace(text="a", candidates=None)]

    monkeypatch.setattr(FakeModels, "generate_content_stream", record_config)

    llm = GoogleLLM(api_key="key")
    messages = [{"role": "user", "content": "hello"}]
    for _ in llm._raw_gen_stream(llm, model="gemini", messages=messages, stream=True):
        pass

    assert seen_configs[-1].thinking_config is None
|
||||
|
||||
|
||||
def test_raw_gen_stream_sets_thinking_config_when_explicitly_requested(monkeypatch):
    """Passing include_thoughts=True must populate config.thinking_config."""
    seen_configs = []

    def record_config(self, *args, **kwargs):
        seen_configs.append(kwargs.get("config"))
        return [types.SimpleNamespace(text="a", candidates=None)]

    monkeypatch.setattr(FakeModels, "generate_content_stream", record_config)

    llm = GoogleLLM(api_key="key")
    messages = [{"role": "user", "content": "hello"}]
    stream = llm._raw_gen_stream(
        llm,
        model="gemini",
        messages=messages,
        stream=True,
        include_thoughts=True,
    )
    for _ in stream:
        pass

    thinking = seen_configs[-1].thinking_config
    assert thinking is not None
    assert thinking.include_thoughts is True
|
||||
|
||||
|
||||
def test_raw_gen_stream_emits_thought_events(monkeypatch):
    """Thought-flagged parts must surface as structured thought events."""
    llm = GoogleLLM(api_key="key")
    messages = [{"role": "user", "content": "hello"}]

    def make_part(text, thought):
        return types.SimpleNamespace(text=text, function_call=None, thought=thought)

    parts = [make_part("thinking token", True), make_part("answer token", False)]
    candidate = types.SimpleNamespace(content=types.SimpleNamespace(parts=parts))
    chunk = types.SimpleNamespace(candidates=[candidate])

    monkeypatch.setattr(
        FakeModels,
        "generate_content_stream",
        lambda self, *args, **kwargs: [chunk],
    )

    emitted = list(
        llm._raw_gen_stream(llm, model="gemini", messages=messages, stream=True)
    )

    assert {"type": "thought", "thought": "thinking token"} in emitted
    assert "answer token" in emitted
|
||||
|
||||
|
||||
def test_raw_gen_stream_keeps_prefix_like_text_as_answer(monkeypatch):
    """Literal reasoning-marker text in an answer part must not become a thought."""
    llm = GoogleLLM(api_key="key")
    messages = [{"role": "user", "content": "hello"}]
    marker_text = "[[DOCSGPT_GOOGLE_REASONING]]this is answer text"

    part = types.SimpleNamespace(text=marker_text, function_call=None, thought=False)
    candidate = types.SimpleNamespace(content=types.SimpleNamespace(parts=[part]))
    chunk = types.SimpleNamespace(candidates=[candidate])

    monkeypatch.setattr(
        FakeModels,
        "generate_content_stream",
        lambda self, *args, **kwargs: [chunk],
    )

    emitted = list(
        llm._raw_gen_stream(llm, model="gemini", messages=messages, stream=True)
    )

    assert marker_text in emitted
    assert all(
        not (isinstance(item, dict) and item.get("type") == "thought")
        for item in emitted
    )
|
||||
|
||||
|
||||
def test_prepare_structured_output_format_type_mapping():
|
||||
llm = GoogleLLM(api_key="key")
|
||||
schema = {
|
||||
@@ -148,4 +266,3 @@ def test_prepare_messages_with_attachments_appends_files(monkeypatch):
|
||||
files_entry = next((p for p in user_msg["content"] if isinstance(p, dict) and "files" in p), None)
|
||||
assert files_entry is not None
|
||||
assert isinstance(files_entry["files"], list) and len(files_entry["files"]) == 2
|
||||
|
||||
|
||||
@@ -14,18 +14,44 @@ class FakeChatCompletions:
|
||||
self.tool_calls = tool_calls
|
||||
|
||||
class _Delta:
|
||||
def __init__(self, content=None):
|
||||
def __init__(self, content=None, reasoning_content=None, tool_calls=None):
|
||||
self.content = content
|
||||
self.reasoning_content = reasoning_content
|
||||
self.tool_calls = tool_calls
|
||||
|
||||
class _Choice:
|
||||
def __init__(self, content=None, delta=None, finish_reason="stop"):
|
||||
def __init__(
|
||||
self,
|
||||
content=None,
|
||||
delta=None,
|
||||
reasoning_content=None,
|
||||
tool_calls=None,
|
||||
finish_reason="stop",
|
||||
):
|
||||
self.message = FakeChatCompletions._Msg(content=content)
|
||||
self.delta = FakeChatCompletions._Delta(content=delta)
|
||||
self.delta = FakeChatCompletions._Delta(
|
||||
content=delta,
|
||||
reasoning_content=reasoning_content,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
self.finish_reason = finish_reason
|
||||
|
||||
class _StreamLine:
|
||||
def __init__(self, deltas):
|
||||
self.choices = [FakeChatCompletions._Choice(delta=d) for d in deltas]
|
||||
choices = []
|
||||
for delta in deltas:
|
||||
if isinstance(delta, dict):
|
||||
choices.append(
|
||||
FakeChatCompletions._Choice(
|
||||
delta=delta.get("content"),
|
||||
reasoning_content=delta.get("reasoning_content"),
|
||||
tool_calls=delta.get("tool_calls"),
|
||||
finish_reason=delta.get("finish_reason", "stop"),
|
||||
)
|
||||
)
|
||||
else:
|
||||
choices.append(FakeChatCompletions._Choice(delta=delta))
|
||||
self.choices = choices
|
||||
|
||||
class _Response:
|
||||
def __init__(self, choices=None, lines=None):
|
||||
@@ -144,6 +170,44 @@ def test_raw_gen_stream_yields_chunks(openai_llm):
|
||||
assert "part2" in "".join(chunks)
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_raw_gen_stream_emits_thought_events(openai_llm):
    """reasoning_content deltas must be emitted as structured thought events."""
    messages = [{"role": "user", "content": "think first"}]
    scripted_lines = [
        FakeChatCompletions._StreamLine([{"reasoning_content": "internal thought"}]),
        FakeChatCompletions._StreamLine([{"content": "final answer"}]),
        FakeChatCompletions._StreamLine([{"finish_reason": "stop"}]),
    ]

    def fake_create(**kwargs):
        return FakeChatCompletions._Response(lines=scripted_lines)

    openai_llm.client.chat.completions.create = fake_create

    chunks = list(openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=messages))

    assert {"type": "thought", "thought": "internal thought"} in chunks
    assert "final answer" in chunks
|
||||
|
||||
|
||||
@pytest.mark.unit
def test_raw_gen_stream_keeps_prefix_like_text_as_answer(openai_llm):
    """Literal reasoning-marker text in content must stay a plain answer chunk."""
    messages = [{"role": "user", "content": "return literal marker"}]
    marker_text = "[[DOCSGPT_OPENAI_REASONING]]this is answer text"
    scripted_lines = [
        FakeChatCompletions._StreamLine([{"content": marker_text}]),
        FakeChatCompletions._StreamLine([{"finish_reason": "stop"}]),
    ]

    def fake_create(**kwargs):
        return FakeChatCompletions._Response(lines=scripted_lines)

    openai_llm.client.chat.completions.create = fake_create

    chunks = list(openai_llm._raw_gen_stream(openai_llm, model="gpt", messages=messages))

    assert marker_text in chunks
    assert all(
        not (isinstance(chunk, dict) and chunk.get("type") == "thought")
        for chunk in chunks
    )
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_prepare_structured_output_format_enforces_required_and_strict(openai_llm):
|
||||
schema = {
|
||||
|
||||
Reference in New Issue
Block a user