mirror of
https://github.com/arc53/DocsGPT.git
synced 2025-11-29 00:23:17 +00:00
token_func proposed change to chunking. open_ai_func proposed change to embedding_pipeline. Late chunking first implementation requires further testing.
95 lines
4.0 KiB
Python
Executable File
95 lines
4.0 KiB
Python
Executable File
from typing import List, Tuple, Union, Optional
|
|
from transformers import AutoTokenizer, AutoModel
|
|
from sentence_transformers import SentenceTransformer
|
|
import torch
|
|
import torch.nn as nn
|
|
from application.parser.schema.base import Document
|
|
|
|
|
|
class LateChunker:
    """Group documents into token-bounded "super chunks" and embed each one.

    Two backends are supported, selected from ``model_name`` in ``__init__``:
    ``sentence-transformers`` models (embedded via ``model.encode``) and plain
    ``transformers`` models (embedded by mean-pooling ``last_hidden_state``).
    """

    def __init__(self, model_name: str, late_tokens: int = 1000, **model_kwargs):
        """
        Initialize the LateChunker with a model, tokenizer, and late_tokens limit.

        Supports both transformers and sentence-transformers models.

        Args:
            model_name: Model id or path. Names containing "sentence-transformers"
                are loaded with ``SentenceTransformer``; anything else with
                ``AutoModel.from_pretrained(..., trust_remote_code=True)``.
            late_tokens: Token budget per super chunk, counted without special
                tokens (see ``late_chunk_with_embeddings``).
            **model_kwargs: Forwarded unchanged to the model constructor.
        """
        self.late_tokens = late_tokens
        self.model_name = model_name

        # Load model based on type
        if "sentence-transformers" in model_name:
            self.model = SentenceTransformer(model_name, **model_kwargs)
            self.wrapper_type = "sentence_transformers"
        else:
            self.model = AutoModel.from_pretrained(
                model_name, trust_remote_code=True, **model_kwargs
            )
            self.wrapper_type = "transformers"

        # Both backends load the tokenizer the same way. It is used for token
        # counting, and (transformers backend only) to build model inputs.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    def tokenize_with_offsets(self, text: str):
        """Tokenize text and return tokens with character offsets.

        Returns:
            ``(input_ids, offset_mapping)`` for ``text``, tokenized without
            special tokens.

        NOTE(review): ``return_offsets_mapping`` is only supported by "fast"
        tokenizers; confirm the chosen model ships one.
        """
        # Calling the tokenizer directly replaces the deprecated
        # ``encode_plus``; the arguments and returned keys are the same.
        tokens = self.tokenizer(
            text, return_offsets_mapping=True, add_special_tokens=False
        )
        return tokens["input_ids"], tokens["offset_mapping"]

    def late_chunk_with_embeddings(
        self, documents: "List[Document]"
    ) -> List[Tuple[str, List[Tuple[int, int]], List[float]]]:
        """
        Combine documents into 'super chunks' that fit within the `late_tokens` limit.

        Documents are packed greedily, in order. A new super chunk is started
        when adding the next document would push the running token count above
        ``late_tokens``. A single document longer than ``late_tokens`` is not
        split; it becomes its own (oversized) super chunk.

        Args:
            documents: Objects exposing a ``text`` attribute.

        Returns:
            One ``(combined_text, span_annotations, embedding)`` tuple per super
            chunk:

            - ``combined_text`` is the member texts joined by a single space.
            - ``span_annotations`` holds each member's ``(start, end)`` token
              range (end exclusive).
            - ``embedding`` is one vector for the whole super chunk.

            Returns ``[]`` when ``documents`` is empty.

        NOTE(review): spans are cumulative counts from tokenizing each document
        separately without special tokens. Re-tokenizing the space-joined text
        (and, on the transformers backend, adding special tokens) may shift or
        merge tokens, so spans are approximate. Confirm before using them for
        per-span pooling.
        """
        super_chunks = []
        current_texts: List[str] = []
        current_spans: List[Tuple[int, int]] = []
        current_token_count = 0

        for doc in documents:
            doc_text = doc.text
            input_ids, _ = self.tokenize_with_offsets(doc_text)
            doc_token_count = len(input_ids)

            # Flush the current super chunk if this document would overflow it.
            # Only flush when something has been collected; otherwise an
            # oversized first document would emit an empty super chunk.
            if current_texts and current_token_count + doc_token_count > self.late_tokens:
                super_chunks.append(
                    self._finalize_super_chunk(current_texts, current_spans)
                )
                current_texts = []
                current_spans = []
                current_token_count = 0

            # Add document to the current super chunk
            end_token = current_token_count + doc_token_count
            current_spans.append((current_token_count, end_token))
            current_texts.append(doc_text)
            current_token_count = end_token

        # Add the final super chunk if there are remaining documents
        if current_texts:
            super_chunks.append(self._finalize_super_chunk(current_texts, current_spans))

        return super_chunks

    def _finalize_super_chunk(
        self, texts: List[str], spans: List[Tuple[int, int]]
    ) -> Tuple[str, List[Tuple[int, int]], List[float]]:
        """Join ``texts`` with spaces, embed the result, and bundle it with ``spans``."""
        combined_text = " ".join(texts)
        return combined_text, spans, self.generate_embeddings(combined_text)

    def generate_embeddings(self, text: str) -> List[float]:
        """Generate a single embedding vector for ``text`` using the loaded model.

        Backends:

        - sentence-transformers: ``model.encode`` on a one-item batch.
        - transformers: tokenize with the tokenizer's default settings, run
          one forward pass without gradient tracking, and mean-pool
          ``last_hidden_state`` over dim 1.

        NOTE(review): no truncation is applied. Text longer than the model's
        maximum input length may fail in the forward pass, so confirm that
        ``late_tokens`` fits the chosen model.

        Raises:
            ValueError: If ``wrapper_type`` is not a supported backend.
        """
        if self.wrapper_type == "sentence_transformers":
            # Sentence-Transformers
            embeddings = self.model.encode([text])
            return embeddings[0].tolist()

        if self.wrapper_type == "transformers":
            # Transformers models
            inputs = self.tokenizer(text, return_tensors="pt")
            # Inference only: disable autograd so no computation graph is kept.
            with torch.no_grad():
                model_output = self.model(**inputs)
            # Assumes last_hidden_state is (batch, seq, hidden) for this single,
            # un-padded input, so dim 1 averages over its tokens. TODO confirm.
            return model_output.last_hidden_state.mean(dim=1).squeeze().tolist()

        raise ValueError("Unsupported model type for embedding generation.")
|