|
|
|
|
|
|
|
|
| import streamlit as st
|
| import torch
|
| import torch.nn.functional as F
|
| from transformers import AutoTokenizer, AutoModelForTokenClassification
|
| import json
|
|
|
# Set the page title and icon for this Streamlit app.
st.set_page_config(page_title="Link Detection", page_icon="🔗")
|
|
|
@st.cache_resource
def load_model(model_path="model_link_token_cls"):
    """Load the tokenizer and token-classification model stored at ``model_path``.

    The model is moved to CUDA when ``torch.cuda.is_available()`` is true,
    otherwise to the CPU, and is put into evaluation mode before being
    returned. The function is wrapped in ``st.cache_resource`` so Streamlit
    can reuse the loaded objects instead of reloading them on every rerun.

    Args:
        model_path: Directory or model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model, device)`` tuple.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(target)

    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    net = AutoModelForTokenClassification.from_pretrained(model_path).to(device)
    net.eval()  # inference only: disable dropout and similar training behaviour

    return tok, net, device
|
|
|
def group_tokens_into_words(tokens, offset_mapping, link_probs):
    """Group sub-word tokens into words based on tokenizer patterns.

    The three inputs are walked in lockstep. A token starts a new word when:

    * it begins with ``"▁"`` (SentencePiece word-start marker), or
    * it follows a special token or is the first token, or
    * it does not begin with ``"##"`` and its start offset lies strictly
      after the end offset of the previous token (a gap in the source text).

    Tokens beginning with ``"##"`` and tokens that directly abut the previous
    token are appended to the current word.

    Tokens whose offsets are ``(0, 0)`` are treated as special tokens: they
    are dropped and always terminate the current word. Offsets may be lists
    or tuples; both forms are recognised.

    Args:
        tokens: Token strings.
        offset_mapping: ``(start, end)`` character offsets, one per token.
        link_probs: One probability per token.

    Returns:
        A list of dicts with keys ``'tokens'``, ``'offsets'`` and ``'probs'``,
        each holding the per-token values of one word in input order.
    """
    words = []
    current = None  # word dict being extended; None right after a boundary

    for token, offsets, prob in zip(tokens, offset_mapping, link_probs):
        # Compare as a tuple so that (0, 0) and [0, 0] are both detected;
        # a plain `offsets == [0, 0]` is False for tuple offsets.
        if tuple(offsets) == (0, 0):
            current = None
            continue

        if current is None or token.startswith("▁"):
            starts_word = True
        elif token.startswith("##"):
            # WordPiece continuation: always glued to the previous token.
            starts_word = False
        else:
            # No explicit marker: split only if there is a gap in the text.
            starts_word = offsets[0] > current['offsets'][-1][1]

        if starts_word:
            current = {'tokens': [], 'offsets': [], 'probs': []}
            words.append(current)

        current['tokens'].append(token)
        current['offsets'].append(offsets)
        current['probs'].append(prob)

    return words
|
|
|
def predict_links(text, tokenizer, model, device, threshold=0.5,
                  max_length=512, doc_stride=128):
    """Predict link words in ``text`` using overlapping sliding windows.

    The full text is tokenized once without special tokens. Windows of at
    most ``max_length`` minus the tokenizer's special-token count are then
    run through the model, each advancing by the window size minus
    ``doc_stride`` (at least one token). For every token the class-1
    probability is averaged over all windows that covered it. Tokens are
    grouped into words with ``group_tokens_into_words`` and a word is
    reported when any of its tokens reaches ``threshold``.

    Args:
        text: Input string; blank or whitespace-only text yields no results.
        tokenizer: Tokenizer that can return offset mappings.
        model: Token-classification model whose output exposes ``.logits``.
        device: Device on which the input tensors are created.
        threshold: Minimum token probability for a word to be reported.
        max_length: Maximum window length including special tokens.
        doc_stride: Number of tokens shared by consecutive windows.

    Returns:
        ``(link_spans, link_details)``: a list of ``(start, end)`` character
        offsets, and a parallel list of dicts with ``text``, ``start``,
        ``end``, ``max_confidence`` and ``avg_confidence`` (rounded to 4
        decimals).
    """
    if not text.strip():
        return [], []

    encoding = tokenizer(
        text,
        add_special_tokens=False,
        truncation=False,
        return_offsets_mapping=True,
    )
    token_ids = encoding["input_ids"]
    token_offsets = encoding["offset_mapping"]
    total = len(token_ids)

    # Running sum / count per token so overlapping windows can be averaged.
    sums = [0.0] * total
    hits = [0] * total

    window_cap = max_length - tokenizer.num_special_tokens_to_add(pair=False)
    advance = max(window_cap - doc_stride, 1)

    win_start = 0
    while win_start < total:
        win_end = min(win_start + window_cap, total)
        with_specials = tokenizer.build_inputs_with_special_tokens(
            token_ids[win_start:win_end]
        )
        input_ids = torch.tensor([with_specials], device=device)
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            output = model(input_ids=input_ids, attention_mask=attention_mask)
            window_probs = F.softmax(output.logits, dim=-1)[0].cpu()

        # NOTE(review): the [1:-1] slice assumes exactly one special token
        # before and one after the content, and column 1 is presumably the
        # "link" class -- confirm against the model/tokenizer in use.
        for pos, p in enumerate(window_probs[1:-1, 1].tolist(), start=win_start):
            if pos < total:
                sums[pos] += p
                hits[pos] += 1

        if win_end == total:
            break
        win_start += advance

    link_probs = [s / c if c > 0 else 0.0 for s, c in zip(sums, hits)]

    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    offset_mapping = [list(pair) for pair in token_offsets]
    words = group_tokens_into_words(tokens, offset_mapping, link_probs)

    link_spans = []
    link_details = []
    for word in words:
        probs = word['probs']
        # One confident token is enough to report the entire word.
        if not any(p >= threshold for p in probs):
            continue

        offsets = word['offsets']
        span_start = offsets[0][0]
        span_end = offsets[-1][1]

        link_spans.append((span_start, span_end))
        link_details.append({
            "text": text[span_start:span_end],
            "start": span_start,
            "end": span_end,
            "max_confidence": round(max(probs), 4),
            "avg_confidence": round(sum(probs) / len(probs), 4),
        })

    return link_spans, link_details
|
|
|
def render_highlighted_text(text, link_spans):
    """Render ``text`` as an HTML block with the given spans highlighted.

    Every piece of ``text`` is HTML-escaped before being placed in the
    markup, so characters such as ``<`` and ``&`` in the input are shown
    literally instead of being interpreted as HTML by the page.

    Args:
        text: The original input string.
        link_spans: Iterable of ``(start, end)`` character offsets into
            ``text``. They may be given in any order. Overlapping spans are
            clamped so no character is emitted twice, and empty spans are
            skipped.

    Returns:
        An HTML string (a styled ``<div>`` wrapping the text), or ``""``
        when ``text`` is empty.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    if not text:
        return ""

    pieces = []
    cursor = 0  # index of the first character not yet emitted

    for start, end in sorted(link_spans, key=lambda span: span[0]):
        # Clamp to the cursor so overlapping spans never repeat text.
        start = max(start, cursor)
        if start >= end:
            continue

        if start > cursor:
            pieces.append(escape(text[cursor:start]))

        pieces.append(
            '<span style="background-color: #90EE90; padding: 2px 4px; '
            'border-radius: 3px; font-weight: 500;">'
            f'{escape(text[start:end])}</span>'
        )
        cursor = end

    if cursor < len(text):
        pieces.append(escape(text[cursor:]))

    html_content = "".join(pieces)

    # The wrapper lines carry no leading indentation: Markdown treats lines
    # indented by four or more spaces as a code block.
    return (
        '\n<div style="\n'
        "padding: 20px;\n"
        "background-color: #f8f9fa;\n"
        "border-radius: 8px;\n"
        "line-height: 1.8;\n"
        "font-size: 16px;\n"
        "white-space: pre-wrap;\n"
        "word-wrap: break-word;\n"
        "font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;\n"
        '">\n'
        f"{html_content}\n"
        "</div>\n"
    )
