parser functions change

token_func: proposed change to chunking. open_ai_func: proposed change to embedding_pipeline. First implementation of late chunking; requires further testing.
This commit is contained in:
Pavel
2024-11-20 21:40:57 +04:00
parent 9247f16add
commit 5a891647bf
3 changed files with 298 additions and 0 deletions


@@ -0,0 +1,94 @@
from typing import List, Tuple

import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

from application.parser.schema.base import Document

class LateChunker:
    def __init__(self, model_name: str, late_tokens: int = 1000, **model_kwargs):
        """
        Initialize the LateChunker with a model, tokenizer, and late_tokens limit.
        Supports both transformers and sentence-transformers models.
        """
        self.late_tokens = late_tokens
        self.model_name = model_name

        # Load model based on type
        if "sentence-transformers" in model_name:
            self.model = SentenceTransformer(model_name, **model_kwargs)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.wrapper_type = "sentence_transformers"
        else:
            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, **model_kwargs)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.wrapper_type = "transformers"

    def tokenize_with_offsets(self, text: str):
        """Tokenize text and return token ids with character offsets (requires a fast tokenizer)."""
        tokens = self.tokenizer.encode_plus(
            text, return_offsets_mapping=True, add_special_tokens=False
        )
        return tokens["input_ids"], tokens["offset_mapping"]

    def late_chunk_with_embeddings(
        self, documents: List[Document]
    ) -> List[Tuple[str, List[Tuple[int, int]], List[float]]]:
        """
        Combines documents into 'super chunks' that fit within the `late_tokens` limit.
        Returns each super chunk as (combined text, token span annotations, embeddings).
        """
        super_chunks = []
        current_super_chunk_text = []
        current_token_count = 0
        span_annotations = []

        for doc in documents:
            doc_text = doc.text
            input_ids, offsets = self.tokenize_with_offsets(doc_text)
            doc_token_count = len(input_ids)

            # Finalize the current super chunk if adding this document would
            # exceed the late_tokens limit. Skip when the chunk is still empty,
            # so a single oversized document does not emit an empty chunk.
            if current_super_chunk_text and current_token_count + doc_token_count > self.late_tokens:
                combined_text = " ".join(current_super_chunk_text)
                embeddings = self.generate_embeddings(combined_text)
                super_chunks.append((combined_text, span_annotations, embeddings))

                # Reset for a new super chunk
                current_super_chunk_text = []
                span_annotations = []
                current_token_count = 0

            # Add the document to the current super chunk. Spans are counted
            # per document; joining with " " may shift token offsets slightly
            # in the combined text.
            start_token = current_token_count
            end_token = current_token_count + doc_token_count
            span_annotations.append((start_token, end_token))
            current_super_chunk_text.append(doc_text)
            current_token_count = end_token

        # Add the final super chunk if there are remaining documents
        if current_super_chunk_text:
            combined_text = " ".join(current_super_chunk_text)
            embeddings = self.generate_embeddings(combined_text)
            super_chunks.append((combined_text, span_annotations, embeddings))

        return super_chunks

    def generate_embeddings(self, text: str) -> List[float]:
        """Generate embeddings for a given text using the loaded model."""
        if self.wrapper_type == "sentence_transformers":
            # Sentence-Transformers handles pooling internally
            embeddings = self.model.encode([text])
            return embeddings[0].tolist()
        elif self.wrapper_type == "transformers":
            # Plain transformers model: mean-pool the last hidden state
            inputs = self.tokenizer(text, return_tensors="pt")
            with torch.no_grad():
                model_output = self.model(**inputs)
            return model_output.last_hidden_state.mean(dim=1).squeeze().tolist()
        else:
            raise ValueError("Unsupported model type for embedding generation.")
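
For reference, a minimal usage sketch of the new class. This assumes `Document` accepts a `text` keyword argument; the model name and document texts are illustrative only, not part of this commit:

# Minimal usage sketch; model name and documents are illustrative,
# and Document(text=...) is assumed from application.parser.schema.base.
from application.parser.schema.base import Document

chunker = LateChunker("sentence-transformers/all-MiniLM-L6-v2", late_tokens=512)
docs = [
    Document(text="First section of a parsed file."),
    Document(text="Second section of a parsed file."),
]

for combined_text, spans, embedding in chunker.late_chunk_with_embeddings(docs):
    print(f"{len(spans)} docs in super chunk, embedding dim {len(embedding)}")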