From f37ca95c1012d5918c2e36840877b3ad1415fae2 Mon Sep 17 00:00:00 2001 From: ManishMadan2882 Date: Tue, 6 May 2025 15:16:14 +0530 Subject: [PATCH] (fix/mermaid loading): load only the diagrams which stream --- frontend/src/components/MermaidRenderer.tsx | 106 +++++++++--------- frontend/src/components/types/index.ts | 1 + .../src/conversation/ConversationBubble.tsx | 25 +++-- .../src/conversation/ConversationMessages.tsx | 6 +- 4 files changed, 73 insertions(+), 65 deletions(-) diff --git a/frontend/src/components/MermaidRenderer.tsx b/frontend/src/components/MermaidRenderer.tsx index ac37cfb4..b9014581 100644 --- a/frontend/src/components/MermaidRenderer.tsx +++ b/frontend/src/components/MermaidRenderer.tsx @@ -10,6 +10,7 @@ import { useDarkTheme } from '../hooks'; const MermaidRenderer: React.FC = ({ code, + isLoading, }) => { const [isDarkTheme] = useDarkTheme(); const diagramId = useRef(`mermaid-${crypto.randomUUID()}`); @@ -21,70 +22,70 @@ const MermaidRenderer: React.FC = ({ const containerRef = useRef(null); const [hoverPosition, setHoverPosition] = useState<{ x: number, y: number } | null>(null); const [isHovering, setIsHovering] = useState(false); - + const handleMouseMove = (event: React.MouseEvent) => { if (!containerRef.current) return; - + const rect = containerRef.current.getBoundingClientRect(); const x = (event.clientX - rect.left) / rect.width; const y = (event.clientY - rect.top) / rect.height; - + setHoverPosition({ x, y }); }; - + const handleMouseEnter = () => setIsHovering(true); const handleMouseLeave = () => { setIsHovering(false); setHoverPosition(null); }; - + const getTransformOrigin = () => { if (!hoverPosition) return 'center center'; return `${hoverPosition.x * 100}% ${hoverPosition.y * 100}%`; }; useEffect(() => { - if (status === 'loading' || !code) return; - + if ((isLoading !== undefined ? isLoading : status === 'loading') || !code) return; + mermaid.initialize({ startOnLoad: true, theme: isDarkTheme ? 'dark' : 'default', securityLevel: 'loose', suppressErrorRendering: true, }); - + const renderDiagram = async (): Promise => { try { await mermaid.parse(code); //throws syntax errors - + const element = document.getElementById(diagramId.current); if (element) { element.removeAttribute('data-processed'); mermaid.contentLoaded(); - + const svgElement = element.querySelector('svg'); if (svgElement) { svgElement.setAttribute('width', '100%'); svgElement.setAttribute('height', 'auto'); svgElement.style.maxWidth = '100%'; svgElement.style.width = '100%'; - + svgElement.removeAttribute('viewBox'); - + } setError(null); } } catch (err) { - + setError( `Failed to render Mermaid diagram: ${err instanceof Error ? err.message : String(err)}` ); } }; - + renderDiagram(); - - - }, [code, isDarkTheme]); + + + }, [code, isDarkTheme, isLoading]); useEffect(() => { @@ -109,13 +110,13 @@ const MermaidRenderer: React.FC = ({ if (!element) return; const svgElement = element.querySelector('svg'); if (!svgElement) return; - + const svgClone = svgElement.cloneNode(true) as SVGElement; - + if (!svgClone.hasAttribute('xmlns')) { svgClone.setAttribute('xmlns', 'http://www.w3.org/2000/svg'); } - + if (!svgClone.hasAttribute('width') || !svgClone.hasAttribute('height')) { const viewBox = svgClone.getAttribute('viewBox')?.split(' ') || []; if (viewBox.length === 4) { @@ -123,15 +124,15 @@ const MermaidRenderer: React.FC = ({ svgClone.setAttribute('height', viewBox[3]); } } - + const serializer = new XMLSerializer(); const svgString = serializer.serializeToString(svgClone); - + const svgBlob = new Blob( - [`\n${svgString}`], + [`\n${svgString}`], { type: 'image/svg+xml' } ); - + const url = URL.createObjectURL(svgBlob); const link = document.createElement('a'); link.href = url; @@ -145,19 +146,19 @@ const MermaidRenderer: React.FC = ({ const downloadPng = (): void => { const element = document.getElementById(diagramId.current); if (!element) return; - + const svgElement = element.querySelector('svg'); if (!svgElement) return; - + const svgClone = svgElement.cloneNode(true) as SVGElement; - + if (!svgClone.hasAttribute('xmlns')) { svgClone.setAttribute('xmlns', 'http://www.w3.org/2000/svg'); } - + let width = parseInt(svgClone.getAttribute('width') || '0'); let height = parseInt(svgClone.getAttribute('height') || '0'); - + if (!width || !height) { const viewBox = svgClone.getAttribute('viewBox')?.split(' ') || []; if (viewBox.length === 4) { @@ -172,30 +173,30 @@ const MermaidRenderer: React.FC = ({ svgClone.setAttribute('height', height.toString()); } } - + const serializer = new XMLSerializer(); const svgString = serializer.serializeToString(svgClone); const svgBase64 = btoa(unescape(encodeURIComponent(svgString))); const dataUrl = `data:image/svg+xml;base64,${svgBase64}`; - + const img = new Image(); img.crossOrigin = 'anonymous'; - + img.onload = function(): void { const canvas = document.createElement('canvas'); canvas.width = width; canvas.height = height; - + const ctx = canvas.getContext('2d'); if (!ctx) { console.error('Could not get canvas context'); return; } - + ctx.fillRect(0, 0, canvas.width, canvas.height); - + ctx.drawImage(img, 0, 0, width, height); - + try { const pngUrl = canvas.toDataURL('image/png'); const link = document.createElement('a'); @@ -210,7 +211,7 @@ const MermaidRenderer: React.FC = ({ downloadSvg(); } }; - + img.src = dataUrl; }; @@ -225,16 +226,17 @@ const MermaidRenderer: React.FC = ({ document.body.removeChild(link); URL.revokeObjectURL(url); }; - - + + const downloadOptions = [ { label: 'Download as SVG', action: downloadSvg }, { label: 'Download as PNG', action: downloadPng }, { label: 'Download as MMD', action: downloadMmd }, ]; - const showDiagramOptions = status !== 'loading' && !error; - const errorRender = status !== 'loading' && error; + const isCurrentlyLoading = isLoading !== undefined ? isLoading : status === 'loading'; + const showDiagramOptions = !isCurrentlyLoading && !error; + const errorRender = !isCurrentlyLoading && error; @@ -246,7 +248,7 @@ const MermaidRenderer: React.FC = ({
- + {showDiagramOptions && (
)} - + {showDiagramOptions && (
- - {status === 'loading' ? ( + + {isCurrentlyLoading ? (
Loading diagram... @@ -308,24 +310,24 @@ const MermaidRenderer: React.FC = ({
) : ( <> -
-
 = ({
               {code}
             
- + {showCode && (
diff --git a/frontend/src/components/types/index.ts b/frontend/src/components/types/index.ts index 25dae251..1c0d138d 100644 --- a/frontend/src/components/types/index.ts +++ b/frontend/src/components/types/index.ts @@ -26,4 +26,5 @@ export type InputProps = { export type MermaidRendererProps = { code: string; + isLoading?: boolean; }; diff --git a/frontend/src/conversation/ConversationBubble.tsx b/frontend/src/conversation/ConversationBubble.tsx index 67cc1c34..5e3b83c1 100644 --- a/frontend/src/conversation/ConversationBubble.tsx +++ b/frontend/src/conversation/ConversationBubble.tsx @@ -53,6 +53,7 @@ const ConversationBubble = forwardRef< toolCalls?: ToolCallsType[]; retryBtn?: React.ReactElement; questionNumber?: number; + isStreaming?: boolean; handleUpdatedQuestionSubmission?: ( updatedquestion?: string, updated?: boolean, @@ -71,6 +72,7 @@ const ConversationBubble = forwardRef< toolCalls, retryBtn, questionNumber, + isStreaming, handleUpdatedQuestionSubmission, }, ref, @@ -195,29 +197,29 @@ const ConversationBubble = forwardRef< }; const processMarkdownContent = (content: string) => { const processedContent = preprocessLaTeX(content); - + const contentSegments: Array<{type: 'text' | 'mermaid', content: string}> = []; - + let lastIndex = 0; const regex = /```mermaid\n([\s\S]*?)```/g; let match; - + while ((match = regex.exec(processedContent)) !== null) { const textBefore = processedContent.substring(lastIndex, match.index); if (textBefore) { contentSegments.push({ type: 'text', content: textBefore }); } - + contentSegments.push({ type: 'mermaid', content: match[1].trim() }); - + lastIndex = match.index + match[0].length; } - + const textAfter = processedContent.substring(lastIndex); if (textAfter) { contentSegments.push({ type: 'text', content: textAfter }); } - + return contentSegments; }; bubble = ( @@ -404,7 +406,7 @@ const ConversationBubble = forwardRef< const { children, className, node, ref, ...rest } = props; const match = /language-(\w+)/.exec(className || ''); const language = match ? match[1] : ''; - + return match ? (
@@ -491,6 +493,7 @@ const ConversationBubble = forwardRef<
)} @@ -505,7 +508,7 @@ const ConversationBubble = forwardRef< {message && (
@@ -513,7 +516,7 @@ const ConversationBubble = forwardRef<
@@ -544,7 +547,7 @@ const ConversationBubble = forwardRef< }`} > (null); const atLast = useRef(true); const [eventInterrupt, setEventInterrupt] = useState(false); - + const handleUserInterruption = () => { if (!eventInterrupt && status === 'loading') { setEventInterrupt(true); @@ -54,7 +54,7 @@ export default function ConversationMessages({ setTimeout(() => { if (!conversationRef?.current) return; - + if (status === 'idle' || !queries[queries.length - 1]?.response) { conversationRef.current.scrollTo({ behavior: 'smooth', @@ -93,6 +93,7 @@ export default function ConversationMessages({ const prepResponseView = (query: Query, index: number) => { let responseView; if (query.thought || query.response) { + const isCurrentlyStreaming = status === 'loading' && index === queries.length - 1; responseView = ( handleFeedback(query, feedback, index)