|
|
|
def main():
    """Run the Streamlit UI: load the model, read input, show detected links.

    The page shows a threshold slider (in percent, converted to a 0-1
    fraction) and a text area. When the "Detect Links" button is pressed
    with non-empty text, the text is rendered with highlighted words, a
    summary count is shown, and the per-word details are displayed as JSON.
    If the model cannot be loaded, an error is shown and nothing else is
    rendered.
    """
    st.title("Link Detection")

    try:
        tokenizer, model, device = load_model()
        st.success(f"Model loaded on {device}")
    except Exception as e:
        # Top-level UI boundary: report the failure on the page and stop.
        st.error(f"Failed to load model: {e}")
        return

    threshold_percent = st.slider(
        "Confidence Threshold (%)",
        min_value=0,
        max_value=100,
        value=5,
        step=1,
        help="Highlights entire word if ANY of its tokens meet this threshold",
    )
    threshold = threshold_percent / 100.0

    text = st.text_area("Input text:", height=200)

    if not st.button("Detect Links"):
        return

    if not text:
        st.warning("Please enter text")
        return

    link_spans, link_details = predict_links(text, tokenizer, model, device, threshold)

    st.subheader("Text with Highlighted Links")
    html = render_highlighted_text(text, link_spans)
    st.markdown(html, unsafe_allow_html=True)

    st.info(f"Found {len(link_details)} words with link confidence above {threshold:.0%}")

    if link_details:
        st.subheader("Link Details (JSON)")
        st.json(link_details)
|
|
|
# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()